diff --git a/build.sh b/build.sh index 087f4ae72..34e4b1153 100755 --- a/build.sh +++ b/build.sh @@ -3,6 +3,11 @@ # Put installed packages into ./bin export GOBIN=$PWD/`dirname $0`/bin -go install -v $PWD/`dirname $0`/cmd/... +export BRANCH=`(git symbolic-ref --short HEAD | cut -d'/' -f 3 )|| ""` +export BUILD=`git rev-parse --short HEAD || ""` -GOOS=js GOARCH=wasm go build -o main.wasm ./cmd/dendritejs +export FLAGS="-X github.com/matrix-org/dendrite/internal.branch=$BRANCH -X github.com/matrix-org/dendrite/internal.build=$BUILD" + +go install -trimpath -ldflags "$FLAGS" -v $PWD/`dirname $0`/cmd/... + +GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o main.wasm ./cmd/dendritejs diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go index 59535c7b9..725c9c074 100644 --- a/build/gobind/monolith.go +++ b/build/gobind/monolith.go @@ -120,7 +120,7 @@ func (m *DendriteMonolith) Start() { keyAPI.SetUserAPI(userAPI) rsAPI := roomserver.NewInternalAPI( - base, keyRing, federation, + base, keyRing, ) eduInputAPI := eduserver.NewInternalAPI( diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 32c5234b1..de51f16da 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -12,8 +12,7 @@ COPY . . RUN go build ./cmd/dendrite-monolith-server RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config -RUN ./generate-config > dendrite.yaml -RUN sed -i "s/disable_tls_validation: false/disable_tls_validation: true/g" dendrite.yaml +RUN ./generate-config --ci > dendrite.yaml RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key ENV SERVER_NAME=localhost diff --git a/build/scripts/complement.sh b/build/scripts/complement.sh index 17ddea57e..29feff304 100755 --- a/build/scripts/complement.sh +++ b/build/scripts/complement.sh @@ -10,10 +10,10 @@ cd `dirname $0`/../.. docker build -t complement-dendrite -f build/scripts/Complement.Dockerfile . # Download Complement -wget https://github.com/matrix-org/complement/archive/master.tar.gz +wget -N https://github.com/matrix-org/complement/archive/master.tar.gz tar -xzf master.tar.gz # Run the tests! cd complement-master -COMPLEMENT_BASE_IMAGE=complement-dendrite:latest go test -v ./tests +COMPLEMENT_BASE_IMAGE=complement-dendrite:latest go test -v -count=1 ./tests diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 57fc3f33a..af43064fe 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -342,8 +342,7 @@ func createRoom( } // send events to the room server - _, err = roomserverAPI.SendEvents(req.Context(), rsAPI, builtEvents, cfg.Matrix.ServerName, nil) - if err != nil { + if err = roomserverAPI.SendEvents(req.Context(), rsAPI, builtEvents, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index d2bc9337d..772775aa0 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -41,15 +41,13 @@ type flows struct { } type flow struct { - Type string `json:"type"` - Stages []string `json:"stages"` + Type string `json:"type"` } func passwordLogin() flows { f := flows{} s := flow{ - Type: "m.login.password", - Stages: []string{"m.login.password"}, + Type: "m.login.password", } f.Flows = append(f.Flows, s) return f diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 5d635c018..202662ab6 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -75,13 +75,12 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us return jsonerror.InternalServerError() } - _, err = roomserverAPI.SendEvents( + if err = roomserverAPI.SendEvents( ctx, rsAPI, []gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, cfg.Matrix.ServerName, nil, - ) - if err != nil { + ); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -270,7 +269,7 @@ func buildMembershipEvent( return nil, err } - return eventutil.BuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) + return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) } // loadProfile lookups the profile of a given user from the database and returns diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index faf92451e..bc51b0b51 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -171,7 +171,7 @@ func SetAvatarURL( return jsonerror.InternalServerError() } - if _, err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -289,7 +289,7 @@ func SetDisplayName( return jsonerror.InternalServerError() } - if _, err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -375,7 +375,7 @@ func buildMembershipEvents( return nil, err } - event, err := eventutil.BuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) + event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) if err != nil { return nil, err } diff --git a/clientapi/routing/rate_limiting.go b/clientapi/routing/rate_limiting.go new file mode 100644 index 000000000..16e3c0565 --- /dev/null +++ b/clientapi/routing/rate_limiting.go @@ -0,0 +1,99 @@ +package routing + +import ( + "net/http" + "sync" + "time" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/util" +) + +type rateLimits struct { + limits map[string]chan struct{} + limitsMutex sync.RWMutex + enabled bool + requestThreshold int64 + cooloffDuration time.Duration +} + +func newRateLimits(cfg *config.RateLimiting) *rateLimits { + l := &rateLimits{ + limits: make(map[string]chan struct{}), + enabled: cfg.Enabled, + requestThreshold: cfg.Threshold, + cooloffDuration: time.Duration(cfg.CooloffMS) * time.Millisecond, + } + if l.enabled { + go l.clean() + } + return l +} + +func (l *rateLimits) clean() { + for { + // On a 30 second interval, we'll take an exclusive write + // lock of the entire map and see if any of the channels are + // empty. If they are then we will close and delete them, + // freeing up memory. + time.Sleep(time.Second * 30) + l.limitsMutex.Lock() + for k, c := range l.limits { + if len(c) == 0 { + close(c) + delete(l.limits, k) + } + } + l.limitsMutex.Unlock() + } +} + +func (l *rateLimits) rateLimit(req *http.Request) *util.JSONResponse { + // If rate limiting is disabled then do nothing. + if !l.enabled { + return nil + } + + // Lock the map long enough to check for rate limiting. We hold it + // for longer here than we really need to but it makes sure that we + // also don't conflict with the cleaner goroutine which might clean + // up a channel after we have retrieved it otherwise. + l.limitsMutex.RLock() + defer l.limitsMutex.RUnlock() + + // First of all, work out if X-Forwarded-For was sent to us. If not + // then we'll just use the IP address of the caller. + caller := req.RemoteAddr + if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" { + caller = forwardedFor + } + + // Look up the caller's channel, if they have one. If they don't then + // let's create one. + rateLimit, ok := l.limits[caller] + if !ok { + l.limits[caller] = make(chan struct{}, l.requestThreshold) + rateLimit = l.limits[caller] + } + + // Check if the user has got free resource slots for this request. + // If they don't then we'll return an error. + select { + case rateLimit <- struct{}{}: + default: + // We hit the rate limit. Tell the client to back off. + return &util.JSONResponse{ + Code: http.StatusTooManyRequests, + JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()), + } + } + + // After the time interval, drain a resource from the rate limiting + // channel. This will free up space in the channel for new requests. + go func() { + <-time.After(l.cooloffDuration) + <-rateLimit + }() + return nil +} diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index bb5265135..178bfafc9 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -115,15 +115,14 @@ func SendRedaction( } var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.BuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("Room does not exist"), } } - _, err = roomserverAPI.SendEvents(context.Background(), rsAPI, []gomatrixserverlib.HeaderedEvent{*e}, cfg.Matrix.ServerName, nil) - if err != nil { + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, []gomatrixserverlib.HeaderedEvent{*e}, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 9e7970392..cd717e2b1 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -60,6 +60,7 @@ func Setup( keyAPI keyserverAPI.KeyInternalAPI, extRoomsProvider api.ExtraPublicRoomsProvider, ) { + rateLimits := newRateLimits(&cfg.RateLimiting) userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg) publicAPIMux.Handle("/versions", @@ -92,6 +93,9 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/join/{roomIDOrAlias}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -119,6 +123,9 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/join", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -130,6 +137,9 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/leave", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -150,6 +160,9 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/invite", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -264,14 +277,23 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return Register(req, userAPI, accountDB, cfg) })).Methods(http.MethodPost, http.MethodOptions) v1mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return LegacyRegister(req, userAPI, cfg) })).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return RegisterAvailable(req, cfg, accountDB) })).Methods(http.MethodGet, http.MethodOptions) @@ -343,6 +365,9 @@ func Setup( r0mux.Handle("/rooms/{roomID}/typing/{userID}", httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -396,6 +421,9 @@ func Setup( r0mux.Handle("/account/whoami", httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return Whoami(req, device) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -404,6 +432,9 @@ func Setup( r0mux.Handle("/login", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return Login(req, accountDB, userAPI, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) @@ -458,6 +489,9 @@ func Setup( r0mux.Handle("/profile/{userID}/avatar_url", httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -480,6 +514,9 @@ func Setup( r0mux.Handle("/profile/{userID}/displayname", httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -517,6 +554,9 @@ func Setup( // Riot logs get flooded unless this is handled r0mux.Handle("/presence/{userID}/status", httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } // TODO: Set presence (probably the responsibility of a presence server not clientapi) return util.JSONResponse{ Code: http.StatusOK, @@ -527,6 +567,9 @@ func Setup( r0mux.Handle("/voip/turnServer", httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return RequestTurnServer(req, device, cfg) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -593,6 +636,9 @@ func Setup( r0mux.Handle("/user_directory/search", httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } postContent := struct { SearchString string `json:"search_term"` Limit int `json:"limit"` @@ -634,6 +680,9 @@ func Setup( r0mux.Handle("/rooms/{roomID}/read_markers", httputil.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } // TODO: return the read_markers. return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} }), @@ -732,6 +781,9 @@ func Setup( r0mux.Handle("/capabilities", httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return GetCapabilities(req, rsAPI) }), ).Methods(http.MethodGet) diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 9cf517cff..9744a5640 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -90,27 +90,26 @@ func SendEvent( // pass the new event to the roomserver and receive the correct event ID // event ID in case of duplicate transaction is discarded - eventID, err := api.SendEvents( + if err := api.SendEvents( req.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, txnAndSessionID, - ) - if err != nil { + ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } util.GetLogger(req.Context()).WithFields(logrus.Fields{ - "event_id": eventID, + "event_id": e.EventID(), "room_id": roomID, "room_version": verRes.RoomVersion, }).Info("Sent event to roomserver") res := util.JSONResponse{ Code: http.StatusOK, - JSON: sendEventResponse{eventID}, + JSON: sendEventResponse{e.EventID()}, } // Add response to transactionsCache if txnID != nil { @@ -158,7 +157,7 @@ func generateSendEvent( } var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.BuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return nil, &util.JSONResponse{ Code: http.StatusNotFound, diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index f1d54a47b..b9575a284 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -354,12 +354,12 @@ func emit3PIDInviteEvent( } queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.BuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes) if err != nil { return err } - _, err = api.SendEvents( + return api.SendEvents( ctx, rsAPI, []gomatrixserverlib.HeaderedEvent{ (*event).Headered(queryRes.RoomVersion), @@ -367,5 +367,4 @@ func emit3PIDInviteEvent( cfg.Matrix.ServerName, nil, ) - return err } diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index e2d23e895..d4f0cee04 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -155,7 +155,7 @@ func main() { stateAPI := currentstateserver.NewInternalAPI(&base.Base.Cfg.CurrentStateServer, base.Base.KafkaConsumer) rsAPI := roomserver.NewInternalAPI( - &base.Base, keyRing, federation, + &base.Base, keyRing, ) eduInputAPI := eduserver.NewInternalAPI( &base.Base, cache.New(), userAPI, diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 26999ebed..fcf3d4c56 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -104,7 +104,7 @@ func main() { keyAPI.SetUserAPI(userAPI) rsComponent := roomserver.NewInternalAPI( - base, keyRing, federation, + base, keyRing, ) rsAPI := rsComponent diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 815117463..717b21a9f 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -81,7 +81,7 @@ func main() { keyRing := serverKeyAPI.KeyRing() rsImpl := roomserver.NewInternalAPI( - base, keyRing, federation, + base, keyRing, ) // call functions directly on the impl unless running in HTTP mode rsAPI := rsImpl diff --git a/cmd/dendrite-room-server/main.go b/cmd/dendrite-room-server/main.go index 0d587e6ee..08ad34bfd 100644 --- a/cmd/dendrite-room-server/main.go +++ b/cmd/dendrite-room-server/main.go @@ -23,13 +23,12 @@ func main() { cfg := setup.ParseFlags(false) base := setup.NewBaseDendrite(cfg, "RoomServerAPI", true) defer base.Close() // nolint: errcheck - federation := base.CreateFederationClient() serverKeyAPI := base.ServerKeyAPIClient() keyRing := serverKeyAPI.KeyRing() fsAPI := base.FederationSenderHTTPClient() - rsAPI := roomserver.NewInternalAPI(base, keyRing, federation) + rsAPI := roomserver.NewInternalAPI(base, keyRing) rsAPI.SetFederationSenderAPI(fsAPI) roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI) diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index c95eb3fce..aeca70946 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -205,7 +205,7 @@ func main() { } stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) - rsAPI := roomserver.NewInternalAPI(base, keyRing, federation) + rsAPI := roomserver.NewInternalAPI(base, keyRing) eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI) asQuery := appservice.NewInternalAPI( base, userAPI, rsAPI, diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index cff376d8c..78ed3af6c 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "fmt" "github.com/matrix-org/dendrite/internal/config" @@ -8,6 +9,9 @@ import ( ) func main() { + defaultsForCI := flag.Bool("ci", false, "sane defaults for CI testing") + flag.Parse() + cfg := &config.Dendrite{} cfg.Defaults() cfg.Global.TrustedIDServers = []string{ @@ -56,6 +60,11 @@ func main() { }, } + if *defaultsForCI { + cfg.ClientAPI.RateLimiting.Enabled = false + cfg.FederationSender.DisableTLSValidation = true + } + j, err := yaml.Marshal(cfg) if err != nil { panic(err) diff --git a/currentstateserver/acls/acls.go b/currentstateserver/acls/acls.go index 12619f5fc..775b6c73a 100644 --- a/currentstateserver/acls/acls.go +++ b/currentstateserver/acls/acls.go @@ -23,17 +23,25 @@ import ( "strings" "sync" - "github.com/matrix-org/dendrite/currentstateserver/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) +type ServerACLDatabase interface { + // GetKnownRooms returns a list of all rooms we know about. + GetKnownRooms(ctx context.Context) ([]string, error) + // GetStateEvent returns the state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) +} + type ServerACLs struct { acls map[string]*serverACL // room ID -> ACL aclsMutex sync.RWMutex // protects the above } -func NewServerACLs(db storage.Database) *ServerACLs { +func NewServerACLs(db ServerACLDatabase) *ServerACLs { ctx := context.TODO() acls := &ServerACLs{ acls: make(map[string]*serverACL), diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 23f142a83..570669c1a 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -133,6 +133,14 @@ client_api: turn_username: "" turn_password: "" + # Settings for rate-limited endpoints. Rate limiting will kick in after the + # threshold number of "slots" have been taken by requests from a specific + # host. Each "slot" will be released after the cooloff time in milliseconds. + rate_limiting: + enabled: true + threshold: 5 + cooloff_ms: 500 + # Configuration for the Current State Server. current_state_server: internal_api: diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index ffdadd522..36afe30ab 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -95,7 +95,7 @@ func MakeJoin( queryRes := api.QueryLatestEventsAndStateResponse{ RoomVersion: verRes.RoomVersion, } - event, err := eventutil.BuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, @@ -266,15 +266,14 @@ func SendJoin( // We are responsible for notifying other servers that the user has joined // the room, so set SendAsServer to cfg.Matrix.ServerName if !alreadyJoined { - _, err = api.SendEvents( + if err = api.SendEvents( httpReq.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ event.Headered(stateAndAuthChainResponse.RoomVersion), }, cfg.Matrix.ServerName, nil, - ) - if err != nil { + ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index d2fbfc712..8bb0a8a94 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -61,7 +61,7 @@ func MakeLeave( } var queryRes api.QueryLatestEventsAndStateResponse - event, err := eventutil.BuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, @@ -247,15 +247,14 @@ func SendLeave( // Send the events to the room server. // We are responsible for notifying other servers that the user has left // the room, so set SendAsServer to cfg.Matrix.ServerName - _, err = api.SendEvents( + if err = api.SendEvents( httpReq.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ event.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, nil, - ) - if err != nil { + ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("producer.SendEvents failed") return jsonerror.InternalServerError() } diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index cad779219..570062adc 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -382,7 +382,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro } // pass the event to the roomserver - _, err := api.SendEvents( + return api.SendEvents( t.context, t.rsAPI, []gomatrixserverlib.HeaderedEvent{ e.Headered(stateResp.RoomVersion), @@ -390,7 +390,6 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro api.DoNotSendToOtherServers, nil, ) - return err } func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error { diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index fa745e286..6dc8621b2 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -296,6 +296,30 @@ func (t *testRoomserverAPI) RemoveRoomAlias( return fmt.Errorf("not implemented") } +func (t *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { + return nil +} + +func (t *testRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error { + return nil +} + type testStateAPI struct { } diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index e8d9a9397..ec6cc1488 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -89,7 +89,7 @@ func CreateInvitesFrom3PIDInvites( } // Send all the events - if _, err := api.SendEvents(req.Context(), rsAPI, evs, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, evs, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -172,7 +172,7 @@ func ExchangeThirdPartyInvite( } // Send the event to the roomserver - if _, err = api.SendEvents( + if err = api.SendEvents( httpReq.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ signedEvent.Event.Headered(verRes.RoomVersion), diff --git a/federationapi/routing/version.go b/federationapi/routing/version.go index 14ecd21e1..906fc2b9b 100644 --- a/federationapi/routing/version.go +++ b/federationapi/routing/version.go @@ -17,6 +17,7 @@ package routing import ( "net/http" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/util" ) @@ -31,5 +32,13 @@ type server struct { // Version returns the server version func Version() util.JSONResponse { - return util.JSONResponse{Code: http.StatusOK, JSON: &version{server{"dev", "Dendrite"}}} + return util.JSONResponse{ + Code: http.StatusOK, + JSON: &version{ + server{ + Name: "Dendrite", + Version: internal.VersionString(), + }, + }, + } } diff --git a/federationsender/api/api.go b/federationsender/api/api.go index cea0010d6..655d1d103 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -14,9 +14,12 @@ import ( // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError type FederationClient interface { + gomatrixserverlib.BackfillClient + gomatrixserverlib.FederatedStateClient GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) + GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. diff --git a/federationsender/internal/api.go b/federationsender/internal/api.go index 6b5f4c342..61663be31 100644 --- a/federationsender/internal/api.go +++ b/federationsender/internal/api.go @@ -136,3 +136,51 @@ func (a *FederationSenderInternalAPI) QueryKeys( } return ires.(gomatrixserverlib.RespQueryKeys), nil } + +func (a *FederationSenderInternalAPI) Backfill( + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, +) (res gomatrixserverlib.Transaction, err error) { + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) + }) + if err != nil { + return gomatrixserverlib.Transaction{}, err + } + return ires.(gomatrixserverlib.Transaction), nil +} + +func (a *FederationSenderInternalAPI) LookupState( + ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, +) (res gomatrixserverlib.RespState, err error) { + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) + }) + if err != nil { + return gomatrixserverlib.RespState{}, err + } + return ires.(gomatrixserverlib.RespState), nil +} + +func (a *FederationSenderInternalAPI) LookupStateIDs( + ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, +) (res gomatrixserverlib.RespStateIDs, err error) { + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.LookupStateIDs(ctx, s, roomID, eventID) + }) + if err != nil { + return gomatrixserverlib.RespStateIDs{}, err + } + return ires.(gomatrixserverlib.RespStateIDs), nil +} + +func (a *FederationSenderInternalAPI) GetEvent( + ctx context.Context, s gomatrixserverlib.ServerName, eventID string, +) (res gomatrixserverlib.Transaction, err error) { + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.GetEvent(ctx, s, eventID) + }) + if err != nil { + return gomatrixserverlib.Transaction{}, err + } + return ires.(gomatrixserverlib.Transaction), nil +} diff --git a/federationsender/inthttp/client.go b/federationsender/inthttp/client.go index 79e220c38..5bfe6089d 100644 --- a/federationsender/inthttp/client.go +++ b/federationsender/inthttp/client.go @@ -26,6 +26,10 @@ const ( FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices" FederationSenderClaimKeysPath = "/federationsender/client/claimKeys" FederationSenderQueryKeysPath = "/federationsender/client/queryKeys" + FederationSenderBackfillPath = "/federationsender/client/backfill" + FederationSenderLookupStatePath = "/federationsender/client/lookupState" + FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs" + FederationSenderGetEventPath = "/federationsender/client/getEvent" ) // NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. @@ -228,3 +232,129 @@ func (h *httpFederationSenderInternalAPI) QueryKeys( } return *response.Res, nil } + +type backfill struct { + S gomatrixserverlib.ServerName + RoomID string + Limit int + EventIDs []string + Res *gomatrixserverlib.Transaction + Err *api.FederationClientError +} + +func (h *httpFederationSenderInternalAPI) Backfill( + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, +) (gomatrixserverlib.Transaction, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "Backfill") + defer span.Finish() + + request := backfill{ + S: s, + RoomID: roomID, + Limit: limit, + EventIDs: eventIDs, + } + var response backfill + apiURL := h.federationSenderURL + FederationSenderBackfillPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return gomatrixserverlib.Transaction{}, err + } + if response.Err != nil { + return gomatrixserverlib.Transaction{}, response.Err + } + return *response.Res, nil +} + +type lookupState struct { + S gomatrixserverlib.ServerName + RoomID string + EventID string + RoomVersion gomatrixserverlib.RoomVersion + Res *gomatrixserverlib.RespState + Err *api.FederationClientError +} + +func (h *httpFederationSenderInternalAPI) LookupState( + ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, +) (gomatrixserverlib.RespState, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "LookupState") + defer span.Finish() + + request := lookupState{ + S: s, + RoomID: roomID, + EventID: eventID, + RoomVersion: roomVersion, + } + var response lookupState + apiURL := h.federationSenderURL + FederationSenderLookupStatePath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return gomatrixserverlib.RespState{}, err + } + if response.Err != nil { + return gomatrixserverlib.RespState{}, response.Err + } + return *response.Res, nil +} + +type lookupStateIDs struct { + S gomatrixserverlib.ServerName + RoomID string + EventID string + Res *gomatrixserverlib.RespStateIDs + Err *api.FederationClientError +} + +func (h *httpFederationSenderInternalAPI) LookupStateIDs( + ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, +) (gomatrixserverlib.RespStateIDs, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "LookupStateIDs") + defer span.Finish() + + request := lookupStateIDs{ + S: s, + RoomID: roomID, + EventID: eventID, + } + var response lookupStateIDs + apiURL := h.federationSenderURL + FederationSenderLookupStateIDsPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return gomatrixserverlib.RespStateIDs{}, err + } + if response.Err != nil { + return gomatrixserverlib.RespStateIDs{}, response.Err + } + return *response.Res, nil +} + +type getEvent struct { + S gomatrixserverlib.ServerName + EventID string + Res *gomatrixserverlib.Transaction + Err *api.FederationClientError +} + +func (h *httpFederationSenderInternalAPI) GetEvent( + ctx context.Context, s gomatrixserverlib.ServerName, eventID string, +) (gomatrixserverlib.Transaction, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "GetEvent") + defer span.Finish() + + request := getEvent{ + S: s, + EventID: eventID, + } + var response getEvent + apiURL := h.federationSenderURL + FederationSenderGetEventPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return gomatrixserverlib.Transaction{}, err + } + if response.Err != nil { + return gomatrixserverlib.Transaction{}, response.Err + } + return *response.Res, nil +} diff --git a/federationsender/inthttp/server.go b/federationsender/inthttp/server.go index b18255760..dfbff1c00 100644 --- a/federationsender/inthttp/server.go +++ b/federationsender/inthttp/server.go @@ -175,4 +175,92 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route return util.JSONResponse{Code: http.StatusOK, JSON: request} }), ) + internalAPIMux.Handle( + FederationSenderBackfillPath, + httputil.MakeInternalAPI("Backfill", func(req *http.Request) util.JSONResponse { + var request backfill + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.Backfill(req.Context(), request.S, request.RoomID, request.Limit, request.EventIDs) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.Res = &res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) + internalAPIMux.Handle( + FederationSenderLookupStatePath, + httputil.MakeInternalAPI("LookupState", func(req *http.Request) util.JSONResponse { + var request lookupState + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.LookupState(req.Context(), request.S, request.RoomID, request.EventID, request.RoomVersion) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.Res = &res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) + internalAPIMux.Handle( + FederationSenderLookupStateIDsPath, + httputil.MakeInternalAPI("LookupStateIDs", func(req *http.Request) util.JSONResponse { + var request lookupStateIDs + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.LookupStateIDs(req.Context(), request.S, request.RoomID, request.EventID) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.Res = &res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) + internalAPIMux.Handle( + FederationSenderGetEventPath, + httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse { + var request getEvent + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.GetEvent(req.Context(), request.S, request.EventID) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.Res = &res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) } diff --git a/go.mod b/go.mod index c69068059..3a9fef9f5 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd - github.com/matrix-org/gomatrixserverlib v0.0.0-20200817100842-9d02141812f2 + github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750 github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.2 diff --git a/go.sum b/go.sum index 332ae05fa..33b4f591a 100644 --- a/go.sum +++ b/go.sum @@ -567,8 +567,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200817100842-9d02141812f2 h1:9wKwfd5KDcXuqZ7/kAaYe0QM4DGM+2awjjvXQtrDa6k= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200817100842-9d02141812f2/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750 h1:k5vsLfpylXHOXgN51N0QNbak9i+4bT33Puk/ZJgcdDw= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= diff --git a/internal/config/config_clientapi.go b/internal/config/config_clientapi.go index f7878276a..521154911 100644 --- a/internal/config/config_clientapi.go +++ b/internal/config/config_clientapi.go @@ -34,6 +34,9 @@ type ClientAPI struct { // TURN options TURN TURN `yaml:"turn"` + + // Rate-limiting options + RateLimiting RateLimiting `yaml:"rate_limiting"` } func (c *ClientAPI) Defaults() { @@ -47,6 +50,7 @@ func (c *ClientAPI) Defaults() { c.RecaptchaBypassSecret = "" c.RecaptchaSiteVerifyAPI = "" c.RegistrationDisabled = false + c.RateLimiting.Defaults() } func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { @@ -61,6 +65,7 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", string(c.RecaptchaSiteVerifyAPI)) } c.TURN.Verify(configErrs) + c.RateLimiting.Verify(configErrs) } type TURN struct { @@ -90,3 +95,29 @@ func (c *TURN) Verify(configErrs *ConfigErrors) { } } } + +type RateLimiting struct { + // Is rate limiting enabled or disabled? + Enabled bool `yaml:"enabled"` + + // How many "slots" a user can occupy sending requests to a rate-limited + // endpoint before we apply rate-limiting + Threshold int64 `yaml:"threshold"` + + // The cooloff period in milliseconds after a request before the "slot" + // is freed again + CooloffMS int64 `yaml:"cooloff_ms"` +} + +func (r *RateLimiting) Verify(configErrs *ConfigErrors) { + if r.Enabled { + checkPositive(configErrs, "client_api.rate_limiting.threshold", r.Threshold) + checkPositive(configErrs, "client_api.rate_limiting.cooloff_ms", r.CooloffMS) + } +} + +func (r *RateLimiting) Defaults() { + r.Enabled = true + r.Threshold = 5 + r.CooloffMS = 500 +} diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index 35c7f33d8..0b878961e 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -30,13 +30,13 @@ import ( // doesn't exist var ErrRoomNoExists = errors.New("Room does not exist") -// BuildEvent builds a Matrix event using the event builder and roomserver query +// QueryAndBuildEvent builds a Matrix event using the event builder and roomserver query // API client provided. If also fills roomserver query API response (if provided) // in case the function calling FillBuilder needs to use it. // Returns ErrRoomNoExists if the state of the room could not be retrieved because // the room doesn't exist // Returns an error if something else went wrong -func BuildEvent( +func QueryAndBuildEvent( ctx context.Context, builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse, @@ -45,11 +45,25 @@ func BuildEvent( queryRes = &api.QueryLatestEventsAndStateResponse{} } - ver, err := AddPrevEventsToEvent(ctx, builder, rsAPI, queryRes) + eventsNeeded, err := queryRequiredEventsForBuilder(ctx, builder, rsAPI, queryRes) if err != nil { // This can pass through a ErrRoomNoExists to the caller return nil, err } + return BuildEvent(ctx, builder, cfg, evTime, eventsNeeded, queryRes) +} + +// BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse +// provided. +func BuildEvent( + ctx context.Context, + builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, + eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, +) (*gomatrixserverlib.HeaderedEvent, error) { + err := addPrevEventsToEvent(builder, eventsNeeded, queryRes) + if err != nil { + return nil, err + } event, err := builder.Build( evTime, cfg.ServerName, cfg.KeyID, @@ -59,23 +73,23 @@ func BuildEvent( return nil, err } - h := event.Headered(ver) + h := event.Headered(queryRes.RoomVersion) return &h, nil } -// AddPrevEventsToEvent fills out the prev_events and auth_events fields in builder -func AddPrevEventsToEvent( +// queryRequiredEventsForBuilder queries the roomserver for auth/prev events needed for this builder. +func queryRequiredEventsForBuilder( ctx context.Context, builder *gomatrixserverlib.EventBuilder, rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse, -) (gomatrixserverlib.RoomVersion, error) { +) (*gomatrixserverlib.StateNeeded, error) { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) if err != nil { - return "", fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) + return nil, fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) } if len(eventsNeeded.Tuples()) == 0 { - return "", errors.New("expecting state tuples for event builder, got none") + return nil, errors.New("expecting state tuples for event builder, got none") } // Ask the roomserver for information about this room @@ -83,17 +97,22 @@ func AddPrevEventsToEvent( RoomID: builder.RoomID, StateToFetch: eventsNeeded.Tuples(), } - if err = rsAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes); err != nil { - return "", fmt.Errorf("rsAPI.QueryLatestEventsAndState: %w", err) - } + return &eventsNeeded, rsAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes) +} +// addPrevEventsToEvent fills out the prev_events and auth_events fields in builder +func addPrevEventsToEvent( + builder *gomatrixserverlib.EventBuilder, + eventsNeeded *gomatrixserverlib.StateNeeded, + queryRes *api.QueryLatestEventsAndStateResponse, +) error { if !queryRes.RoomExists { - return "", ErrRoomNoExists + return ErrRoomNoExists } eventFormat, err := queryRes.RoomVersion.EventFormat() if err != nil { - return "", fmt.Errorf("queryRes.RoomVersion.EventFormat: %w", err) + return fmt.Errorf("queryRes.RoomVersion.EventFormat: %w", err) } builder.Depth = queryRes.Depth @@ -103,13 +122,13 @@ func AddPrevEventsToEvent( for i := range queryRes.StateEvents { err = authEvents.AddEvent(&queryRes.StateEvents[i].Event) if err != nil { - return "", fmt.Errorf("authEvents.AddEvent: %w", err) + return fmt.Errorf("authEvents.AddEvent: %w", err) } } refs, err := eventsNeeded.AuthEventReferences(&authEvents) if err != nil { - return "", fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err) + return fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err) } truncAuth, truncPrev := truncateAuthAndPrevEvents(refs, queryRes.LatestEvents) @@ -129,7 +148,7 @@ func AddPrevEventsToEvent( builder.PrevEvents = v2PrevRefs } - return queryRes.RoomVersion, nil + return nil } // truncateAuthAndPrevEvents limits the number of events we add into diff --git a/internal/setup/base.go b/internal/setup/base.go index 7bf06e748..ec2bbc4cf 100644 --- a/internal/setup/base.go +++ b/internal/setup/base.go @@ -100,6 +100,8 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, useHTTPAPIs boo internal.SetupHookLogging(cfg.Logging, componentName) internal.SetupPprof() + logrus.Infof("Dendrite version %s", internal.VersionString()) + closer, err := cfg.SetupTracing("Dendrite" + componentName) if err != nil { logrus.WithError(err).Panicf("failed to start opentracing") diff --git a/internal/version.go b/internal/version.go new file mode 100644 index 000000000..851a09384 --- /dev/null +++ b/internal/version.go @@ -0,0 +1,26 @@ +package internal + +import "fmt" + +// -ldflags "-X github.com/matrix-org/dendrite/internal.branch=master" +var branch string + +// -ldflags "-X github.com/matrix-org/dendrite/internal.build=alpha" +var build string + +const ( + VersionMajor = 0 + VersionMinor = 0 + VersionPatch = 0 +) + +func VersionString() string { + version := fmt.Sprintf("%d.%d.%d", VersionMajor, VersionMinor, VersionPatch) + if branch != "" { + version += fmt.Sprintf("-%s", branch) + } + if build != "" { + version += fmt.Sprintf("+%s", build) + } + return version +} diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go new file mode 100644 index 000000000..775b6c73a --- /dev/null +++ b/roomserver/acls/acls.go @@ -0,0 +1,164 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package acls + +import ( + "context" + "encoding/json" + "fmt" + "net" + "regexp" + "strings" + "sync" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +type ServerACLDatabase interface { + // GetKnownRooms returns a list of all rooms we know about. + GetKnownRooms(ctx context.Context) ([]string, error) + // GetStateEvent returns the state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) +} + +type ServerACLs struct { + acls map[string]*serverACL // room ID -> ACL + aclsMutex sync.RWMutex // protects the above +} + +func NewServerACLs(db ServerACLDatabase) *ServerACLs { + ctx := context.TODO() + acls := &ServerACLs{ + acls: make(map[string]*serverACL), + } + // Look up all of the rooms that the current state server knows about. + rooms, err := db.GetKnownRooms(ctx) + if err != nil { + logrus.WithError(err).Fatalf("Failed to get known rooms") + } + // For each room, let's see if we have a server ACL state event. If we + // do then we'll process it into memory so that we have the regexes to + // hand. + for _, room := range rooms { + state, err := db.GetStateEvent(ctx, room, "m.room.server_acl", "") + if err != nil { + logrus.WithError(err).Errorf("Failed to get server ACLs for room %q", room) + continue + } + if state != nil { + acls.OnServerACLUpdate(&state.Event) + } + } + return acls +} + +type ServerACL struct { + Allowed []string `json:"allow"` + Denied []string `json:"deny"` + AllowIPLiterals bool `json:"allow_ip_literals"` +} + +type serverACL struct { + ServerACL + allowedRegexes []*regexp.Regexp + deniedRegexes []*regexp.Regexp +} + +func compileACLRegex(orig string) (*regexp.Regexp, error) { + escaped := regexp.QuoteMeta(orig) + escaped = strings.Replace(escaped, "\\?", ".", -1) + escaped = strings.Replace(escaped, "\\*", ".*", -1) + return regexp.Compile(escaped) +} + +func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) { + acls := &serverACL{} + if err := json.Unmarshal(state.Content(), &acls.ServerACL); err != nil { + logrus.WithError(err).Errorf("Failed to unmarshal state content for server ACLs") + return + } + // The spec calls only for * (zero or more chars) and ? (exactly one char) + // to be supported as wildcard components, so we will escape all of the regex + // special characters and then replace * and ? with their regex counterparts. + // https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl + for _, orig := range acls.Allowed { + if expr, err := compileACLRegex(orig); err != nil { + logrus.WithError(err).Errorf("Failed to compile allowed regex") + } else { + acls.allowedRegexes = append(acls.allowedRegexes, expr) + } + } + for _, orig := range acls.Denied { + if expr, err := compileACLRegex(orig); err != nil { + logrus.WithError(err).Errorf("Failed to compile denied regex") + } else { + acls.deniedRegexes = append(acls.deniedRegexes, expr) + } + } + logrus.WithFields(logrus.Fields{ + "allow_ip_literals": acls.AllowIPLiterals, + "num_allowed": len(acls.allowedRegexes), + "num_denied": len(acls.deniedRegexes), + }).Debugf("Updating server ACLs for %q", state.RoomID()) + s.aclsMutex.Lock() + defer s.aclsMutex.Unlock() + s.acls[state.RoomID()] = acls +} + +func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerName, roomID string) bool { + s.aclsMutex.RLock() + // First of all check if we have an ACL for this room. If we don't then + // no servers are banned from the room. + acls, ok := s.acls[roomID] + if !ok { + s.aclsMutex.RUnlock() + return false + } + s.aclsMutex.RUnlock() + // Split the host and port apart. This is because the spec calls on us to + // validate the hostname only in cases where the port is also present. + if serverNameOnly, _, err := net.SplitHostPort(string(serverName)); err == nil { + serverName = gomatrixserverlib.ServerName(serverNameOnly) + } + // Check if the hostname is an IPv4 or IPv6 literal. We cheat here by adding + // a /0 prefix length just to trick ParseCIDR into working. If we find that + // the server is an IP literal and we don't allow those then stop straight + // away. + if _, _, err := net.ParseCIDR(fmt.Sprintf("%s/0", serverName)); err == nil { + if !acls.AllowIPLiterals { + return true + } + } + // Check if the hostname matches one of the denied regexes. If it does then + // the server is banned from the room. + for _, expr := range acls.deniedRegexes { + if expr.MatchString(string(serverName)) { + return true + } + } + // Check if the hostname matches one of the allowed regexes. If it does then + // the server is NOT banned from the room. + for _, expr := range acls.allowedRegexes { + if expr.MatchString(string(serverName)) { + return false + } + } + // If we've got to this point then we haven't matched any regexes or an IP + // hostname if disallowed. The spec calls for default-deny here. + return true +} diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go new file mode 100644 index 000000000..9fb6a5581 --- /dev/null +++ b/roomserver/acls/acls_test.go @@ -0,0 +1,105 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package acls + +import ( + "regexp" + "testing" +) + +func TestOpenACLsWithBlacklist(t *testing.T) { + roomID := "!test:test.com" + allowRegex, err := compileACLRegex("*") + if err != nil { + t.Fatalf(err.Error()) + } + denyRegex, err := compileACLRegex("foo.com") + if err != nil { + t.Fatalf(err.Error()) + } + + acls := ServerACLs{ + acls: make(map[string]*serverACL), + } + + acls.acls[roomID] = &serverACL{ + ServerACL: ServerACL{ + AllowIPLiterals: true, + }, + allowedRegexes: []*regexp.Regexp{allowRegex}, + deniedRegexes: []*regexp.Regexp{denyRegex}, + } + + if acls.IsServerBannedFromRoom("1.2.3.4", roomID) { + t.Fatal("Expected 1.2.3.4 to be allowed but wasn't") + } + if acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) { + t.Fatal("Expected 1.2.3.4:2345 to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("foo.com", roomID) { + t.Fatal("Expected foo.com to be banned but wasn't") + } + if !acls.IsServerBannedFromRoom("foo.com:3456", roomID) { + t.Fatal("Expected foo.com:3456 to be banned but wasn't") + } + if acls.IsServerBannedFromRoom("bar.com", roomID) { + t.Fatal("Expected bar.com to be allowed but wasn't") + } + if acls.IsServerBannedFromRoom("bar.com:4567", roomID) { + t.Fatal("Expected bar.com:4567 to be allowed but wasn't") + } +} + +func TestDefaultACLsWithWhitelist(t *testing.T) { + roomID := "!test:test.com" + allowRegex, err := compileACLRegex("foo.com") + if err != nil { + t.Fatalf(err.Error()) + } + + acls := ServerACLs{ + acls: make(map[string]*serverACL), + } + + acls.acls[roomID] = &serverACL{ + ServerACL: ServerACL{ + AllowIPLiterals: false, + }, + allowedRegexes: []*regexp.Regexp{allowRegex}, + deniedRegexes: []*regexp.Regexp{}, + } + + if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) { + t.Fatal("Expected 1.2.3.4 to be banned but wasn't") + } + if !acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) { + t.Fatal("Expected 1.2.3.4:2345 to be banned but wasn't") + } + if acls.IsServerBannedFromRoom("foo.com", roomID) { + t.Fatal("Expected foo.com to be allowed but wasn't") + } + if acls.IsServerBannedFromRoom("foo.com:3456", roomID) { + t.Fatal("Expected foo.com:3456 to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("bar.com", roomID) { + t.Fatal("Expected bar.com to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("baz.com", roomID) { + t.Fatal("Expected baz.com to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("qux.com:4567", roomID) { + t.Fatal("Expected qux.com:4567 to be allowed but wasn't") + } +} diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 3b2d4bd77..eecefe322 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -112,6 +112,20 @@ type RoomserverInternalAPI interface { response *QueryStateAndAuthChainResponse, ) error + // QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from + // the response. + QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error + // QueryRoomsForUser retrieves a list of room IDs matching the given query. + QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error + // QueryBulkStateContent does a bulk query for state event content in the given rooms. + QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error + // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. + QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error + // QueryKnownUsers returns a list of users that we know about from our joined rooms. + QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error + // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. + QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error + // Query a given amount (or less) of events prior to a given set of events. PerformBackfill( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 0e1b645e4..643309307 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -245,6 +245,47 @@ func (t *RoomserverInternalAPITrace) RemoveRoomAlias( return err } +func (t *RoomserverInternalAPITrace) QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error { + err := t.Impl.QueryCurrentState(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryCurrentState req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryRoomsForUser retrieves a list of room IDs matching the given query. +func (t *RoomserverInternalAPITrace) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error { + err := t.Impl.QueryRoomsForUser(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryRoomsForUser req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryBulkStateContent does a bulk query for state event content in the given rooms. +func (t *RoomserverInternalAPITrace) QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error { + err := t.Impl.QueryBulkStateContent(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryBulkStateContent req=%+v res=%+v", js(req), js(res)) + return err +} + +// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. +func (t *RoomserverInternalAPITrace) QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error { + err := t.Impl.QuerySharedUsers(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QuerySharedUsers req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryKnownUsers returns a list of users that we know about from our joined rooms. +func (t *RoomserverInternalAPITrace) QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error { + err := t.Impl.QueryKnownUsers(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryKnownUsers req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. +func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error { + err := t.Impl.QueryServerBannedFromRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryServerBannedFromRoom req=%+v res=%+v", js(req), js(res)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 05c981df4..73c4994a7 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -83,5 +83,4 @@ type InputRoomEventsRequest struct { // InputRoomEventsResponse is a response to InputRoomEvents type InputRoomEventsResponse struct { - EventID string `json:"event_id"` } diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 4e1d09c30..d0d0474d8 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -17,6 +17,11 @@ package api import ( + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" ) @@ -225,3 +230,102 @@ type QueryPublishedRoomsResponse struct { // The list of published rooms. RoomIDs []string } + +type QuerySharedUsersRequest struct { + UserID string + ExcludeRoomIDs []string + IncludeRoomIDs []string +} + +type QuerySharedUsersResponse struct { + UserIDsToCount map[string]int +} + +type QueryRoomsForUserRequest struct { + UserID string + // The desired membership of the user. If this is the empty string then no rooms are returned. + WantMembership string +} + +type QueryRoomsForUserResponse struct { + RoomIDs []string +} + +type QueryBulkStateContentRequest struct { + // Returns state events in these rooms + RoomIDs []string + // If true, treats the '*' StateKey as "all state events of this type" rather than a literal value of '*' + AllowWildcards bool + // The state events to return. Only a small subset of tuples are allowed in this request as only certain events + // have their content fields extracted. Specifically, the tuple Type must be one of: + // m.room.avatar + // m.room.create + // m.room.canonical_alias + // m.room.guest_access + // m.room.history_visibility + // m.room.join_rules + // m.room.member + // m.room.name + // m.room.topic + // Any other tuple type will result in the query failing. + StateTuples []gomatrixserverlib.StateKeyTuple +} +type QueryBulkStateContentResponse struct { + // map of room ID -> tuple -> content_value + Rooms map[string]map[gomatrixserverlib.StateKeyTuple]string +} + +type QueryCurrentStateRequest struct { + RoomID string + StateTuples []gomatrixserverlib.StateKeyTuple +} + +type QueryCurrentStateResponse struct { + StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent +} + +type QueryKnownUsersRequest struct { + UserID string `json:"user_id"` + SearchString string `json:"search_string"` + Limit int `json:"limit"` +} + +type QueryKnownUsersResponse struct { + Users []authtypes.FullyQualifiedProfile `json:"profiles"` +} + +type QueryServerBannedFromRoomRequest struct { + ServerName gomatrixserverlib.ServerName `json:"server_name"` + RoomID string `json:"room_id"` +} + +type QueryServerBannedFromRoomResponse struct { + Banned bool `json:"banned"` +} + +// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode. +func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { + se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents)) + for k, v := range r.StateEvents { + // use 0x1F (unit separator) as the delimiter between type/state key, + se[fmt.Sprintf("%s\x1F%s", k.EventType, k.StateKey)] = v + } + return json.Marshal(se) +} + +func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { + res := make(map[string]*gomatrixserverlib.HeaderedEvent) + err := json.Unmarshal(data, &res) + if err != nil { + return err + } + r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(res)) + for k, v := range res { + fields := strings.Split(k, "\x1F") + r.StateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: fields[0], + StateKey: fields[1], + }] = v + } + return nil +} diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 207c12c8f..82a4a5719 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -26,7 +26,7 @@ import ( func SendEvents( ctx context.Context, rsAPI RoomserverInternalAPI, events []gomatrixserverlib.HeaderedEvent, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, -) (string, error) { +) error { ires := make([]InputRoomEvent, len(events)) for i, event := range events { ires[i] = InputRoomEvent{ @@ -77,19 +77,16 @@ func SendEventWithState( StateEventIDs: stateEventIDs, }) - _, err = SendInputRoomEvents(ctx, rsAPI, ires) - return err + return SendInputRoomEvents(ctx, rsAPI, ires) } // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( ctx context.Context, rsAPI RoomserverInternalAPI, ires []InputRoomEvent, -) (eventID string, err error) { +) error { request := InputRoomEventsRequest{InputRoomEvents: ires} var response InputRoomEventsResponse - err = rsAPI.InputRoomEvents(ctx, &request, &response) - eventID = response.EventID - return + return rsAPI.InputRoomEvents(ctx, &request, &response) } // SendInvite event to the roomserver. @@ -136,3 +133,102 @@ func GetEvent(ctx context.Context, rsAPI RoomserverInternalAPI, eventID string) } return &res.Events[0] } + +// GetStateEvent returns the current state event in the room or nil. +func GetStateEvent(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent { + var res QueryCurrentStateResponse + err := rsAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{ + RoomID: roomID, + StateTuples: []gomatrixserverlib.StateKeyTuple{tuple}, + }, &res) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to QueryCurrentState") + return nil + } + ev, ok := res.StateEvents[tuple] + if ok { + return ev + } + return nil +} + +// IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs. +func IsServerBannedFromRoom(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, serverName gomatrixserverlib.ServerName) bool { + req := &QueryServerBannedFromRoomRequest{ + ServerName: serverName, + RoomID: roomID, + } + res := &QueryServerBannedFromRoomResponse{} + if err := rsAPI.QueryServerBannedFromRoom(ctx, req, res); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to QueryServerBannedFromRoom") + return true + } + return res.Banned +} + +// PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the +// published room directory. +// due to lots of switches +// nolint:gocyclo +func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) { + avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} + nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} + canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} + topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""} + guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""} + visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""} + joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} + + var stateRes QueryBulkStateContentResponse + err := rsAPI.QueryBulkStateContent(ctx, &QueryBulkStateContentRequest{ + RoomIDs: roomIDs, + AllowWildcards: true, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple, + {EventType: gomatrixserverlib.MRoomMember, StateKey: "*"}, + }, + }, &stateRes) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed") + return nil, err + } + chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs)) + i := 0 + for roomID, data := range stateRes.Rooms { + pub := gomatrixserverlib.PublicRoom{ + RoomID: roomID, + } + joinCount := 0 + var joinRule, guestAccess string + for tuple, contentVal := range data { + if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" { + joinCount++ + continue + } + switch tuple { + case avatarTuple: + pub.AvatarURL = contentVal + case nameTuple: + pub.Name = contentVal + case topicTuple: + pub.Topic = contentVal + case canonicalTuple: + pub.CanonicalAlias = contentVal + case visibilityTuple: + pub.WorldReadable = contentVal == "world_readable" + // need both of these to determine whether guests can join + case joinRuleTuple: + joinRule = contentVal + case guestTuple: + guestAccess = contentVal + } + } + if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" { + pub.GuestCanJoin = true + } + pub.JoinedMembersCount = joinCount + chunk[i] = pub + i++ + } + return chunk, nil +} diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 4139582b6..d576a8175 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "time" "github.com/matrix-org/dendrite/roomserver/api" @@ -239,16 +240,19 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent( } builder.AuthEvents = refs - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomID) + roomInfo, err := r.DB.RoomInfo(ctx, roomID) if err != nil { return err } + if roomInfo == nil { + return fmt.Errorf("room %s does not exist", roomID) + } // Build the event now := time.Now() event, err := builder.Build( now, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID, - r.Cfg.Matrix.PrivateKey, roomVersion, + r.Cfg.Matrix.PrivateKey, roomInfo.RoomVersion, ) if err != nil { return err @@ -257,7 +261,7 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent( // Create the request ire := api.InputRoomEvent{ Kind: api.KindNew, - Event: event.Headered(roomVersion), + Event: event.Headered(roomInfo.RoomVersion), AuthEventIDs: event.AuthEventIDs(), SendAsServer: serverName, } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index f94c72f05..bdea650ea 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -1,26 +1,129 @@ package internal import ( - "sync" + "context" "github.com/Shopify/sarama" fsAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/roomserver/acls" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/input" + "github.com/matrix-org/dendrite/roomserver/internal/perform" + "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" ) // RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI type RoomserverInternalAPI struct { + *input.Inputer + *query.Queryer + *perform.Inviter + *perform.Joiner + *perform.Leaver + *perform.Publisher + *perform.Backfiller DB storage.Database Cfg *config.RoomServer Producer sarama.SyncProducer Cache caching.RoomServerCaches ServerName gomatrixserverlib.ServerName KeyRing gomatrixserverlib.JSONVerifier - FedClient *gomatrixserverlib.FederationClient - OutputRoomEventTopic string // Kafka topic for new output room events - mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent fsAPI fsAPI.FederationSenderInternalAPI + OutputRoomEventTopic string // Kafka topic for new output room events +} + +func NewRoomserverAPI( + cfg *config.RoomServer, roomserverDB storage.Database, producer sarama.SyncProducer, + outputRoomEventTopic string, caches caching.RoomServerCaches, + keyRing gomatrixserverlib.JSONVerifier, +) *RoomserverInternalAPI { + a := &RoomserverInternalAPI{ + DB: roomserverDB, + Cfg: cfg, + Cache: caches, + ServerName: cfg.Matrix.ServerName, + KeyRing: keyRing, + Queryer: &query.Queryer{ + DB: roomserverDB, + Cache: caches, + ServerACLs: acls.NewServerACLs(roomserverDB), + }, + Inputer: &input.Inputer{ + DB: roomserverDB, + OutputRoomEventTopic: outputRoomEventTopic, + Producer: producer, + ServerName: cfg.Matrix.ServerName, + }, + // perform-er structs get initialised when we have a federation sender to use + } + return a +} + +// SetFederationSenderInputAPI passes in a federation sender input API reference +// so that we can avoid the chicken-and-egg problem of both the roomserver input API +// and the federation sender input API being interdependent. +func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) { + r.fsAPI = fsAPI + + r.Inviter = &perform.Inviter{ + DB: r.DB, + Cfg: r.Cfg, + FSAPI: r.fsAPI, + Inputer: r.Inputer, + } + r.Joiner = &perform.Joiner{ + ServerName: r.Cfg.Matrix.ServerName, + Cfg: r.Cfg, + DB: r.DB, + FSAPI: r.fsAPI, + Inputer: r.Inputer, + } + r.Leaver = &perform.Leaver{ + Cfg: r.Cfg, + DB: r.DB, + FSAPI: r.fsAPI, + Inputer: r.Inputer, + } + r.Publisher = &perform.Publisher{ + DB: r.DB, + } + r.Backfiller = &perform.Backfiller{ + ServerName: r.ServerName, + DB: r.DB, + FSAPI: r.fsAPI, + KeyRing: r.KeyRing, + } +} + +func (r *RoomserverInternalAPI) PerformInvite( + ctx context.Context, + req *api.PerformInviteRequest, + res *api.PerformInviteResponse, +) error { + outputEvents, err := r.Inviter.PerformInvite(ctx, req, res) + if err != nil { + return err + } + if len(outputEvents) == 0 { + return nil + } + return r.WriteOutputEvents(req.Event.RoomID(), outputEvents) +} + +func (r *RoomserverInternalAPI) PerformLeave( + ctx context.Context, + req *api.PerformLeaveRequest, + res *api.PerformLeaveResponse, +) error { + outputEvents, err := r.Leaver.PerformLeave(ctx, req, res) + if err != nil { + return err + } + if len(outputEvents) == 0 { + return nil + } + return r.WriteOutputEvents(req.RoomID, outputEvents) } diff --git a/roomserver/internal/input_authevents.go b/roomserver/internal/helpers/auth.go similarity index 96% rename from roomserver/internal/input_authevents.go rename to roomserver/internal/helpers/auth.go index e3828f566..060f0a0e9 100644 --- a/roomserver/internal/input_authevents.go +++ b/roomserver/internal/helpers/auth.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package helpers import ( "context" @@ -23,9 +23,9 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// checkAuthEvents checks that the event passes authentication checks +// CheckAuthEvents checks that the event passes authentication checks // Returns the numeric IDs for the auth events. -func checkAuthEvents( +func CheckAuthEvents( ctx context.Context, db storage.Database, event gomatrixserverlib.HeaderedEvent, @@ -63,7 +63,7 @@ func checkAuthEvents( type authEvents struct { stateKeyNIDMap map[string]types.EventStateKeyNID state stateEntryMap - events eventMap + events EventMap } // Create implements gomatrixserverlib.AuthEventProvider @@ -99,7 +99,7 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) * if !ok { return nil } - event, ok := ae.events.lookup(eventNID) + event, ok := ae.events.Lookup(eventNID) if !ok { return nil } @@ -118,7 +118,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * if !ok { return nil } - event, ok := ae.events.lookup(eventNID) + event, ok := ae.events.Lookup(eventNID) if !ok { return nil } @@ -224,10 +224,10 @@ func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.Even // Map from numeric event ID to event. // Implemented using binary search on a sorted array. -type eventMap []types.Event +type EventMap []types.Event // lookup an entry in the event map. -func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { +func (m EventMap) Lookup(eventNID types.EventNID) (event *types.Event, ok bool) { // Since the list is sorted we can implement this using binary search. // This is faster than using a hash map. // We don't have to worry about pathological cases because the keys are fixed diff --git a/roomserver/internal/input_authevents_test.go b/roomserver/internal/helpers/auth_test.go similarity index 97% rename from roomserver/internal/input_authevents_test.go rename to roomserver/internal/helpers/auth_test.go index 6b981571b..2a1c3ea49 100644 --- a/roomserver/internal/input_authevents_test.go +++ b/roomserver/internal/helpers/auth_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package helpers import ( "testing" @@ -95,7 +95,7 @@ func TestStateEntryMap(t *testing.T) { } func TestEventMap(t *testing.T) { - events := eventMap([]types.Event{ + events := EventMap([]types.Event{ {EventNID: 1}, {EventNID: 2}, {EventNID: 3}, @@ -123,7 +123,7 @@ func TestEventMap(t *testing.T) { } for _, testCase := range testCases { - gotEvent, gotOK := events.lookup(testCase.inputEventNID) + gotEvent, gotOK := events.Lookup(testCase.inputEventNID) if testCase.wantOK != gotOK { t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK) } diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go new file mode 100644 index 000000000..b7e6ce86c --- /dev/null +++ b/roomserver/internal/helpers/helpers.go @@ -0,0 +1,379 @@ +package helpers + +import ( + "context" + "fmt" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/auth" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// TODO: temporary package which has helper functions used by both internal/perform packages. +// Move these to a more sensible place. + +func UpdateToInviteMembership( + mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, + roomVersion gomatrixserverlib.RoomVersion, +) ([]api.OutputEvent, error) { + // We may have already sent the invite to the user, either because we are + // reprocessing this event, or because the we received this invite from a + // remote server via the federation invite API. In those cases we don't need + // to send the event. + needsSending, err := mu.SetToInvite(*add) + if err != nil { + return nil, err + } + if needsSending { + // We notify the consumers using a special event even though we will + // notify them about the change in current state as part of the normal + // room event stream. This ensures that the consumers only have to + // consider a single stream of events when determining whether a user + // is invited, rather than having to combine multiple streams themselves. + onie := api.OutputNewInviteEvent{ + Event: add.Headered(roomVersion), + RoomVersion: roomVersion, + } + updates = append(updates, api.OutputEvent{ + Type: api.OutputTypeNewInviteEvent, + NewInviteEvent: &onie, + }) + } + return updates, nil +} + +func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) { + info, err := db.RoomInfo(ctx, roomID) + if err != nil { + return false, err + } + if info == nil { + return false, fmt.Errorf("unknown room %s", roomID) + } + + eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) + if err != nil { + return false, err + } + + events, err := db.Events(ctx, eventNIDs) + if err != nil { + return false, err + } + gmslEvents := make([]gomatrixserverlib.Event, len(events)) + for i := range events { + gmslEvents[i] = events[i].Event + } + return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil +} + +func IsInvitePending( + ctx context.Context, db storage.Database, + roomID, userID string, +) (bool, string, string, error) { + // Look up the room NID for the supplied room ID. + info, err := db.RoomInfo(ctx, roomID) + if err != nil { + return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err) + } + if info == nil { + return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID) + } + + // Look up the state key NID for the supplied user ID. + targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID}) + if err != nil { + return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) + } + targetUserNID, targetUserFound := targetUserNIDs[userID] + if !targetUserFound { + return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) + } + + // Let's see if we have an event active for the user in the room. If + // we do then it will contain a server name that we can direct the + // send_leave to. + senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID) + if err != nil { + return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err) + } + if len(senderUserNIDs) == 0 { + return false, "", "", nil + } + userNIDToEventID := make(map[types.EventStateKeyNID]string) + for i, nid := range senderUserNIDs { + userNIDToEventID[nid] = eventIDs[i] + } + + // Look up the user ID from the NID. + senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs) + if err != nil { + return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err) + } + if len(senderUsers) == 0 { + return false, "", "", fmt.Errorf("no senderUsers") + } + + senderUser, senderUserFound := senderUsers[senderUserNIDs[0]] + if !senderUserFound { + return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers) + } + + return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil +} + +// GetMembershipsAtState filters the state events to +// only keep the "m.room.member" events with a "join" membership. These events are returned. +// Returns an error if there was an issue fetching the events. +func GetMembershipsAtState( + ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, +) ([]types.Event, error) { + + var eventNIDs []types.EventNID + for _, entry := range stateEntries { + // Filter the events to retrieve to only keep the membership events + if entry.EventTypeNID == types.MRoomMemberNID { + eventNIDs = append(eventNIDs, entry.EventNID) + } + } + + // Get all of the events in this state + stateEvents, err := db.Events(ctx, eventNIDs) + if err != nil { + return nil, err + } + + if !joinedOnly { + return stateEvents, nil + } + + // Filter the events to only keep the "join" membership events + var events []types.Event + for _, event := range stateEvents { + membership, err := event.Membership() + if err != nil { + return nil, err + } + + if membership == gomatrixserverlib.Join { + events = append(events, event) + } + } + + return events, nil +} + +func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info) + // Lookup the event NID + eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) + if err != nil { + return nil, err + } + eventIDs := []string{eIDs[eventNID]} + + prevState, err := db.StateAtEventIDs(ctx, eventIDs) + if err != nil { + return nil, err + } + + // Fetch the state as it was when this event was fired + return roomState.LoadCombinedStateAfterEvents(ctx, prevState) +} + +func LoadEvents( + ctx context.Context, db storage.Database, eventNIDs []types.EventNID, +) ([]gomatrixserverlib.Event, error) { + stateEvents, err := db.Events(ctx, eventNIDs) + if err != nil { + return nil, err + } + + result := make([]gomatrixserverlib.Event, len(stateEvents)) + for i := range stateEvents { + result[i] = stateEvents[i].Event + } + return result, nil +} + +func LoadStateEvents( + ctx context.Context, db storage.Database, stateEntries []types.StateEntry, +) ([]gomatrixserverlib.Event, error) { + eventNIDs := make([]types.EventNID, len(stateEntries)) + for i := range stateEntries { + eventNIDs[i] = stateEntries[i].EventNID + } + return LoadEvents(ctx, db, eventNIDs) +} + +func CheckServerAllowedToSeeEvent( + ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, +) (bool, error) { + roomState := state.NewStateResolution(db, info) + stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) + if err != nil { + return false, err + } + + // TODO: We probably want to make it so that we don't have to pull + // out all the state if possible. + stateAtEvent, err := LoadStateEvents(ctx, db, stateEntries) + if err != nil { + return false, err + } + + return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil +} + +// TODO: Remove this when we have tests to assert correctness of this function +// nolint:gocyclo +func ScanEventTree( + ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int, + serverName gomatrixserverlib.ServerName, +) ([]types.EventNID, error) { + var resultNIDs []types.EventNID + var err error + var allowed bool + var events []types.Event + var next []string + var pre string + + // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be) + // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing + // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in + // duplicate events being sent in response to /backfill requests. + initialIgnoreList := make(map[string]bool, len(visited)) + for k, v := range visited { + initialIgnoreList[k] = v + } + + resultNIDs = make([]types.EventNID, 0, limit) + + var checkedServerInRoom bool + var isServerInRoom bool + + // Loop through the event IDs to retrieve the requested events and go + // through the whole tree (up to the provided limit) using the events' + // "prev_event" key. +BFSLoop: + for len(front) > 0 { + // Prevent unnecessary allocations: reset the slice only when not empty. + if len(next) > 0 { + next = make([]string, 0) + } + // Retrieve the events to process from the database. + events, err = db.EventsFromIDs(ctx, front) + if err != nil { + return resultNIDs, err + } + + if !checkedServerInRoom && len(events) > 0 { + // It's nasty that we have to extract the room ID from an event, but many federation requests + // only talk in event IDs, no room IDs at all (!!!) + ev := events[0] + isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID()) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") + } + checkedServerInRoom = true + } + + for _, ev := range events { + // Break out of the loop if the provided limit is reached. + if len(resultNIDs) == limit { + break BFSLoop + } + + if !initialIgnoreList[ev.EventID()] { + // Update the list of events to retrieve. + resultNIDs = append(resultNIDs, ev.EventNID) + } + // Loop through the event's parents. + for _, pre = range ev.PrevEventIDs() { + // Only add an event to the list of next events to process if it + // hasn't been seen before. + if !visited[pre] { + visited[pre] = true + allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom) + if err != nil { + util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( + "Error checking if allowed to see event", + ) + return resultNIDs, err + } + + // If the event hasn't been seen before and the HS + // requesting to retrieve it is allowed to do so, add it to + // the list of events to retrieve. + if allowed { + next = append(next, pre) + } else { + util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event") + } + } + } + } + // Repeat the same process with the parent events we just processed. + front = next + } + + return resultNIDs, err +} + +func QueryLatestEventsAndState( + ctx context.Context, db storage.Database, + request *api.QueryLatestEventsAndStateRequest, + response *api.QueryLatestEventsAndStateResponse, +) error { + roomInfo, err := db.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if roomInfo == nil || roomInfo.IsStub { + response.RoomExists = false + return nil + } + + roomState := state.NewStateResolution(db, *roomInfo) + response.RoomExists = true + response.RoomVersion = roomInfo.RoomVersion + + var currentStateSnapshotNID types.StateSnapshotNID + response.LatestEvents, currentStateSnapshotNID, response.Depth, err = + db.LatestEventIDs(ctx, roomInfo.RoomNID) + if err != nil { + return err + } + + var stateEntries []types.StateEntry + if len(request.StateToFetch) == 0 { + // Look up all room state. + stateEntries, err = roomState.LoadStateAtSnapshot( + ctx, currentStateSnapshotNID, + ) + } else { + // Look up the current state for the requested tuples. + stateEntries, err = roomState.LoadStateAtSnapshotForStringTuples( + ctx, currentStateSnapshotNID, request.StateToFetch, + ) + } + if err != nil { + return err + } + + stateEvents, err := LoadStateEvents(ctx, db, stateEntries) + if err != nil { + return err + } + + for _, event := range stateEvents { + response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion)) + } + + return nil +} diff --git a/roomserver/internal/input.go b/roomserver/internal/input.go deleted file mode 100644 index e85e9830d..000000000 --- a/roomserver/internal/input.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package input contains the code processes new room events -package internal - -import ( - "context" - "encoding/json" - "sync" - - "github.com/Shopify/sarama" - "github.com/matrix-org/dendrite/roomserver/api" - log "github.com/sirupsen/logrus" - - fsAPI "github.com/matrix-org/dendrite/federationsender/api" -) - -// SetFederationSenderInputAPI passes in a federation sender input API reference -// so that we can avoid the chicken-and-egg problem of both the roomserver input API -// and the federation sender input API being interdependent. -func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) { - r.fsAPI = fsAPI -} - -// WriteOutputEvents implements OutputRoomEventWriter -func (r *RoomserverInternalAPI) WriteOutputEvents(roomID string, updates []api.OutputEvent) error { - messages := make([]*sarama.ProducerMessage, len(updates)) - for i := range updates { - value, err := json.Marshal(updates[i]) - if err != nil { - return err - } - logger := log.WithFields(log.Fields{ - "room_id": roomID, - "type": updates[i].Type, - }) - if updates[i].NewRoomEvent != nil { - logger = logger.WithFields(log.Fields{ - "event_type": updates[i].NewRoomEvent.Event.Type(), - "event_id": updates[i].NewRoomEvent.Event.EventID(), - "adds_state": len(updates[i].NewRoomEvent.AddsStateEventIDs), - "removes_state": len(updates[i].NewRoomEvent.RemovesStateEventIDs), - "send_as_server": updates[i].NewRoomEvent.SendAsServer, - "sender": updates[i].NewRoomEvent.Event.Sender(), - }) - } - logger.Infof("Producing to topic '%s'", r.OutputRoomEventTopic) - messages[i] = &sarama.ProducerMessage{ - Topic: r.OutputRoomEventTopic, - Key: sarama.StringEncoder(roomID), - Value: sarama.ByteEncoder(value), - } - } - return r.Producer.SendMessages(messages) -} - -// InputRoomEvents implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) InputRoomEvents( - ctx context.Context, - request *api.InputRoomEventsRequest, - response *api.InputRoomEventsResponse, -) (err error) { - for i, e := range request.InputRoomEvents { - roomID := "global" - if r.DB.SupportsConcurrentRoomInputs() { - roomID = e.Event.RoomID() - } - mutex, _ := r.mutexes.LoadOrStore(roomID, &sync.Mutex{}) - mutex.(*sync.Mutex).Lock() - if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil { - mutex.(*sync.Mutex).Unlock() - return err - } - mutex.(*sync.Mutex).Unlock() - } - return nil -} diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go new file mode 100644 index 000000000..7a44ff42c --- /dev/null +++ b/roomserver/internal/input/input.go @@ -0,0 +1,157 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package input contains the code processes new room events +package input + +import ( + "context" + "encoding/json" + "sync" + "time" + + "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" + "go.uber.org/atomic" +) + +type Inputer struct { + DB storage.Database + Producer sarama.SyncProducer + ServerName gomatrixserverlib.ServerName + OutputRoomEventTopic string + + workers sync.Map // room ID -> *inputWorker +} + +type inputTask struct { + ctx context.Context + event *api.InputRoomEvent + wg *sync.WaitGroup + err error // written back by worker, only safe to read when all tasks are done +} + +type inputWorker struct { + r *Inputer + running atomic.Bool + input chan *inputTask +} + +func (w *inputWorker) start() { + if !w.running.CAS(false, true) { + return + } + defer w.running.Store(false) + for { + select { + case task := <-w.input: + _, task.err = w.r.processRoomEvent(task.ctx, task.event) + task.wg.Done() + case <-time.After(time.Second * 5): + return + } + } +} + +// WriteOutputEvents implements OutputRoomEventWriter +func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) error { + messages := make([]*sarama.ProducerMessage, len(updates)) + for i := range updates { + value, err := json.Marshal(updates[i]) + if err != nil { + return err + } + logger := log.WithFields(log.Fields{ + "room_id": roomID, + "type": updates[i].Type, + }) + if updates[i].NewRoomEvent != nil { + logger = logger.WithFields(log.Fields{ + "event_type": updates[i].NewRoomEvent.Event.Type(), + "event_id": updates[i].NewRoomEvent.Event.EventID(), + "adds_state": len(updates[i].NewRoomEvent.AddsStateEventIDs), + "removes_state": len(updates[i].NewRoomEvent.RemovesStateEventIDs), + "send_as_server": updates[i].NewRoomEvent.SendAsServer, + "sender": updates[i].NewRoomEvent.Event.Sender(), + }) + } + logger.Infof("Producing to topic '%s'", r.OutputRoomEventTopic) + messages[i] = &sarama.ProducerMessage{ + Topic: r.OutputRoomEventTopic, + Key: sarama.StringEncoder(roomID), + Value: sarama.ByteEncoder(value), + } + } + return r.Producer.SendMessages(messages) +} + +// InputRoomEvents implements api.RoomserverInternalAPI +func (r *Inputer) InputRoomEvents( + ctx context.Context, + request *api.InputRoomEventsRequest, + response *api.InputRoomEventsResponse, +) error { + // Create a wait group. Each task that we dispatch will call Done on + // this wait group so that we know when all of our events have been + // processed. + wg := &sync.WaitGroup{} + wg.Add(len(request.InputRoomEvents)) + tasks := make([]*inputTask, len(request.InputRoomEvents)) + + for i, e := range request.InputRoomEvents { + // Work out if we are running per-room workers or if we're just doing + // it on a global basis (e.g. SQLite). + roomID := "global" + if r.DB.SupportsConcurrentRoomInputs() { + roomID = e.Event.RoomID() + } + + // Look up the worker, or create it if it doesn't exist. This channel + // is buffered to reduce the chance that we'll be blocked by another + // room - the channel will be quite small as it's just pointer types. + w, _ := r.workers.LoadOrStore(roomID, &inputWorker{ + r: r, + input: make(chan *inputTask, 10), + }) + worker := w.(*inputWorker) + + // Create a task. This contains the input event and a reference to + // the wait group, so that the worker can notify us when this specific + // task has been finished. + tasks[i] = &inputTask{ + ctx: ctx, + event: &request.InputRoomEvents[i], + wg: wg, + } + + // Send the task to the worker. + go worker.start() + worker.input <- tasks[i] + } + + // Wait for all of the workers to return results about our tasks. + wg.Wait() + + // If any of the tasks returned an error, we should probably report + // that back to the caller. + for _, task := range tasks { + if task.err != nil { + return task.err + } + } + return nil +} diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input/input_events.go similarity index 84% rename from roomserver/internal/input_events.go rename to roomserver/internal/input/input_events.go index a63082990..6ee679da6 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package input import ( "context" @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -35,9 +36,9 @@ import ( // state deltas when sending to kafka streams // TODO: Break up function - we should probably do transaction ID checks before calling this. // nolint:gocyclo -func (r *RoomserverInternalAPI) processRoomEvent( +func (r *Inputer) processRoomEvent( ctx context.Context, - input api.InputRoomEvent, + input *api.InputRoomEvent, ) (eventID string, err error) { // Parse and validate the event JSON headered := input.Event @@ -45,7 +46,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( // Check that the event passes authentication checks and work out // the numeric IDs for the auth events. - authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) + authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) if err != nil { logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") return @@ -64,7 +65,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( } // Store the event. - roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) + _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) if err != nil { return "", fmt.Errorf("r.DB.StoreEvent: %w", err) } @@ -89,10 +90,18 @@ func (r *RoomserverInternalAPI) processRoomEvent( return event.EventID(), nil } + roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) + if err != nil { + return "", fmt.Errorf("r.DB.RoomInfo: %w", err) + } + if roomInfo == nil { + return "", fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) + } + if stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event) + err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event) if err != nil { return "", fmt.Errorf("r.calculateAndSetState: %w", err) } @@ -100,7 +109,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( if err = r.updateLatestEvents( ctx, // context - roomNID, // room NID to update + roomInfo, // room info for the room being updated stateAtEvent, // state at event (below) event, // event input.SendAsServer, // send as server @@ -132,22 +141,22 @@ func (r *RoomserverInternalAPI) processRoomEvent( return event.EventID(), nil } -func (r *RoomserverInternalAPI) calculateAndSetState( +func (r *Inputer) calculateAndSetState( ctx context.Context, - input api.InputRoomEvent, - roomNID types.RoomNID, + input *api.InputRoomEvent, + roomInfo types.RoomInfo, stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, ) error { var err error - roomState := state.NewStateResolution(r.DB) + roomState := state.NewStateResolution(r.DB, roomInfo) if input.HasState { // Check here if we think we're in the room already. stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID // Request join memberships only for local users only. - if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil { + if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { // If we have no local users that are joined to the room then any state about // the room that we have is quite possibly out of date. Therefore in that case // we should overwrite it rather than merge it. @@ -161,14 +170,14 @@ func (r *RoomserverInternalAPI) calculateAndSetState( return err } - if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { return err } } else { stateAtEvent.Overwrite = false // We haven't been told what the state at the event is so we need to calculate it from the prev_events - if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, roomNID); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil { return err } } diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input/input_latest_events.go similarity index 95% rename from roomserver/internal/input_latest_events.go rename to roomserver/internal/input/input_latest_events.go index f11a78d72..67a7d8a40 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package input import ( "bytes" @@ -47,15 +47,15 @@ import ( // 7 <----- latest // // Can only be called once at a time -func (r *RoomserverInternalAPI) updateLatestEvents( +func (r *Inputer) updateLatestEvents( ctx context.Context, - roomNID types.RoomNID, + roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event, sendAsServer string, transactionID *api.TransactionID, ) (err error) { - updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID) + updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) if err != nil { return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) } @@ -66,7 +66,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( ctx: ctx, api: r, updater: updater, - roomNID: roomNID, + roomInfo: roomInfo, stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer, @@ -87,9 +87,9 @@ func (r *RoomserverInternalAPI) updateLatestEvents( // when there are so many variables to pass around. type latestEventsUpdater struct { ctx context.Context - api *RoomserverInternalAPI + api *Inputer updater *shared.LatestEventsUpdater - roomNID types.RoomNID + roomInfo *types.RoomInfo stateAtEvent types.StateAtEvent event gomatrixserverlib.Event transactionID *api.TransactionID @@ -196,7 +196,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { return fmt.Errorf("u.api.WriteOutputEvents: %w", err) } - if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { + if err = u.updater.SetLatestEvents(u.roomInfo.RoomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { return fmt.Errorf("u.updater.SetLatestEvents: %w", err) } @@ -209,7 +209,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB) + roomState := state.NewStateResolution(u.api.DB, *u.roomInfo) // Get a list of the current latest events. latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) @@ -221,7 +221,7 @@ func (u *latestEventsUpdater) latestState() error { // of the state after the events. The snapshot state will be resolved // using the correct state resolution algorithm for the room. u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents( - u.ctx, u.roomNID, latestStateAtEvents, + u.ctx, latestStateAtEvents, ) if err != nil { return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) @@ -303,13 +303,8 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) latestEventIDs[i] = u.latest[i].EventID } - roomVersion, err := u.api.DB.GetRoomVersionForRoom(u.ctx, u.event.RoomID()) - if err != nil { - return nil, err - } - ore := api.OutputNewRoomEvent{ - Event: u.event.Headered(roomVersion), + Event: u.event.Headered(u.roomInfo.RoomVersion), LastSentEventID: u.lastEventIDSent, LatestEventIDs: latestEventIDs, TransactionID: u.transactionID, @@ -337,7 +332,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) // include extra state events if they were added as nearly every downstream component will care about it // and we'd rather not have them all hit QueryEventsByID at the same time! if len(ore.AddsStateEventIDs) > 0 { - ore.AddStateEvents, err = u.extraEventsForIDs(roomVersion, ore.AddsStateEventIDs) + ore.AddStateEvents, err = u.extraEventsForIDs(u.roomInfo.RoomVersion, ore.AddsStateEventIDs) if err != nil { return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) } diff --git a/roomserver/internal/input_membership.go b/roomserver/internal/input/input_membership.go similarity index 83% rename from roomserver/internal/input_membership.go rename to roomserver/internal/input/input_membership.go index bcecfca0e..8befcd647 100644 --- a/roomserver/internal/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package input import ( "context" "fmt" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -28,7 +29,7 @@ import ( // user affected by a change in the current state of the room. // Returns a list of output events to write to the kafka log to inform the // consumers about the invites added or retired by the change in current state. -func (r *RoomserverInternalAPI) updateMemberships( +func (r *Inputer) updateMemberships( ctx context.Context, updater *shared.LatestEventsUpdater, removed, added []types.StateEntry, @@ -59,13 +60,13 @@ func (r *RoomserverInternalAPI) updateMemberships( var re *gomatrixserverlib.Event targetUserNID := change.EventStateKeyNID if change.removedEventNID != 0 { - ev, _ := eventMap(events).lookup(change.removedEventNID) + ev, _ := helpers.EventMap(events).Lookup(change.removedEventNID) if ev != nil { re = &ev.Event } } if change.addedEventNID != 0 { - ev, _ := eventMap(events).lookup(change.addedEventNID) + ev, _ := helpers.EventMap(events).Lookup(change.addedEventNID) if ev != nil { ae = &ev.Event } @@ -77,7 +78,7 @@ func (r *RoomserverInternalAPI) updateMemberships( return updates, nil } -func (r *RoomserverInternalAPI) updateMembership( +func (r *Inputer) updateMembership( updater *shared.LatestEventsUpdater, targetUserNID types.EventStateKeyNID, remove, add *gomatrixserverlib.Event, @@ -120,7 +121,7 @@ func (r *RoomserverInternalAPI) updateMembership( switch newMembership { case gomatrixserverlib.Invite: - return updateToInviteMembership(mu, add, updates, updater.RoomVersion()) + return helpers.UpdateToInviteMembership(mu, add, updates, updater.RoomVersion()) case gomatrixserverlib.Join: return updateToJoinMembership(mu, add, updates) case gomatrixserverlib.Leave, gomatrixserverlib.Ban: @@ -132,45 +133,15 @@ func (r *RoomserverInternalAPI) updateMembership( } } -func (r *RoomserverInternalAPI) isLocalTarget(event *gomatrixserverlib.Event) bool { +func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool { isTargetLocalUser := false if statekey := event.StateKey(); statekey != nil { _, domain, _ := gomatrixserverlib.SplitID('@', *statekey) - isTargetLocalUser = domain == r.Cfg.Matrix.ServerName + isTargetLocalUser = domain == r.ServerName } return isTargetLocalUser } -func updateToInviteMembership( - mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, - roomVersion gomatrixserverlib.RoomVersion, -) ([]api.OutputEvent, error) { - // We may have already sent the invite to the user, either because we are - // reprocessing this event, or because the we received this invite from a - // remote server via the federation invite API. In those cases we don't need - // to send the event. - needsSending, err := mu.SetToInvite(*add) - if err != nil { - return nil, err - } - if needsSending { - // We notify the consumers using a special event even though we will - // notify them about the change in current state as part of the normal - // room event stream. This ensures that the consumers only have to - // consider a single stream of events when determining whether a user - // is invited, rather than having to combine multiple streams themselves. - onie := api.OutputNewInviteEvent{ - Event: add.Headered(roomVersion), - RoomVersion: roomVersion, - } - updates = append(updates, api.OutputEvent{ - Type: api.OutputTypeNewInviteEvent, - NewInviteEvent: &onie, - }) - } - return updates, nil -} - func updateToJoinMembership( mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go new file mode 100644 index 000000000..668c80787 --- /dev/null +++ b/roomserver/internal/perform/perform_backfill.go @@ -0,0 +1,562 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perform + +import ( + "context" + "fmt" + + federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/auth" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +type Backfiller struct { + ServerName gomatrixserverlib.ServerName + DB storage.Database + FSAPI federationSenderAPI.FederationSenderInternalAPI + KeyRing gomatrixserverlib.JSONVerifier +} + +// PerformBackfill implements api.RoomServerQueryAPI +func (r *Backfiller) PerformBackfill( + ctx context.Context, + request *api.PerformBackfillRequest, + response *api.PerformBackfillResponse, +) error { + // if we are requesting the backfill then we need to do a federation hit + // TODO: we could be more sensible and fetch as many events we already have then request the rest + // which is what the syncapi does already. + if request.ServerName == r.ServerName { + return r.backfillViaFederation(ctx, request, response) + } + // someone else is requesting the backfill, try to service their request. + var err error + var front []string + + // The limit defines the maximum number of events to retrieve, so it also + // defines the highest number of elements in the map below. + visited := make(map[string]bool, request.Limit) + + // this will include these events which is what we want + front = request.PrevEventIDs() + + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if info == nil || info.IsStub { + return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID) + } + + // Scan the event tree for events to send back. + resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName) + if err != nil { + return err + } + + // Retrieve events from the list that was filled previously. + var loadedEvents []gomatrixserverlib.Event + loadedEvents, err = helpers.LoadEvents(ctx, r.DB, resultNIDs) + if err != nil { + return err + } + + for _, event := range loadedEvents { + response.Events = append(response.Events, event.Headered(info.RoomVersion)) + } + + return err +} + +func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error { + info, err := r.DB.RoomInfo(ctx, req.RoomID) + if err != nil { + return err + } + if info == nil || info.IsStub { + return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) + } + requester := newBackfillRequester(r.DB, r.FSAPI, r.ServerName, req.BackwardsExtremities) + // Request 100 items regardless of what the query asks for. + // We don't want to go much higher than this. + // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass + // (so we don't need to hit /state_ids which the test has no listener for) + // Specifically the test "Outbound federation can backfill events" + events, err := gomatrixserverlib.RequestBackfill( + ctx, requester, + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100) + if err != nil { + return err + } + logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) + + // persist these new events - auth checks have already been done + roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) + if err != nil { + return err + } + + for _, ev := range backfilledEventMap { + // now add state for these events + stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()] + if !ok { + // this should be impossible as all events returned must have pass Step 5 of the PDU checks + // which requires a list of state IDs. + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks") + continue + } + var entries []types.StateEntry + if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil { + // attempt to fetch the missing events + r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs) + // try again + entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event") + return err + } + } + + var beforeStateSnapshotNID types.StateSnapshotNID + if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid") + return err + } + if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid") + } + } + + // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point. + + res.Events = events + return nil +} + +// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just +// best effort. +func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, + backfillRequester *backfillRequester, stateIDs []string) { + + servers := backfillRequester.servers + + // work out which are missing + nidMap, err := r.DB.EventNIDs(ctx, stateIDs) + if err != nil { + util.GetLogger(ctx).WithError(err).Warn("cannot query missing events") + return + } + missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event + for _, id := range stateIDs { + if _, ok := nidMap[id]; !ok { + missingMap[id] = nil + } + } + util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers)) + + // fetch the events from federation. Loop the servers first so if we find one that works we stick with them + for _, srv := range servers { + for id, ev := range missingMap { + if ev != nil { + continue // already found + } + logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) + res, err := r.FSAPI.GetEvent(ctx, srv, id) + if err != nil { + logger.WithError(err).Warn("failed to get event from server") + continue + } + loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) + if err != nil { + logger.WithError(err).Warn("failed to load and verify event") + continue + } + logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result) + for _, res := range result { + if res.Error != nil { + logger.WithError(err).Warn("event failed PDU checks") + continue + } + missingMap[id] = res.Event + } + } + } + + var newEvents []gomatrixserverlib.HeaderedEvent + for _, ev := range missingMap { + if ev != nil { + newEvents = append(newEvents, *ev) + } + } + util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) + persistEvents(ctx, r.DB, newEvents) +} + +// backfillRequester implements gomatrixserverlib.BackfillRequester +type backfillRequester struct { + db storage.Database + fsAPI federationSenderAPI.FederationSenderInternalAPI + thisServer gomatrixserverlib.ServerName + bwExtrems map[string][]string + + // per-request state + servers []gomatrixserverlib.ServerName + eventIDToBeforeStateIDs map[string][]string + eventIDMap map[string]gomatrixserverlib.Event +} + +func newBackfillRequester(db storage.Database, fsAPI federationSenderAPI.FederationSenderInternalAPI, thisServer gomatrixserverlib.ServerName, bwExtrems map[string][]string) *backfillRequester { + return &backfillRequester{ + db: db, + fsAPI: fsAPI, + thisServer: thisServer, + eventIDToBeforeStateIDs: make(map[string][]string), + eventIDMap: make(map[string]gomatrixserverlib.Event), + bwExtrems: bwExtrems, + } +} + +func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.HeaderedEvent) ([]string, error) { + b.eventIDMap[targetEvent.EventID()] = targetEvent.Unwrap() + if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok { + return ids, nil + } + if len(targetEvent.PrevEventIDs()) == 0 && targetEvent.Type() == "m.room.create" && targetEvent.StateKeyEquals("") { + util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID()).Info("Backfilled to the beginning of the room") + b.eventIDToBeforeStateIDs[targetEvent.EventID()] = []string{} + return nil, nil + } + // if we have exactly 1 prev event and we know the state of the room at that prev event, then just roll forward the prev event. + // Else, we have to hit /state_ids because either we don't know the state at all at this event (new backwards extremity) or + // we don't know the result of state res to merge forks (2 or more prev_events) + if len(targetEvent.PrevEventIDs()) == 1 { + prevEventID := targetEvent.PrevEventIDs()[0] + prevEvent, ok := b.eventIDMap[prevEventID] + if !ok { + goto FederationHit + } + prevEventStateIDs, ok := b.eventIDToBeforeStateIDs[prevEventID] + if !ok { + goto FederationHit + } + newStateIDs := b.calculateNewStateIDs(targetEvent.Unwrap(), prevEvent, prevEventStateIDs) + if newStateIDs != nil { + b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs + return newStateIDs, nil + } + // else we failed to calculate the new state, so fallthrough + } + +FederationHit: + var lastErr error + logrus.WithField("event_id", targetEvent.EventID()).Info("Requesting /state_ids at event") + for _, srv := range b.servers { // hit any valid server + c := gomatrixserverlib.FederatedStateProvider{ + FedClient: b.fsAPI, + RememberAuthEvents: false, + Server: srv, + } + res, err := c.StateIDsBeforeEvent(ctx, targetEvent) + if err != nil { + lastErr = err + continue + } + b.eventIDToBeforeStateIDs[targetEvent.EventID()] = res + return res, nil + } + return nil, lastErr +} + +func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrixserverlib.Event, prevEventStateIDs []string) []string { + newStateIDs := prevEventStateIDs[:] + if prevEvent.StateKey() == nil { + // state is the same as the previous event + b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs + return newStateIDs + } + + missingState := false // true if we are missing the info for a state event ID + foundEvent := false // true if we found a (type, state_key) match + // find which state ID to replace, if any + for i, id := range newStateIDs { + ev, ok := b.eventIDMap[id] + if !ok { + missingState = true + continue + } + // The state IDs BEFORE the target event are the state IDs BEFORE the prev_event PLUS the prev_event itself + if ev.Type() == prevEvent.Type() && ev.StateKeyEquals(*prevEvent.StateKey()) { + newStateIDs[i] = prevEvent.EventID() + foundEvent = true + break + } + } + if !foundEvent && !missingState { + // we can be certain that this is new state + newStateIDs = append(newStateIDs, prevEvent.EventID()) + foundEvent = true + } + + if foundEvent { + b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs + return newStateIDs + } + return nil +} + +func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, + event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) { + + // try to fetch the events from the database first + events, err := b.ProvideEvents(roomVer, eventIDs) + if err != nil { + // non-fatal, fallthrough + logrus.WithError(err).Info("Failed to fetch events") + } else { + logrus.Infof("Fetched %d/%d events from the database", len(events), len(eventIDs)) + if len(events) == len(eventIDs) { + result := make(map[string]*gomatrixserverlib.Event) + for i := range events { + result[events[i].EventID()] = &events[i] + b.eventIDMap[events[i].EventID()] = events[i] + } + return result, nil + } + } + + c := gomatrixserverlib.FederatedStateProvider{ + FedClient: b.fsAPI, + RememberAuthEvents: false, + Server: b.servers[0], + } + result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs) + if err != nil { + return nil, err + } + for eventID, ev := range result { + b.eventIDMap[eventID] = *ev + } + return result, nil +} + +// ServersAtEvent is called when trying to determine which server to request from. +// It returns a list of servers which can be queried for backfill requests. These servers +// will be servers that are in the room already. The entries at the beginning are preferred servers +// and will be tried first. An empty list will fail the request. +// nolint:gocyclo +func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName { + // eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use + // its successor, so look it up. + successor := "" +FindSuccessor: + for sucID, prevEventIDs := range b.bwExtrems { + for _, pe := range prevEventIDs { + if pe == eventID { + successor = sucID + break FindSuccessor + } + } + } + if successor == "" { + logrus.WithField("event_id", eventID).Error("ServersAtEvent: failed to find successor of this event to determine room state") + return nil + } + eventID = successor + + // getMembershipsBeforeEventNID requires a NID, so retrieving the NID for + // the event is necessary. + NIDs, err := b.db.EventNIDs(ctx, []string{eventID}) + if err != nil { + logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get event NID for event") + return nil + } + + info, err := b.db.RoomInfo(ctx, roomID) + if err != nil { + logrus.WithError(err).WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room") + return nil + } + if info == nil || info.IsStub { + logrus.WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room, room is missing") + return nil + } + + stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, *info, NIDs[eventID]) + if err != nil { + logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") + return nil + } + + // possibly return all joined servers depending on history visiblity + memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries) + if err != nil { + logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") + return nil + } + logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis)) + + // Retrieve all "m.room.member" state events of "join" membership, which + // contains the list of users in the room before the event, therefore all + // the servers in it at that moment. + memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, stateEntries, true) + if err != nil { + logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") + return nil + } + memberEvents = append(memberEvents, memberEventsFromVis...) + + // Store the server names in a temporary map to avoid duplicates. + serverSet := make(map[gomatrixserverlib.ServerName]bool) + for _, event := range memberEvents { + serverSet[event.Origin()] = true + } + var servers []gomatrixserverlib.ServerName + for server := range serverSet { + if server == b.thisServer { + continue + } + servers = append(servers, server) + } + b.servers = servers + return servers +} + +// Backfill performs a backfill request to the given server. +// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid +func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, + limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) { + + tx, err := b.fsAPI.Backfill(ctx, server, roomID, limit, fromEventIDs) + return tx, err +} + +func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.Event, error) { + ctx := context.Background() + nidMap, err := b.db.EventNIDs(ctx, eventIDs) + if err != nil { + logrus.WithError(err).WithField("event_ids", eventIDs).Error("Failed to find events") + return nil, err + } + eventNIDs := make([]types.EventNID, len(nidMap)) + i := 0 + for _, nid := range nidMap { + eventNIDs[i] = nid + i++ + } + eventsWithNids, err := b.db.Events(ctx, eventNIDs) + if err != nil { + logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") + return nil, err + } + events := make([]gomatrixserverlib.Event, len(eventsWithNids)) + for i := range eventsWithNids { + events[i] = eventsWithNids[i].Event + } + return events, nil +} + +// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if the provided state indicated a 'shared' history visibility. +// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just +// pull all events and then filter by that table. +func joinEventsFromHistoryVisibility( + ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) { + + var eventNIDs []types.EventNID + for _, entry := range stateEntries { + // Filter the events to retrieve to only keep the membership events + if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID { + eventNIDs = append(eventNIDs, entry.EventNID) + break + } + } + + // Get all of the events in this state + stateEvents, err := db.Events(ctx, eventNIDs) + if err != nil { + return nil, err + } + events := make([]gomatrixserverlib.Event, len(stateEvents)) + for i := range stateEvents { + events[i] = stateEvents[i].Event + } + visibility := auth.HistoryVisibilityForRoom(events) + if visibility != "shared" { + logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility) + return nil, nil + } + // get joined members + info, err := db.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) + if err != nil { + return nil, err + } + return db.Events(ctx, joinEventNIDs) +} + +func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { + var roomNID types.RoomNID + backfilledEventMap := make(map[string]types.Event) + for j, ev := range events { + nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs()) + if err != nil { // this shouldn't happen as RequestBackfill already found them + logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events") + continue + } + authNids := make([]types.EventNID, len(nidMap)) + i := 0 + for _, nid := range nidMap { + authNids[i] = nid + i++ + } + var stateAtEvent types.StateAtEvent + var redactedEventID string + var redactionEvent *gomatrixserverlib.Event + roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") + continue + } + // If storing this event results in it being redacted, then do so. + // It's also possible for this event to be a redaction which results in another event being + // redacted, which we don't care about since we aren't returning it in this backfill. + if redactedEventID == ev.EventID() { + eventToRedact := ev.Unwrap() + redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") + continue + } + ev = redactedEvent.Headered(ev.RoomVersion) + events[j] = ev + } + backfilledEventMap[ev.EventID()] = types.Event{ + EventNID: stateAtEvent.StateEntry.EventNID, + Event: ev.Unwrap(), + } + } + return roomNID, backfilledEventMap +} diff --git a/roomserver/internal/perform_invite.go b/roomserver/internal/perform/perform_invite.go similarity index 77% rename from roomserver/internal/perform_invite.go rename to roomserver/internal/perform/perform_invite.go index 1cfbcc18c..e06ad062d 100644 --- a/roomserver/internal/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -1,11 +1,28 @@ -package internal +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perform import ( "context" "fmt" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" @@ -13,22 +30,29 @@ import ( log "github.com/sirupsen/logrus" ) +type Inviter struct { + DB storage.Database + Cfg *config.RoomServer + FSAPI federationSenderAPI.FederationSenderInternalAPI + Inputer *input.Inputer +} + // nolint:gocyclo -func (r *RoomserverInternalAPI) PerformInvite( +func (r *Inviter) PerformInvite( ctx context.Context, req *api.PerformInviteRequest, res *api.PerformInviteResponse, -) error { +) ([]api.OutputEvent, error) { event := req.Event if event.StateKey() == nil { - return fmt.Errorf("invite must be a state event") + return nil, fmt.Errorf("invite must be a state event") } roomID := event.RoomID() targetUserID := *event.StateKey() info, err := r.DB.RoomInfo(ctx, roomID) if err != nil { - return fmt.Errorf("Failed to load RoomInfo: %w", err) + return nil, fmt.Errorf("Failed to load RoomInfo: %w", err) } log.WithFields(log.Fields{ @@ -52,11 +76,11 @@ func (r *RoomserverInternalAPI) PerformInvite( } if len(inviteState) == 0 { if err = event.SetUnsignedField("invite_room_state", struct{}{}); err != nil { - return fmt.Errorf("event.SetUnsignedField: %w", err) + return nil, fmt.Errorf("event.SetUnsignedField: %w", err) } } else { if err = event.SetUnsignedField("invite_room_state", inviteState); err != nil { - return fmt.Errorf("event.SetUnsignedField: %w", err) + return nil, fmt.Errorf("event.SetUnsignedField: %w", err) } } @@ -64,7 +88,7 @@ func (r *RoomserverInternalAPI) PerformInvite( if info != nil { _, isAlreadyJoined, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) if err != nil { - return fmt.Errorf("r.DB.GetMembership: %w", err) + return nil, fmt.Errorf("r.DB.GetMembership: %w", err) } } if isAlreadyJoined { @@ -99,7 +123,7 @@ func (r *RoomserverInternalAPI) PerformInvite( Code: api.PerformErrorNotAllowed, Msg: "User is already joined to room", } - return nil + return nil, nil } if isOriginLocal { @@ -107,7 +131,7 @@ func (r *RoomserverInternalAPI) PerformInvite( // try and see if the user is allowed to make this invite. We can't do // this for invites coming in over federation - we have to take those on // trust. - _, err = checkAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) + _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) if err != nil { log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( "processInviteEvent.checkAuthEvents failed for event", @@ -117,9 +141,9 @@ func (r *RoomserverInternalAPI) PerformInvite( Msg: err.Error(), Code: api.PerformErrorNotAllowed, } - return nil + return nil, nil } - return fmt.Errorf("checkAuthEvents: %w", err) + return nil, fmt.Errorf("checkAuthEvents: %w", err) } // If the invite originated from us and the target isn't local then we @@ -133,13 +157,13 @@ func (r *RoomserverInternalAPI) PerformInvite( InviteRoomState: inviteState, } fsRes := &federationSenderAPI.PerformInviteResponse{} - if err = r.fsAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { + if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { res.Error = &api.PerformError{ Msg: err.Error(), Code: api.PerformErrorNoOperation, } - log.WithError(err).WithField("event_id", event.EventID()).Error("r.fsAPI.PerformInvite failed") - return nil + log.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") + return nil, nil } event = fsRes.Event } @@ -159,8 +183,8 @@ func (r *RoomserverInternalAPI) PerformInvite( }, } inputRes := &api.InputRoomEventsResponse{} - if err = r.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { - return fmt.Errorf("r.InputRoomEvents: %w", err) + if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { + return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } } else { // The invite originated over federation. Process the membership @@ -168,25 +192,23 @@ func (r *RoomserverInternalAPI) PerformInvite( // invite. updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion) if err != nil { - return fmt.Errorf("r.DB.MembershipUpdater: %w", err) + return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) } unwrapped := event.Unwrap() - outputUpdates, err := updateToInviteMembership(updater, &unwrapped, nil, req.Event.RoomVersion) + outputUpdates, err := helpers.UpdateToInviteMembership(updater, &unwrapped, nil, req.Event.RoomVersion) if err != nil { - return fmt.Errorf("updateToInviteMembership: %w", err) + return nil, fmt.Errorf("updateToInviteMembership: %w", err) } if err = updater.Commit(); err != nil { - return fmt.Errorf("updater.Commit: %w", err) + return nil, fmt.Errorf("updater.Commit: %w", err) } - if err = r.WriteOutputEvents(roomID, outputUpdates); err != nil { - return fmt.Errorf("r.WriteOutputEvents: %w", err) - } + return outputUpdates, nil } - return nil + return nil, nil } func buildInviteStrippedState( @@ -208,7 +230,7 @@ func buildInviteStrippedState( StateKey: "", }) } - roomState := state.NewStateResolution(db) + roomState := state.NewStateResolution(db, *info) stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( ctx, info.StateSnapshotNID, stateWanted, ) diff --git a/roomserver/internal/perform_join.go b/roomserver/internal/perform/perform_join.go similarity index 76% rename from roomserver/internal/perform_join.go rename to roomserver/internal/perform/perform_join.go index 3b9b1b3ca..3d1942272 100644 --- a/roomserver/internal/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -1,4 +1,18 @@ -package internal +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perform import ( "context" @@ -8,14 +22,27 @@ import ( "time" fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/internal/input" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) +type Joiner struct { + ServerName gomatrixserverlib.ServerName + Cfg *config.RoomServer + FSAPI fsAPI.FederationSenderInternalAPI + DB storage.Database + + Inputer *input.Inputer +} + // PerformJoin handles joining matrix rooms, including over federation by talking to the federationsender. -func (r *RoomserverInternalAPI) PerformJoin( +func (r *Joiner) PerformJoin( ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse, @@ -34,7 +61,7 @@ func (r *RoomserverInternalAPI) PerformJoin( res.RoomID = roomID } -func (r *RoomserverInternalAPI) performJoin( +func (r *Joiner) performJoin( ctx context.Context, req *api.PerformJoinRequest, ) (string, error) { @@ -63,7 +90,7 @@ func (r *RoomserverInternalAPI) performJoin( } } -func (r *RoomserverInternalAPI) performJoinRoomByAlias( +func (r *Joiner) performJoinRoomByAlias( ctx context.Context, req *api.PerformJoinRequest, ) (string, error) { @@ -85,7 +112,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias( ServerName: domain, // the server to ask } dirRes := fsAPI.PerformDirectoryLookupResponse{} - err = r.fsAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes) + err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes) if err != nil { logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias) return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err) @@ -112,7 +139,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias( // TODO: Break this function up a bit // nolint:gocyclo -func (r *RoomserverInternalAPI) performJoinRoomByID( +func (r *Joiner) performJoinRoomByID( ctx context.Context, req *api.PerformJoinRequest, ) (string, error) { @@ -161,8 +188,8 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( // where we might think we know about a room in the following // section but don't know the latest state as all of our users // have left. - serverInRoom, _ := r.isServerCurrentlyInRoom(ctx, r.ServerName, req.RoomIDOrAlias) - isInvitePending, inviteSender, _, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID) + serverInRoom, _ := helpers.IsServerCurrentlyInRoom(ctx, r.DB, r.ServerName, req.RoomIDOrAlias) + isInvitePending, inviteSender, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) if err == nil && isInvitePending && !serverInRoom { // Check if there's an invite pending. _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) @@ -188,15 +215,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( // locally on the homeserver. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - buildRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.BuildEvent( - ctx, // the request context - &eb, // the template join event - r.Cfg.Matrix, // the server configuration - time.Now(), // the event timestamp to use - r, // the roomserver API to use - &buildRes, // the query response - ) + event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) switch err { case nil: @@ -228,7 +247,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - if err = r.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { var notAllowed *gomatrixserverlib.NotAllowed if errors.As(err, ¬Allowed) { return "", &api.PerformError{ @@ -271,7 +290,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( return req.RoomIDOrAlias, nil } -func (r *RoomserverInternalAPI) performFederatedJoinRoomByID( +func (r *Joiner) performFederatedJoinRoomByID( ctx context.Context, req *api.PerformJoinRequest, ) error { @@ -283,7 +302,7 @@ func (r *RoomserverInternalAPI) performFederatedJoinRoomByID( Content: req.Content, // the membership event content } fedRes := fsAPI.PerformJoinResponse{} - r.fsAPI.PerformJoin(ctx, &fedReq, &fedRes) + r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes) if fedRes.LastError != nil { return &api.PerformError{ Code: api.PerformErrRemote, @@ -293,3 +312,31 @@ func (r *RoomserverInternalAPI) performFederatedJoinRoomByID( } return nil } + +func buildEvent( + ctx context.Context, db storage.Database, cfg *config.Global, builder *gomatrixserverlib.EventBuilder, +) (*gomatrixserverlib.HeaderedEvent, *api.QueryLatestEventsAndStateResponse, error) { + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) + if err != nil { + return nil, nil, fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) + } + + if len(eventsNeeded.Tuples()) == 0 { + return nil, nil, errors.New("expecting state tuples for event builder, got none") + } + + var queryRes api.QueryLatestEventsAndStateResponse + err = helpers.QueryLatestEventsAndState(ctx, db, &api.QueryLatestEventsAndStateRequest{ + RoomID: builder.RoomID, + StateToFetch: eventsNeeded.Tuples(), + }, &queryRes) + if err != nil { + return nil, nil, fmt.Errorf("QueryLatestEventsAndState: %w", err) + } + + ev, err := eventutil.BuildEvent(ctx, builder, cfg, time.Now(), &eventsNeeded, &queryRes) + if err != nil { + return nil, nil, err + } + return ev, &queryRes, nil +} diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go new file mode 100644 index 000000000..aaa3b5b16 --- /dev/null +++ b/roomserver/internal/perform/perform_leave.go @@ -0,0 +1,183 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perform + +import ( + "context" + "fmt" + "strings" + + fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/internal/input" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/gomatrixserverlib" +) + +type Leaver struct { + Cfg *config.RoomServer + DB storage.Database + FSAPI fsAPI.FederationSenderInternalAPI + + Inputer *input.Inputer +} + +// WriteOutputEvents implements OutputRoomEventWriter +func (r *Leaver) PerformLeave( + ctx context.Context, + req *api.PerformLeaveRequest, + res *api.PerformLeaveResponse, +) ([]api.OutputEvent, error) { + _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return nil, fmt.Errorf("Supplied user ID %q in incorrect format", req.UserID) + } + if domain != r.Cfg.Matrix.ServerName { + return nil, fmt.Errorf("User %q does not belong to this homeserver", req.UserID) + } + if strings.HasPrefix(req.RoomID, "!") { + return r.performLeaveRoomByID(ctx, req, res) + } + return nil, fmt.Errorf("Room ID %q is invalid", req.RoomID) +} + +func (r *Leaver) performLeaveRoomByID( + ctx context.Context, + req *api.PerformLeaveRequest, + res *api.PerformLeaveResponse, // nolint:unparam +) ([]api.OutputEvent, error) { + // If there's an invite outstanding for the room then respond to + // that. + isInvitePending, senderUser, eventID, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) + if err == nil && isInvitePending { + return r.performRejectInvite(ctx, req, res, senderUser, eventID) + } + + // There's no invite pending, so first of all we want to find out + // if the room exists and if the user is actually in it. + latestReq := api.QueryLatestEventsAndStateRequest{ + RoomID: req.RoomID, + StateToFetch: []gomatrixserverlib.StateKeyTuple{ + { + EventType: gomatrixserverlib.MRoomMember, + StateKey: req.UserID, + }, + }, + } + latestRes := api.QueryLatestEventsAndStateResponse{} + if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &latestReq, &latestRes); err != nil { + return nil, err + } + if !latestRes.RoomExists { + return nil, fmt.Errorf("Room %q does not exist", req.RoomID) + } + + // Now let's see if the user is in the room. + if len(latestRes.StateEvents) == 0 { + return nil, fmt.Errorf("User %q is not a member of room %q", req.UserID, req.RoomID) + } + membership, err := latestRes.StateEvents[0].Membership() + if err != nil { + return nil, fmt.Errorf("Error getting membership: %w", err) + } + if membership != gomatrixserverlib.Join { + // TODO: should be able to handle "invite" in this case too, if + // it's a case of kicking or banning or such + return nil, fmt.Errorf("User %q is not joined to the room (membership is %q)", req.UserID, membership) + } + + // Prepare the template for the leave event. + userID := req.UserID + eb := gomatrixserverlib.EventBuilder{ + Type: gomatrixserverlib.MRoomMember, + Sender: userID, + StateKey: &userID, + RoomID: req.RoomID, + Redacts: "", + } + if err = eb.SetContent(map[string]interface{}{"membership": "leave"}); err != nil { + return nil, fmt.Errorf("eb.SetContent: %w", err) + } + if err = eb.SetUnsigned(struct{}{}); err != nil { + return nil, fmt.Errorf("eb.SetUnsigned: %w", err) + } + + // We know that the user is in the room at this point so let's build + // a leave event. + // TODO: Check what happens if the room exists on the server + // but everyone has since left. I suspect it does the wrong thing. + event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) + if err != nil { + return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) + } + + // Give our leave event to the roomserver input stream. The + // roomserver will process the membership change and notify + // downstream automatically. + inputReq := api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{ + { + Kind: api.KindNew, + Event: event.Headered(buildRes.RoomVersion), + AuthEventIDs: event.AuthEventIDs(), + SendAsServer: string(r.Cfg.Matrix.ServerName), + }, + }, + } + inputRes := api.InputRoomEventsResponse{} + if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + return nil, fmt.Errorf("r.InputRoomEvents: %w", err) + } + + return nil, nil +} + +func (r *Leaver) performRejectInvite( + ctx context.Context, + req *api.PerformLeaveRequest, + res *api.PerformLeaveResponse, // nolint:unparam + senderUser, eventID string, +) ([]api.OutputEvent, error) { + _, domain, err := gomatrixserverlib.SplitID('@', senderUser) + if err != nil { + return nil, fmt.Errorf("User ID %q invalid: %w", senderUser, err) + } + + // Ask the federation sender to perform a federated leave for us. + leaveReq := fsAPI.PerformLeaveRequest{ + RoomID: req.RoomID, + UserID: req.UserID, + ServerNames: []gomatrixserverlib.ServerName{domain}, + } + leaveRes := fsAPI.PerformLeaveResponse{} + if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { + return nil, err + } + + // Withdraw the invite, so that the sync API etc are + // notified that we rejected it. + return []api.OutputEvent{ + { + Type: api.OutputTypeRetireInviteEvent, + RetireInviteEvent: &api.OutputRetireInviteEvent{ + EventID: eventID, + Membership: "leave", + TargetUserID: req.UserID, + }, + }, + }, nil +} diff --git a/roomserver/internal/perform/perform_publish.go b/roomserver/internal/perform/perform_publish.go new file mode 100644 index 000000000..6ff42ac1a --- /dev/null +++ b/roomserver/internal/perform/perform_publish.go @@ -0,0 +1,39 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perform + +import ( + "context" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage" +) + +type Publisher struct { + DB storage.Database +} + +func (r *Publisher) PerformPublish( + ctx context.Context, + req *api.PerformPublishRequest, + res *api.PerformPublishResponse, +) { + err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public") + if err != nil { + res.Error = &api.PerformError{ + Msg: err.Error(), + } + } +} diff --git a/roomserver/internal/perform_backfill.go b/roomserver/internal/perform_backfill.go deleted file mode 100644 index 65c88860c..000000000 --- a/roomserver/internal/perform_backfill.go +++ /dev/null @@ -1,305 +0,0 @@ -package internal - -import ( - "context" - - "github.com/matrix-org/dendrite/roomserver/auth" - "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" -) - -// backfillRequester implements gomatrixserverlib.BackfillRequester -type backfillRequester struct { - db storage.Database - fedClient *gomatrixserverlib.FederationClient - thisServer gomatrixserverlib.ServerName - bwExtrems map[string][]string - - // per-request state - servers []gomatrixserverlib.ServerName - eventIDToBeforeStateIDs map[string][]string - eventIDMap map[string]gomatrixserverlib.Event -} - -func newBackfillRequester(db storage.Database, fedClient *gomatrixserverlib.FederationClient, thisServer gomatrixserverlib.ServerName, bwExtrems map[string][]string) *backfillRequester { - return &backfillRequester{ - db: db, - fedClient: fedClient, - thisServer: thisServer, - eventIDToBeforeStateIDs: make(map[string][]string), - eventIDMap: make(map[string]gomatrixserverlib.Event), - bwExtrems: bwExtrems, - } -} - -func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.HeaderedEvent) ([]string, error) { - b.eventIDMap[targetEvent.EventID()] = targetEvent.Unwrap() - if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok { - return ids, nil - } - if len(targetEvent.PrevEventIDs()) == 0 && targetEvent.Type() == "m.room.create" && targetEvent.StateKeyEquals("") { - util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID()).Info("Backfilled to the beginning of the room") - b.eventIDToBeforeStateIDs[targetEvent.EventID()] = []string{} - return nil, nil - } - // if we have exactly 1 prev event and we know the state of the room at that prev event, then just roll forward the prev event. - // Else, we have to hit /state_ids because either we don't know the state at all at this event (new backwards extremity) or - // we don't know the result of state res to merge forks (2 or more prev_events) - if len(targetEvent.PrevEventIDs()) == 1 { - prevEventID := targetEvent.PrevEventIDs()[0] - prevEvent, ok := b.eventIDMap[prevEventID] - if !ok { - goto FederationHit - } - prevEventStateIDs, ok := b.eventIDToBeforeStateIDs[prevEventID] - if !ok { - goto FederationHit - } - newStateIDs := b.calculateNewStateIDs(targetEvent.Unwrap(), prevEvent, prevEventStateIDs) - if newStateIDs != nil { - b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs - return newStateIDs, nil - } - // else we failed to calculate the new state, so fallthrough - } - -FederationHit: - var lastErr error - logrus.WithField("event_id", targetEvent.EventID()).Info("Requesting /state_ids at event") - for _, srv := range b.servers { // hit any valid server - c := gomatrixserverlib.FederatedStateProvider{ - FedClient: b.fedClient, - RememberAuthEvents: false, - Server: srv, - } - res, err := c.StateIDsBeforeEvent(ctx, targetEvent) - if err != nil { - lastErr = err - continue - } - b.eventIDToBeforeStateIDs[targetEvent.EventID()] = res - return res, nil - } - return nil, lastErr -} - -func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrixserverlib.Event, prevEventStateIDs []string) []string { - newStateIDs := prevEventStateIDs[:] - if prevEvent.StateKey() == nil { - // state is the same as the previous event - b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs - return newStateIDs - } - - missingState := false // true if we are missing the info for a state event ID - foundEvent := false // true if we found a (type, state_key) match - // find which state ID to replace, if any - for i, id := range newStateIDs { - ev, ok := b.eventIDMap[id] - if !ok { - missingState = true - continue - } - // The state IDs BEFORE the target event are the state IDs BEFORE the prev_event PLUS the prev_event itself - if ev.Type() == prevEvent.Type() && ev.StateKeyEquals(*prevEvent.StateKey()) { - newStateIDs[i] = prevEvent.EventID() - foundEvent = true - break - } - } - if !foundEvent && !missingState { - // we can be certain that this is new state - newStateIDs = append(newStateIDs, prevEvent.EventID()) - foundEvent = true - } - - if foundEvent { - b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs - return newStateIDs - } - return nil -} - -func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) { - - // try to fetch the events from the database first - events, err := b.ProvideEvents(roomVer, eventIDs) - if err != nil { - // non-fatal, fallthrough - logrus.WithError(err).Info("Failed to fetch events") - } else { - logrus.Infof("Fetched %d/%d events from the database", len(events), len(eventIDs)) - if len(events) == len(eventIDs) { - result := make(map[string]*gomatrixserverlib.Event) - for i := range events { - result[events[i].EventID()] = &events[i] - b.eventIDMap[events[i].EventID()] = events[i] - } - return result, nil - } - } - - c := gomatrixserverlib.FederatedStateProvider{ - FedClient: b.fedClient, - RememberAuthEvents: false, - Server: b.servers[0], - } - result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs) - if err != nil { - return nil, err - } - for eventID, ev := range result { - b.eventIDMap[eventID] = *ev - } - return result, nil -} - -// ServersAtEvent is called when trying to determine which server to request from. -// It returns a list of servers which can be queried for backfill requests. These servers -// will be servers that are in the room already. The entries at the beginning are preferred servers -// and will be tried first. An empty list will fail the request. -func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName { - // eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use - // its successor, so look it up. - successor := "" -FindSuccessor: - for sucID, prevEventIDs := range b.bwExtrems { - for _, pe := range prevEventIDs { - if pe == eventID { - successor = sucID - break FindSuccessor - } - } - } - if successor == "" { - logrus.WithField("event_id", eventID).Error("ServersAtEvent: failed to find successor of this event to determine room state") - return nil - } - eventID = successor - - // getMembershipsBeforeEventNID requires a NID, so retrieving the NID for - // the event is necessary. - NIDs, err := b.db.EventNIDs(ctx, []string{eventID}) - if err != nil { - logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get event NID for event") - return nil - } - - stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID]) - if err != nil { - logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") - return nil - } - - // possibly return all joined servers depending on history visiblity - memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries) - if err != nil { - logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") - return nil - } - logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis)) - - // Retrieve all "m.room.member" state events of "join" membership, which - // contains the list of users in the room before the event, therefore all - // the servers in it at that moment. - memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true) - if err != nil { - logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") - return nil - } - memberEvents = append(memberEvents, memberEventsFromVis...) - - // Store the server names in a temporary map to avoid duplicates. - serverSet := make(map[gomatrixserverlib.ServerName]bool) - for _, event := range memberEvents { - serverSet[event.Origin()] = true - } - var servers []gomatrixserverlib.ServerName - for server := range serverSet { - if server == b.thisServer { - continue - } - servers = append(servers, server) - } - b.servers = servers - return servers -} - -// Backfill performs a backfill request to the given server. -// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid -func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, - fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) { - - tx, err := b.fedClient.Backfill(ctx, server, roomID, limit, fromEventIDs) - return &tx, err -} - -func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.Event, error) { - ctx := context.Background() - nidMap, err := b.db.EventNIDs(ctx, eventIDs) - if err != nil { - logrus.WithError(err).WithField("event_ids", eventIDs).Error("Failed to find events") - return nil, err - } - eventNIDs := make([]types.EventNID, len(nidMap)) - i := 0 - for _, nid := range nidMap { - eventNIDs[i] = nid - i++ - } - eventsWithNids, err := b.db.Events(ctx, eventNIDs) - if err != nil { - logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") - return nil, err - } - events := make([]gomatrixserverlib.Event, len(eventsWithNids)) - for i := range eventsWithNids { - events[i] = eventsWithNids[i].Event - } - return events, nil -} - -// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if the provided state indicated a 'shared' history visibility. -// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just -// pull all events and then filter by that table. -func joinEventsFromHistoryVisibility( - ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) { - - var eventNIDs []types.EventNID - for _, entry := range stateEntries { - // Filter the events to retrieve to only keep the membership events - if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID { - eventNIDs = append(eventNIDs, entry.EventNID) - break - } - } - - // Get all of the events in this state - stateEvents, err := db.Events(ctx, eventNIDs) - if err != nil { - return nil, err - } - events := make([]gomatrixserverlib.Event, len(stateEvents)) - for i := range stateEvents { - events[i] = stateEvents[i].Event - } - visibility := auth.HistoryVisibilityForRoom(events) - if visibility != "shared" { - logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility) - return nil, nil - } - // get joined members - info, err := db.RoomInfo(ctx, roomID) - if err != nil { - return nil, err - } - joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) - if err != nil { - return nil, err - } - return db.Events(ctx, joinEventNIDs) -} diff --git a/roomserver/internal/perform_leave.go b/roomserver/internal/perform_leave.go deleted file mode 100644 index b8603147c..000000000 --- a/roomserver/internal/perform_leave.go +++ /dev/null @@ -1,223 +0,0 @@ -package internal - -import ( - "context" - "fmt" - "strings" - "time" - - fsAPI "github.com/matrix-org/dendrite/federationsender/api" - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -// WriteOutputEvents implements OutputRoomEventWriter -func (r *RoomserverInternalAPI) PerformLeave( - ctx context.Context, - req *api.PerformLeaveRequest, - res *api.PerformLeaveResponse, -) error { - _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) - if err != nil { - return fmt.Errorf("Supplied user ID %q in incorrect format", req.UserID) - } - if domain != r.Cfg.Matrix.ServerName { - return fmt.Errorf("User %q does not belong to this homeserver", req.UserID) - } - if strings.HasPrefix(req.RoomID, "!") { - return r.performLeaveRoomByID(ctx, req, res) - } - return fmt.Errorf("Room ID %q is invalid", req.RoomID) -} - -func (r *RoomserverInternalAPI) performLeaveRoomByID( - ctx context.Context, - req *api.PerformLeaveRequest, - res *api.PerformLeaveResponse, // nolint:unparam -) error { - // If there's an invite outstanding for the room then respond to - // that. - isInvitePending, senderUser, eventID, err := r.isInvitePending(ctx, req.RoomID, req.UserID) - if err == nil && isInvitePending { - return r.performRejectInvite(ctx, req, res, senderUser, eventID) - } - - // There's no invite pending, so first of all we want to find out - // if the room exists and if the user is actually in it. - latestReq := api.QueryLatestEventsAndStateRequest{ - RoomID: req.RoomID, - StateToFetch: []gomatrixserverlib.StateKeyTuple{ - { - EventType: gomatrixserverlib.MRoomMember, - StateKey: req.UserID, - }, - }, - } - latestRes := api.QueryLatestEventsAndStateResponse{} - if err = r.QueryLatestEventsAndState(ctx, &latestReq, &latestRes); err != nil { - return err - } - if !latestRes.RoomExists { - return fmt.Errorf("Room %q does not exist", req.RoomID) - } - - // Now let's see if the user is in the room. - if len(latestRes.StateEvents) == 0 { - return fmt.Errorf("User %q is not a member of room %q", req.UserID, req.RoomID) - } - membership, err := latestRes.StateEvents[0].Membership() - if err != nil { - return fmt.Errorf("Error getting membership: %w", err) - } - if membership != gomatrixserverlib.Join { - // TODO: should be able to handle "invite" in this case too, if - // it's a case of kicking or banning or such - return fmt.Errorf("User %q is not joined to the room (membership is %q)", req.UserID, membership) - } - - // Prepare the template for the leave event. - userID := req.UserID - eb := gomatrixserverlib.EventBuilder{ - Type: gomatrixserverlib.MRoomMember, - Sender: userID, - StateKey: &userID, - RoomID: req.RoomID, - Redacts: "", - } - if err = eb.SetContent(map[string]interface{}{"membership": "leave"}); err != nil { - return fmt.Errorf("eb.SetContent: %w", err) - } - if err = eb.SetUnsigned(struct{}{}); err != nil { - return fmt.Errorf("eb.SetUnsigned: %w", err) - } - - // We know that the user is in the room at this point so let's build - // a leave event. - // TODO: Check what happens if the room exists on the server - // but everyone has since left. I suspect it does the wrong thing. - buildRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.BuildEvent( - ctx, // the request context - &eb, // the template leave event - r.Cfg.Matrix, // the server configuration - time.Now(), // the event timestamp to use - r, // the roomserver API to use - &buildRes, // the query response - ) - if err != nil { - return fmt.Errorf("eventutil.BuildEvent: %w", err) - } - - // Give our leave event to the roomserver input stream. The - // roomserver will process the membership change and notify - // downstream automatically. - inputReq := api.InputRoomEventsRequest{ - InputRoomEvents: []api.InputRoomEvent{ - { - Kind: api.KindNew, - Event: event.Headered(buildRes.RoomVersion), - AuthEventIDs: event.AuthEventIDs(), - SendAsServer: string(r.Cfg.Matrix.ServerName), - }, - }, - } - inputRes := api.InputRoomEventsResponse{} - if err = r.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { - return fmt.Errorf("r.InputRoomEvents: %w", err) - } - - return nil -} - -func (r *RoomserverInternalAPI) performRejectInvite( - ctx context.Context, - req *api.PerformLeaveRequest, - res *api.PerformLeaveResponse, // nolint:unparam - senderUser, eventID string, -) error { - _, domain, err := gomatrixserverlib.SplitID('@', senderUser) - if err != nil { - return fmt.Errorf("User ID %q invalid: %w", senderUser, err) - } - - // Ask the federation sender to perform a federated leave for us. - leaveReq := fsAPI.PerformLeaveRequest{ - RoomID: req.RoomID, - UserID: req.UserID, - ServerNames: []gomatrixserverlib.ServerName{domain}, - } - leaveRes := fsAPI.PerformLeaveResponse{} - if err := r.fsAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { - return err - } - - // Withdraw the invite, so that the sync API etc are - // notified that we rejected it. - return r.WriteOutputEvents(req.RoomID, []api.OutputEvent{ - { - Type: api.OutputTypeRetireInviteEvent, - RetireInviteEvent: &api.OutputRetireInviteEvent{ - EventID: eventID, - Membership: "leave", - TargetUserID: req.UserID, - }, - }, - }) -} - -func (r *RoomserverInternalAPI) isInvitePending( - ctx context.Context, - roomID, userID string, -) (bool, string, string, error) { - // Look up the room NID for the supplied room ID. - info, err := r.DB.RoomInfo(ctx, roomID) - if err != nil { - return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err) - } - if info == nil { - return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID) - } - - // Look up the state key NID for the supplied user ID. - targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{userID}) - if err != nil { - return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) - } - targetUserNID, targetUserFound := targetUserNIDs[userID] - if !targetUserFound { - return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) - } - - // Let's see if we have an event active for the user in the room. If - // we do then it will contain a server name that we can direct the - // send_leave to. - senderUserNIDs, eventIDs, err := r.DB.GetInvitesForUser(ctx, info.RoomNID, targetUserNID) - if err != nil { - return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err) - } - if len(senderUserNIDs) == 0 { - return false, "", "", nil - } - userNIDToEventID := make(map[types.EventStateKeyNID]string) - for i, nid := range senderUserNIDs { - userNIDToEventID[nid] = eventIDs[i] - } - - // Look up the user ID from the NID. - senderUsers, err := r.DB.EventStateKeys(ctx, senderUserNIDs) - if err != nil { - return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err) - } - if len(senderUsers) == 0 { - return false, "", "", fmt.Errorf("no senderUsers") - } - - senderUser, senderUserFound := senderUsers[senderUserNIDs[0]] - if !senderUserFound { - return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers) - } - - return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil -} diff --git a/roomserver/internal/perform_publish.go b/roomserver/internal/perform_publish.go deleted file mode 100644 index d7863620a..000000000 --- a/roomserver/internal/perform_publish.go +++ /dev/null @@ -1,20 +0,0 @@ -package internal - -import ( - "context" - - "github.com/matrix-org/dendrite/roomserver/api" -) - -func (r *RoomserverInternalAPI) PerformPublish( - ctx context.Context, - req *api.PerformPublishRequest, - res *api.PerformPublishResponse, -) { - err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public") - if err != nil { - res.Error = &api.PerformError{ - Msg: err.Error(), - } - } -} diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go deleted file mode 100644 index 897164330..000000000 --- a/roomserver/internal/query.go +++ /dev/null @@ -1,960 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// Copyright 2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "fmt" - - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/auth" - "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/dendrite/roomserver/version" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" -) - -// QueryLatestEventsAndState implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryLatestEventsAndState( - ctx context.Context, - request *api.QueryLatestEventsAndStateRequest, - response *api.QueryLatestEventsAndStateResponse, -) error { - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) - if err != nil { - response.RoomExists = false - return nil - } - - roomState := state.NewStateResolution(r.DB) - - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - if info.IsStub { - return nil - } - response.RoomExists = true - response.RoomVersion = roomVersion - - var currentStateSnapshotNID types.StateSnapshotNID - response.LatestEvents, currentStateSnapshotNID, response.Depth, err = - r.DB.LatestEventIDs(ctx, info.RoomNID) - if err != nil { - return err - } - - var stateEntries []types.StateEntry - if len(request.StateToFetch) == 0 { - // Look up all room state. - stateEntries, err = roomState.LoadStateAtSnapshot( - ctx, currentStateSnapshotNID, - ) - } else { - // Look up the current state for the requested tuples. - stateEntries, err = roomState.LoadStateAtSnapshotForStringTuples( - ctx, currentStateSnapshotNID, request.StateToFetch, - ) - } - if err != nil { - return err - } - - stateEvents, err := r.loadStateEvents(ctx, stateEntries) - if err != nil { - return err - } - - for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) - } - - return nil -} - -// QueryStateAfterEvents implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryStateAfterEvents( - ctx context.Context, - request *api.QueryStateAfterEventsRequest, - response *api.QueryStateAfterEventsResponse, -) error { - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) - if err != nil { - response.RoomExists = false - return nil - } - - roomState := state.NewStateResolution(r.DB) - - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - if info.IsStub { - return nil - } - response.RoomExists = true - response.RoomVersion = roomVersion - - prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) - if err != nil { - switch err.(type) { - case types.MissingEventError: - return nil - default: - return err - } - } - response.PrevEventsExist = true - - // Look up the currrent state for the requested tuples. - stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( - ctx, info.RoomNID, prevStates, request.StateToFetch, - ) - if err != nil { - return err - } - - stateEvents, err := r.loadStateEvents(ctx, stateEntries) - if err != nil { - return err - } - - for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) - } - - return nil -} - -// QueryEventsByID implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryEventsByID( - ctx context.Context, - request *api.QueryEventsByIDRequest, - response *api.QueryEventsByIDResponse, -) error { - eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs) - if err != nil { - return err - } - - var eventNIDs []types.EventNID - for _, nid := range eventNIDMap { - eventNIDs = append(eventNIDs, nid) - } - - events, err := r.loadEvents(ctx, eventNIDs) - if err != nil { - return err - } - - for _, event := range events { - roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID()) - if verr != nil { - return verr - } - - response.Events = append(response.Events, event.Headered(roomVersion)) - } - - return nil -} - -func (r *RoomserverInternalAPI) loadStateEvents( - ctx context.Context, stateEntries []types.StateEntry, -) ([]gomatrixserverlib.Event, error) { - eventNIDs := make([]types.EventNID, len(stateEntries)) - for i := range stateEntries { - eventNIDs[i] = stateEntries[i].EventNID - } - return r.loadEvents(ctx, eventNIDs) -} - -func (r *RoomserverInternalAPI) loadEvents( - ctx context.Context, eventNIDs []types.EventNID, -) ([]gomatrixserverlib.Event, error) { - stateEvents, err := r.DB.Events(ctx, eventNIDs) - if err != nil { - return nil, err - } - - result := make([]gomatrixserverlib.Event, len(stateEvents)) - for i := range stateEvents { - result[i] = stateEvents[i].Event - } - return result, nil -} - -// QueryMembershipForUser implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryMembershipForUser( - ctx context.Context, - request *api.QueryMembershipForUserRequest, - response *api.QueryMembershipForUserResponse, -) error { - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - - membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) - if err != nil { - return err - } - - if membershipEventNID == 0 { - response.HasBeenInRoom = false - return nil - } - - response.IsInRoom = stillInRoom - response.HasBeenInRoom = true - - evs, err := r.DB.Events(ctx, []types.EventNID{membershipEventNID}) - if err != nil { - return err - } - if len(evs) != 1 { - return fmt.Errorf("failed to load membership event for event NID %d", membershipEventNID) - } - - response.EventID = evs[0].EventID() - response.Membership, err = evs[0].Membership() - return err -} - -// QueryMembershipsForRoom implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryMembershipsForRoom( - ctx context.Context, - request *api.QueryMembershipsForRoomRequest, - response *api.QueryMembershipsForRoomResponse, -) error { - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - - membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) - if err != nil { - return err - } - - if membershipEventNID == 0 { - response.HasBeenInRoom = false - response.JoinEvents = nil - return nil - } - - response.HasBeenInRoom = true - response.JoinEvents = []gomatrixserverlib.ClientEvent{} - - var events []types.Event - var stateEntries []types.StateEntry - if stillInRoom { - var eventNIDs []types.EventNID - eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, false) - if err != nil { - return err - } - - events, err = r.DB.Events(ctx, eventNIDs) - } else { - stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID) - if err != nil { - logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") - return err - } - events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) - } - - if err != nil { - return err - } - - for _, event := range events { - clientEvent := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll) - response.JoinEvents = append(response.JoinEvents, clientEvent) - } - - return nil -} - -func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) { - roomState := state.NewStateResolution(db) - // Lookup the event NID - eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) - if err != nil { - return nil, err - } - eventIDs := []string{eIDs[eventNID]} - - prevState, err := db.StateAtEventIDs(ctx, eventIDs) - if err != nil { - return nil, err - } - - // Fetch the state as it was when this event was fired - return roomState.LoadCombinedStateAfterEvents(ctx, prevState) -} - -// getMembershipsAtState filters the state events to -// only keep the "m.room.member" events with a "join" membership. These events are returned. -// Returns an error if there was an issue fetching the events. -func getMembershipsAtState( - ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, -) ([]types.Event, error) { - - var eventNIDs []types.EventNID - for _, entry := range stateEntries { - // Filter the events to retrieve to only keep the membership events - if entry.EventTypeNID == types.MRoomMemberNID { - eventNIDs = append(eventNIDs, entry.EventNID) - } - } - - // Get all of the events in this state - stateEvents, err := db.Events(ctx, eventNIDs) - if err != nil { - return nil, err - } - - if !joinedOnly { - return stateEvents, nil - } - - // Filter the events to only keep the "join" membership events - var events []types.Event - for _, event := range stateEvents { - membership, err := event.Membership() - if err != nil { - return nil, err - } - - if membership == gomatrixserverlib.Join { - events = append(events, event) - } - } - - return events, nil -} - -// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( - ctx context.Context, - request *api.QueryServerAllowedToSeeEventRequest, - response *api.QueryServerAllowedToSeeEventResponse, -) (err error) { - events, err := r.DB.EventsFromIDs(ctx, []string{request.EventID}) - if err != nil { - return - } - if len(events) == 0 { - response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see - return - } - isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, events[0].RoomID()) - if err != nil { - return - } - response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent( - ctx, request.EventID, request.ServerName, isServerInRoom, - ) - return -} - -func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent( - ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, -) (bool, error) { - roomState := state.NewStateResolution(r.DB) - stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) - if err != nil { - return false, err - } - - // TODO: We probably want to make it so that we don't have to pull - // out all the state if possible. - stateAtEvent, err := r.loadStateEvents(ctx, stateEntries) - if err != nil { - return false, err - } - - return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil -} - -// QueryMissingEvents implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryMissingEvents( - ctx context.Context, - request *api.QueryMissingEventsRequest, - response *api.QueryMissingEventsResponse, -) error { - var front []string - eventsToFilter := make(map[string]bool, len(request.LatestEvents)) - visited := make(map[string]bool, request.Limit) // request.Limit acts as a hint to size. - for _, id := range request.EarliestEvents { - visited[id] = true - } - - for _, id := range request.LatestEvents { - if !visited[id] { - front = append(front, id) - eventsToFilter[id] = true - } - } - - resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) - if err != nil { - return err - } - - loadedEvents, err := r.loadEvents(ctx, resultNIDs) - if err != nil { - return err - } - - response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter)) - for _, event := range loadedEvents { - if !eventsToFilter[event.EventID()] { - roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID()) - if verr != nil { - return verr - } - - response.Events = append(response.Events, event.Headered(roomVersion)) - } - } - - return err -} - -// PerformBackfill implements api.RoomServerQueryAPI -func (r *RoomserverInternalAPI) PerformBackfill( - ctx context.Context, - request *api.PerformBackfillRequest, - response *api.PerformBackfillResponse, -) error { - // if we are requesting the backfill then we need to do a federation hit - // TODO: we could be more sensible and fetch as many events we already have then request the rest - // which is what the syncapi does already. - if request.ServerName == r.ServerName { - return r.backfillViaFederation(ctx, request, response) - } - // someone else is requesting the backfill, try to service their request. - var err error - var front []string - - // The limit defines the maximum number of events to retrieve, so it also - // defines the highest number of elements in the map below. - visited := make(map[string]bool, request.Limit) - - // this will include these events which is what we want - front = request.PrevEventIDs() - - // Scan the event tree for events to send back. - resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) - if err != nil { - return err - } - - // Retrieve events from the list that was filled previously. - var loadedEvents []gomatrixserverlib.Event - loadedEvents, err = r.loadEvents(ctx, resultNIDs) - if err != nil { - return err - } - - for _, event := range loadedEvents { - roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID()) - if verr != nil { - return verr - } - - response.Events = append(response.Events, event.Headered(roomVersion)) - } - - return err -} - -func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error { - roomVer, err := r.DB.GetRoomVersionForRoom(ctx, req.RoomID) - if err != nil { - return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err) - } - requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities) - // Request 100 items regardless of what the query asks for. - // We don't want to go much higher than this. - // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass - // (so we don't need to hit /state_ids which the test has no listener for) - // Specifically the test "Outbound federation can backfill events" - events, err := gomatrixserverlib.RequestBackfill( - ctx, requester, - r.KeyRing, req.RoomID, roomVer, req.PrevEventIDs(), 100) - if err != nil { - return err - } - logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) - - // persist these new events - auth checks have already been done - roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) - if err != nil { - return err - } - - for _, ev := range backfilledEventMap { - // now add state for these events - stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()] - if !ok { - // this should be impossible as all events returned must have pass Step 5 of the PDU checks - // which requires a list of state IDs. - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks") - continue - } - var entries []types.StateEntry - if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil { - // attempt to fetch the missing events - r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs) - // try again - entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs) - if err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event") - return err - } - } - - var beforeStateSnapshotNID types.StateSnapshotNID - if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid") - return err - } - if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid") - } - } - - // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point. - - res.Events = events - return nil -} - -func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) { - info, err := r.DB.RoomInfo(ctx, roomID) - if err != nil { - return false, err - } - if info == nil { - return false, fmt.Errorf("unknown room %s", roomID) - } - - eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) - if err != nil { - return false, err - } - - events, err := r.DB.Events(ctx, eventNIDs) - if err != nil { - return false, err - } - gmslEvents := make([]gomatrixserverlib.Event, len(events)) - for i := range events { - gmslEvents[i] = events[i].Event - } - return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil -} - -// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just -// best effort. -func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - backfillRequester *backfillRequester, stateIDs []string) { - - servers := backfillRequester.servers - - // work out which are missing - nidMap, err := r.DB.EventNIDs(ctx, stateIDs) - if err != nil { - util.GetLogger(ctx).WithError(err).Warn("cannot query missing events") - return - } - missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event - for _, id := range stateIDs { - if _, ok := nidMap[id]; !ok { - missingMap[id] = nil - } - } - util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers)) - - // fetch the events from federation. Loop the servers first so if we find one that works we stick with them - for _, srv := range servers { - for id, ev := range missingMap { - if ev != nil { - continue // already found - } - logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) - res, err := r.FedClient.GetEvent(ctx, srv, id) - if err != nil { - logger.WithError(err).Warn("failed to get event from server") - continue - } - loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) - if err != nil { - logger.WithError(err).Warn("failed to load and verify event") - continue - } - logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result) - for _, res := range result { - if res.Error != nil { - logger.WithError(err).Warn("event failed PDU checks") - continue - } - missingMap[id] = res.Event - } - } - } - - var newEvents []gomatrixserverlib.HeaderedEvent - for _, ev := range missingMap { - if ev != nil { - newEvents = append(newEvents, *ev) - } - } - util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) - persistEvents(ctx, r.DB, newEvents) -} - -// TODO: Remove this when we have tests to assert correctness of this function -// nolint:gocyclo -func (r *RoomserverInternalAPI) scanEventTree( - ctx context.Context, front []string, visited map[string]bool, limit int, - serverName gomatrixserverlib.ServerName, -) ([]types.EventNID, error) { - var resultNIDs []types.EventNID - var err error - var allowed bool - var events []types.Event - var next []string - var pre string - - // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be) - // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing - // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in - // duplicate events being sent in response to /backfill requests. - initialIgnoreList := make(map[string]bool, len(visited)) - for k, v := range visited { - initialIgnoreList[k] = v - } - - resultNIDs = make([]types.EventNID, 0, limit) - - var checkedServerInRoom bool - var isServerInRoom bool - - // Loop through the event IDs to retrieve the requested events and go - // through the whole tree (up to the provided limit) using the events' - // "prev_event" key. -BFSLoop: - for len(front) > 0 { - // Prevent unnecessary allocations: reset the slice only when not empty. - if len(next) > 0 { - next = make([]string, 0) - } - // Retrieve the events to process from the database. - events, err = r.DB.EventsFromIDs(ctx, front) - if err != nil { - return resultNIDs, err - } - - if !checkedServerInRoom && len(events) > 0 { - // It's nasty that we have to extract the room ID from an event, but many federation requests - // only talk in event IDs, no room IDs at all (!!!) - ev := events[0] - isServerInRoom, err = r.isServerCurrentlyInRoom(ctx, serverName, ev.RoomID()) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") - } - checkedServerInRoom = true - } - - for _, ev := range events { - // Break out of the loop if the provided limit is reached. - if len(resultNIDs) == limit { - break BFSLoop - } - - if !initialIgnoreList[ev.EventID()] { - // Update the list of events to retrieve. - resultNIDs = append(resultNIDs, ev.EventNID) - } - // Loop through the event's parents. - for _, pre = range ev.PrevEventIDs() { - // Only add an event to the list of next events to process if it - // hasn't been seen before. - if !visited[pre] { - visited[pre] = true - allowed, err = r.checkServerAllowedToSeeEvent(ctx, pre, serverName, isServerInRoom) - if err != nil { - util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( - "Error checking if allowed to see event", - ) - return resultNIDs, err - } - - // If the event hasn't been seen before and the HS - // requesting to retrieve it is allowed to do so, add it to - // the list of events to retrieve. - if allowed { - next = append(next, pre) - } else { - util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event") - } - } - } - } - // Repeat the same process with the parent events we just processed. - front = next - } - - return resultNIDs, err -} - -// QueryStateAndAuthChain implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryStateAndAuthChain( - ctx context.Context, - request *api.QueryStateAndAuthChainRequest, - response *api.QueryStateAndAuthChainResponse, -) error { - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - if info.IsStub { - return nil - } - response.RoomExists = true - response.RoomVersion = info.RoomVersion - - stateEvents, err := r.loadStateAtEventIDs(ctx, request.PrevEventIDs) - if err != nil { - return err - } - response.PrevEventsExist = true - - // add the auth event IDs for the current state events too - var authEventIDs []string - authEventIDs = append(authEventIDs, request.AuthEventIDs...) - for _, se := range stateEvents { - authEventIDs = append(authEventIDs, se.AuthEventIDs()...) - } - authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe - - authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) - if err != nil { - return err - } - - if request.ResolveState { - if stateEvents, err = state.ResolveConflictsAdhoc( - info.RoomVersion, stateEvents, authEvents, - ); err != nil { - return err - } - } - - for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) - } - - for _, event := range authEvents { - response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(info.RoomVersion)) - } - - return err -} - -func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { - roomState := state.NewStateResolution(r.DB) - prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) - if err != nil { - switch err.(type) { - case types.MissingEventError: - return nil, nil - default: - return nil, err - } - } - - // Look up the currrent state for the requested tuples. - stateEntries, err := roomState.LoadCombinedStateAfterEvents( - ctx, prevStates, - ) - if err != nil { - return nil, err - } - - return r.loadStateEvents(ctx, stateEntries) -} - -type eventsFromIDs func(context.Context, []string) ([]types.Event, error) - -// getAuthChain fetches the auth chain for the given auth events. An auth chain -// is the list of all events that are referenced in the auth_events section, and -// all their auth_events, recursively. The returned set of events contain the -// given events. Will *not* error if we don't have all auth events. -func getAuthChain( - ctx context.Context, fn eventsFromIDs, authEventIDs []string, -) ([]gomatrixserverlib.Event, error) { - // List of event IDs to fetch. On each pass, these events will be requested - // from the database and the `eventsToFetch` will be updated with any new - // events that we have learned about and need to find. When `eventsToFetch` - // is eventually empty, we should have reached the end of the chain. - eventsToFetch := authEventIDs - authEventsMap := make(map[string]gomatrixserverlib.Event) - - for len(eventsToFetch) > 0 { - // Try to retrieve the events from the database. - events, err := fn(ctx, eventsToFetch) - if err != nil { - return nil, err - } - - // We've now fetched these events so clear out `eventsToFetch`. Soon we may - // add newly discovered events to this for the next pass. - eventsToFetch = eventsToFetch[:0] - - for _, event := range events { - // Store the event in the event map - this prevents us from requesting it - // from the database again. - authEventsMap[event.EventID()] = event.Event - - // Extract all of the auth events from the newly obtained event. If we - // don't already have a record of the event, record it in the list of - // events we want to request for the next pass. - for _, authEvent := range event.AuthEvents() { - if _, ok := authEventsMap[authEvent.EventID]; !ok { - eventsToFetch = append(eventsToFetch, authEvent.EventID) - } - } - } - } - - // We've now retrieved all of the events we can. Flatten them down into an - // array and return them. - var authEvents []gomatrixserverlib.Event - for _, event := range authEventsMap { - authEvents = append(authEvents, event) - } - - return authEvents, nil -} - -func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { - var roomNID types.RoomNID - backfilledEventMap := make(map[string]types.Event) - for j, ev := range events { - nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs()) - if err != nil { // this shouldn't happen as RequestBackfill already found them - logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events") - continue - } - authNids := make([]types.EventNID, len(nidMap)) - i := 0 - for _, nid := range nidMap { - authNids[i] = nid - i++ - } - var stateAtEvent types.StateAtEvent - var redactedEventID string - var redactionEvent *gomatrixserverlib.Event - roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids) - if err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") - continue - } - // If storing this event results in it being redacted, then do so. - // It's also possible for this event to be a redaction which results in another event being - // redacted, which we don't care about since we aren't returning it in this backfill. - if redactedEventID == ev.EventID() { - eventToRedact := ev.Unwrap() - redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact) - if err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") - continue - } - ev = redactedEvent.Headered(ev.RoomVersion) - events[j] = ev - } - backfilledEventMap[ev.EventID()] = types.Event{ - EventNID: stateAtEvent.StateEntry.EventNID, - Event: ev.Unwrap(), - } - } - return roomNID, backfilledEventMap -} - -// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryRoomVersionCapabilities( - ctx context.Context, - request *api.QueryRoomVersionCapabilitiesRequest, - response *api.QueryRoomVersionCapabilitiesResponse, -) error { - response.DefaultRoomVersion = version.DefaultRoomVersion() - response.AvailableRoomVersions = make(map[gomatrixserverlib.RoomVersion]string) - for v, desc := range version.SupportedRoomVersions() { - if desc.Stable { - response.AvailableRoomVersions[v] = "stable" - } else { - response.AvailableRoomVersions[v] = "unstable" - } - } - return nil -} - -// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryRoomVersionForRoom( - ctx context.Context, - request *api.QueryRoomVersionForRoomRequest, - response *api.QueryRoomVersionForRoomResponse, -) error { - if roomVersion, ok := r.Cache.GetRoomVersion(request.RoomID); ok { - response.RoomVersion = roomVersion - return nil - } - - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) - if err != nil { - return err - } - response.RoomVersion = roomVersion - r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion) - return nil -} - -func (r *RoomserverInternalAPI) QueryPublishedRooms( - ctx context.Context, - req *api.QueryPublishedRoomsRequest, - res *api.QueryPublishedRoomsResponse, -) error { - rooms, err := r.DB.GetPublishedRooms(ctx) - if err != nil { - return err - } - res.RoomIDs = rooms - return nil -} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go new file mode 100644 index 000000000..f76c93166 --- /dev/null +++ b/roomserver/internal/query/query.go @@ -0,0 +1,602 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package query + +import ( + "context" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/acls" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/roomserver/version" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +type Queryer struct { + DB storage.Database + Cache caching.RoomServerCaches + ServerACLs *acls.ServerACLs +} + +// QueryLatestEventsAndState implements api.RoomserverInternalAPI +func (r *Queryer) QueryLatestEventsAndState( + ctx context.Context, + request *api.QueryLatestEventsAndStateRequest, + response *api.QueryLatestEventsAndStateResponse, +) error { + return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response) +} + +// QueryStateAfterEvents implements api.RoomserverInternalAPI +func (r *Queryer) QueryStateAfterEvents( + ctx context.Context, + request *api.QueryStateAfterEventsRequest, + response *api.QueryStateAfterEventsResponse, +) error { + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if info == nil || info.IsStub { + return nil + } + + roomState := state.NewStateResolution(r.DB, *info) + response.RoomExists = true + response.RoomVersion = info.RoomVersion + + prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) + if err != nil { + switch err.(type) { + case types.MissingEventError: + return nil + default: + return err + } + } + response.PrevEventsExist = true + + // Look up the currrent state for the requested tuples. + stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( + ctx, prevStates, request.StateToFetch, + ) + if err != nil { + return err + } + + stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) + if err != nil { + return err + } + + for _, event := range stateEvents { + response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) + } + + return nil +} + +// QueryEventsByID implements api.RoomserverInternalAPI +func (r *Queryer) QueryEventsByID( + ctx context.Context, + request *api.QueryEventsByIDRequest, + response *api.QueryEventsByIDResponse, +) error { + eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs) + if err != nil { + return err + } + + var eventNIDs []types.EventNID + for _, nid := range eventNIDMap { + eventNIDs = append(eventNIDs, nid) + } + + events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs) + if err != nil { + return err + } + + for _, event := range events { + roomVersion, verr := r.roomVersion(event.RoomID()) + if verr != nil { + return verr + } + + response.Events = append(response.Events, event.Headered(roomVersion)) + } + + return nil +} + +// QueryMembershipForUser implements api.RoomserverInternalAPI +func (r *Queryer) QueryMembershipForUser( + ctx context.Context, + request *api.QueryMembershipForUserRequest, + response *api.QueryMembershipForUserResponse, +) error { + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + + membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID) + if err != nil { + return err + } + + if membershipEventNID == 0 { + response.HasBeenInRoom = false + return nil + } + + response.IsInRoom = stillInRoom + response.HasBeenInRoom = true + + evs, err := r.DB.Events(ctx, []types.EventNID{membershipEventNID}) + if err != nil { + return err + } + if len(evs) != 1 { + return fmt.Errorf("failed to load membership event for event NID %d", membershipEventNID) + } + + response.EventID = evs[0].EventID() + response.Membership, err = evs[0].Membership() + return err +} + +// QueryMembershipsForRoom implements api.RoomserverInternalAPI +func (r *Queryer) QueryMembershipsForRoom( + ctx context.Context, + request *api.QueryMembershipsForRoomRequest, + response *api.QueryMembershipsForRoomResponse, +) error { + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + + membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender) + if err != nil { + return err + } + + if membershipEventNID == 0 { + response.HasBeenInRoom = false + response.JoinEvents = nil + return nil + } + + response.HasBeenInRoom = true + response.JoinEvents = []gomatrixserverlib.ClientEvent{} + + var events []types.Event + var stateEntries []types.StateEntry + if stillInRoom { + var eventNIDs []types.EventNID + eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, false) + if err != nil { + return err + } + + events, err = r.DB.Events(ctx, eventNIDs) + } else { + stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, *info, membershipEventNID) + if err != nil { + logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") + return err + } + events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) + } + + if err != nil { + return err + } + + for _, event := range events { + clientEvent := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll) + response.JoinEvents = append(response.JoinEvents, clientEvent) + } + + return nil +} + +// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI +func (r *Queryer) QueryServerAllowedToSeeEvent( + ctx context.Context, + request *api.QueryServerAllowedToSeeEventRequest, + response *api.QueryServerAllowedToSeeEventResponse, +) (err error) { + events, err := r.DB.EventsFromIDs(ctx, []string{request.EventID}) + if err != nil { + return + } + if len(events) == 0 { + response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see + return + } + roomID := events[0].RoomID() + isServerInRoom, err := helpers.IsServerCurrentlyInRoom(ctx, r.DB, request.ServerName, roomID) + if err != nil { + return + } + info, err := r.DB.RoomInfo(ctx, roomID) + if err != nil { + return err + } + if info == nil { + return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) + } + response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( + ctx, r.DB, *info, request.EventID, request.ServerName, isServerInRoom, + ) + return +} + +// QueryMissingEvents implements api.RoomserverInternalAPI +// nolint:gocyclo +func (r *Queryer) QueryMissingEvents( + ctx context.Context, + request *api.QueryMissingEventsRequest, + response *api.QueryMissingEventsResponse, +) error { + var front []string + eventsToFilter := make(map[string]bool, len(request.LatestEvents)) + visited := make(map[string]bool, request.Limit) // request.Limit acts as a hint to size. + for _, id := range request.EarliestEvents { + visited[id] = true + } + + for _, id := range request.LatestEvents { + if !visited[id] { + front = append(front, id) + eventsToFilter[id] = true + } + } + events, err := r.DB.EventsFromIDs(ctx, front) + if err != nil { + return err + } + if len(events) == 0 { + return nil // we are missing the events being asked to search from, give up. + } + info, err := r.DB.RoomInfo(ctx, events[0].RoomID()) + if err != nil { + return err + } + if info == nil || info.IsStub { + return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) + } + + resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName) + if err != nil { + return err + } + + loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs) + if err != nil { + return err + } + + response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter)) + for _, event := range loadedEvents { + if !eventsToFilter[event.EventID()] { + roomVersion, verr := r.roomVersion(event.RoomID()) + if verr != nil { + return verr + } + + response.Events = append(response.Events, event.Headered(roomVersion)) + } + } + + return err +} + +// QueryStateAndAuthChain implements api.RoomserverInternalAPI +func (r *Queryer) QueryStateAndAuthChain( + ctx context.Context, + request *api.QueryStateAndAuthChainRequest, + response *api.QueryStateAndAuthChainResponse, +) error { + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if info == nil || info.IsStub { + return nil + } + response.RoomExists = true + response.RoomVersion = info.RoomVersion + + stateEvents, err := r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs) + if err != nil { + return err + } + response.PrevEventsExist = true + + // add the auth event IDs for the current state events too + var authEventIDs []string + authEventIDs = append(authEventIDs, request.AuthEventIDs...) + for _, se := range stateEvents { + authEventIDs = append(authEventIDs, se.AuthEventIDs()...) + } + authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe + + authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs) + if err != nil { + return err + } + + if request.ResolveState { + if stateEvents, err = state.ResolveConflictsAdhoc( + info.RoomVersion, stateEvents, authEvents, + ); err != nil { + return err + } + } + + for _, event := range stateEvents { + response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) + } + + for _, event := range authEvents { + response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(info.RoomVersion)) + } + + return err +} + +func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.Event, error) { + roomState := state.NewStateResolution(r.DB, roomInfo) + prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) + if err != nil { + switch err.(type) { + case types.MissingEventError: + return nil, nil + default: + return nil, err + } + } + + // Look up the currrent state for the requested tuples. + stateEntries, err := roomState.LoadCombinedStateAfterEvents( + ctx, prevStates, + ) + if err != nil { + return nil, err + } + + return helpers.LoadStateEvents(ctx, r.DB, stateEntries) +} + +type eventsFromIDs func(context.Context, []string) ([]types.Event, error) + +// getAuthChain fetches the auth chain for the given auth events. An auth chain +// is the list of all events that are referenced in the auth_events section, and +// all their auth_events, recursively. The returned set of events contain the +// given events. Will *not* error if we don't have all auth events. +func getAuthChain( + ctx context.Context, fn eventsFromIDs, authEventIDs []string, +) ([]gomatrixserverlib.Event, error) { + // List of event IDs to fetch. On each pass, these events will be requested + // from the database and the `eventsToFetch` will be updated with any new + // events that we have learned about and need to find. When `eventsToFetch` + // is eventually empty, we should have reached the end of the chain. + eventsToFetch := authEventIDs + authEventsMap := make(map[string]gomatrixserverlib.Event) + + for len(eventsToFetch) > 0 { + // Try to retrieve the events from the database. + events, err := fn(ctx, eventsToFetch) + if err != nil { + return nil, err + } + + // We've now fetched these events so clear out `eventsToFetch`. Soon we may + // add newly discovered events to this for the next pass. + eventsToFetch = eventsToFetch[:0] + + for _, event := range events { + // Store the event in the event map - this prevents us from requesting it + // from the database again. + authEventsMap[event.EventID()] = event.Event + + // Extract all of the auth events from the newly obtained event. If we + // don't already have a record of the event, record it in the list of + // events we want to request for the next pass. + for _, authEvent := range event.AuthEvents() { + if _, ok := authEventsMap[authEvent.EventID]; !ok { + eventsToFetch = append(eventsToFetch, authEvent.EventID) + } + } + } + } + + // We've now retrieved all of the events we can. Flatten them down into an + // array and return them. + var authEvents []gomatrixserverlib.Event + for _, event := range authEventsMap { + authEvents = append(authEvents, event) + } + + return authEvents, nil +} + +// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI +func (r *Queryer) QueryRoomVersionCapabilities( + ctx context.Context, + request *api.QueryRoomVersionCapabilitiesRequest, + response *api.QueryRoomVersionCapabilitiesResponse, +) error { + response.DefaultRoomVersion = version.DefaultRoomVersion() + response.AvailableRoomVersions = make(map[gomatrixserverlib.RoomVersion]string) + for v, desc := range version.SupportedRoomVersions() { + if desc.Stable { + response.AvailableRoomVersions[v] = "stable" + } else { + response.AvailableRoomVersions[v] = "unstable" + } + } + return nil +} + +// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI +func (r *Queryer) QueryRoomVersionForRoom( + ctx context.Context, + request *api.QueryRoomVersionForRoomRequest, + response *api.QueryRoomVersionForRoomResponse, +) error { + if roomVersion, ok := r.Cache.GetRoomVersion(request.RoomID); ok { + response.RoomVersion = roomVersion + return nil + } + + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if info == nil { + return fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", request.RoomID) + } + response.RoomVersion = info.RoomVersion + r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion) + return nil +} + +func (r *Queryer) roomVersion(roomID string) (gomatrixserverlib.RoomVersion, error) { + var res api.QueryRoomVersionForRoomResponse + err := r.QueryRoomVersionForRoom(context.Background(), &api.QueryRoomVersionForRoomRequest{ + RoomID: roomID, + }, &res) + return res.RoomVersion, err +} + +func (r *Queryer) QueryPublishedRooms( + ctx context.Context, + req *api.QueryPublishedRoomsRequest, + res *api.QueryPublishedRoomsResponse, +) error { + rooms, err := r.DB.GetPublishedRooms(ctx) + if err != nil { + return err + } + res.RoomIDs = rooms + return nil +} + +func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { + res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) + for _, tuple := range req.StateTuples { + ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) + if err != nil { + return err + } + if ev != nil { + res.StateEvents[tuple] = ev + } + } + return nil +} + +func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { + roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership) + if err != nil { + return err + } + res.RoomIDs = roomIDs + return nil +} + +func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { + users, err := r.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit) + if err != nil { + return err + } + for _, user := range users { + res.Users = append(res.Users, authtypes.FullyQualifiedProfile{ + UserID: user, + }) + } + return nil +} + +func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { + events, err := r.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards) + if err != nil { + return err + } + res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) + for _, ev := range events { + if res.Rooms[ev.RoomID] == nil { + res.Rooms[ev.RoomID] = make(map[gomatrixserverlib.StateKeyTuple]string) + } + room := res.Rooms[ev.RoomID] + room[gomatrixserverlib.StateKeyTuple{ + EventType: ev.EventType, + StateKey: ev.StateKey, + }] = ev.ContentValue + res.Rooms[ev.RoomID] = room + } + return nil +} + +func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { + roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") + if err != nil { + return err + } + roomIDs = append(roomIDs, req.IncludeRoomIDs...) + excludeMap := make(map[string]bool) + for _, roomID := range req.ExcludeRoomIDs { + excludeMap[roomID] = true + } + // filter out excluded rooms + j := 0 + for i := range roomIDs { + // move elements to include to the beginning of the slice + // then trim elements on the right + if !excludeMap[roomIDs[i]] { + roomIDs[j] = roomIDs[i] + j++ + } + } + roomIDs = roomIDs[:j] + + users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs) + if err != nil { + return err + } + res.UserIDsToCount = users + return nil +} + +func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error { + if r.ServerACLs == nil { + return errors.New("no server ACL tracking") + } + res.Banned = r.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID) + return nil +} diff --git a/roomserver/internal/query_test.go b/roomserver/internal/query/query_test.go similarity index 98% rename from roomserver/internal/query_test.go rename to roomserver/internal/query/query_test.go index 92e008324..b4cb99b85 100644 --- a/roomserver/internal/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package query import ( "context" diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 4ffc3c8bb..1ff1fc82b 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -44,6 +44,12 @@ const ( RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities" RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom" RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms" + RoomserverQueryCurrentStatePath = "/roomserver/queryCurrentState" + RoomserverQueryRoomsForUserPath = "/roomserver/queryRoomsForUser" + RoomserverQueryBulkStateContentPath = "/roomserver/queryBulkStateContent" + RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers" + RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers" + RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" ) type httpRoomserverInternalAPI struct { @@ -389,3 +395,69 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom( } return err } + +func (h *httpRoomserverInternalAPI) QueryCurrentState( + ctx context.Context, + request *api.QueryCurrentStateRequest, + response *api.QueryCurrentStateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpRoomserverInternalAPI) QueryRoomsForUser( + ctx context.Context, + request *api.QueryRoomsForUserRequest, + response *api.QueryRoomsForUserResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpRoomserverInternalAPI) QueryBulkStateContent( + ctx context.Context, + request *api.QueryBulkStateContentRequest, + response *api.QueryBulkStateContentResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpRoomserverInternalAPI) QuerySharedUsers( + ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpRoomserverInternalAPI) QueryKnownUsers( + ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( + ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 0ac36a2a4..ebfb296d8 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -312,4 +312,82 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(RoomserverQueryCurrentStatePath, + httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse { + request := api.QueryCurrentStateRequest{} + response := api.QueryCurrentStateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQueryRoomsForUserPath, + httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse { + request := api.QueryRoomsForUserRequest{} + response := api.QueryRoomsForUserResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQueryBulkStateContentPath, + httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse { + request := api.QueryBulkStateContentRequest{} + response := api.QueryBulkStateContentResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQuerySharedUsersPath, + httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse { + request := api.QuerySharedUsersRequest{} + response := api.QuerySharedUsersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQuerySharedUsersPath, + httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse { + request := api.QueryKnownUsersRequest{} + response := api.QueryKnownUsersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath, + httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse { + request := api.QueryServerBannedFromRoomRequest{} + response := api.QueryServerBannedFromRoomResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 21af5f32d..2eabf4504 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -38,7 +38,6 @@ func AddInternalRoutes(router *mux.Router, intAPI api.RoomserverInternalAPI) { func NewInternalAPI( base *setup.BaseDendrite, keyRing gomatrixserverlib.JSONVerifier, - fedClient *gomatrixserverlib.FederationClient, ) api.RoomserverInternalAPI { cfg := &base.Cfg.RoomServer @@ -47,14 +46,8 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to room server db") } - return &internal.RoomserverInternalAPI{ - DB: roomserverDB, - Cfg: cfg, - Producer: base.KafkaProducer, - OutputRoomEventTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)), - Cache: base.Caches, - ServerName: cfg.Matrix.ServerName, - FedClient: fedClient, - KeyRing: keyRing, - } + return internal.NewRoomserverAPI( + cfg, roomserverDB, base.KafkaProducer, string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)), + base.Caches, keyRing, + ) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index bcd9afb38..786d4f31f 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -112,10 +112,9 @@ func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js Cfg: cfg, } - rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{}, nil) + rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{}) hevents := mustLoadEvents(t, ver, events) - _, err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil) - if err != nil { + if err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil { t.Errorf("failed to SendEvents: %s", err) } return rsAPI, dp, hevents diff --git a/roomserver/state/state.go b/roomserver/state/state.go index b9ad4a504..37e6807a3 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -31,12 +31,14 @@ import ( ) type StateResolution struct { - db storage.Database + db storage.Database + roomInfo types.RoomInfo } -func NewStateResolution(db storage.Database) StateResolution { +func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution { return StateResolution{ - db: db, + db: db, + roomInfo: roomInfo, } } @@ -339,7 +341,7 @@ func (v StateResolution) loadStateAtSnapshotForNumericTuples( // This is typically the state before an event. // Returns a sorted list of state entries or an error if there was a problem talking to the database. func (v StateResolution) LoadStateAfterEventsForStringTuples( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { @@ -347,24 +349,18 @@ func (v StateResolution) LoadStateAfterEventsForStringTuples( if err != nil { return nil, err } - return v.loadStateAfterEventsForNumericTuples(ctx, roomNID, prevStates, numericTuples) + return v.loadStateAfterEventsForNumericTuples(ctx, prevStates, numericTuples) } func (v StateResolution) loadStateAfterEventsForNumericTuples( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID) - if err != nil { - return nil, err - } - if len(prevStates) == 1 { // Fast path for a single event. prevState := prevStates[0] - var result []types.StateEntry - result, err = v.loadStateAtSnapshotForNumericTuples( + result, err := v.loadStateAtSnapshotForNumericTuples( ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples, ) if err != nil { @@ -403,7 +399,7 @@ func (v StateResolution) loadStateAfterEventsForNumericTuples( // TODO: Add metrics for this as it could take a long time for big rooms // with large conflicts. - fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) + fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) if err != nil { return nil, err } @@ -527,7 +523,6 @@ func init() { func (v StateResolution) CalculateAndStoreStateBeforeEvent( ctx context.Context, event gomatrixserverlib.Event, - roomNID types.RoomNID, ) (types.StateSnapshotNID, error) { // Load the state at the prev events. prevEventRefs := event.PrevEvents() @@ -542,14 +537,13 @@ func (v StateResolution) CalculateAndStoreStateBeforeEvent( } // The state before this event will be the state after the events that came before it. - return v.CalculateAndStoreStateAfterEvents(ctx, roomNID, prevStates) + return v.CalculateAndStoreStateAfterEvents(ctx, prevStates) } // CalculateAndStoreStateAfterEvents finds the room state after the given events. // Stores the resulting state in the database and returns a numeric ID for that snapshot. func (v StateResolution) CalculateAndStoreStateAfterEvents( ctx context.Context, - roomNID types.RoomNID, prevStates []types.StateAtEvent, ) (types.StateSnapshotNID, error) { metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} @@ -558,7 +552,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // 2) There weren't any prev_events for this event so the state is // empty. metrics.algorithm = "empty_state" - stateNID, err := v.db.AddState(ctx, roomNID, nil, nil) + stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID, nil, nil) if err != nil { err = fmt.Errorf("v.db.AddState: %w", err) } @@ -590,7 +584,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // add the state event as a block of size one to the end of the blocks. metrics.algorithm = "single_delta" stateNID, err := v.db.AddState( - ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, + ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, ) if err != nil { err = fmt.Errorf("v.db.AddState: %w", err) @@ -601,7 +595,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // So fall through to calculateAndStoreStateAfterManyEvents } - stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) + stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, v.roomInfo.RoomNID, prevStates, metrics) if err != nil { return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err) } @@ -624,13 +618,8 @@ func (v StateResolution) calculateAndStoreStateAfterManyEvents( prevStates []types.StateAtEvent, metrics calculateStateMetrics, ) (types.StateSnapshotNID, error) { - roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID) - if err != nil { - return metrics.stop(0, err) - } - state, algorithm, conflictLength, err := - v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) + v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) metrics.algorithm = algorithm if err != nil { return metrics.stop(0, err) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 5f6416145..c4119f7ed 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -17,6 +17,7 @@ package storage import ( "context" + "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" @@ -66,8 +67,6 @@ type Database interface { Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) - // Look up a room version from the room NID. - GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. StoreEvent( ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, @@ -91,7 +90,7 @@ type Database interface { // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // Returns the latest events in the room and the last eventID sent to the log along with an updater. // If this returns an error then no further action is required. - GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (*shared.LatestEventsUpdater, error) + GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error) // Look up event ID by transaction's info. // This is used to determine if the room event is processed/processing already. // Returns an empty string if no such event exists. @@ -136,10 +135,26 @@ type Database interface { // not found. // Returns an error if the retrieval went wrong. EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) - // Look up the room version for a given room. - GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) // Publish or unpublish a room from the room directory. PublishRoom(ctx context.Context, roomID string, publish bool) error // Returns a list of room IDs for rooms which are published. GetPublishedRooms(ctx context.Context) ([]string, error) + + // TODO: factor out - from currentstateserver + + // GetStateEvent returns the state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). + GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) + // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. + // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. + GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) + // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. + JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // GetKnownUsers searches all users that userID knows about. + GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) + // GetKnownRooms returns a list of all rooms we know about. + GetKnownRooms(ctx context.Context) ([]string, error) } diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 13cef638f..5164f654f 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -18,7 +18,9 @@ package postgres import ( "context" "database/sql" + "fmt" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -62,6 +64,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( ); ` +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" + // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + @@ -99,6 +105,19 @@ const updateMembershipSQL = "" + "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" + " WHERE room_nid = $1 AND target_nid = $2" +const selectRoomsWithMembershipSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + +// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is +// joined to. Since this information is used to populate the user directory, we will +// only return users that the user would ordinarily be able to see anyway. +var selectKnownUsersSQL = "" + + "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + + "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + " WHERE room_nid = ANY(" + + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -108,6 +127,9 @@ type membershipStatements struct { selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomStmt *sql.Stmt updateMembershipStmt *sql.Stmt + selectRoomsWithMembershipStmt *sql.Stmt + selectJoinedUsersSetForRoomsStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt } func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -126,6 +148,9 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, + {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, + {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, + {&s.selectKnownUsersStmt, selectKnownUsersSQL}, }.Prepare(db) } @@ -222,3 +247,61 @@ func (s *membershipStatements) UpdateMembership( ) return err } + +func (s *membershipStatements) SelectRoomsWithMembership( + ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, +) ([]types.RoomNID, error) { + rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} + +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { + roomIDarray := make([]int64, len(roomNIDs)) + for i := range roomNIDs { + roomIDarray[i] = int64(roomNIDs[i]) + } + rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") + result := make(map[types.EventStateKeyNID]int) + for rows.Next() { + var userID types.EventStateKeyNID + var count int + if err := rows.Scan(&userID, &count); err != nil { + return nil, err + } + result[userID] = count + } + return result, rows.Err() +} + +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + result := []string{} + defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 691c04ba6..ef1b7891a 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -21,6 +21,7 @@ import ( "errors" "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" @@ -68,24 +69,32 @@ const selectLatestEventNIDsForUpdateSQL = "" + const updateLatestEventNIDsSQL = "" + "UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1" -const selectRoomVersionForRoomIDSQL = "" + - "SELECT room_version FROM roomserver_rooms WHERE room_id = $1" - const selectRoomVersionForRoomNIDSQL = "" + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" const selectRoomInfoSQL = "" + "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" +const selectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms" + +const bulkSelectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" + +const bulkSelectRoomNIDsSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt - selectRoomVersionForRoomIDStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt + selectRoomIDsStmt *sql.Stmt + bulkSelectRoomIDsStmt *sql.Stmt + bulkSelectRoomNIDsStmt *sql.Stmt } func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -100,12 +109,30 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, - {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, + {&s.selectRoomIDsStmt, selectRoomIDsSQL}, + {&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL}, + {&s.bulkSelectRoomNIDsStmt, bulkSelectRoomNIDsSQL}, }.Prepare(db) } +func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { + rows, err := s.selectRoomIDsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, @@ -192,18 +219,6 @@ func (s *roomStatements) UpdateLatestEventNIDs( return err } -func (s *roomStatements) SelectRoomVersionForRoomID( - ctx context.Context, txn *sql.Tx, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - var roomVersion gomatrixserverlib.RoomVersion - stmt := sqlutil.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) - err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) - if err == sql.ErrNoRows { - return roomVersion, errors.New("room not found") - } - return roomVersion, err -} - func (s *roomStatements) SelectRoomVersionForRoomNID( ctx context.Context, roomNID types.RoomNID, ) (gomatrixserverlib.RoomVersion, error) { @@ -214,3 +229,45 @@ func (s *roomStatements) SelectRoomVersionForRoomNID( } return roomVersion, err } + +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { + var array pq.Int64Array + for _, nid := range roomNIDs { + array = append(array, int64(nid)) + } + rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} + +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { + var array pq.StringArray + for _, roomID := range roomIDs { + array = append(array, roomID) + } + rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err = rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go index e9a0f6982..29eab0c98 100644 --- a/roomserver/storage/shared/latest_events_updater.go +++ b/roomserver/storage/shared/latest_events_updater.go @@ -12,15 +12,15 @@ import ( type LatestEventsUpdater struct { transaction d *Database - roomNID types.RoomNID + roomInfo types.RoomInfo latestEvents []types.StateAtEventAndReference lastEventIDSent string currentStateSnapshotNID types.StateSnapshotNID } -func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomNID types.RoomNID) (*LatestEventsUpdater, error) { +func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) { eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) if err != nil { txn.Rollback() // nolint: errcheck return nil, err @@ -39,14 +39,13 @@ func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomN } } return &LatestEventsUpdater{ - transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, }, nil } // RoomVersion implements types.RoomRecentEventsUpdater func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { - version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID) - return + return u.roomInfo.RoomVersion } // LatestEvents implements types.RoomRecentEventsUpdater @@ -118,5 +117,5 @@ func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { } func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 4af61be8f..a3b33a4fe 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -5,13 +5,16 @@ import ( "database/sql" "encoding/json" "fmt" + "sort" + csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "github.com/tidwall/gjson" ) @@ -229,30 +232,6 @@ func (d *Database) StateEntries( return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) } -func (d *Database) GetRoomVersionForRoom( - ctx context.Context, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok { - return roomVersion, nil - } - return d.RoomsTable.SelectRoomVersionForRoomID( - ctx, nil, roomID, - ) -} - -func (d *Database) GetRoomVersionForRoomNID( - ctx context.Context, roomNID types.RoomNID, -) (gomatrixserverlib.RoomVersion, error) { - if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok { - if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok { - return roomVersion, nil - } - } - return d.RoomsTable.SelectRoomVersionForRoomNID( - ctx, roomNID, - ) -} - func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, roomID, creatorUserID) @@ -387,7 +366,7 @@ func (d *Database) MembershipUpdater( } func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, roomInfo types.RoomInfo, ) (*LatestEventsUpdater, error) { txn, err := d.DB.Begin() if err != nil { @@ -395,7 +374,7 @@ func (d *Database) GetLatestEventsForUpdate( } var updater *LatestEventsUpdater _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { - updater, err = NewLatestEventsUpdater(ctx, d, txn, roomNID) + updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) return nil }) return updater, err @@ -735,3 +714,190 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { } return &evs[0] } + +// GetStateEvent returns the current state event of a given type for a given room with a given state key +// If no event could be found, returns nil +// If there was an issue during the retrieval, returns an error +func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err != nil { + return nil, err + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey) + if err != nil { + return nil, err + } + entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + if err != nil { + return nil, err + } + // return the event requested + for _, e := range entries { + if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { + data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID}) + if err != nil { + return nil, err + } + if len(data) == 0 { + return nil, fmt.Errorf("GetStateEvent: no json for event nid %d", e.EventNID) + } + ev, err := gomatrixserverlib.NewEventFromTrustedJSON(data[0].EventJSON, false, roomInfo.RoomVersion) + if err != nil { + return nil, err + } + h := ev.Headered(roomInfo.RoomVersion) + return &h, nil + } + } + + return nil, fmt.Errorf("GetStateEvent: no event type '%s' with key '%s' exists in room %s", evType, stateKey, roomID) +} + +// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). +func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { + var membershipState tables.MembershipState + switch membership { + case "join": + membershipState = tables.MembershipStateJoin + case "invite": + membershipState = tables.MembershipStateInvite + case "leave": + membershipState = tables.MembershipStateLeaveOrBan + case "ban": + membershipState = tables.MembershipStateLeaveOrBan + default: + return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership) + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) + } + roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) + if err != nil { + return nil, err + } + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) + if err != nil { + return nil, err + } + if len(roomIDs) != len(roomNIDs) { + return nil, fmt.Errorf("GetRoomsByMembership: missing room IDs, got %d want %d", len(roomIDs), len(roomNIDs)) + } + return roomIDs, nil +} + +// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. +// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. +func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]csstables.StrippedEvent, error) { + return nil, fmt.Errorf("not implemented yet") +} + +// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. +func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs) + if err != nil { + return nil, err + } + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs) + if err != nil { + return nil, err + } + stateKeyNIDs := make([]types.EventStateKeyNID, len(userNIDToCount)) + i := 0 + for nid := range userNIDToCount { + stateKeyNIDs[i] = nid + i++ + } + nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs) + if err != nil { + return nil, err + } + if len(nidToUserID) != len(userNIDToCount) { + return nil, fmt.Errorf("found %d users but only have state key nids for %d of them", len(userNIDToCount), len(nidToUserID)) + } + result := make(map[string]int, len(userNIDToCount)) + for nid, count := range userNIDToCount { + result[nidToUserID[nid]] = count + } + return result, nil +} + +// GetKnownUsers searches all users that userID knows about. +func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) + if err != nil { + return nil, err + } + return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit) +} + +// GetKnownRooms returns a list of all rooms we know about. +func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { + return d.RoomsTable.SelectRoomIDs(ctx) +} + +// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops +// it should live in this package! + +func (d *Database) loadStateAtSnapshot( + ctx context.Context, stateNID types.StateSnapshotNID, +) ([]types.StateEntry, error) { + stateBlockNIDLists, err := d.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := d.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + +type stateEntryListMap []types.StateEntryList + +func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { + list := []types.StateEntryList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateBlockNID >= stateBlockNID + }) + if i < len(list) && list[i].StateBlockNID == stateBlockNID { + ok = true + stateEntries = list[i].StateEntries + } + return +} + +type stateEntryByStateKeySorter []types.StateEntry + +func (s stateEntryByStateKeySorter) Len() int { return len(s) } +func (s stateEntryByStateKeySorter) Less(i, j int) bool { + return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) +} +func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index b3ee69c00..0d5ce516d 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -18,6 +18,8 @@ package sqlite3 import ( "context" "database/sql" + "fmt" + "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -38,6 +40,10 @@ const membershipSchema = ` ); ` +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" + // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + @@ -75,6 +81,19 @@ const updateMembershipSQL = "" + "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" + " WHERE room_nid = $4 AND target_nid = $5" +const selectRoomsWithMembershipSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + +// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is +// joined to. Since this information is used to populate the user directory, we will +// only return users that the user would ordinarily be able to see anyway. +var selectKnownUsersSQL = "" + + "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + + "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + " WHERE room_nid IN (" + + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -84,7 +103,9 @@ type membershipStatements struct { selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomStmt *sql.Stmt + selectRoomsWithMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt } func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -105,6 +126,8 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, + {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, + {&s.selectKnownUsersStmt, selectKnownUsersSQL}, }.Prepare(db) } @@ -203,3 +226,62 @@ func (s *membershipStatements) UpdateMembership( ) return err } + +func (s *membershipStatements) SelectRoomsWithMembership( + ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, +) ([]types.RoomNID, error) { + rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} + +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { + iRoomNIDs := make([]interface{}, len(roomNIDs)) + for i, v := range roomNIDs { + iRoomNIDs[i] = v + } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) + rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") + result := make(map[types.EventStateKeyNID]int) + for rows.Next() { + var userID types.EventStateKeyNID + var count int + if err := rows.Scan(&userID, &count); err != nil { + return nil, err + } + result[userID] = count + } + return result, rows.Err() +} + +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + result := []string{} + defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index fc1bcf22f..b4564aff9 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -21,7 +21,9 @@ import ( "encoding/json" "errors" "fmt" + "strings" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" @@ -58,15 +60,21 @@ const selectLatestEventNIDsForUpdateSQL = "" + const updateLatestEventNIDsSQL = "" + "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4" -const selectRoomVersionForRoomIDSQL = "" + - "SELECT room_version FROM roomserver_rooms WHERE room_id = $1" - const selectRoomVersionForRoomNIDSQL = "" + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" const selectRoomInfoSQL = "" + "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" +const selectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms" + +const bulkSelectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" + +const bulkSelectRoomNIDsSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt @@ -74,9 +82,9 @@ type roomStatements struct { selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt - selectRoomVersionForRoomIDStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt + selectRoomIDsStmt *sql.Stmt } func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -93,12 +101,29 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, - {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, + {&s.selectRoomIDsStmt, selectRoomIDsSQL}, }.Prepare(db) } +func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { + rows, err := s.selectRoomIDsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} + func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDsJSON string @@ -198,18 +223,6 @@ func (s *roomStatements) UpdateLatestEventNIDs( return err } -func (s *roomStatements) SelectRoomVersionForRoomID( - ctx context.Context, txn *sql.Tx, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - var roomVersion gomatrixserverlib.RoomVersion - stmt := sqlutil.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) - err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) - if err == sql.ErrNoRows { - return roomVersion, errors.New("room not found") - } - return roomVersion, err -} - func (s *roomStatements) SelectRoomVersionForRoomNID( ctx context.Context, roomNID types.RoomNID, ) (gomatrixserverlib.RoomVersion, error) { @@ -220,3 +233,47 @@ func (s *roomStatements) SelectRoomVersionForRoomNID( } return roomVersion, err } + +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { + iRoomNIDs := make([]interface{}, len(roomNIDs)) + for i, v := range roomNIDs { + iRoomNIDs[i] = v + } + sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} + +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { + iRoomIDs := make([]interface{}, len(roomIDs)) + for i, v := range roomIDs { + iRoomIDs[i] = v + } + sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) + rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err = rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 33782171e..4a74bf736 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -150,7 +150,7 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { } func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, roomInfo types.RoomInfo, ) (*shared.LatestEventsUpdater, error) { // TODO: Do not use transactions. We should be holding open this transaction but we cannot have // multiple write transactions on sqlite. The code will perform additional @@ -158,7 +158,7 @@ func (d *Database) GetLatestEventsForUpdate( // 'database is locked' errors. As sqlite doesn't support multi-process on the // same DB anyway, and we only execute updates sequentially, the only worries // are for rolling back when things go wrong. (atomicity) - return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomNID) + return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo) } func (d *Database) MembershipUpdater( diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index ca9159d07..a142f2b1a 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -63,9 +63,11 @@ type Rooms interface { SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error - SelectRoomVersionForRoomID(ctx context.Context, txn *sql.Tx, roomID string) (gomatrixserverlib.RoomVersion, error) SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) + SelectRoomIDs(ctx context.Context) ([]string, error) + BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) + BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) } type Transactions interface { @@ -121,6 +123,11 @@ type Membership interface { SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error + SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) + // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the + // counts of how many rooms they are joined. + SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) } type Published interface {