diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index d6c54ead1..bbbdd82a4 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,7 +2,7 @@ -* [ ] I have added tests for PR _or_ I have justified why this PR doesn't need tests. +* [ ] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [ ] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Your Name ` diff --git a/CHANGES.md b/CHANGES.md index ba14dd07a..55df36f96 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,19 @@ # Changelog +## Dendrite 0.10.6 (2022-11-01) + +### Features + +* History visibility checks have been optimised, which should speed up response times on a variety of endpoints (including `/sync`, `/messages`, `/context` and others) and reduce database load +* The built-in NATS Server has been updated to version 2.9.4 +* Some other minor dependencies have been updated + +### Fixes + +* A panic has been fixed in the sync API PDU stream which could cause requests to fail +* The `/members` response now contains the `room_id` field, which may fix some E2EE problems with clients using the JS SDK (contributed by [ashkitten](https://github.com/ashkitten)) +* The auth difference calculation in state resolution v2 has been tweaked for clarity (and moved into gomatrixserverlib with the rest of the state resolution code) + ## Dendrite 0.10.5 (2022-10-31) ### Features diff --git a/appservice/api/query.go b/appservice/api/query.go index 4d1cf9474..eb567b2ee 100644 --- a/appservice/api/query.go +++ b/appservice/api/query.go @@ -19,11 +19,13 @@ package api import ( "context" + "encoding/json" "errors" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) // AppServiceInternalAPI is used to query user and room alias data from application @@ -41,6 +43,10 @@ type AppServiceInternalAPI interface { req *UserIDExistsRequest, resp *UserIDExistsResponse, ) error + + Locations(ctx context.Context, req *LocationRequest, resp *LocationResponse) error + User(ctx context.Context, request *UserRequest, response *UserResponse) error + Protocols(ctx context.Context, req *ProtocolRequest, resp *ProtocolResponse) error } // RoomAliasExistsRequest is a request to an application service @@ -77,6 +83,73 @@ type UserIDExistsResponse struct { UserIDExists bool `json:"exists"` } +const ( + ASProtocolPath = "/_matrix/app/unstable/thirdparty/protocol/" + ASUserPath = "/_matrix/app/unstable/thirdparty/user" + ASLocationPath = "/_matrix/app/unstable/thirdparty/location" +) + +type ProtocolRequest struct { + Protocol string `json:"protocol,omitempty"` +} + +type ProtocolResponse struct { + Protocols map[string]ASProtocolResponse `json:"protocols"` + Exists bool `json:"exists"` +} + +type ASProtocolResponse struct { + FieldTypes map[string]FieldType `json:"field_types,omitempty"` // NOTSPEC: field_types is required by the spec + Icon string `json:"icon"` + Instances []ProtocolInstance `json:"instances"` + LocationFields []string `json:"location_fields"` + UserFields []string `json:"user_fields"` +} + +type FieldType struct { + Placeholder string `json:"placeholder"` + Regexp string `json:"regexp"` +} + +type ProtocolInstance struct { + Description string `json:"desc"` + Icon string `json:"icon,omitempty"` + NetworkID string `json:"network_id,omitempty"` // NOTSPEC: network_id is required by the spec + Fields json.RawMessage `json:"fields,omitempty"` // NOTSPEC: fields is required by the spec +} + +type UserRequest struct { + Protocol string `json:"protocol"` + Params string `json:"params"` +} + +type UserResponse struct { + Users []ASUserResponse `json:"users,omitempty"` + Exists bool `json:"exists,omitempty"` +} + +type ASUserResponse struct { + Protocol string `json:"protocol"` + UserID string `json:"userid"` + Fields json.RawMessage `json:"fields"` +} + +type LocationRequest struct { + Protocol string `json:"protocol"` + Params string `json:"params"` +} + +type LocationResponse struct { + Locations []ASLocationResponse `json:"locations,omitempty"` + Exists bool `json:"exists,omitempty"` +} + +type ASLocationResponse struct { + Alias string `json:"alias"` + Protocol string `json:"protocol"` + Fields json.RawMessage `json:"fields"` +} + // RetrieveUserProfile is a wrapper that queries both the local database and // application services for a given user's profile // TODO: Remove this, it's called from federationapi and clientapi but is a pure function diff --git a/appservice/appservice.go b/appservice/appservice.go index 9000adb1d..0c778b6ca 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "net/http" + "sync" "time" "github.com/gorilla/mux" @@ -58,8 +59,10 @@ func NewInternalAPI( // Create appserivce query API with an HTTP client that will be used for all // outbound and inbound requests (inbound only for the internal API) appserviceQueryAPI := &query.AppServiceQueryAPI{ - HTTPClient: client, - Cfg: &base.Cfg.AppServiceAPI, + HTTPClient: client, + Cfg: &base.Cfg.AppServiceAPI, + ProtocolCache: map[string]appserviceAPI.ASProtocolResponse{}, + CacheMu: sync.Mutex{}, } if len(base.Cfg.Derived.ApplicationServices) == 0 { diff --git a/appservice/inthttp/client.go b/appservice/inthttp/client.go index 3ae2c9278..f7f164877 100644 --- a/appservice/inthttp/client.go +++ b/appservice/inthttp/client.go @@ -13,6 +13,9 @@ import ( const ( AppServiceRoomAliasExistsPath = "/appservice/RoomAliasExists" AppServiceUserIDExistsPath = "/appservice/UserIDExists" + AppServiceLocationsPath = "/appservice/locations" + AppServiceUserPath = "/appservice/users" + AppServiceProtocolsPath = "/appservice/protocols" ) // httpAppServiceQueryAPI contains the URL to an appservice query API and a @@ -58,3 +61,24 @@ func (h *httpAppServiceQueryAPI) UserIDExists( h.httpClient, ctx, request, response, ) } + +func (h *httpAppServiceQueryAPI) Locations(ctx context.Context, request *api.LocationRequest, response *api.LocationResponse) error { + return httputil.CallInternalRPCAPI( + "ASLocation", h.appserviceURL+AppServiceLocationsPath, + h.httpClient, ctx, request, response, + ) +} + +func (h *httpAppServiceQueryAPI) User(ctx context.Context, request *api.UserRequest, response *api.UserResponse) error { + return httputil.CallInternalRPCAPI( + "ASUser", h.appserviceURL+AppServiceUserPath, + h.httpClient, ctx, request, response, + ) +} + +func (h *httpAppServiceQueryAPI) Protocols(ctx context.Context, request *api.ProtocolRequest, response *api.ProtocolResponse) error { + return httputil.CallInternalRPCAPI( + "ASProtocols", h.appserviceURL+AppServiceProtocolsPath, + h.httpClient, ctx, request, response, + ) +} diff --git a/appservice/inthttp/server.go b/appservice/inthttp/server.go index 01d9f9895..ccf5c83d8 100644 --- a/appservice/inthttp/server.go +++ b/appservice/inthttp/server.go @@ -2,6 +2,7 @@ package inthttp import ( "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/internal/httputil" ) @@ -17,4 +18,19 @@ func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) { AppServiceUserIDExistsPath, httputil.MakeInternalRPCAPI("AppserviceUserIDExists", a.UserIDExists), ) + + internalAPIMux.Handle( + AppServiceProtocolsPath, + httputil.MakeInternalRPCAPI("AppserviceProtocols", a.Protocols), + ) + + internalAPIMux.Handle( + AppServiceLocationsPath, + httputil.MakeInternalRPCAPI("AppserviceLocations", a.Locations), + ) + + internalAPIMux.Handle( + AppServiceUserPath, + httputil.MakeInternalRPCAPI("AppserviceUser", a.User), + ) } diff --git a/appservice/query/query.go b/appservice/query/query.go index 53b34cb18..2348eab4b 100644 --- a/appservice/query/query.go +++ b/appservice/query/query.go @@ -18,13 +18,18 @@ package query import ( "context" + "encoding/json" + "io" "net/http" "net/url" + "strings" + "sync" + + "github.com/opentracing/opentracing-go" + log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/setup/config" - opentracing "github.com/opentracing/opentracing-go" - log "github.com/sirupsen/logrus" ) const roomAliasExistsPath = "/rooms/" @@ -32,8 +37,10 @@ const userIDExistsPath = "/users/" // AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI type AppServiceQueryAPI struct { - HTTPClient *http.Client - Cfg *config.AppServiceAPI + HTTPClient *http.Client + Cfg *config.AppServiceAPI + ProtocolCache map[string]api.ASProtocolResponse + CacheMu sync.Mutex } // RoomAliasExists performs a request to '/room/{roomAlias}' on all known @@ -165,3 +172,178 @@ func (a *AppServiceQueryAPI) UserIDExists( response.UserIDExists = false return nil } + +type thirdpartyResponses interface { + api.ASProtocolResponse | []api.ASUserResponse | []api.ASLocationResponse +} + +func requestDo[T thirdpartyResponses](client *http.Client, url string, response *T) (err error) { + origURL := url + // try v1 and unstable appservice endpoints + for _, version := range []string{"v1", "unstable"} { + var resp *http.Response + var body []byte + asURL := strings.Replace(origURL, "unstable", version, 1) + resp, err = client.Get(asURL) + if err != nil { + continue + } + defer resp.Body.Close() // nolint: errcheck + body, err = io.ReadAll(resp.Body) + if err != nil { + continue + } + return json.Unmarshal(body, &response) + } + return err +} + +func (a *AppServiceQueryAPI) Locations( + ctx context.Context, + req *api.LocationRequest, + resp *api.LocationResponse, +) error { + params, err := url.ParseQuery(req.Params) + if err != nil { + return err + } + + for _, as := range a.Cfg.Derived.ApplicationServices { + var asLocations []api.ASLocationResponse + params.Set("access_token", as.HSToken) + + url := as.URL + api.ASLocationPath + if req.Protocol != "" { + url += "/" + req.Protocol + } + + if err := requestDo[[]api.ASLocationResponse](a.HTTPClient, url+"?"+params.Encode(), &asLocations); err != nil { + log.WithError(err).Error("unable to get 'locations' from application service") + continue + } + + resp.Locations = append(resp.Locations, asLocations...) + } + + if len(resp.Locations) == 0 { + resp.Exists = false + return nil + } + resp.Exists = true + return nil +} + +func (a *AppServiceQueryAPI) User( + ctx context.Context, + req *api.UserRequest, + resp *api.UserResponse, +) error { + params, err := url.ParseQuery(req.Params) + if err != nil { + return err + } + + for _, as := range a.Cfg.Derived.ApplicationServices { + var asUsers []api.ASUserResponse + params.Set("access_token", as.HSToken) + + url := as.URL + api.ASUserPath + if req.Protocol != "" { + url += "/" + req.Protocol + } + + if err := requestDo[[]api.ASUserResponse](a.HTTPClient, url+"?"+params.Encode(), &asUsers); err != nil { + log.WithError(err).Error("unable to get 'user' from application service") + continue + } + + resp.Users = append(resp.Users, asUsers...) + } + + if len(resp.Users) == 0 { + resp.Exists = false + return nil + } + resp.Exists = true + return nil +} + +func (a *AppServiceQueryAPI) Protocols( + ctx context.Context, + req *api.ProtocolRequest, + resp *api.ProtocolResponse, +) error { + + // get a single protocol response + if req.Protocol != "" { + + a.CacheMu.Lock() + defer a.CacheMu.Unlock() + if proto, ok := a.ProtocolCache[req.Protocol]; ok { + resp.Exists = true + resp.Protocols = map[string]api.ASProtocolResponse{ + req.Protocol: proto, + } + return nil + } + + response := api.ASProtocolResponse{} + for _, as := range a.Cfg.Derived.ApplicationServices { + var proto api.ASProtocolResponse + if err := requestDo[api.ASProtocolResponse](a.HTTPClient, as.URL+api.ASProtocolPath+req.Protocol, &proto); err != nil { + log.WithError(err).Error("unable to get 'protocol' from application service") + continue + } + + if len(response.Instances) != 0 { + response.Instances = append(response.Instances, proto.Instances...) + } else { + response = proto + } + } + + if len(response.Instances) == 0 { + resp.Exists = false + return nil + } + + resp.Exists = true + resp.Protocols = map[string]api.ASProtocolResponse{ + req.Protocol: response, + } + a.ProtocolCache[req.Protocol] = response + return nil + } + + response := make(map[string]api.ASProtocolResponse, len(a.Cfg.Derived.ApplicationServices)) + + for _, as := range a.Cfg.Derived.ApplicationServices { + for _, p := range as.Protocols { + var proto api.ASProtocolResponse + if err := requestDo[api.ASProtocolResponse](a.HTTPClient, as.URL+api.ASProtocolPath+p, &proto); err != nil { + log.WithError(err).Error("unable to get 'protocol' from application service") + continue + } + existing, ok := response[p] + if !ok { + response[p] = proto + continue + } + existing.Instances = append(existing.Instances, proto.Instances...) + response[p] = existing + } + } + + if len(response) == 0 { + resp.Exists = false + return nil + } + + a.CacheMu.Lock() + defer a.CacheMu.Unlock() + a.ProtocolCache = response + + resp.Exists = true + resp.Protocols = response + return nil +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 17e9d5cfd..f35aa7e12 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -869,12 +869,50 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/protocols", - httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { - // TODO: Return the third party protcols - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, + httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return Protocols(req, asAPI, device, "") + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/thirdparty/protocol/{protocolID}", + httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) } + return Protocols(req, asAPI, device, vars["protocolID"]) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/thirdparty/user/{protocolID}", + httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return User(req, asAPI, device, vars["protocolID"], req.URL.Query()) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/thirdparty/user", + httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return User(req, asAPI, device, "", req.URL.Query()) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/thirdparty/location/{protocolID}", + httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return Location(req, asAPI, device, vars["protocolID"], req.URL.Query()) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/thirdparty/location", + httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return Location(req, asAPI, device, "", req.URL.Query()) }), ).Methods(http.MethodGet, http.MethodOptions) diff --git a/clientapi/routing/thirdparty.go b/clientapi/routing/thirdparty.go new file mode 100644 index 000000000..e757cd411 --- /dev/null +++ b/clientapi/routing/thirdparty.go @@ -0,0 +1,106 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + "net/url" + + "github.com/matrix-org/util" + + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/userapi/api" +) + +// Protocols implements +// +// GET /_matrix/client/v3/thirdparty/protocols/{protocol} +// GET /_matrix/client/v3/thirdparty/protocols +func Protocols(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, device *api.Device, protocol string) util.JSONResponse { + resp := &appserviceAPI.ProtocolResponse{} + + if err := asAPI.Protocols(req.Context(), &appserviceAPI.ProtocolRequest{Protocol: protocol}, resp); err != nil { + return jsonerror.InternalServerError() + } + if !resp.Exists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The protocol is unknown."), + } + } + if protocol != "" { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: resp.Protocols[protocol], + } + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: resp.Protocols, + } +} + +// User implements +// +// GET /_matrix/client/v3/thirdparty/user +// GET /_matrix/client/v3/thirdparty/user/{protocol} +func User(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, device *api.Device, protocol string, params url.Values) util.JSONResponse { + resp := &appserviceAPI.UserResponse{} + + params.Del("access_token") + if err := asAPI.User(req.Context(), &appserviceAPI.UserRequest{ + Protocol: protocol, + Params: params.Encode(), + }, resp); err != nil { + return jsonerror.InternalServerError() + } + if !resp.Exists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The Matrix User ID was not found"), + } + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: resp.Users, + } +} + +// Location implements +// +// GET /_matrix/client/v3/thirdparty/location +// GET /_matrix/client/v3/thirdparty/location/{protocol} +func Location(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, device *api.Device, protocol string, params url.Values) util.JSONResponse { + resp := &appserviceAPI.LocationResponse{} + + params.Del("access_token") + if err := asAPI.Locations(req.Context(), &appserviceAPI.LocationRequest{ + Protocol: protocol, + Params: params.Encode(), + }, resp); err != nil { + return jsonerror.InternalServerError() + } + if !resp.Exists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("No portal rooms were found."), + } + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: resp.Locations, + } +} diff --git a/go.mod b/go.mod index 6dac2343f..61f327cae 100644 --- a/go.mod +++ b/go.mod @@ -22,11 +22,11 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221031151122-0885c35ebe74 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 - github.com/nats-io/nats-server/v2 v2.9.3 - github.com/nats-io/nats.go v1.18.0 + github.com/nats-io/nats-server/v2 v2.9.4 + github.com/nats-io/nats.go v1.19.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 @@ -35,7 +35,7 @@ require ( github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 github.com/sirupsen/logrus v1.9.0 - github.com/stretchr/testify v1.8.0 + github.com/stretchr/testify v1.8.1 github.com/tidwall/gjson v1.14.3 github.com/tidwall/sjson v1.2.5 github.com/uber/jaeger-client-go v2.30.0+incompatible @@ -43,8 +43,8 @@ require ( github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 golang.org/x/crypto v0.1.0 - golang.org/x/image v0.0.0-20220902085622-e7cb96979f69 - golang.org/x/mobile v0.0.0-20221012134814-c746ac228303 + golang.org/x/image v0.1.0 + golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e golang.org/x/net v0.1.0 golang.org/x/term v0.1.0 gopkg.in/h2non/bimg.v1 v1.1.9 @@ -121,11 +121,11 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect go.etcd.io/bbolt v1.3.6 // indirect - golang.org/x/exp v0.0.0-20220916125017-b168a2c6b86b // indirect + golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 // indirect golang.org/x/mod v0.6.0 // indirect golang.org/x/sys v0.1.0 // indirect golang.org/x/text v0.4.0 // indirect - golang.org/x/time v0.0.0-20220922220347-f3bd1da661af // indirect + golang.org/x/time v0.1.0 // indirect golang.org/x/tools v0.2.0 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/macaroon.v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 766fb47f4..de40f5f33 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221031151122-0885c35ebe74 h1:I4LUlFqxZ72m3s9wIvUIV2FpprsxW28dO/0lAgepCZY= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221031151122-0885c35ebe74/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e h1:6I34fdyiHMRCxL6GOb/G8ZyI1WWlb6ZxCF2hIGSMSCc= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6 h1:nAT5w41Q9uWTSnpKW55/hBwP91j2IFYPDRs0jJ8TyFI= github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= @@ -396,7 +396,8 @@ github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/em github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= +github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattomatic/dijkstra v0.0.0-20130617153013-6f6d134eb237/go.mod h1:UOnLAUmVG5paym8pD3C4B9BQylUDC2vXFJJpT7JrlEA= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI= @@ -425,10 +426,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.3 h1:HrfzA7G9LNetKkm1z+jU/e9kuAe+E6uaBuuq9EB5sQQ= -github.com/nats-io/nats-server/v2 v2.9.3/go.mod h1:4sq8wvrpbvSzL1n3ZfEYnH4qeUuIl5W990j3kw13rRk= -github.com/nats-io/nats.go v1.18.0 h1:o480Ao6kuSSFyJO75rGTXCEPj7LGkY84C1Ye+Uhm4c0= -github.com/nats-io/nats.go v1.18.0/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= +github.com/nats-io/nats-server/v2 v2.9.4 h1:GvRgv1936J/zYUwMg/cqtYaJ6L+bgeIOIvPslbesdow= +github.com/nats-io/nats-server/v2 v2.9.4/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= +github.com/nats-io/nats.go v1.19.0 h1:H6j8aBnTQFoVrTGB6Xjd903UMdE7jz6DS4YkmAqgZ9Q= +github.com/nats-io/nats.go v1.19.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= @@ -557,8 +558,9 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -566,8 +568,9 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= @@ -650,13 +653,13 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20220916125017-b168a2c6b86b h1:SCE/18RnFsLrjydh/R/s5EVvHoZprqEQUuoxK8q2Pc4= -golang.org/x/exp v0.0.0-20220916125017-b168a2c6b86b/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= +golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 h1:QfTh0HpN6hlw6D3vu8DAwC8pBIwikq0AI1evdm+FksE= +golang.org/x/exp v0.0.0-20221031165847-c99f073a8326/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20220902085622-e7cb96979f69 h1:Lj6HJGCSn5AjxRAH2+r35Mir4icalbqku+CLUtjnvXY= -golang.org/x/image v0.0.0-20220902085622-e7cb96979f69/go.mod h1:doUCurBvlfPMKfmIpRIywoHmhN3VyhnoFDbvIEWF4hY= +golang.org/x/image v0.1.0 h1:r8Oj8ZA2Xy12/b5KZYj3tuv7NG/fBz3TwQVvpJ9l8Rk= +golang.org/x/image v0.1.0/go.mod h1:iyPr49SD/G/TBxYVB/9RRtGUT5eNbo2u4NamWeQcD5c= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -671,8 +674,8 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mobile v0.0.0-20220722155234-aaac322e2105/go.mod h1:pe2sM7Uk+2Su1y7u/6Z8KJ24D7lepUjFZbhFOrmDfuQ= -golang.org/x/mobile v0.0.0-20221012134814-c746ac228303 h1:K4fp1rDuJBz0FCPAWzIJwnzwNEM7S6yobdZzMrZ/Zws= -golang.org/x/mobile v0.0.0-20221012134814-c746ac228303/go.mod h1:M32cGdzp91A8Ex9qQtyZinr19EYxzkFqDjW2oyHzTDQ= +golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e h1:zSgtO19fpg781xknwqiQPmOHaASr6E7ZVlTseLd9Fx4= +golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e/go.mod h1:aAjjkJNdrh3PMckS4B10TGS2nag27cbKR1y2BpUxsiY= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= @@ -752,7 +755,6 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -821,7 +823,8 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220730100132-1609e554cd39/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -841,8 +844,8 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20220922220347-f3bd1da661af h1:Yx9k8YCG3dvF87UAn2tu2HQLf2dt/eR1bXxpLMWeH+Y= -golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.1.0 h1:xYY+Bajn2a7VBmTM5GikTmnK8ZuX8YgnQCqZpbBNtmA= +golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/version.go b/internal/version.go index 7254ab102..f762adf90 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 10 - VersionPatch = 5 + VersionPatch = 6 VersionTag = "" // example: "rc1" ) diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index a6de8ac84..7efad7af6 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "sort" "strings" "github.com/matrix-org/gomatrixserverlib" @@ -159,7 +160,7 @@ func GetMembershipsAtState( ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, ) ([]types.Event, error) { - var eventNIDs []types.EventNID + var eventNIDs types.EventNIDs for _, entry := range stateEntries { // Filter the events to retrieve to only keep the membership events if entry.EventTypeNID == types.MRoomMemberNID { @@ -167,6 +168,14 @@ func GetMembershipsAtState( } } + // There are no events to get, don't bother asking the database + if len(eventNIDs) == 0 { + return []types.Event{}, nil + } + + sort.Sort(eventNIDs) + util.Unique(eventNIDs) + // Get all of the events in this state stateEvents, err := db.Events(ctx, eventNIDs) if err != nil { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 0db046a86..8850e5c46 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -239,16 +239,42 @@ func (r *Queryer) QueryMembershipAtEvent( return fmt.Errorf("unable to get state before event: %w", err) } + // If we only have one or less state entries, we can short circuit the below + // loop and avoid hitting the database + allStateEventNIDs := make(map[types.EventNID]types.StateEntry) + for _, eventID := range request.EventIDs { + stateEntry := stateEntries[eventID] + for _, s := range stateEntry { + allStateEventNIDs[s.EventNID] = s + } + } + + var canShortCircuit bool + if len(allStateEventNIDs) <= 1 { + canShortCircuit = true + } + + var memberships []types.Event for _, eventID := range request.EventIDs { stateEntry, ok := stateEntries[eventID] - if !ok { + if !ok || len(stateEntry) == 0 { response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{} continue } - memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + + // If we can short circuit, e.g. we only have 0 or 1 membership events, we only get the memberships + // once. If we have more than one membership event, we need to get the state for each state entry. + if canShortCircuit { + if len(memberships) == 0 { + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + } + } else { + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + } if err != nil { return fmt.Errorf("unable to get memberships at state: %w", err) } + res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) for i := range memberships { diff --git a/roomserver/state/state.go b/roomserver/state/state.go index cb96d83ec..1cfde5e4b 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -18,17 +18,17 @@ package state import ( "context" - "database/sql" "fmt" "sort" "sync" "time" - "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" + + "github.com/matrix-org/dendrite/roomserver/types" ) type StateResolutionStorage interface { @@ -37,6 +37,7 @@ type StateResolutionStorage interface { StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) @@ -130,21 +131,10 @@ func (v *StateResolution) LoadMembershipAtEvent( span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent") defer span.Finish() - // De-dupe snapshotNIDs - snapshotNIDMap := make(map[types.StateSnapshotNID][]string) // map from snapshot NID to eventIDs - for i := range eventIDs { - eventID := eventIDs[i] - snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) - if err != nil && err != sql.ErrNoRows { - return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err) - } - if snapshotNID == 0 { - // If we don't know a state snapshot for this event then we can't calculate - // memberships at the time of the event, so skip over it. This means that - // it isn't guaranteed that the response map will contain every single event. - continue - } - snapshotNIDMap[snapshotNID] = append(snapshotNIDMap[snapshotNID], eventID) + // Get a mapping from snapshotNID -> eventIDs + snapshotNIDMap, err := v.db.BulkSelectSnapshotsFromEventIDs(ctx, eventIDs) + if err != nil { + return nil, err } snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap)) @@ -157,24 +147,45 @@ func (v *StateResolution) LoadMembershipAtEvent( return nil, err } + var wantStateBlocks []types.StateBlockNID + for _, x := range stateBlockNIDLists { + wantStateBlocks = append(wantStateBlocks, x.StateBlockNIDs...) + } + + stateEntryLists, err := v.db.StateEntriesForTuples(ctx, uniqueStateBlockNIDs(wantStateBlocks), []types.StateKeyTuple{ + { + EventTypeNID: types.MRoomMemberNID, + EventStateKeyNID: stateKeyNID, + }, + }) + if err != nil { + return nil, err + } + + stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) + stateEntriesMap := stateEntryListMap(stateEntryLists) + result := make(map[string][]types.StateEntry) for _, stateBlockNIDList := range stateBlockNIDLists { - // Query the membership event for the user at the given stateblocks - stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{ - { - EventTypeNID: types.MRoomMemberNID, - EventStateKeyNID: stateKeyNID, - }, - }) - if err != nil { - return nil, err + stateBlockNIDs, ok := stateBlockNIDsMap.lookup(stateBlockNIDList.StateSnapshotNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + return nil, fmt.Errorf("corrupt DB: Missing state snapshot numeric ID %d", stateBlockNIDList.StateSnapshotNID) } - evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID] + for _, stateBlockNID := range stateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + return nil, fmt.Errorf("corrupt DB: Missing state block numeric ID %d", stateBlockNID) + } - for _, evID := range evIDs { - for _, x := range stateEntryLists { - result[evID] = append(result[evID], x.StateEntries...) + evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID] + + for _, evID := range evIDs { + result[evID] = append(result[evID], entries...) } } } @@ -944,7 +955,6 @@ func (v *StateResolution) resolveConflictsV2( authSets := make(map[string][]*gomatrixserverlib.Event, len(conflicted)) authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) gotAuthEvents := make(map[string]struct{}, estimate*3) - authDifference := make([]*gomatrixserverlib.Event, 0, estimate) knownAuthEvents := make(map[string]types.Event, estimate*3) // For each conflicted event, let's try and get the needed auth events. @@ -992,41 +1002,6 @@ func (v *StateResolution) resolveConflictsV2( // longer need this after this point. gotAuthEvents = nil // nolint:ineffassign - // This function helps us to work out whether an event exists in one of the - // auth sets. - isInAuthList := func(k string, event *gomatrixserverlib.Event) bool { - for _, e := range authSets[k] { - if e.EventID() == event.EventID() { - return true - } - } - return false - } - - // This function works out if an event exists in all of the auth sets. - isInAllAuthLists := func(event *gomatrixserverlib.Event) bool { - for k := range authSets { - if !isInAuthList(k, event) { - return false - } - } - return true - } - - // Look through all of the auth events that we've been given and work out if - // there are any events which don't appear in all of the auth sets. If they - // don't then we add them to the auth difference. - func() { - span, _ := opentracing.StartSpanFromContext(ctx, "isInAllAuthLists") - defer span.Finish() - - for _, event := range authEvents { - if !isInAllAuthLists(event) { - authDifference = append(authDifference, event) - } - } - }() - // Resolve the conflicts. resolvedEvents := func() []*gomatrixserverlib.Event { span, _ := opentracing.StartSpanFromContext(ctx, "gomatrixserverlib.ResolveStateConflictsV2") @@ -1036,7 +1011,6 @@ func (v *StateResolution) resolveConflictsV2( conflictedEvents, nonConflictedEvents, authEvents, - authDifference, ) }() diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 094537948..c39a8cbba 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -72,6 +72,7 @@ type Database interface { Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. StoreEvent( ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID, diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 1e7ca7669..9b5ed6eda 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -22,11 +22,12 @@ import ( "sort" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const eventsSchema = ` @@ -80,6 +81,9 @@ const insertEventSQL = "" + const selectEventSQL = "" + "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" +const bulkSelectSnapshotsForEventIDsSQL = "" + + "SELECT event_id, state_snapshot_nid FROM roomserver_events WHERE event_id = ANY($1)" + // Bulk lookup of events by string ID. // Sort by the numeric IDs for event type and state key. // This means we can use binary search to lookup entries by type and state key. @@ -150,6 +154,7 @@ const selectEventRejectedSQL = "" + type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt + bulkSelectSnapshotsForEventIDsStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt bulkSelectStateEventByNIDStmt *sql.Stmt @@ -179,6 +184,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, + {&s.bulkSelectSnapshotsForEventIDsStmt, bulkSelectSnapshotsForEventIDsSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, {&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL}, {&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL}, @@ -230,6 +236,29 @@ func (s *eventStatements) SelectEvent( return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } +func (s *eventStatements) BulkSelectSnapshotsFromEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) (map[types.StateSnapshotNID][]string, error) { + stmt := sqlutil.TxStmt(txn, s.bulkSelectSnapshotsForEventIDsStmt) + + rows, err := stmt.QueryContext(ctx, pq.Array(eventIDs)) + if err != nil { + return nil, err + } + + var eventID string + var stateNID types.StateSnapshotNID + result := make(map[types.StateSnapshotNID][]string) + for rows.Next() { + if err := rows.Scan(&eventID, &stateNID); err != nil { + return nil, err + } + result[stateNID] = append(result[stateNID], eventID) + } + + return result, rows.Err() +} + // bulkSelectStateEventByID lookups a list of state events by event ID. // If not excluding rejected events, and any of the requested events are missing from // the database it returns a types.MissingEventError. If excluding rejected events, diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 42c0c8f2d..cc880a6c8 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -5,8 +5,9 @@ import ( "database/sql" "fmt" - "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/roomserver/types" ) type RoomUpdater struct { @@ -186,6 +187,10 @@ func (u *RoomUpdater) EventIDs( return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs) } +func (u *RoomUpdater) BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error) { + return u.d.EventsTable.BulkSelectSnapshotsFromEventIDs(ctx, u.txn, eventIDs) +} + func (u *RoomUpdater) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index ed86280bf..4455ec3bf 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -469,6 +469,23 @@ func (d *Database) events( eventNIDs = append(eventNIDs, nid) } } + // If we don't need to get any events from the database, short circuit now + if len(eventNIDs) == 0 { + results := make([]types.Event, 0, len(inputEventNIDs)) + for _, nid := range inputEventNIDs { + event, ok := events[nid] + if !ok || event == nil { + return nil, fmt.Errorf("event %d missing", nid) + } + results = append(results, types.Event{ + EventNID: nid, + Event: event, + }) + } + if !redactionsArePermanent { + d.applyRedactions(results) + } + } eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil { return nil, err @@ -534,6 +551,12 @@ func (d *Database) events( return results, nil } +func (d *Database) BulkSelectSnapshotsFromEventIDs( + ctx context.Context, eventIDs []string, +) (map[types.StateSnapshotNID][]string, error) { + return d.EventsTable.BulkSelectSnapshotsFromEventIDs(ctx, nil, eventIDs) +} + func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, diff --git a/roomserver/storage/sqlite3/deltas/20221027084407_published_appservice.go b/roomserver/storage/sqlite3/deltas/20221027084407_published_appservice.go index cd923b1c1..410fb7cf6 100644 --- a/roomserver/storage/sqlite3/deltas/20221027084407_published_appservice.go +++ b/roomserver/storage/sqlite3/deltas/20221027084407_published_appservice.go @@ -24,8 +24,8 @@ func UpPulishedAppservice(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` ALTER TABLE roomserver_published RENAME TO roomserver_published_tmp; CREATE TABLE IF NOT EXISTS roomserver_published ( room_id TEXT NOT NULL, - appservice_id TEXT NOT NULL, - network_id TEXT NOT NULL, + appservice_id TEXT NOT NULL DEFAULT '', + network_id TEXT NOT NULL DEFAULT '', published BOOLEAN NOT NULL DEFAULT false, CONSTRAINT unique_published_idx PRIMARY KEY (room_id, appservice_id, network_id) ); diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 950d03b03..f39b9902d 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -23,11 +23,12 @@ import ( "sort" "strings" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const eventsSchema = ` @@ -57,6 +58,9 @@ const insertEventSQL = ` const selectEventSQL = "" + "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" +const bulkSelectSnapshotsForEventIDsSQL = "" + + "SELECT event_id, state_snapshot_nid FROM roomserver_events WHERE event_id IN ($1)" + // Bulk lookup of events by string ID. // Sort by the numeric IDs for event type and state key. // This means we can use binary search to lookup entries by type and state key. @@ -124,6 +128,7 @@ type eventStatements struct { db *sql.DB insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt + bulkSelectSnapshotsForEventIDsStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateEventByIDExcludingRejectedStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt @@ -153,6 +158,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { return s, sqlutil.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, + {&s.bulkSelectSnapshotsForEventIDsStmt, bulkSelectSnapshotsForEventIDsSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, {&s.bulkSelectStateEventByIDExcludingRejectedStmt, bulkSelectStateEventByIDExcludingRejectedSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, @@ -203,6 +209,40 @@ func (s *eventStatements) SelectEvent( return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } +func (s *eventStatements) BulkSelectSnapshotsFromEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) (map[types.StateSnapshotNID][]string, error) { + qry := strings.Replace(bulkSelectSnapshotsForEventIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) + stmt, err := s.db.Prepare(qry) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "BulkSelectSnapshotsFromEventIDs: stmt.close() failed") + + params := make([]interface{}, len(eventIDs)) + for i := range eventIDs { + params[i] = eventIDs[i] + } + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "BulkSelectSnapshotsFromEventIDs: rows.close() failed") + + var eventID string + var stateNID types.StateSnapshotNID + result := make(map[types.StateSnapshotNID][]string) + for rows.Next() { + if err := rows.Scan(&eventID, &stateNID); err != nil { + return nil, err + } + result[stateNID] = append(result[stateNID], eventID) + } + + return result, rows.Err() +} + // bulkSelectStateEventByID lookups a list of state events by event ID. // If not excluding rejected events, and any of the requested events are missing from // the database it returns a types.MissingEventError. If excluding rejected events, diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 8d6ca324c..50d27c756 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -44,6 +44,7 @@ type Events interface { referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, ) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) + BulkSelectSnapshotsFromEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[types.StateSnapshotNID][]string, error) // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error) diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 0ed164c7e..095a868c7 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -93,7 +93,6 @@ func Context( } stateFilter := gomatrixserverlib.StateFilter{ - Limit: 100, NotSenders: filter.NotSenders, NotTypes: filter.NotTypes, Senders: filter.Senders, diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 05c7deef0..3fcc3235c 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -135,6 +135,6 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{gomatrixserverlib.HeaderedToClientEvents(result, gomatrixserverlib.FormatSync)}, + JSON: getMembershipResponse{gomatrixserverlib.HeaderedToClientEvents(result, gomatrixserverlib.FormatAll)}, } } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index aef355def..081ec6cb1 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -294,7 +294,7 @@ type SearchRequest struct { BeforeLimit int `json:"before_limit,omitempty"` IncludeProfile bool `json:"include_profile,omitempty"` } `json:"event_context"` - Filter gomatrixserverlib.StateFilter `json:"filter"` + Filter gomatrixserverlib.RoomEventFilter `json:"filter"` Groupings struct { GroupBy []struct { Key string `json:"key"` diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 2ccf0be1a..48ed20021 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -91,8 +91,7 @@ const selectCurrentStateSQL = "" + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + " AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" + " AND ( $6::bool IS NULL OR contains_url = $6 )" + - " AND (event_id = ANY($7)) IS NOT TRUE" + - " LIMIT $8" + " AND (event_id = ANY($7)) IS NOT TRUE" const selectJoinedUsersSQL = "" + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" @@ -290,7 +289,6 @@ func (s *currentRoomStateStatements) SelectCurrentState( pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), stateFilter.ContainsURL, pq.StringArray(excludeEventIDs), - stateFilter.Limit, ) if err != nil { return nil, err diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 0ecbdf4d2..3b69b26f6 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -144,8 +144,7 @@ const selectStateInRangeFilteredSQL = "" + " AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " AND ( $8::bool IS NULL OR contains_url = $8 )" + - " ORDER BY id ASC" + - " LIMIT $9" + " ORDER BY id ASC" // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). const selectStateInRangeSQL = "" + @@ -153,8 +152,7 @@ const selectStateInRangeSQL = "" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + " AND room_id = ANY($3)" + - " ORDER BY id ASC" + - " LIMIT $4" + " ORDER BY id ASC" const deleteEventsForRoomSQL = "" + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" @@ -264,13 +262,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange( pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)), stateFilter.ContainsURL, - stateFilter.Limit, ) } else { stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) rows, err = stmt.QueryContext( ctx, r.Low(), r.High(), pq.StringArray(roomIDs), - r.High()-r.Low(), ) } diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index ff45e786e..7a381f68b 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -277,7 +277,8 @@ func (s *currentRoomStateStatements) SelectCurrentState( }, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, - excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone, + excludeEventIDs, stateFilter.ContainsURL, 0, + FilterOrderNone, ) if err != nil { return nil, fmt.Errorf("s.prepareWithFilters: %w", err) diff --git a/syncapi/storage/sqlite3/filtering.go b/syncapi/storage/sqlite3/filtering.go index 05edb7b8c..17a37a2df 100644 --- a/syncapi/storage/sqlite3/filtering.go +++ b/syncapi/storage/sqlite3/filtering.go @@ -84,8 +84,10 @@ func prepareWithFilters( case FilterOrderDesc: query += " ORDER BY id DESC" } - query += fmt.Sprintf(" LIMIT $%d", offset+1) - params = append(params, limit) + if limit > 0 { + query += fmt.Sprintf(" LIMIT $%d", offset+1) + params = append(params, limit) + } var stmt *sql.Stmt var err error diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 77c692ff0..1aa4bfff7 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -200,7 +200,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange( s.db, txn, stmtSQL, inputParams, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, - nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc, + nil, stateFilter.ContainsURL, 0, FilterOrderAsc, ) } else { stmt, params, err = prepareWithFilters( diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 700f25c10..e4de30e1c 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/base64" + "math" "strconv" "time" @@ -74,12 +75,14 @@ func (p *InviteStreamProvider) IncrementalSync( return to } for roomID := range retiredInvites { - if _, ok := req.Response.Rooms.Invite[roomID]; ok { - continue - } - if _, ok := req.Response.Rooms.Join[roomID]; ok { + membership, _, err := snapshot.SelectMembershipForUser(ctx, roomID, req.Device.UserID, math.MaxInt64) + // Skip if the user is an existing member of the room. + // Otherwise, the NewLeaveResponse will eject the user from the room unintentionally + if membership == gomatrixserverlib.Join || + err != nil { continue } + lr := types.NewLeaveResponse() h := sha256.Sum256(append([]byte(roomID), []byte(strconv.FormatInt(int64(to), 10))...)) lr.Timeline.Events = append(lr.Timeline.Events, gomatrixserverlib.ClientEvent{ @@ -93,7 +96,6 @@ func (p *InviteStreamProvider) IncrementalSync( Content: gomatrixserverlib.RawJSON(`{"membership":"leave"}`), }) req.Response.Rooms.Leave[roomID] = lr - } return maxID diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 90f401481..5ea2732f4 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -301,7 +301,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } // Applies the history visibility rules - events, err := applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents) + events, err := applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, delta.RoomID, device.UserID, recentEvents) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") } @@ -321,10 +321,14 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( ) if len(delta.StateEvents) > 0 { - updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) + if last := delta.StateEvents[len(delta.StateEvents)-1]; last != nil { + updateLatestPosition(last.EventID()) + } } if len(events) > 0 { - updateLatestPosition(events[len(events)-1].EventID()) + if last := events[len(events)-1]; last != nil { + updateLatestPosition(last.EventID()) + } } switch delta.Membership { @@ -374,12 +378,12 @@ func applyHistoryVisibilityFilter( snapshot storage.DatabaseTransaction, rsAPI roomserverAPI.SyncRoomserverAPI, roomID, userID string, - limit int, recentEvents []*gomatrixserverlib.HeaderedEvent, ) ([]*gomatrixserverlib.HeaderedEvent, error) { // We need to make sure we always include the latest states events, if they are in the timeline. // We grep at least limit * 2 events, to ensure we really get the needed events. - stateEvents, err := snapshot.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil) + filter := gomatrixserverlib.DefaultStateFilter() + stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) if err != nil { // Not a fatal error, we can continue without the stateEvents, // they are only needed if there are state events in the timeline. @@ -517,7 +521,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( events := recentEvents // Only apply history visibility checks if the response is for joined rooms if !isPeek { - events, err = applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents) + events, err = applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, roomID, device.UserID, recentEvents) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") } @@ -597,7 +601,6 @@ func (p *PDUStreamProvider) lazyLoadMembers( } // Query missing membership events filter := gomatrixserverlib.DefaultStateFilter() - filter.Limit = stateFilter.Limit filter.Senders = &wantUsers filter.Types = &[]string{gomatrixserverlib.MRoomMember} memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter) diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 620dfdcdb..e5e5fdb5b 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -79,7 +79,6 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat // for the rest of the data to trickle down. filter.AccountData.Limit = math.MaxInt32 filter.Room.AccountData.Limit = math.MaxInt32 - filter.Room.State.Limit = math.MaxInt32 } logger := util.GetLogger(req.Context()).WithFields(logrus.Fields{ diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index a12876946..97c17e188 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "strings" + "sync" "time" "github.com/matrix-org/gomatrixserverlib" @@ -23,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" + userAPITypes "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/dendrite/userapi/util" ) @@ -36,6 +38,11 @@ type OutputRoomEventConsumer struct { topic string pgClient pushgateway.Client syncProducer *producers.SyncAPI + msgCounts map[gomatrixserverlib.ServerName]userAPITypes.MessageStats + roomCounts map[gomatrixserverlib.ServerName]map[string]bool // map from serverName to map from rommID to "isEncrypted" + lastUpdate time.Time + countsLock sync.Mutex + serverName gomatrixserverlib.ServerName } func NewOutputRoomEventConsumer( @@ -57,6 +64,11 @@ func NewOutputRoomEventConsumer( pgClient: pgClient, rsAPI: rsAPI, syncProducer: syncProducer, + msgCounts: map[gomatrixserverlib.ServerName]userAPITypes.MessageStats{}, + roomCounts: map[gomatrixserverlib.ServerName]map[string]bool{}, + lastUpdate: time.Now(), + countsLock: sync.Mutex{}, + serverName: cfg.Matrix.ServerName, } } @@ -88,6 +100,10 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return true } + if s.cfg.Matrix.ReportStats.Enabled { + go s.storeMessageStats(ctx, event.Type(), event.Sender(), event.RoomID()) + } + log.WithFields(log.Fields{ "event_id": event.EventID(), "event_type": event.Type(), @@ -107,6 +123,68 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return true } +func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventType, eventSender, roomID string) { + s.countsLock.Lock() + defer s.countsLock.Unlock() + + // reset the roomCounts on a day change + if s.lastUpdate.Day() != time.Now().Day() { + s.roomCounts[s.serverName] = make(map[string]bool) + s.lastUpdate = time.Now() + } + + _, sender, err := gomatrixserverlib.SplitID('@', eventSender) + if err != nil { + return + } + msgCount := s.msgCounts[s.serverName] + roomCount := s.roomCounts[s.serverName] + if roomCount == nil { + roomCount = make(map[string]bool) + } + switch eventType { + case "m.room.message": + roomCount[roomID] = false + msgCount.Messages++ + if sender == s.serverName { + msgCount.SentMessages++ + } + case "m.room.encrypted": + roomCount[roomID] = true + msgCount.MessagesE2EE++ + if sender == s.serverName { + msgCount.SentMessagesE2EE++ + } + default: + return + } + s.msgCounts[s.serverName] = msgCount + s.roomCounts[s.serverName] = roomCount + + for serverName, stats := range s.msgCounts { + var normalRooms, encryptedRooms int64 = 0, 0 + for _, isEncrypted := range s.roomCounts[s.serverName] { + if isEncrypted { + encryptedRooms++ + } else { + normalRooms++ + } + } + err := s.db.UpsertDailyRoomsMessages(ctx, serverName, stats, normalRooms, encryptedRooms) + if err != nil { + log.WithError(err).Errorf("failed to upsert daily messages") + } + // Clear stats if we successfully stored it + if err == nil { + stats.Messages = 0 + stats.SentMessages = 0 + stats.MessagesE2EE = 0 + stats.SentMessagesE2EE = 0 + s.msgCounts[serverName] = stats + } + } +} + func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) if err != nil { diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index e4587670f..265e3a3aa 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -2,7 +2,10 @@ package consumers import ( "context" + "reflect" + "sync" "testing" + "time" "github.com/matrix-org/gomatrixserverlib" "github.com/stretchr/testify/assert" @@ -12,6 +15,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/storage" + userAPITypes "github.com/matrix-org/dendrite/userapi/types" ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { @@ -132,3 +136,122 @@ func Test_evaluatePushRules(t *testing.T) { } }) } + +func TestMessageStats(t *testing.T) { + type args struct { + eventType string + eventSender string + roomID string + } + tests := []struct { + name string + args args + ourServer gomatrixserverlib.ServerName + lastUpdate time.Time + initRoomCounts map[gomatrixserverlib.ServerName]map[string]bool + wantStats userAPITypes.MessageStats + }{ + { + name: "m.room.create does not count as a message", + ourServer: "localhost", + args: args{ + eventType: "m.room.create", + eventSender: "@alice:localhost", + }, + }, + { + name: "our server - message", + ourServer: "localhost", + args: args{ + eventType: "m.room.message", + eventSender: "@alice:localhost", + roomID: "normalRoom", + }, + wantStats: userAPITypes.MessageStats{Messages: 1, SentMessages: 1}, + }, + { + name: "our server - E2EE message", + ourServer: "localhost", + args: args{ + eventType: "m.room.encrypted", + eventSender: "@alice:localhost", + roomID: "encryptedRoom", + }, + wantStats: userAPITypes.MessageStats{Messages: 1, SentMessages: 1, MessagesE2EE: 1, SentMessagesE2EE: 1}, + }, + + { + name: "remote server - message", + ourServer: "localhost", + args: args{ + eventType: "m.room.message", + eventSender: "@alice:remote", + roomID: "normalRoom", + }, + wantStats: userAPITypes.MessageStats{Messages: 2, SentMessages: 1, MessagesE2EE: 1, SentMessagesE2EE: 1}, + }, + { + name: "remote server - E2EE message", + ourServer: "localhost", + args: args{ + eventType: "m.room.encrypted", + eventSender: "@alice:remote", + roomID: "encryptedRoom", + }, + wantStats: userAPITypes.MessageStats{Messages: 2, SentMessages: 1, MessagesE2EE: 2, SentMessagesE2EE: 1}, + }, + { + name: "day change creates a new room map", + ourServer: "localhost", + lastUpdate: time.Now().Add(-time.Hour * 24), + initRoomCounts: map[gomatrixserverlib.ServerName]map[string]bool{ + "localhost": {"encryptedRoom": true}, + }, + args: args{ + eventType: "m.room.encrypted", + eventSender: "@alice:remote", + roomID: "someOtherRoom", + }, + wantStats: userAPITypes.MessageStats{Messages: 2, SentMessages: 1, MessagesE2EE: 3, SentMessagesE2EE: 1}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.lastUpdate.IsZero() { + tt.lastUpdate = time.Now() + } + if tt.initRoomCounts == nil { + tt.initRoomCounts = map[gomatrixserverlib.ServerName]map[string]bool{} + } + s := &OutputRoomEventConsumer{ + db: db, + msgCounts: map[gomatrixserverlib.ServerName]userAPITypes.MessageStats{}, + roomCounts: tt.initRoomCounts, + countsLock: sync.Mutex{}, + lastUpdate: tt.lastUpdate, + serverName: tt.ourServer, + } + s.storeMessageStats(context.Background(), tt.args.eventType, tt.args.eventSender, tt.args.roomID) + t.Logf("%+v", s.roomCounts) + gotStats, activeRooms, activeE2EERooms, err := db.DailyRoomsMessages(context.Background(), tt.ourServer) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if !reflect.DeepEqual(gotStats, tt.wantStats) { + t.Fatalf("expected %+v, got %+v", tt.wantStats, gotStats) + } + if tt.args.eventType == "m.room.encrypted" && activeE2EERooms != 1 { + t.Fatalf("expected room to be activeE2EE") + } + if tt.args.eventType == "m.room.message" && activeRooms != 1 { + t.Fatalf("expected room to be active") + } + }) + } + }) +} diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index fb12b53af..28ef26559 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -19,6 +19,8 @@ import ( "encoding/json" "errors" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/userapi/api" @@ -144,6 +146,8 @@ type Database interface { type Statistics interface { UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) + DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) + UpsertDailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/postgres/stats_table.go b/userapi/storage/postgres/stats_table.go index 20eb0bf46..f62467fa4 100644 --- a/userapi/storage/postgres/stats_table.go +++ b/userapi/storage/postgres/stats_table.go @@ -20,13 +20,14 @@ import ( "time" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" ) const userDailyVisitsSchema = ` @@ -43,6 +44,35 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_timestamp_idx ON userapi_daily_v CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON userapi_daily_visits(localpart, timestamp); ` +const messagesDailySchema = ` +CREATE TABLE IF NOT EXISTS userapi_daily_stats ( + timestamp BIGINT NOT NULL, + server_name TEXT NOT NULL, + messages BIGINT NOT NULL, + sent_messages BIGINT NOT NULL, + e2ee_messages BIGINT NOT NULL, + sent_e2ee_messages BIGINT NOT NULL, + active_rooms BIGINT NOT NULL, + active_e2ee_rooms BIGINT NOT NULL, + CONSTRAINT daily_stats_unique UNIQUE (timestamp, server_name) +); +` + +const upsertDailyMessagesSQL = ` + INSERT INTO userapi_daily_stats AS u (timestamp, server_name, messages, sent_messages, e2ee_messages, sent_e2ee_messages, active_rooms, active_e2ee_rooms) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT ON CONSTRAINT daily_stats_unique + DO UPDATE SET + messages=u.messages+excluded.messages, sent_messages=u.sent_messages+excluded.sent_messages, + e2ee_messages=u.e2ee_messages+excluded.e2ee_messages, sent_e2ee_messages=u.sent_e2ee_messages+excluded.sent_e2ee_messages, + active_rooms=GREATEST($7, u.active_rooms), active_e2ee_rooms=GREATEST($8, u.active_e2ee_rooms) +` + +const selectDailyMessagesSQL = ` + SELECT messages, sent_messages, e2ee_messages, sent_e2ee_messages, active_rooms, active_e2ee_rooms + FROM userapi_daily_stats + WHERE server_name = $1 AND timestamp = $2; +` + const countUsersLastSeenAfterSQL = "" + "SELECT COUNT(*) FROM (" + " SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " + @@ -170,6 +200,8 @@ type statsStatements struct { countUserByAccountTypeStmt *sql.Stmt countRegisteredUserByTypeStmt *sql.Stmt dbEngineVersionStmt *sql.Stmt + upsertMessagesStmt *sql.Stmt + selectDailyMessagesStmt *sql.Stmt } func NewPostgresStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.StatsTable, error) { @@ -182,6 +214,10 @@ func NewPostgresStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) if err != nil { return nil, err } + _, err = db.Exec(messagesDailySchema) + if err != nil { + return nil, err + } go s.startTimers() return s, sqlutil.StatementList{ {&s.countUsersLastSeenAfterStmt, countUsersLastSeenAfterSQL}, @@ -191,6 +227,8 @@ func NewPostgresStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) {&s.countUserByAccountTypeStmt, countUserByAccountTypeSQL}, {&s.countRegisteredUserByTypeStmt, countRegisteredUserByTypeStmt}, {&s.dbEngineVersionStmt, queryDBEngineVersion}, + {&s.upsertMessagesStmt, upsertDailyMessagesSQL}, + {&s.selectDailyMessagesStmt, selectDailyMessagesSQL}, }.Prepare(db) } @@ -435,3 +473,34 @@ func (s *statsStatements) UpdateUserDailyVisits( } return err } + +func (s *statsStatements) UpsertDailyStats( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, stats types.MessageStats, + activeRooms, activeE2EERooms int64, +) error { + stmt := sqlutil.TxStmt(txn, s.upsertMessagesStmt) + timestamp := time.Now().Truncate(time.Hour * 24) + _, err := stmt.ExecContext(ctx, + gomatrixserverlib.AsTimestamp(timestamp), + serverName, + stats.Messages, stats.SentMessages, stats.MessagesE2EE, stats.SentMessagesE2EE, + activeRooms, activeE2EERooms, + ) + return err +} + +func (s *statsStatements) DailyRoomsMessages( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (msgStats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { + stmt := sqlutil.TxStmt(txn, s.selectDailyMessagesStmt) + timestamp := time.Now().Truncate(time.Hour * 24) + + err = stmt.QueryRowContext(ctx, serverName, gomatrixserverlib.AsTimestamp(timestamp)). + Scan(&msgStats.Messages, &msgStats.SentMessages, &msgStats.MessagesE2EE, &msgStats.SentMessagesE2EE, &activeRooms, &activeE2EERooms) + if err != nil && err != sql.ErrNoRows { + return msgStats, 0, 0, err + } + return msgStats, activeRooms, activeE2EERooms, nil +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index f8b6ad311..f8b8d02c9 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -29,13 +29,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" - "github.com/matrix-org/dendrite/userapi/types" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" ) // Database represents an account database @@ -808,3 +807,15 @@ func (d *Database) RemovePushers( func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) { return d.Stats.UserStatistics(ctx, nil) } + +func (d *Database) UpsertDailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Stats.UpsertDailyStats(ctx, txn, serverName, stats, activeRooms, activeE2EERooms) + }) +} + +func (d *Database) DailyRoomsMessages( + ctx context.Context, serverName gomatrixserverlib.ServerName, +) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { + return d.Stats.DailyRoomsMessages(ctx, nil, serverName) +} diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index 35e3c653e..a1365c944 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -44,6 +44,35 @@ CREATE INDEX IF NOT EXISTS userapi_daily_visits_timestamp_idx ON userapi_daily_v CREATE INDEX IF NOT EXISTS userapi_daily_visits_localpart_timestamp_idx ON userapi_daily_visits(localpart, timestamp); ` +const messagesDailySchema = ` +CREATE TABLE IF NOT EXISTS userapi_daily_stats ( + timestamp BIGINT NOT NULL, + server_name TEXT NOT NULL, + messages BIGINT NOT NULL, + sent_messages BIGINT NOT NULL, + e2ee_messages BIGINT NOT NULL, + sent_e2ee_messages BIGINT NOT NULL, + active_rooms BIGINT NOT NULL, + active_e2ee_rooms BIGINT NOT NULL, + CONSTRAINT daily_stats_unique UNIQUE (timestamp, server_name) +); +` + +const upsertDailyMessagesSQL = ` + INSERT INTO userapi_daily_stats (timestamp, server_name, messages, sent_messages, e2ee_messages, sent_e2ee_messages, active_rooms, active_e2ee_rooms) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (timestamp, server_name) + DO UPDATE SET + messages=messages+excluded.messages, sent_messages=sent_messages+excluded.sent_messages, + e2ee_messages=e2ee_messages+excluded.e2ee_messages, sent_e2ee_messages=sent_e2ee_messages+excluded.sent_e2ee_messages, + active_rooms=MAX($7, active_rooms), active_e2ee_rooms=MAX($8, active_e2ee_rooms) +` + +const selectDailyMessagesSQL = ` + SELECT messages, sent_messages, e2ee_messages, sent_e2ee_messages, active_rooms, active_e2ee_rooms + FROM userapi_daily_stats + WHERE server_name = $1 AND timestamp = $2; +` + const countUsersLastSeenAfterSQL = "" + "SELECT COUNT(*) FROM (" + " SELECT localpart FROM userapi_devices WHERE last_seen_ts > $1 " + @@ -176,6 +205,8 @@ type statsStatements struct { countUserByAccountTypeStmt *sql.Stmt countRegisteredUserByTypeStmt *sql.Stmt dbEngineVersionStmt *sql.Stmt + upsertMessagesStmt *sql.Stmt + selectDailyMessagesStmt *sql.Stmt } func NewSQLiteStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.StatsTable, error) { @@ -189,6 +220,10 @@ func NewSQLiteStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (t if err != nil { return nil, err } + _, err = db.Exec(messagesDailySchema) + if err != nil { + return nil, err + } go s.startTimers() return s, sqlutil.StatementList{ {&s.countUsersLastSeenAfterStmt, countUsersLastSeenAfterSQL}, @@ -198,6 +233,8 @@ func NewSQLiteStatsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (t {&s.countUserByAccountTypeStmt, countUserByAccountTypeSQL}, {&s.countRegisteredUserByTypeStmt, countRegisteredUserByTypeSQL}, {&s.dbEngineVersionStmt, queryDBEngineVersion}, + {&s.upsertMessagesStmt, upsertDailyMessagesSQL}, + {&s.selectDailyMessagesStmt, selectDailyMessagesSQL}, }.Prepare(db) } @@ -451,3 +488,34 @@ func (s *statsStatements) UpdateUserDailyVisits( } return err } + +func (s *statsStatements) UpsertDailyStats( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, stats types.MessageStats, + activeRooms, activeE2EERooms int64, +) error { + stmt := sqlutil.TxStmt(txn, s.upsertMessagesStmt) + timestamp := time.Now().Truncate(time.Hour * 24) + _, err := stmt.ExecContext(ctx, + gomatrixserverlib.AsTimestamp(timestamp), + serverName, + stats.Messages, stats.SentMessages, stats.MessagesE2EE, stats.SentMessagesE2EE, + activeRooms, activeE2EERooms, + ) + return err +} + +func (s *statsStatements) DailyRoomsMessages( + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (msgStats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { + stmt := sqlutil.TxStmt(txn, s.selectDailyMessagesStmt) + timestamp := time.Now().Truncate(time.Hour * 24) + + err = stmt.QueryRowContext(ctx, serverName, gomatrixserverlib.AsTimestamp(timestamp)). + Scan(&msgStats.Messages, &msgStats.SentMessages, &msgStats.MessagesE2EE, &msgStats.SentMessagesE2EE, &activeRooms, &activeE2EERooms) + if err != nil && err != sql.ErrNoRows { + return msgStats, 0, 0, err + } + return msgStats, activeRooms, activeE2EERooms, nil +} diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 1b239e442..5e1dd0971 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -20,6 +20,8 @@ import ( "encoding/json" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/types" @@ -115,7 +117,9 @@ type NotificationTable interface { type StatsTable interface { UserStatistics(ctx context.Context, txn *sql.Tx) (*types.UserStatistics, *types.DatabaseEngine, error) + DailyRoomsMessages(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (msgStats types.MessageStats, activeRooms, activeE2EERooms int64, err error) UpdateUserDailyVisits(ctx context.Context, txn *sql.Tx, startTime, lastUpdate time.Time) error + UpsertDailyStats(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error } type NotificationFilter uint32 diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index c4aec552c..a547423bc 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -8,6 +8,9 @@ import ( "testing" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" @@ -16,8 +19,6 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3" "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) func mustMakeDBs(t *testing.T, dbType test.DBType) ( @@ -227,7 +228,7 @@ func Test_UserStatistics(t *testing.T) { mustUserUpdateRegistered(t, ctx, db, "user4", time.Now().AddDate(0, -2, 0)) mustUpdateDeviceLastSeen(t, ctx, db, "user4", time.Now()) startTime := time.Now().AddDate(0, 0, -2) - err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24).Add(time.Hour)) + err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24)) if err != nil { t.Fatalf("unable to update daily visits stats: %v", err) } @@ -278,7 +279,7 @@ func Test_UserStatistics(t *testing.T) { mustUpdateDeviceLastSeen(t, ctx, db, "user1", time.Now().AddDate(0, 0, -i)) mustUpdateDeviceLastSeen(t, ctx, db, "user5", time.Now().AddDate(0, 0, -i)) startTime := time.Now().AddDate(0, 0, -i) - err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24).Add(time.Hour)) + err := statsDB.UpdateUserDailyVisits(ctx, nil, startTime, startTime.Truncate(time.Hour*24)) if err != nil { t.Fatalf("unable to update daily visits stats: %v", err) } diff --git a/userapi/types/statistics.go b/userapi/types/statistics.go index 09564f78f..b74e32add 100644 --- a/userapi/types/statistics.go +++ b/userapi/types/statistics.go @@ -28,3 +28,10 @@ type DatabaseEngine struct { Engine string Version string } + +type MessageStats struct { + Messages int64 + SentMessages int64 + MessagesE2EE int64 + SentMessagesE2EE int64 +} diff --git a/userapi/util/phonehomestats.go b/userapi/util/phonehomestats.go index b17f62060..6f36568c9 100644 --- a/userapi/util/phonehomestats.go +++ b/userapi/util/phonehomestats.go @@ -24,11 +24,12 @@ import ( "syscall" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" ) type phoneHomeStats struct { @@ -109,12 +110,19 @@ func (p *phoneHomeStats) collect() { } // message and room stats - // TODO: Find a solution to actually set these values + // TODO: Find a solution to actually set this value p.stats["total_room_count"] = 0 - p.stats["daily_messages"] = 0 - p.stats["daily_sent_messages"] = 0 - p.stats["daily_e2ee_messages"] = 0 - p.stats["daily_sent_e2ee_messages"] = 0 + + messageStats, activeRooms, activeE2EERooms, err := p.db.DailyRoomsMessages(ctx, p.serverName) + if err != nil { + logrus.WithError(err).Warn("unable to query message stats, using default values") + } + p.stats["daily_messages"] = messageStats.Messages + p.stats["daily_sent_messages"] = messageStats.SentMessages + p.stats["daily_e2ee_messages"] = messageStats.MessagesE2EE + p.stats["daily_sent_e2ee_messages"] = messageStats.SentMessagesE2EE + p.stats["daily_active_rooms"] = activeRooms + p.stats["daily_active_e2ee_rooms"] = activeE2EERooms // user stats and DB engine userStats, db, err := p.db.UserStatistics(ctx)