diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml deleted file mode 100644 index 9d755a244..000000000 --- a/.buildkite/pipeline.yaml +++ /dev/null @@ -1,49 +0,0 @@ -steps: - - command: - # https://github.com/golangci/golangci-lint#memory-usage-of-golangci-lint - - "GOGC=20 ./scripts/find-lint.sh" - label: "\U0001F9F9 Lint / :go: 1.12" - agents: - # Use a larger instance as linting takes a looot of memory - queue: "medium" - plugins: - - docker#v3.0.1: - image: "golang:1.12" - - - wait - - - command: - - "go build ./cmd/..." - label: "\U0001F528 Build / :go: 1.11" - plugins: - - docker#v3.0.1: - image: "golang:1.11" - retry: - automatic: - - exit_status: 128 - limit: 3 - - - command: - - "go build ./cmd/..." - label: "\U0001F528 Build / :go: 1.12" - plugins: - - docker#v3.0.1: - image: "golang:1.12" - retry: - automatic: - - exit_status: 128 - limit: 3 - - - command: - - "go test ./..." - label: "\U0001F9EA Unit tests / :go: 1.11" - plugins: - - docker#v3.0.1: - image: "golang:1.11" - - - command: - - "go test ./..." - label: "\U0001F9EA Unit tests / :go: 1.12" - plugins: - - docker#v3.0.1: - image: "golang:1.12" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 22ad0586f..dc962fee7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,6 +20,39 @@ should pick up any unit test and run it). There are also [scripts](scripts) for [linting](scripts/find-lint.sh) and doing a [build/test/lint run](scripts/build-test-lint.sh). +## Continuous Integration + +When a Pull Request is submitted, continuous integration jobs are run +automatically to ensure the code builds and is relatively well-written. Checks +are run on [Buildkite](https://buildkite.com/matrix-dot-org/dendrite/) and +[CircleCI](https://circleci.com/gh/matrix-org/dendrite/). + +If a job fails, click the "details" button and you should be taken to the job's +logs. + +![Click the details button on the failing build step](docs/images/details-button-location.jpg) + +Scroll down to the failing step and you should see some log output. Scan +the logs until you find what it's complaining about, fix it, submit a new +commit, then rinse and repeat until CI passes. + +### Running CI Tests Locally + +To save waiting for CI to finish after every commit, it is ideal to run the +checks locally before pushing, fixing errors first. This also saves other +people time as only so many PRs can be tested at a given time. + +To execute what Buildkite tests, simply run `./scripts/build-test-lint.sh`. +This script will build the code, lint it, and run `go test ./...` with race +condition checking enabled. If something needs to be changed, fix it and then +run the script again until it no longer complains. Be warned that the linting +can take a significant amount of CPU and RAM. + +CircleCI simply runs [Sytest](https://github.com/matrix-org/sytest) with a test +whitelist. See +[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md#using-a-sytest-docker-image) +for instructions on setting it up to run locally. + ## Picking Things To Do diff --git a/INSTALL.md b/INSTALL.md index 82f7f00af..0fb0c08e5 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -35,10 +35,10 @@ cd dendrite If using Kafka, install and start it (c.f. [scripts/install-local-kafka.sh](scripts/install-local-kafka.sh)): ```bash -MIRROR=http://apache.mirror.anlx.net/kafka/0.10.2.0/kafka_2.11-0.10.2.0.tgz +KAFKA_URL=http://archive.apache.org/dist/kafka/2.1.0/kafka_2.11-2.1.0.tgz # Only download the kafka if it isn't already downloaded. -test -f kafka.tgz || wget $MIRROR -O kafka.tgz +test -f kafka.tgz || wget $KAFKA_URL -O kafka.tgz # Unpack the kafka over the top of any existing installation mkdir -p kafka && tar xzf kafka.tgz -C kafka --strip-components 1 diff --git a/README.md b/README.md index 8eadaf431..4e628c0ff 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Dendrite [![Build Status](https://badge.buildkite.com/4be40938ab19f2bbc4a6c6724517353ee3ec1422e279faf374.svg)](https://buildkite.com/matrix-dot-org/dendrite) [![CircleCI](https://circleci.com/gh/matrix-org/dendrite.svg?style=svg)](https://circleci.com/gh/matrix-org/dendrite) [![Dendrite Dev on Matrix](https://img.shields.io/matrix/dendrite-dev:matrix.org.svg?label=%23dendrite-dev%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite-dev:matrix.org) [![Dendrite on Matrix](https://img.shields.io/matrix/dendrite:matrix.org.svg?label=%23dendrite%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite:matrix.org) +# Dendrite [![Build Status](https://badge.buildkite.com/4be40938ab19f2bbc4a6c6724517353ee3ec1422e279faf374.svg?branch=master)](https://buildkite.com/matrix-dot-org/dendrite) [![CircleCI](https://circleci.com/gh/matrix-org/dendrite.svg?style=svg)](https://circleci.com/gh/matrix-org/dendrite) [![Dendrite Dev on Matrix](https://img.shields.io/matrix/dendrite-dev:matrix.org.svg?label=%23dendrite-dev%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite-dev:matrix.org) [![Dendrite on Matrix](https://img.shields.io/matrix/dendrite:matrix.org.svg?label=%23dendrite%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite:matrix.org) Dendrite will be a matrix homeserver written in go. diff --git a/appservice/api/query.go b/appservice/api/query.go index 9ec214486..9542df565 100644 --- a/appservice/api/query.go +++ b/appservice/api/query.go @@ -20,13 +20,13 @@ package api import ( "context" "database/sql" - "errors" "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/common" commonHTTP "github.com/matrix-org/dendrite/common/http" opentracing "github.com/opentracing/opentracing-go" ) @@ -134,9 +134,9 @@ func (h *httpAppServiceQueryAPI) UserIDExists( return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } -// RetreiveUserProfile is a wrapper that queries both the local database and +// RetrieveUserProfile is a wrapper that queries both the local database and // application services for a given user's profile -func RetreiveUserProfile( +func RetrieveUserProfile( ctx context.Context, userID string, asAPI AppServiceQueryAPI, @@ -164,7 +164,7 @@ func RetreiveUserProfile( // If no user exists, return if !userResp.UserIDExists { - return nil, errors.New("no known profile for given user ID") + return nil, common.ErrProfileNoExists } // Try to query the user from the local database again diff --git a/clientapi/auth/authtypes/profile.go b/clientapi/auth/authtypes/profile.go index 6cf508f4f..0bc49658b 100644 --- a/clientapi/auth/authtypes/profile.go +++ b/clientapi/auth/authtypes/profile.go @@ -14,7 +14,7 @@ package authtypes -// Profile represents the profile for a Matrix account on this home server. +// Profile represents the profile for a Matrix account. type Profile struct { Localpart string DisplayName string diff --git a/clientapi/auth/storage/accounts/filter_table.go b/clientapi/auth/storage/accounts/filter_table.go index 81bae4545..2b07ef17e 100644 --- a/clientapi/auth/storage/accounts/filter_table.go +++ b/clientapi/auth/storage/accounts/filter_table.go @@ -17,6 +17,7 @@ package accounts import ( "context" "database/sql" + "encoding/json" "github.com/matrix-org/gomatrixserverlib" ) @@ -71,25 +72,44 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) { func (s *filterStatements) selectFilter( ctx context.Context, localpart string, filterID string, -) (filter []byte, err error) { - err = s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filter) - return +) (*gomatrixserverlib.Filter, error) { + // Retrieve filter from database (stored as canonical JSON) + var filterData []byte + err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + if err != nil { + return nil, err + } + + // Unmarshal JSON into Filter struct + var filter gomatrixserverlib.Filter + if err = json.Unmarshal(filterData, &filter); err != nil { + return nil, err + } + return &filter, nil } func (s *filterStatements) insertFilter( - ctx context.Context, filter []byte, localpart string, + ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, ) (filterID string, err error) { var existingFilterID string - // This can result in a race condition when two clients try to insert the - // same filter and localpart at the same time, however this is not a - // problem as both calls will result in the same filterID - filterJSON, err := gomatrixserverlib.CanonicalJSON(filter) + // Serialise json + filterJSON, err := json.Marshal(filter) + if err != nil { + return "", err + } + // Remove whitespaces and sort JSON data + // needed to prevent from inserting the same filter multiple times + filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON) if err != nil { return "", err } - // Check if filter already exists in the database + // Check if filter already exists in the database using its localpart and content + // + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, localpart, filterJSON).Scan(&existingFilterID) if err != nil && err != sql.ErrNoRows { diff --git a/clientapi/auth/storage/accounts/storage.go b/clientapi/auth/storage/accounts/storage.go index 27c0a176a..41d75daad 100644 --- a/clientapi/auth/storage/accounts/storage.go +++ b/clientapi/auth/storage/accounts/storage.go @@ -230,7 +230,7 @@ func (d *Database) newMembership( } // Only "join" membership events can be considered as new memberships - if membership == "join" { + if membership == gomatrixserverlib.Join { if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil { return err } @@ -344,11 +344,11 @@ func (d *Database) GetThreePIDsForLocalpart( } // GetFilter looks up the filter associated with a given local user and filter ID. -// Returns a filter represented as a byte slice. Otherwise returns an error if -// no such filter exists or if there was an error talking to the database. +// Returns a filter structure. Otherwise returns an error if no such filter exists +// or if there was an error talking to the database. func (d *Database) GetFilter( ctx context.Context, localpart string, filterID string, -) ([]byte, error) { +) (*gomatrixserverlib.Filter, error) { return d.filter.selectFilter(ctx, localpart, filterID) } @@ -356,7 +356,7 @@ func (d *Database) GetFilter( // Returns the filterID as a string. Otherwise returns an error if something // goes wrong. func (d *Database) PutFilter( - ctx context.Context, localpart string, filter []byte, + ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, ) (string, error) { return d.filter.insertFilter(ctx, filter, localpart) } diff --git a/clientapi/auth/storage/devices/devices_table.go b/clientapi/auth/storage/devices/devices_table.go index 96d6521d8..60aa563a2 100644 --- a/clientapi/auth/storage/devices/devices_table.go +++ b/clientapi/auth/storage/devices/devices_table.go @@ -169,6 +169,8 @@ func (s *devicesStatements) selectDeviceByToken( return &dev, err } +// selectDeviceByID retrieves a device from the database with the given user +// localpart and deviceID func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*authtypes.Device, error) { diff --git a/clientapi/auth/storage/devices/storage.go b/clientapi/auth/storage/devices/storage.go index 7032fe7bf..82c8e97a2 100644 --- a/clientapi/auth/storage/devices/storage.go +++ b/clientapi/auth/storage/devices/storage.go @@ -84,7 +84,7 @@ func (d *Database) CreateDevice( if deviceID != nil { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error - // Revoke existing token for this device + // Revoke existing tokens for this device if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { return err } diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 30e00f723..d57a6d370 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -33,13 +33,6 @@ func SaveAccountData( req *http.Request, accountDB *accounts.Database, device *authtypes.Device, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, ) util.JSONResponse { - if req.Method != http.MethodPut { - return util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), - } - } - if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go new file mode 100644 index 000000000..cd4530d1b --- /dev/null +++ b/clientapi/routing/auth_fallback.go @@ -0,0 +1,210 @@ +// Copyright 2019 Parminder Singh +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "html/template" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/util" +) + +// recaptchaTemplate is an HTML webpage template for recaptcha auth +const recaptchaTemplate = ` + + +Authentication + + + + + + +
+
+

+ Hello! We need to prevent computer programs and other automated + things from creating accounts on this server. +

+

+ Please verify that you're not a robot. +

+ +
+
+ +
+ +
+ + +` + +// successTemplate is an HTML template presented to the user after successful +// recaptcha completion +const successTemplate = ` + + +Success! + + + + +
+

Thank you!

+

You may now close this window and return to the application.

+
+ + +` + +// serveTemplate fills template data and serves it using http.ResponseWriter +func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]string) { + t := template.Must(template.New("response").Parse(templateHTML)) + if err := t.Execute(w, data); err != nil { + panic(err) + } +} + +// AuthFallback implements GET and POST /auth/{authType}/fallback/web?session={sessionID} +func AuthFallback( + w http.ResponseWriter, req *http.Request, authType string, + cfg config.Dendrite, +) *util.JSONResponse { + sessionID := req.URL.Query().Get("session") + + if sessionID == "" { + return writeHTTPMessage(w, req, + "Session ID not provided", + http.StatusBadRequest, + ) + } + + serveRecaptcha := func() { + data := map[string]string{ + "myUrl": req.URL.String(), + "session": sessionID, + "siteKey": cfg.Matrix.RecaptchaPublicKey, + } + serveTemplate(w, recaptchaTemplate, data) + } + + serveSuccess := func() { + data := map[string]string{} + serveTemplate(w, successTemplate, data) + } + + if req.Method == http.MethodGet { + // Handle Recaptcha + if authType == authtypes.LoginTypeRecaptcha { + if err := checkRecaptchaEnabled(&cfg, w, req); err != nil { + return err + } + + serveRecaptcha() + return nil + } + return &util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("Unknown auth stage type"), + } + } else if req.Method == http.MethodPost { + // Handle Recaptcha + if authType == authtypes.LoginTypeRecaptcha { + if err := checkRecaptchaEnabled(&cfg, w, req); err != nil { + return err + } + + clientIP := req.RemoteAddr + err := req.ParseForm() + if err != nil { + res := httputil.LogThenError(req, err) + return &res + } + + response := req.Form.Get("g-recaptcha-response") + if err := validateRecaptcha(&cfg, response, clientIP); err != nil { + util.GetLogger(req.Context()).Error(err) + return err + } + + // Success. Add recaptcha as a completed login flow + AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + + serveSuccess() + return nil + } + + return &util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("Unknown auth stage type"), + } + } + return &util.JSONResponse{ + Code: http.StatusMethodNotAllowed, + JSON: jsonerror.NotFound("Bad method"), + } +} + +// checkRecaptchaEnabled creates an error response if recaptcha is not usable on homeserver. +func checkRecaptchaEnabled( + cfg *config.Dendrite, + w http.ResponseWriter, + req *http.Request, +) *util.JSONResponse { + if !cfg.Matrix.RecaptchaEnabled { + return writeHTTPMessage(w, req, + "Recaptcha login is disabled on this Homeserver", + http.StatusBadRequest, + ) + } + return nil +} + +// writeHTTPMessage writes the given header and message to the HTTP response writer. +// Returns an error JSONResponse obtained through httputil.LogThenError if the writing failed, otherwise nil. +func writeHTTPMessage( + w http.ResponseWriter, req *http.Request, + message string, header int, +) *util.JSONResponse { + w.WriteHeader(header) + _, err := w.Write([]byte(message)) + if err != nil { + res := httputil.LogThenError(req, err) + return &res + } + return nil +} diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index a7187c495..620246d28 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -15,6 +15,7 @@ package routing import ( + "encoding/json" "fmt" "net/http" "strings" @@ -54,10 +55,6 @@ const ( presetPublicChat = "public_chat" ) -const ( - joinRulePublic = "public" - joinRuleInvite = "invite" -) const ( historyVisibilityShared = "shared" // TODO: These should be implemented once history visibility is implemented @@ -97,6 +94,27 @@ func (r createRoomRequest) Validate() *util.JSONResponse { } } + // Validate creation_content fields defined in the spec by marshalling the + // creation_content map into bytes and then unmarshalling the bytes into + // common.CreateContent. + + creationContentBytes, err := json.Marshal(r.CreationContent) + if err != nil { + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("malformed creation_content"), + } + } + + var CreationContent gomatrixserverlib.CreateContent + err = json.Unmarshal(creationContentBytes, &CreationContent) + if err != nil { + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("malformed creation_content"), + } + } + return nil } @@ -154,7 +172,17 @@ func createRoom( JSON: jsonerror.InvalidArgumentValue(err.Error()), } } - // TODO: visibility/presets/raw initial state/creation content + + // Clobber keys: creator, room_version + + if r.CreationContent == nil { + r.CreationContent = make(map[string]interface{}, 2) + } + + r.CreationContent["creator"] = userID + r.CreationContent["room_version"] = "1" // TODO: We set this to 1 before we support Room versioning + + // TODO: visibility/presets/raw initial state // TODO: Create room alias association // Make sure this doesn't fall into an application service's namespace though! @@ -163,13 +191,13 @@ func createRoom( "roomID": roomID, }).Info("Creating new room") - profile, err := appserviceAPI.RetreiveUserProfile(req.Context(), userID, asAPI, accountDB) + profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB) if err != nil { return httputil.LogThenError(req, err) } - membershipContent := common.MemberContent{ - Membership: "join", + membershipContent := gomatrixserverlib.MemberContent{ + Membership: gomatrixserverlib.Join, DisplayName: profile.DisplayName, AvatarURL: profile.AvatarURL, } @@ -177,19 +205,19 @@ func createRoom( var joinRules, historyVisibility string switch r.Preset { case presetPrivateChat: - joinRules = joinRuleInvite + joinRules = gomatrixserverlib.Invite historyVisibility = historyVisibilityShared case presetTrustedPrivateChat: - joinRules = joinRuleInvite + joinRules = gomatrixserverlib.Invite historyVisibility = historyVisibilityShared // TODO If trusted_private_chat, all invitees are given the same power level as the room creator. case presetPublicChat: - joinRules = joinRulePublic + joinRules = gomatrixserverlib.Public historyVisibility = historyVisibilityShared default: // Default room rules, r.Preset was previously checked for valid values so // only a request with no preset should end up here. - joinRules = joinRuleInvite + joinRules = gomatrixserverlib.Invite historyVisibility = historyVisibilityShared } @@ -214,11 +242,11 @@ func createRoom( // harder to reason about, hence sticking to a strict static ordering. // TODO: Synapse has txn/token ID on each event. Do we need to do this here? eventsToMake := []fledglingEvent{ - {"m.room.create", "", common.CreateContent{Creator: userID}}, + {"m.room.create", "", r.CreationContent}, {"m.room.member", userID, membershipContent}, {"m.room.power_levels", "", common.InitialPowerLevelsContent(userID)}, // TODO: m.room.canonical_alias - {"m.room.join_rules", "", common.JoinRulesContent{JoinRule: joinRules}}, + {"m.room.join_rules", "", gomatrixserverlib.JoinRuleContent{JoinRule: joinRules}}, {"m.room.history_visibility", "", common.HistoryVisibilityContent{HistoryVisibility: historyVisibility}}, } if r.GuestCanJoin { diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index cf6f24a7d..c858e88aa 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -106,13 +106,6 @@ func UpdateDeviceByID( req *http.Request, deviceDB *devices.Database, device *authtypes.Device, deviceID string, ) util.JSONResponse { - if req.Method != http.MethodPut { - return util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad Method"), - } - } - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { return httputil.LogThenError(req, err) diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index b15082981..371c62f87 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -128,12 +128,16 @@ func SetLocalAlias( // 1. The new method for checking for things matching an AS's namespace // 2. Using an overall Regex object for all AS's just like we did for usernames for _, appservice := range cfg.Derived.ApplicationServices { - if aliasNamespaces, ok := appservice.NamespaceMap["aliases"]; ok { - for _, namespace := range aliasNamespaces { - if namespace.Exclusive && namespace.RegexpObject.MatchString(alias) { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.ASExclusive("Alias is reserved by an application service"), + // Don't prevent AS from creating aliases in its own namespace + // Note that Dendrite uses SenderLocalpart as UserID for AS users + if device.UserID != appservice.SenderLocalpart { + if aliasNamespaces, ok := appservice.NamespaceMap["aliases"]; ok { + for _, namespace := range aliasNamespaces { + if namespace.Exclusive && namespace.RegexpObject.MatchString(alias) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.ASExclusive("Alias is reserved by an application service"), + } } } } @@ -171,13 +175,36 @@ func SetLocalAlias( } // RemoveLocalAlias implements DELETE /directory/room/{roomAlias} -// TODO: Check if the user has the power level to remove an alias func RemoveLocalAlias( req *http.Request, device *authtypes.Device, alias string, aliasAPI roomserverAPI.RoomserverAliasAPI, ) util.JSONResponse { + + creatorQueryReq := roomserverAPI.GetCreatorIDForAliasRequest{ + Alias: alias, + } + var creatorQueryRes roomserverAPI.GetCreatorIDForAliasResponse + if err := aliasAPI.GetCreatorIDForAlias(req.Context(), &creatorQueryReq, &creatorQueryRes); err != nil { + return httputil.LogThenError(req, err) + } + + if creatorQueryRes.UserID == "" { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("Alias does not exist"), + } + } + + if creatorQueryRes.UserID != device.UserID { + // TODO: Still allow deletion if user is admin + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("You do not have permission to delete this alias"), + } + } + queryReq := roomserverAPI.RemoveRoomAliasRequest{ Alias: alias, UserID: device.UserID, diff --git a/clientapi/routing/filter.go b/clientapi/routing/filter.go index 109c55da1..eec501ff7 100644 --- a/clientapi/routing/filter.go +++ b/clientapi/routing/filter.go @@ -17,13 +17,10 @@ package routing import ( "net/http" - "encoding/json" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -32,12 +29,6 @@ import ( func GetFilter( req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, filterID string, ) util.JSONResponse { - if req.Method != http.MethodGet { - return util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), - } - } if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, @@ -49,7 +40,7 @@ func GetFilter( return httputil.LogThenError(req, err) } - res, err := accountDB.GetFilter(req.Context(), localpart, filterID) + filter, err := accountDB.GetFilter(req.Context(), localpart, filterID) if err != nil { //TODO better error handling. This error message is *probably* right, // but if there are obscure db errors, this will also be returned, @@ -59,11 +50,6 @@ func GetFilter( JSON: jsonerror.NotFound("No such filter"), } } - filter := gomatrix.Filter{} - err = json.Unmarshal(res, &filter) - if err != nil { - httputil.LogThenError(req, err) - } return util.JSONResponse{ Code: http.StatusOK, @@ -79,12 +65,6 @@ type filterResponse struct { func PutFilter( req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, ) util.JSONResponse { - if req.Method != http.MethodPost { - return util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), - } - } if userID != device.UserID { return util.JSONResponse{ Code: http.StatusForbidden, @@ -97,21 +77,21 @@ func PutFilter( return httputil.LogThenError(req, err) } - var filter gomatrix.Filter + var filter gomatrixserverlib.Filter if reqErr := httputil.UnmarshalJSONRequest(req, &filter); reqErr != nil { return *reqErr } - filterArray, err := json.Marshal(filter) - if err != nil { + // Validate generates a user-friendly error + if err = filter.Validate(); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Filter is malformed"), + JSON: jsonerror.BadJSON("Invalid filter: " + err.Error()), } } - filterID, err := accountDB.PutFilter(req.Context(), localpart, filterArray) + filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter) if err != nil { return httputil.LogThenError(req, err) } diff --git a/clientapi/routing/getevent.go b/clientapi/routing/getevent.go new file mode 100644 index 000000000..7071d16f0 --- /dev/null +++ b/clientapi/routing/getevent.go @@ -0,0 +1,127 @@ +// Copyright 2019 Alex Chen +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +type getEventRequest struct { + req *http.Request + device *authtypes.Device + roomID string + eventID string + cfg config.Dendrite + federation *gomatrixserverlib.FederationClient + keyRing gomatrixserverlib.KeyRing + requestedEvent gomatrixserverlib.Event +} + +// GetEvent implements GET /_matrix/client/r0/rooms/{roomId}/event/{eventId} +// https://matrix.org/docs/spec/client_server/r0.4.0.html#get-matrix-client-r0-rooms-roomid-event-eventid +func GetEvent( + req *http.Request, + device *authtypes.Device, + roomID string, + eventID string, + cfg config.Dendrite, + queryAPI api.RoomserverQueryAPI, + federation *gomatrixserverlib.FederationClient, + keyRing gomatrixserverlib.KeyRing, +) util.JSONResponse { + eventsReq := api.QueryEventsByIDRequest{ + EventIDs: []string{eventID}, + } + var eventsResp api.QueryEventsByIDResponse + err := queryAPI.QueryEventsByID(req.Context(), &eventsReq, &eventsResp) + if err != nil { + return httputil.LogThenError(req, err) + } + + if len(eventsResp.Events) == 0 { + // Event not found locally + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"), + } + } + + requestedEvent := eventsResp.Events[0] + + r := getEventRequest{ + req: req, + device: device, + roomID: roomID, + eventID: eventID, + cfg: cfg, + federation: federation, + keyRing: keyRing, + requestedEvent: requestedEvent, + } + + stateReq := api.QueryStateAfterEventsRequest{ + RoomID: r.requestedEvent.RoomID(), + PrevEventIDs: r.requestedEvent.PrevEventIDs(), + StateToFetch: []gomatrixserverlib.StateKeyTuple{{ + EventType: gomatrixserverlib.MRoomMember, + StateKey: device.UserID, + }}, + } + var stateResp api.QueryStateAfterEventsResponse + if err := queryAPI.QueryStateAfterEvents(req.Context(), &stateReq, &stateResp); err != nil { + return httputil.LogThenError(req, err) + } + + if !stateResp.RoomExists { + util.GetLogger(req.Context()).Errorf("Expected to find room for event %s but failed", r.requestedEvent.EventID()) + return jsonerror.InternalServerError() + } + + if !stateResp.PrevEventsExist { + // Missing some events locally; stateResp.StateEvents unavailable. + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"), + } + } + + for _, stateEvent := range stateResp.StateEvents { + if stateEvent.StateKeyEquals(r.device.UserID) { + membership, err := stateEvent.Membership() + if err != nil { + return httputil.LogThenError(req, err) + } + if membership == gomatrixserverlib.Join { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: gomatrixserverlib.ToClientEvent(r.requestedEvent, gomatrixserverlib.FormatAll), + } + } + } + } + + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The event was not found or you do not have permission to read this event"), + } +} diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 9c02a93ca..432c982b4 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -70,7 +70,7 @@ func JoinRoomByIDOrAlias( return httputil.LogThenError(req, err) } - content["membership"] = "join" + content["membership"] = gomatrixserverlib.Join content["displayname"] = profile.DisplayName content["avatar_url"] = profile.AvatarURL diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index abcf7f569..02d958152 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -18,7 +18,6 @@ import ( "net/http" "context" - "database/sql" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -42,10 +41,12 @@ type flow struct { } type passwordRequest struct { - User string `json:"user"` - Password string `json:"password"` + User string `json:"user"` + Password string `json:"password"` + // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") + // Thus a pointer is needed to differentiate between the two InitialDisplayName *string `json:"initial_device_display_name"` - DeviceID string `json:"device_id"` + DeviceID *string `json:"device_id"` } type loginResponse struct { @@ -107,10 +108,10 @@ func Login( token, err := auth.GenerateAccessToken() if err != nil { - httputil.LogThenError(req, err) + return httputil.LogThenError(req, err) } - dev, err := getDevice(req.Context(), r, deviceDB, acc, localpart, token) + dev, err := getDevice(req.Context(), r, deviceDB, acc, token) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -134,20 +135,16 @@ func Login( } } -// check if device exists else create one +// getDevice returns a new or existing device func getDevice( ctx context.Context, r passwordRequest, deviceDB *devices.Database, acc *authtypes.Account, - localpart, token string, + token string, ) (dev *authtypes.Device, err error) { - dev, err = deviceDB.GetDeviceByID(ctx, localpart, r.DeviceID) - if err == sql.ErrNoRows { - // device doesn't exist, create one - dev, err = deviceDB.CreateDevice( - ctx, acc.Localpart, nil, token, r.InitialDisplayName, - ) - } + dev, err = deviceDB.CreateDevice( + ctx, acc.Localpart, r.DeviceID, token, r.InitialDisplayName, + ) return } diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go index d20138534..3294fbcdc 100644 --- a/clientapi/routing/logout.go +++ b/clientapi/routing/logout.go @@ -20,7 +20,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/httputil" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -29,13 +28,6 @@ import ( func Logout( req *http.Request, deviceDB *devices.Database, device *authtypes.Device, ) util.JSONResponse { - if req.Method != http.MethodPost { - return util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), - } - } - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { return httputil.LogThenError(req, err) diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index b308de79a..c71ac2de2 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -58,27 +58,12 @@ func SendMembership( } } - inviteStored, err := threepid.CheckAndProcessInvite( - req.Context(), device, &body, cfg, queryAPI, accountDB, producer, + inviteStored, jsonErrResp := checkAndProcessThreepid( + req, device, &body, cfg, queryAPI, accountDB, producer, membership, roomID, evTime, ) - if err == threepid.ErrMissingParameter { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(err.Error()), - } - } else if err == threepid.ErrNotTrusted { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.NotTrusted(body.IDServer), - } - } else if err == common.ErrRoomNoExists { - return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound(err.Error()), - } - } else if err != nil { - return httputil.LogThenError(req, err) + if jsonErrResp != nil { + return *jsonErrResp } // If an invite has been stored on an identity server, it means that a @@ -114,9 +99,18 @@ func SendMembership( return httputil.LogThenError(req, err) } + var returnData interface{} = struct{}{} + + // The join membership requires the room id to be sent in the response + if membership == gomatrixserverlib.Join { + returnData = struct { + RoomID string `json:"room_id"` + }{roomID} + } + return util.JSONResponse{ Code: http.StatusOK, - JSON: struct{}{}, + JSON: returnData, } } @@ -147,10 +141,10 @@ func buildMembershipEvent( // "unban" or "kick" isn't a valid membership value, change it to "leave" if membership == "unban" || membership == "kick" { - membership = "leave" + membership = gomatrixserverlib.Leave } - content := common.MemberContent{ + content := gomatrixserverlib.MemberContent{ Membership: membership, DisplayName: profile.DisplayName, AvatarURL: profile.AvatarURL, @@ -182,7 +176,7 @@ func loadProfile( var profile *authtypes.Profile if serverName == cfg.Matrix.ServerName { - profile, err = appserviceAPI.RetreiveUserProfile(ctx, userID, asAPI, accountDB) + profile, err = appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB) } else { profile = &authtypes.Profile{} } @@ -198,7 +192,7 @@ func loadProfile( func getMembershipStateKey( body threepid.MembershipRequest, device *authtypes.Device, membership string, ) (stateKey string, reason string, err error) { - if membership == "ban" || membership == "unban" || membership == "kick" || membership == "invite" { + if membership == gomatrixserverlib.Ban || membership == "unban" || membership == "kick" || membership == gomatrixserverlib.Invite { // If we're in this case, the state key is contained in the request body, // possibly along with a reason (for "kick" and "ban") so we need to parse // it @@ -215,3 +209,41 @@ func getMembershipStateKey( return } + +func checkAndProcessThreepid( + req *http.Request, + device *authtypes.Device, + body *threepid.MembershipRequest, + cfg config.Dendrite, + queryAPI roomserverAPI.RoomserverQueryAPI, + accountDB *accounts.Database, + producer *producers.RoomserverProducer, + membership, roomID string, + evTime time.Time, +) (inviteStored bool, errRes *util.JSONResponse) { + + inviteStored, err := threepid.CheckAndProcessInvite( + req.Context(), device, body, cfg, queryAPI, accountDB, producer, + membership, roomID, evTime, + ) + if err == threepid.ErrMissingParameter { + return inviteStored, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(err.Error()), + } + } else if err == threepid.ErrNotTrusted { + return inviteStored, &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.NotTrusted(body.IDServer), + } + } else if err == common.ErrRoomNoExists { + return inviteStored, &util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound(err.Error()), + } + } else if err != nil { + er := httputil.LogThenError(req, err) + return inviteStored, &er + } + return +} diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index e57d16fbf..a87c6f743 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -30,49 +30,61 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrix" "github.com/matrix-org/util" ) // GetProfile implements GET /profile/{userID} func GetProfile( - req *http.Request, accountDB *accounts.Database, userID string, asAPI appserviceAPI.AppServiceQueryAPI, + req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, + userID string, + asAPI appserviceAPI.AppServiceQueryAPI, + federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { - if req.Method != http.MethodGet { - return util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), - } - } - profile, err := appserviceAPI.RetreiveUserProfile(req.Context(), userID, asAPI, accountDB) + profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) if err != nil { + if err == common.ErrProfileNoExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), + } + } + return httputil.LogThenError(req, err) } - res := common.ProfileResponse{ - AvatarURL: profile.AvatarURL, - DisplayName: profile.DisplayName, - } return util.JSONResponse{ Code: http.StatusOK, - JSON: res, + JSON: common.ProfileResponse{ + AvatarURL: profile.AvatarURL, + DisplayName: profile.DisplayName, + }, } } // GetAvatarURL implements GET /profile/{userID}/avatar_url func GetAvatarURL( - req *http.Request, accountDB *accounts.Database, userID string, asAPI appserviceAPI.AppServiceQueryAPI, + req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, + userID string, asAPI appserviceAPI.AppServiceQueryAPI, + federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { - profile, err := appserviceAPI.RetreiveUserProfile(req.Context(), userID, asAPI, accountDB) + profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) if err != nil { + if err == common.ErrProfileNoExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), + } + } + return httputil.LogThenError(req, err) } - res := common.AvatarURL{ - AvatarURL: profile.AvatarURL, - } return util.JSONResponse{ Code: http.StatusOK, - JSON: res, + JSON: common.AvatarURL{ + AvatarURL: profile.AvatarURL, + }, } } @@ -158,18 +170,27 @@ func SetAvatarURL( // GetDisplayName implements GET /profile/{userID}/displayname func GetDisplayName( - req *http.Request, accountDB *accounts.Database, userID string, asAPI appserviceAPI.AppServiceQueryAPI, + req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, + userID string, asAPI appserviceAPI.AppServiceQueryAPI, + federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { - profile, err := appserviceAPI.RetreiveUserProfile(req.Context(), userID, asAPI, accountDB) + profile, err := getProfile(req.Context(), accountDB, cfg, userID, asAPI, federation) if err != nil { + if err == common.ErrProfileNoExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The user does not exist or does not have a profile"), + } + } + return httputil.LogThenError(req, err) } - res := common.DisplayName{ - DisplayName: profile.DisplayName, - } + return util.JSONResponse{ Code: http.StatusOK, - JSON: res, + JSON: common.DisplayName{ + DisplayName: profile.DisplayName, + }, } } @@ -253,6 +274,48 @@ func SetDisplayName( } } +// getProfile gets the full profile of a user by querying the database or a +// remote homeserver. +// Returns an error when something goes wrong or specifically +// common.ErrProfileNoExists when the profile doesn't exist. +func getProfile( + ctx context.Context, accountDB *accounts.Database, cfg *config.Dendrite, + userID string, + asAPI appserviceAPI.AppServiceQueryAPI, + federation *gomatrixserverlib.FederationClient, +) (*authtypes.Profile, error) { + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return nil, err + } + + if domain != cfg.Matrix.ServerName { + profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") + if fedErr != nil { + if x, ok := fedErr.(gomatrix.HTTPError); ok { + if x.Code == http.StatusNotFound { + return nil, common.ErrProfileNoExists + } + } + + return nil, fedErr + } + + return &authtypes.Profile{ + Localpart: localpart, + DisplayName: profile.DisplayName, + AvatarURL: profile.AvatarURL, + }, nil + } + + profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB) + if err != nil { + return nil, err + } + + return profile, nil +} + func buildMembershipEvents( ctx context.Context, memberships []authtypes.Membership, @@ -269,8 +332,8 @@ func buildMembershipEvents( StateKey: &userID, } - content := common.MemberContent{ - Membership: "join", + content := gomatrixserverlib.MemberContent{ + Membership: gomatrixserverlib.Join, } content.DisplayName = newProfile.DisplayName diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index b1522e82b..d0f36a6fd 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -29,6 +29,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" "github.com/matrix-org/dendrite/common/config" @@ -70,12 +71,17 @@ func init() { } // sessionsDict keeps track of completed auth stages for each session. +// It shouldn't be passed by value because it contains a mutex. type sessionsDict struct { + sync.Mutex sessions map[string][]authtypes.LoginType } // GetCompletedStages returns the completed stages for a session. -func (d sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { +func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { + d.Lock() + defer d.Unlock() + if completedStages, ok := d.sessions[sessionID]; ok { return completedStages } @@ -83,17 +89,25 @@ func (d sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType return make([]authtypes.LoginType, 0) } -// AddCompletedStage records that a session has completed an auth stage. -func (d *sessionsDict) AddCompletedStage(sessionID string, stage authtypes.LoginType) { - d.sessions[sessionID] = append(d.GetCompletedStages(sessionID), stage) -} - func newSessionsDict() *sessionsDict { return &sessionsDict{ sessions: make(map[string][]authtypes.LoginType), } } +// AddCompletedSessionStage records that a session has completed an auth stage. +func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) { + sessions.Lock() + defer sessions.Unlock() + + for _, completedStage := range sessions.sessions[sessionID] { + if completedStage == stage { + return + } + } + sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage) +} + var ( // TODO: Remove old sessions. Need to do so on a session-specific timeout. // sessions stores the completed flow stages for all sessions. Referenced using their sessionID. @@ -115,7 +129,10 @@ type registerRequest struct { // user-interactive auth params Auth authDict `json:"auth"` + // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") + // Thus a pointer is needed to differentiate between the two InitialDisplayName *string `json:"initial_device_display_name"` + DeviceID *string `json:"device_id"` // Prevent this user from logging in InhibitLogin common.WeakBoolean `json:"inhibit_login"` @@ -243,8 +260,8 @@ func validateRecaptcha( ) *util.JSONResponse { if !cfg.Matrix.RecaptchaEnabled { return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Captcha registration is disabled"), + Code: http.StatusConflict, + JSON: jsonerror.Unknown("Captcha registration is disabled"), } } @@ -279,8 +296,8 @@ func validateRecaptcha( body, err := ioutil.ReadAll(resp.Body) if err != nil { return &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Error in contacting captcha server" + err.Error()), + Code: http.StatusGatewayTimeout, + JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()), } } err = json.Unmarshal(body, &r) @@ -521,7 +538,7 @@ func handleRegistrationFlow( } // Add Recaptcha to the list of completed registration stages - sessions.AddCompletedStage(sessionID, authtypes.LoginTypeRecaptcha) + AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) case authtypes.LoginTypeSharedSecret: // Check shared secret against config @@ -534,7 +551,7 @@ func handleRegistrationFlow( } // Add SharedSecret to the list of completed registration stages - sessions.AddCompletedStage(sessionID, authtypes.LoginTypeSharedSecret) + AddCompletedSessionStage(sessionID, authtypes.LoginTypeSharedSecret) case "": // Extract the access token from the request, if there's one to extract @@ -564,7 +581,7 @@ func handleRegistrationFlow( case authtypes.LoginTypeDummy: // there is nothing to do // Add Dummy to the list of completed registration stages - sessions.AddCompletedStage(sessionID, authtypes.LoginTypeDummy) + AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) default: return util.JSONResponse{ @@ -620,7 +637,7 @@ func handleApplicationServiceRegistration( // application service registration is entirely separate. return completeRegistration( req.Context(), accountDB, deviceDB, r.Username, "", appserviceID, - r.InhibitLogin, r.InitialDisplayName, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -640,7 +657,7 @@ func checkAndCompleteFlow( // This flow was completed, registration can continue return completeRegistration( req.Context(), accountDB, deviceDB, r.Username, r.Password, "", - r.InhibitLogin, r.InitialDisplayName, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -691,10 +708,10 @@ func LegacyRegister( return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") } - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) + return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil) case authtypes.LoginTypeDummy: // there is nothing to do - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) + return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil) default: return util.JSONResponse{ Code: http.StatusNotImplemented, @@ -732,13 +749,19 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u return nil } +// completeRegistration runs some rudimentary checks against the submitted +// input, then if successful creates an account and a newly associated device +// We pass in each individual part of the request here instead of just passing a +// registerRequest, as this function serves requests encoded as both +// registerRequests and legacyRegisterRequests, which share some attributes but +// not all func completeRegistration( ctx context.Context, accountDB *accounts.Database, deviceDB *devices.Database, username, password, appserviceID string, inhibitLogin common.WeakBoolean, - displayName *string, + displayName, deviceID *string, ) util.JSONResponse { if username == "" { return util.JSONResponse{ @@ -767,6 +790,9 @@ func completeRegistration( } } + // Increment prometheus counter for created users + amtRegUsers.Inc() + // Check whether inhibit_login option is set. If so, don't create an access // token or a device for this user if inhibitLogin { @@ -787,8 +813,7 @@ func completeRegistration( } } - // TODO: Use the device ID in the request. - dev, err := deviceDB.CreateDevice(ctx, username, nil, token, displayName) + dev, err := deviceDB.CreateDevice(ctx, username, deviceID, token, displayName) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -796,9 +821,6 @@ func completeRegistration( } } - // Increment prometheus counter for created users - amtRegUsers.Inc() - return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go new file mode 100644 index 000000000..6e7324cd8 --- /dev/null +++ b/clientapi/routing/room_tagging.go @@ -0,0 +1,234 @@ +// Copyright 2019 Sumukha PK +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// newTag creates and returns a new gomatrix.TagContent +func newTag() gomatrix.TagContent { + return gomatrix.TagContent{ + Tags: make(map[string]gomatrix.TagProperties), + } +} + +// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags +func GetTags( + req *http.Request, + accountDB *accounts.Database, + device *authtypes.Device, + userID string, + roomID string, + syncProducer *producers.SyncAPIProducer, +) util.JSONResponse { + + if device.UserID != userID { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Cannot retrieve another user's tags"), + } + } + + _, data, err := obtainSavedTags(req, userID, roomID, accountDB) + if err != nil { + return httputil.LogThenError(req, err) + } + + if len(data) == 0 { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: data[0].Content, + } +} + +// PutTag implements PUT /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags/{tag} +// Put functionality works by getting existing data from the DB (if any), adding +// the tag to the "map" and saving the new "map" to the DB +func PutTag( + req *http.Request, + accountDB *accounts.Database, + device *authtypes.Device, + userID string, + roomID string, + tag string, + syncProducer *producers.SyncAPIProducer, +) util.JSONResponse { + + if device.UserID != userID { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Cannot modify another user's tags"), + } + } + + var properties gomatrix.TagProperties + if reqErr := httputil.UnmarshalJSONRequest(req, &properties); reqErr != nil { + return *reqErr + } + + localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + if err != nil { + return httputil.LogThenError(req, err) + } + + var tagContent gomatrix.TagContent + if len(data) > 0 { + if err = json.Unmarshal(data[0].Content, &tagContent); err != nil { + return httputil.LogThenError(req, err) + } + } else { + tagContent = newTag() + } + tagContent.Tags[tag] = properties + if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + return httputil.LogThenError(req, err) + } + + // Send data to syncProducer in order to inform clients of changes + // Run in a goroutine in order to prevent blocking the tag request response + go func() { + if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } + }() + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +// DeleteTag implements DELETE /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags/{tag} +// Delete functionality works by obtaining the saved tags, removing the intended tag from +// the "map" and then saving the new "map" in the DB +func DeleteTag( + req *http.Request, + accountDB *accounts.Database, + device *authtypes.Device, + userID string, + roomID string, + tag string, + syncProducer *producers.SyncAPIProducer, +) util.JSONResponse { + + if device.UserID != userID { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Cannot modify another user's tags"), + } + } + + localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + if err != nil { + return httputil.LogThenError(req, err) + } + + // If there are no tags in the database, exit + if len(data) == 0 { + // Spec only defines 200 responses for this endpoint so we don't return anything else. + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } + } + + var tagContent gomatrix.TagContent + err = json.Unmarshal(data[0].Content, &tagContent) + if err != nil { + return httputil.LogThenError(req, err) + } + + // Check whether the tag to be deleted exists + if _, ok := tagContent.Tags[tag]; ok { + delete(tagContent.Tags, tag) + } else { + // Spec only defines 200 responses for this endpoint so we don't return anything else. + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } + } + if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + return httputil.LogThenError(req, err) + } + + // Send data to syncProducer in order to inform clients of changes + // Run in a goroutine in order to prevent blocking the tag request response + go func() { + if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } + }() + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} + +// obtainSavedTags gets all tags scoped to a userID and roomID +// from the database +func obtainSavedTags( + req *http.Request, + userID string, + roomID string, + accountDB *accounts.Database, +) (string, []gomatrixserverlib.ClientEvent, error) { + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return "", nil, err + } + + data, err := accountDB.GetAccountDataByType( + req.Context(), localpart, roomID, "m.tag", + ) + + return localpart, data, err +} + +// saveTagData saves the provided tag data into the database +func saveTagData( + req *http.Request, + localpart string, + roomID string, + accountDB *accounts.Database, + Tag gomatrix.TagContent, +) error { + newTagData, err := json.Marshal(Tag) + if err != nil { + return err + } + + return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData)) +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 8135e49af..d4b323a2d 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -93,7 +93,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/join/{roomIDOrAlias}", - common.MakeAuthAPI("join", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + common.MakeAuthAPI(gomatrixserverlib.Join, authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { vars, err := common.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -132,6 +132,15 @@ func Setup( nil, cfg, queryAPI, producer, transactionsCache) }), ).Methods(http.MethodPut, http.MethodOptions) + r0mux.Handle("/rooms/{roomID}/event/{eventID}", + common.MakeAuthAPI("rooms_get_event", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := common.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, queryAPI, federation, keyRing) + }), + ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { vars, err := common.URLDecodeMapValues(mux.Vars(req)) @@ -236,6 +245,13 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) + r0mux.Handle("/auth/{authType}/fallback/web", + common.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { + vars := mux.Vars(req) + return AuthFallback(w, req, vars["authType"], cfg) + }), + ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) + r0mux.Handle("/pushrules/", common.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { // TODO: Implement push rules API @@ -283,7 +299,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetProfile(req, accountDB, vars["userID"], asAPI) + return GetProfile(req, accountDB, &cfg, vars["userID"], asAPI, federation) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -293,7 +309,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetAvatarURL(req, accountDB, vars["userID"], asAPI) + return GetAvatarURL(req, accountDB, &cfg, vars["userID"], asAPI, federation) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -315,7 +331,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetDisplayName(req, accountDB, vars["userID"], asAPI) + return GetDisplayName(req, accountDB, &cfg, vars["userID"], asAPI, federation) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -483,4 +499,34 @@ func Setup( }} }), ).Methods(http.MethodGet, http.MethodOptions) + + r0mux.Handle("/user/{userId}/rooms/{roomId}/tags", + common.MakeAuthAPI("get_tags", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := common.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", + common.MakeAuthAPI("put_tag", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := common.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + }), + ).Methods(http.MethodPut, http.MethodOptions) + + r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", + common.MakeAuthAPI("delete_tag", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := common.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + }), + ).Methods(http.MethodDelete, http.MethodOptions) } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index e916e451e..9696b360e 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -50,7 +50,7 @@ func SendEvent( ) util.JSONResponse { if txnID != nil { // Try to fetch response from transactionsCache - if res, ok := txnCache.FetchTransaction(*txnID); ok { + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { return *res } } @@ -83,7 +83,7 @@ func SendEvent( } // Add response to transactionsCache if txnID != nil { - txnCache.AddTransaction(*txnID, &res) + txnCache.AddTransaction(device.AccessToken, *txnID, &res) } return res diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 2538577fd..bfe5060a8 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -56,10 +56,10 @@ type idServerLookupResponse struct { // idServerLookupResponse represents the response described at https://matrix.org/docs/spec/client_server/r0.2.0.html#invitation-storage type idServerStoreInviteResponse struct { - PublicKey string `json:"public_key"` - Token string `json:"token"` - DisplayName string `json:"display_name"` - PublicKeys []common.PublicKey `json:"public_keys"` + PublicKey string `json:"public_key"` + Token string `json:"token"` + DisplayName string `json:"display_name"` + PublicKeys []gomatrixserverlib.PublicKey `json:"public_keys"` } var ( @@ -91,7 +91,7 @@ func CheckAndProcessInvite( producer *producers.RoomserverProducer, membership string, roomID string, evTime time.Time, ) (inviteStoredOnIDServer bool, err error) { - if membership != "invite" || (body.Address == "" && body.IDServer == "" && body.Medium == "") { + if membership != gomatrixserverlib.Invite || (body.Address == "" && body.IDServer == "" && body.Medium == "") { // If none of the 3PID-specific fields are supplied, it's a standard invite // so return nil for it to be processed as such return @@ -342,7 +342,7 @@ func emit3PIDInviteEvent( } validityURL := fmt.Sprintf("https://%s/_matrix/identity/api/v1/pubkey/isvalid", body.IDServer) - content := common.ThirdPartyInviteContent{ + content := gomatrixserverlib.ThirdPartyInviteContent{ DisplayName: res.DisplayName, KeyValidityURL: validityURL, PublicKey: res.PublicKey, diff --git a/cmd/create-room-events/main.go b/cmd/create-room-events/main.go index 1d05b2a12..8475914f0 100644 --- a/cmd/create-room-events/main.go +++ b/cmd/create-room-events/main.go @@ -86,7 +86,7 @@ func main() { // Build a m.room.member event. b.Type = "m.room.member" b.StateKey = userID - b.SetContent(map[string]string{"membership": "join"}) // nolint: errcheck + b.SetContent(map[string]string{"membership": gomatrixserverlib.Join}) // nolint: errcheck b.AuthEvents = []gomatrixserverlib.EventReference{create} member := buildAndOutput() diff --git a/common/config/config.go b/common/config/config.go index 9fcab8cf9..40232fb03 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -498,6 +498,11 @@ func (config *Dendrite) checkMatrix(configErrs *configErrors) { checkNotEmpty(configErrs, "matrix.server_name", string(config.Matrix.ServerName)) checkNotEmpty(configErrs, "matrix.private_key", string(config.Matrix.PrivateKeyPath)) checkNotZero(configErrs, "matrix.federation_certificates", int64(len(config.Matrix.FederationCertificatePaths))) + if config.Matrix.RecaptchaEnabled { + checkNotEmpty(configErrs, "matrix.recaptcha_public_key", string(config.Matrix.RecaptchaPublicKey)) + checkNotEmpty(configErrs, "matrix.recaptcha_private_key", string(config.Matrix.RecaptchaPrivateKey)) + checkNotEmpty(configErrs, "matrix.recaptcha_siteverify_api", string(config.Matrix.RecaptchaSiteVerifyAPI)) + } } // checkMedia verifies the parameters media.* are valid. diff --git a/common/config/config_test.go b/common/config/config_test.go index acc4dbd12..110c8b84c 100644 --- a/common/config/config_test.go +++ b/common/config/config_test.go @@ -54,12 +54,14 @@ database: server_key: "postgresql:///server_keys" sync_api: "postgresql:///syn_api" room_server: "postgresql:///room_server" + appservice: "postgresql:///appservice" listen: room_server: "localhost:7770" client_api: "localhost:7771" federation_api: "localhost:7772" sync_api: "localhost:7773" media_api: "localhost:7774" + appservice_api: "localhost:7777" typing_server: "localhost:7778" logging: - type: "file" diff --git a/common/eventcontent.go b/common/eventcontent.go index 971c4f0a7..c07c56276 100644 --- a/common/eventcontent.go +++ b/common/eventcontent.go @@ -14,47 +14,7 @@ package common -// CreateContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-create -type CreateContent struct { - Creator string `json:"creator"` - Federate *bool `json:"m.federate,omitempty"` -} - -// MemberContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-member -type MemberContent struct { - Membership string `json:"membership"` - DisplayName string `json:"displayname,omitempty"` - AvatarURL string `json:"avatar_url,omitempty"` - Reason string `json:"reason,omitempty"` - ThirdPartyInvite *TPInvite `json:"third_party_invite,omitempty"` -} - -// TPInvite is the "Invite" structure defined at http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-member -type TPInvite struct { - DisplayName string `json:"display_name"` - Signed TPInviteSigned `json:"signed"` -} - -// TPInviteSigned is the "signed" structure defined at http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-member -type TPInviteSigned struct { - MXID string `json:"mxid"` - Signatures map[string]map[string]string `json:"signatures"` - Token string `json:"token"` -} - -// ThirdPartyInviteContent is the content event for https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-third-party-invite -type ThirdPartyInviteContent struct { - DisplayName string `json:"display_name"` - KeyValidityURL string `json:"key_validity_url"` - PublicKey string `json:"public_key"` - PublicKeys []PublicKey `json:"public_keys"` -} - -// PublicKey is the PublicKeys structure in https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-third-party-invite -type PublicKey struct { - KeyValidityURL string `json:"key_validity_url"` - PublicKey string `json:"public_key"` -} +import "github.com/matrix-org/gomatrixserverlib" // NameContent is the event content for https://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-name type NameContent struct { @@ -71,51 +31,26 @@ type GuestAccessContent struct { GuestAccess string `json:"guest_access"` } -// JoinRulesContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-join-rules -type JoinRulesContent struct { - JoinRule string `json:"join_rule"` -} - // HistoryVisibilityContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-history-visibility type HistoryVisibilityContent struct { HistoryVisibility string `json:"history_visibility"` } -// PowerLevelContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-power-levels -type PowerLevelContent struct { - EventsDefault int `json:"events_default"` - Invite int `json:"invite"` - StateDefault int `json:"state_default"` - Redact int `json:"redact"` - Ban int `json:"ban"` - UsersDefault int `json:"users_default"` - Events map[string]int `json:"events"` - Kick int `json:"kick"` - Users map[string]int `json:"users"` -} - // InitialPowerLevelsContent returns the initial values for m.room.power_levels on room creation // if they have not been specified. // http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-power-levels // https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/handlers/room.py#L294 -func InitialPowerLevelsContent(roomCreator string) PowerLevelContent { - return PowerLevelContent{ - EventsDefault: 0, - Invite: 0, - StateDefault: 50, - Redact: 50, - Ban: 50, - UsersDefault: 0, - Events: map[string]int{ - "m.room.name": 50, - "m.room.power_levels": 100, - "m.room.history_visibility": 100, - "m.room.canonical_alias": 50, - "m.room.avatar": 50, - }, - Kick: 50, - Users: map[string]int{roomCreator: 100}, +func InitialPowerLevelsContent(roomCreator string) (c gomatrixserverlib.PowerLevelContent) { + c.Defaults() + c.Events = map[string]int64{ + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.history_visibility": 100, + "m.room.canonical_alias": 50, + "m.room.avatar": 50, } + c.Users = map[string]int64{roomCreator: 100} + return c } // AliasesContent is the event content for http://matrix.org/docs/spec/client_server/r0.2.0.html#m-room-aliases diff --git a/common/httpapi.go b/common/httpapi.go index 99e15830a..bf634ff4a 100644 --- a/common/httpapi.go +++ b/common/httpapi.go @@ -10,6 +10,7 @@ import ( "github.com/matrix-org/util" opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -43,6 +44,24 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse return http.HandlerFunc(withSpan) } +// MakeHTMLAPI adds Span metrics to the HTML Handler function +// This is used to serve HTML alongside JSON error messages +func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { + withSpan := func(w http.ResponseWriter, req *http.Request) { + span := opentracing.StartSpan(metricsName) + defer span.Finish() + req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) + if err := f(w, req); err != nil { + h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse { + return *err + })) + h.ServeHTTP(w, req) + } + } + + return prometheus.InstrumentHandler(metricsName, http.HandlerFunc(withSpan)) +} + // MakeInternalAPI turns a util.JSONRequestHandler function into an http.Handler. // This is used for APIs that are internal to dendrite. // If we are passed a tracing context in the request headers then we use that diff --git a/common/log.go b/common/log.go index 89a705822..f9ed84edb 100644 --- a/common/log.go +++ b/common/log.go @@ -15,9 +15,12 @@ package common import ( + "fmt" "os" "path" "path/filepath" + "runtime" + "strings" "github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dugong" @@ -54,15 +57,35 @@ func (h *logLevelHook) Levels() []logrus.Level { return levels } +// callerPrettyfier is a function that given a runtime.Frame object, will +// extract the calling function's name and file, and return them in a nicely +// formatted way +func callerPrettyfier(f *runtime.Frame) (string, string) { + // Retrieve just the function name + s := strings.Split(f.Function, ".") + funcname := s[len(s)-1] + + // Append a newline + tab to it to move the actual log content to its own line + funcname += "\n\t" + + // Surround the filepath in brackets and append line number so IDEs can quickly + // navigate + filename := fmt.Sprintf(" [%s:%d]", f.File, f.Line) + + return funcname, filename +} + // SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded. func SetupStdLogging() { + logrus.SetReportCaller(true) logrus.SetFormatter(&utcFormatter{ &logrus.TextFormatter{ TimestampFormat: "2006-01-02T15:04:05.000000000Z07:00", FullTimestamp: true, DisableColors: false, DisableTimestamp: false, - DisableSorting: false, + QuoteEmptyFields: true, + CallerPrettyfier: callerPrettyfier, }, }) } @@ -71,8 +94,8 @@ func SetupStdLogging() { // If something fails here it means that the logging was improperly configured, // so we just exit with the error func SetupHookLogging(hooks []config.LogrusHook, componentName string) { + logrus.SetReportCaller(true) for _, hook := range hooks { - // Check we received a proper logging level level, err := logrus.ParseLevel(hook.Level) if err != nil { @@ -126,6 +149,7 @@ func setupFileHook(hook config.LogrusHook, level logrus.Level, componentName str DisableColors: true, DisableTimestamp: false, DisableSorting: false, + QuoteEmptyFields: true, }, }, &dugong.DailyRotationSchedule{GZip: true}, diff --git a/common/transactions/transactions.go b/common/transactions/transactions.go index febcb9a75..80b403a98 100644 --- a/common/transactions/transactions.go +++ b/common/transactions/transactions.go @@ -22,7 +22,14 @@ import ( // DefaultCleanupPeriod represents the default time duration after which cacheCleanService runs. const DefaultCleanupPeriod time.Duration = 30 * time.Minute -type txnsMap map[string]*util.JSONResponse +type txnsMap map[CacheKey]*util.JSONResponse + +// CacheKey is the type for the key in a transactions cache. +// This is needed because the spec requires transaction IDs to have a per-access token scope. +type CacheKey struct { + AccessToken string + TxnID string +} // Cache represents a temporary store for response entries. // Entries are evicted after a certain period, defined by cleanupPeriod. @@ -50,14 +57,14 @@ func NewWithCleanupPeriod(cleanupPeriod time.Duration) *Cache { return &t } -// FetchTransaction looks up an entry for txnID in Cache. +// FetchTransaction looks up an entry for the (accessToken, txnID) tuple in Cache. // Looks in both the txnMaps. // Returns (JSON response, true) if txnID is found, else the returned bool is false. -func (t *Cache) FetchTransaction(txnID string) (*util.JSONResponse, bool) { +func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, bool) { t.RLock() defer t.RUnlock() for _, txns := range t.txnsMaps { - res, ok := txns[txnID] + res, ok := txns[CacheKey{accessToken, txnID}] if ok { return res, true } @@ -65,13 +72,13 @@ func (t *Cache) FetchTransaction(txnID string) (*util.JSONResponse, bool) { return nil, false } -// AddTransaction adds an entry for txnID in Cache for later access. +// AddTransaction adds an entry for the (accessToken, txnID) tuple in Cache. // Adds to the front txnMap. -func (t *Cache) AddTransaction(txnID string, res *util.JSONResponse) { +func (t *Cache) AddTransaction(accessToken, txnID string, res *util.JSONResponse) { t.Lock() defer t.Unlock() - t.txnsMaps[0][txnID] = res + t.txnsMaps[0][CacheKey{accessToken, txnID}] = res } // cacheCleanService is responsible for cleaning up entries after cleanupPeriod. diff --git a/common/transactions/transactions_test.go b/common/transactions/transactions_test.go index 0cdb776cc..f565e4846 100644 --- a/common/transactions/transactions_test.go +++ b/common/transactions/transactions_test.go @@ -24,27 +24,54 @@ type fakeType struct { } var ( - fakeTxnID = "aRandomTxnID" - fakeResponse = &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: "0"}} + fakeAccessToken = "aRandomAccessToken" + fakeAccessToken2 = "anotherRandomAccessToken" + fakeTxnID = "aRandomTxnID" + fakeResponse = &util.JSONResponse{ + Code: http.StatusOK, JSON: fakeType{ID: "0"}, + } + fakeResponse2 = &util.JSONResponse{ + Code: http.StatusOK, JSON: fakeType{ID: "1"}, + } ) // TestCache creates a New Cache and tests AddTransaction & FetchTransaction func TestCache(t *testing.T) { fakeTxnCache := New() - fakeTxnCache.AddTransaction(fakeTxnID, fakeResponse) + fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) // Add entries for noise. for i := 1; i <= 100; i++ { fakeTxnCache.AddTransaction( + fakeAccessToken, fakeTxnID+string(i), &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: string(i)}}, ) } - testResponse, ok := fakeTxnCache.FetchTransaction(fakeTxnID) + testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID) if !ok { t.Error("Failed to retrieve entry for txnID: ", fakeTxnID) } else if testResponse.JSON != fakeResponse.JSON { t.Error("Fetched response incorrect. Expected: ", fakeResponse.JSON, " got: ", testResponse.JSON) } } + +// TestCacheScope ensures transactions with the same transaction ID are not shared +// across multiple access tokens. +func TestCacheScope(t *testing.T) { + cache := New() + cache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) + cache.AddTransaction(fakeAccessToken2, fakeTxnID, fakeResponse2) + + if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID); !ok { + t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) + } else if res.JSON != fakeResponse.JSON { + t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse.JSON, res.JSON) + } + if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID); !ok { + t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) + } else if res.JSON != fakeResponse2.JSON { + t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON) + } +} diff --git a/common/types.go b/common/types.go index 6888d3806..91765be00 100644 --- a/common/types.go +++ b/common/types.go @@ -15,9 +15,14 @@ package common import ( + "errors" "strconv" ) +// ErrProfileNoExists is returned when trying to lookup a user's profile that +// doesn't exist locally. +var ErrProfileNoExists = errors.New("no known profile for given user ID") + // AccountData represents account data sent from the client API server to the // sync API server type AccountData struct { diff --git a/docker/README.md b/docker/README.md index 7d18ce605..ff88c0818 100644 --- a/docker/README.md +++ b/docker/README.md @@ -58,7 +58,7 @@ docker-compose up kafka zookeeper postgres and the following dendrite components ``` -docker-compose up client_api media_api sync_api room_server public_rooms_api +docker-compose up client_api media_api sync_api room_server public_rooms_api typing_server docker-compose up client_api_proxy ``` diff --git a/docker/dendrite-docker.yml b/docker/dendrite-docker.yml index c2e7682eb..abb8c3307 100644 --- a/docker/dendrite-docker.yml +++ b/docker/dendrite-docker.yml @@ -114,6 +114,7 @@ listen: media_api: "media_api:7774" public_rooms_api: "public_rooms_api:7775" federation_sender: "federation_sender:7776" + typing_server: "typing_server:7777" # The configuration for tracing the dendrite components. tracing: diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 763e5b0f0..9cf67457c 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -95,6 +95,16 @@ services: networks: - internal + typing_server: + container_name: dendrite_typing_server + hostname: typing_server + entrypoint: ["bash", "./docker/services/typing-server.sh"] + build: ./ + volumes: + - ..:/build + networks: + - internal + federation_api_proxy: container_name: dendrite_federation_api_proxy hostname: federation_api_proxy diff --git a/docker/services/typing-server.sh b/docker/services/typing-server.sh new file mode 100644 index 000000000..16ee0fa62 --- /dev/null +++ b/docker/services/typing-server.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +bash ./docker/build.sh + +./bin/dendrite-typing-server --config=dendrite.yaml diff --git a/docs/images/details-button-location.jpg b/docs/images/details-button-location.jpg new file mode 100644 index 000000000..53129a6e1 Binary files /dev/null and b/docs/images/details-button-location.jpg differ diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 0b60408f7..6f6574dd7 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -58,7 +58,7 @@ func MakeJoin( Type: "m.room.member", StateKey: &userID, } - err = builder.SetContent(map[string]interface{}{"membership": "join"}) + err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Join}) if err != nil { return httputil.LogThenError(httpReq, err) } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 3c57d39d1..a982b87f8 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -56,7 +56,7 @@ func MakeLeave( Type: "m.room.member", StateKey: &userID, } - err = builder.SetContent(map[string]interface{}{"membership": "leave"}) + err = builder.SetContent(map[string]interface{}{"membership": gomatrixserverlib.Leave}) if err != nil { return httputil.LogThenError(httpReq, err) } @@ -153,7 +153,7 @@ func SendLeave( mem, err := event.Membership() if err != nil { return httputil.LogThenError(httpReq, err) - } else if mem != "leave" { + } else if mem != gomatrixserverlib.Leave { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("The membership in the event content must be set to leave"), diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index aa4fcdc42..2b478cfbf 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -53,7 +53,7 @@ func GetProfile( return httputil.LogThenError(httpReq, err) } - profile, err := appserviceAPI.RetreiveUserProfile(httpReq.Context(), userID, asAPI, accountDB) + profile, err := appserviceAPI.RetrieveUserProfile(httpReq.Context(), userID, asAPI, accountDB) if err != nil { return httputil.LogThenError(httpReq, err) } diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 16704e0b2..9f576790b 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -64,8 +64,9 @@ func Setup( // {keyID} argument and always return a response containing all of the keys. v2keysmux.Handle("/server/{keyID}", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet) + v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet) - v1fedmux.Handle("/send/{txnID}/", common.MakeFedAPI( + v1fedmux.Handle("/send/{txnID}", common.MakeFedAPI( "federation_send", cfg.Matrix.ServerName, keys, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) @@ -260,7 +261,7 @@ func Setup( }, )).Methods(http.MethodPost) - v1fedmux.Handle("/backfill/{roomID}/", common.MakeFedAPI( + v1fedmux.Handle("/backfill/{roomID}", common.MakeFedAPI( "federation_backfill", cfg.Matrix.ServerName, keys, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse { vars, err := common.URLDecodeMapValues(mux.Vars(httpReq)) diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index 27796067b..7fa02be91 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -27,7 +27,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common/config" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" @@ -38,11 +37,11 @@ import ( ) type invite struct { - MXID string `json:"mxid"` - RoomID string `json:"room_id"` - Sender string `json:"sender"` - Token string `json:"token"` - Signed common.TPInviteSigned `json:"signed"` + MXID string `json:"mxid"` + RoomID string `json:"room_id"` + Sender string `json:"sender"` + Token string `json:"token"` + Signed gomatrixserverlib.MemberThirdPartyInviteSigned `json:"signed"` } type invites struct { @@ -194,16 +193,16 @@ func createInviteFrom3PIDInvite( StateKey: &inv.MXID, } - profile, err := appserviceAPI.RetreiveUserProfile(ctx, inv.MXID, asAPI, accountDB) + profile, err := appserviceAPI.RetrieveUserProfile(ctx, inv.MXID, asAPI, accountDB) if err != nil { return nil, err } - content := common.MemberContent{ + content := gomatrixserverlib.MemberContent{ AvatarURL: profile.AvatarURL, DisplayName: profile.DisplayName, - Membership: "invite", - ThirdPartyInvite: &common.TPInvite{ + Membership: gomatrixserverlib.Invite, + ThirdPartyInvite: &gomatrixserverlib.MemberThirdPartyInvite{ Signed: inv.Signed, }, } @@ -330,7 +329,7 @@ func sendToRemoteServer( func fillDisplayName( builder *gomatrixserverlib.EventBuilder, authEvents gomatrixserverlib.AuthEvents, ) error { - var content common.MemberContent + var content gomatrixserverlib.MemberContent if err := json.Unmarshal(builder.Content, &content); err != nil { return err } @@ -343,7 +342,7 @@ func fillDisplayName( return nil } - var thirdPartyInviteContent common.ThirdPartyInviteContent + var thirdPartyInviteContent gomatrixserverlib.ThirdPartyInviteContent if err := json.Unmarshal(thirdPartyInviteEvent.Content(), &thirdPartyInviteContent); err != nil { return err } diff --git a/federationsender/api/query.go b/federationsender/api/query.go new file mode 100644 index 000000000..ebc6e833f --- /dev/null +++ b/federationsender/api/query.go @@ -0,0 +1,98 @@ +package api + +import ( + "context" + "net/http" + + commonHTTP "github.com/matrix-org/dendrite/common/http" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/opentracing/opentracing-go" +) + +// QueryJoinedHostsInRoomRequest is a request to QueryJoinedHostsInRoom +type QueryJoinedHostsInRoomRequest struct { + RoomID string `json:"room_id"` +} + +// QueryJoinedHostsInRoomResponse is a response to QueryJoinedHostsInRoom +type QueryJoinedHostsInRoomResponse struct { + JoinedHosts []types.JoinedHost `json:"joined_hosts"` +} + +// QueryJoinedHostServerNamesRequest is a request to QueryJoinedHostServerNames +type QueryJoinedHostServerNamesInRoomRequest struct { + RoomID string `json:"room_id"` +} + +// QueryJoinedHostServerNamesResponse is a response to QueryJoinedHostServerNames +type QueryJoinedHostServerNamesInRoomResponse struct { + ServerNames []gomatrixserverlib.ServerName `json:"server_names"` +} + +// FederationSenderQueryAPI is used to query information from the federation sender. +type FederationSenderQueryAPI interface { + // Query the joined hosts and the membership events accounting for their participation in a room. + // Note that if a server has multiple users in the room, it will have multiple entries in the returned slice. + // See `QueryJoinedHostServerNamesInRoom` for a de-duplicated version. + QueryJoinedHostsInRoom( + ctx context.Context, + request *QueryJoinedHostsInRoomRequest, + response *QueryJoinedHostsInRoomResponse, + ) error + // Query the server names of the joined hosts in a room. + // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice + // containing only the server names (without information for membership events). + QueryJoinedHostServerNamesInRoom( + ctx context.Context, + request *QueryJoinedHostServerNamesInRoomRequest, + response *QueryJoinedHostServerNamesInRoomResponse, + ) error +} + +// FederationSenderQueryJoinedHostsInRoomPath is the HTTP path for the QueryJoinedHostsInRoom API. +const FederationSenderQueryJoinedHostsInRoomPath = "/api/federationsender/queryJoinedHostsInRoom" + +// FederationSenderQueryJoinedHostServerNamesInRoomPath is the HTTP path for the QueryJoinedHostServerNamesInRoom API. +const FederationSenderQueryJoinedHostServerNamesInRoomPath = "/api/federationsender/queryJoinedHostServerNamesInRoom" + +// NewFederationSenderQueryAPIHTTP creates a FederationSenderQueryAPI implemented by talking to a HTTP POST API. +// If httpClient is nil then it uses the http.DefaultClient +func NewFederationSenderQueryAPIHTTP(federationSenderURL string, httpClient *http.Client) FederationSenderQueryAPI { + if httpClient == nil { + httpClient = http.DefaultClient + } + return &httpFederationSenderQueryAPI{federationSenderURL, httpClient} +} + +type httpFederationSenderQueryAPI struct { + federationSenderURL string + httpClient *http.Client +} + +// QueryJoinedHostsInRoom implements FederationSenderQueryAPI +func (h *httpFederationSenderQueryAPI) QueryJoinedHostsInRoom( + ctx context.Context, + request *QueryJoinedHostsInRoomRequest, + response *QueryJoinedHostsInRoomResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostsInRoom") + defer span.Finish() + + apiURL := h.federationSenderURL + FederationSenderQueryJoinedHostsInRoomPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +// QueryJoinedHostServerNamesInRoom implements FederationSenderQueryAPI +func (h *httpFederationSenderQueryAPI) QueryJoinedHostServerNamesInRoom( + ctx context.Context, + request *QueryJoinedHostServerNamesInRoomRequest, + response *QueryJoinedHostServerNamesInRoomResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom") + defer span.Finish() + + apiURL := h.federationSenderURL + FederationSenderQueryJoinedHostServerNamesInRoomPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/federationsender/consumers/roomserver.go b/federationsender/consumers/roomserver.go index 45e48f166..3ba978b1d 100644 --- a/federationsender/consumers/roomserver.go +++ b/federationsender/consumers/roomserver.go @@ -233,7 +233,7 @@ func joinedHostsFromEvents(evs []gomatrixserverlib.Event) ([]types.JoinedHost, e if err != nil { return nil, err } - if membership != "join" { + if membership != gomatrixserverlib.Join { continue } _, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) diff --git a/federationsender/query/query.go b/federationsender/query/query.go new file mode 100644 index 000000000..ec9242863 --- /dev/null +++ b/federationsender/query/query.go @@ -0,0 +1,55 @@ +package query + +import ( + "context" + + "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/federationsender/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// FederationSenderQueryDatabase has the APIs needed to implement the query API. +type FederationSenderQueryDatabase interface { + GetJoinedHosts( + ctx context.Context, roomID string, + ) ([]types.JoinedHost, error) +} + +// FederationSenderQueryAPI is an implementation of api.FederationSenderQueryAPI +type FederationSenderQueryAPI struct { + DB FederationSenderQueryDatabase +} + +// QueryJoinedHostsInRoom implements api.FederationSenderQueryAPI +func (f *FederationSenderQueryAPI) QueryJoinedHostsInRoom( + ctx context.Context, + request *api.QueryJoinedHostsInRoomRequest, + response *api.QueryJoinedHostsInRoomResponse, +) (err error) { + response.JoinedHosts, err = f.DB.GetJoinedHosts(ctx, request.RoomID) + return +} + +// QueryJoinedHostServerNamesInRoom implements api.FederationSenderQueryAPI +func (f *FederationSenderQueryAPI) QueryJoinedHostServerNamesInRoom( + ctx context.Context, + request *api.QueryJoinedHostServerNamesInRoomRequest, + response *api.QueryJoinedHostServerNamesInRoomResponse, +) (err error) { + joinedHosts, err := f.DB.GetJoinedHosts(ctx, request.RoomID) + if err != nil { + return + } + + serverNamesSet := make(map[gomatrixserverlib.ServerName]bool, len(joinedHosts)) + for _, host := range joinedHosts { + serverNamesSet[host.ServerName] = true + } + + response.ServerNames = make([]gomatrixserverlib.ServerName, 0, len(serverNamesSet)) + for name := range serverNamesSet { + response.ServerNames = append(response.ServerNames, name) + } + + return +} diff --git a/go.mod b/go.mod index 072d9ef30..d51f0a33e 100644 --- a/go.mod +++ b/go.mod @@ -20,10 +20,11 @@ require ( github.com/jaegertracing/jaeger-client-go v0.0.0-20170921145708-3ad49a1d839b github.com/jaegertracing/jaeger-lib v0.0.0-20170920222118-21a3da6d66fe github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 + github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/lib/pq v0.0.0-20170918175043-23da1db4f16d github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5 - github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af - github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2 + github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 + github.com/matrix-org/gomatrixserverlib v0.0.0-20190814163046-d6285a18401f github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 github.com/matttproud/golang_protobuf_extensions v1.0.1 @@ -40,8 +41,9 @@ require ( github.com/prometheus/common v0.0.0-20170108231212-dd2f054febf4 github.com/prometheus/procfs v0.0.0-20170128160123-1878d9fbb537 github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5 - github.com/sirupsen/logrus v1.3.0 - github.com/stretchr/testify v1.2.2 + github.com/sirupsen/logrus v1.4.2 + github.com/stretchr/objx v0.2.0 // indirect + github.com/stretchr/testify v1.3.0 github.com/tidwall/gjson v1.1.5 github.com/tidwall/match v1.0.1 github.com/tidwall/sjson v1.0.3 @@ -54,7 +56,7 @@ require ( go.uber.org/zap v1.7.1 golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613 golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95 - golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 + golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7 gopkg.in/Shopify/sarama.v1 v1.11.0 gopkg.in/airbrake/gobrake.v2 v2.0.9 gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20170727041045-23bcc3c4eae3 diff --git a/go.sum b/go.sum index ce3c07dd7..56781c9a6 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,7 @@ github.com/jaegertracing/jaeger-lib v0.0.0-20170920222118-21a3da6d66fe/go.mod h1 github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 h1:KAZ1BW2TCmT6PRihDPpocIy1QTtsAsrx6TneU/4+CMg= github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= @@ -47,10 +48,18 @@ github.com/matrix-org/gomatrix v0.0.0-20171003113848-a7fc80c8060c h1:aZap604NyBG github.com/matrix-org/gomatrix v0.0.0-20171003113848-a7fc80c8060c/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af h1:piaIBNQGIHnni27xRB7VKkEwoWCgAmeuYf8pxAyG0bI= github.com/matrix-org/gomatrix v0.0.0-20190130130140-385f072fe9af/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= +github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bhrnp3Ky1qgx/fzCtCALOoGYylh2tpS9K4= +github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrixserverlib v0.0.0-20181109104322-1c2cbc0872f0 h1:3UzhmERBbis4ZaB3imEbZwtDjGz/oVRC2cLLEajCzJA= github.com/matrix-org/gomatrixserverlib v0.0.0-20181109104322-1c2cbc0872f0/go.mod h1:YHyhIQUmuXyKtoVfDUMk/DyU93Taamlu6nPZkij/JtA= github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2 h1:pYajAEdi3sowj4iSunqctchhcMNW3rDjeeH0T4uDkMY= github.com/matrix-org/gomatrixserverlib v0.0.0-20190619132215-178ed5e3b8e2/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E= +github.com/matrix-org/gomatrixserverlib v0.0.0-20190724145009-a6df10ef35d6 h1:B8n1H5Wb1B5jwLzTylBpY0kJCMRqrofT7PmOw4aJFJA= +github.com/matrix-org/gomatrixserverlib v0.0.0-20190724145009-a6df10ef35d6/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E= +github.com/matrix-org/gomatrixserverlib v0.0.0-20190805173246-3a2199d5ecd6 h1:xr69Hk6QM3RIN6JSvx3RpDowBGpHpDDqhqXCeySwYow= +github.com/matrix-org/gomatrixserverlib v0.0.0-20190805173246-3a2199d5ecd6/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E= +github.com/matrix-org/gomatrixserverlib v0.0.0-20190814163046-d6285a18401f h1:20CZL7ApB7xgR7sZF9yD/qpsP51Sfx0TTgUJ3vKgnZQ= +github.com/matrix-org/gomatrixserverlib v0.0.0-20190814163046-d6285a18401f/go.mod h1:sf0RcKOdiwJeTti7A313xsaejNUGYDq02MQZ4JD4w/E= github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 h1:p7WTwG+aXM86+yVrYAiCMW3ZHSmotVvuRbjtt3jC+4A= github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A= github.com/matrix-org/util v0.0.0-20171013132526-8b1c8ab81986 h1:TiWl4hLvezAhRPM8tPcPDFTysZ7k4T/1J4GPp/iqlZo= @@ -88,9 +97,14 @@ github.com/sirupsen/logrus v0.0.0-20170822132746-89742aefa4b2 h1:+8J/sCAVv2Y9Ct1 github.com/sirupsen/logrus v0.0.0-20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME= github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v0.0.0-20170809224252-890a5c3458b4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/tidwall/gjson v1.0.2 h1:5BsM7kyEAHAUGEGDkEKO9Mdyiuw6QQ6TSDdarP0Nnmk= github.com/tidwall/gjson v1.0.2/go.mod h1:c/nTNbUr0E0OrXEhq1pwa8iEgc2DOt4ZZqAt1HtCkPA= github.com/tidwall/gjson v1.1.5 h1:QysILxBeUEY3GTLA0fQVgkQG1zme8NxGvhh2SSqWNwI= @@ -126,6 +140,9 @@ golang.org/x/sys v0.0.0-20171012164349-43eea11bc926 h1:PY6OU86NqbyZiOzaPnDw6oOjA golang.org/x/sys v0.0.0-20171012164349-43eea11bc926/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33 h1:I6FyU15t786LL7oL/hn43zqTuEGr4PN7F4XJ1p4E3Y8= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7 h1:LepdCS8Gf/MVejFIt8lsiexZATdoGVyp5bcyS+rYoUI= +golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= gopkg.in/Shopify/sarama.v1 v1.11.0 h1:/3kaCyeYaPbr59IBjeqhIcUOB1vXlIVqXAYa5g5C5F0= gopkg.in/Shopify/sarama.v1 v1.11.0/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc= gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= @@ -140,4 +157,3 @@ gopkg.in/yaml.v2 v2.0.0-20171116090243-287cf08546ab h1:yZ6iByf7GKeJ3gsd1Dr/xaj1D gopkg.in/yaml.v2 v2.0.0-20171116090243-287cf08546ab/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= - diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 9c8f43c44..80ad8418d 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -55,7 +55,7 @@ type downloadRequest struct { Logger *log.Entry } -// Download implements /download amd /thumbnail +// Download implements GET /download and GET /thumbnail // Files from this server (i.e. origin == cfg.ServerName) are served directly // Files from remote servers (i.e. origin != cfg.ServerName) are cached locally. // If they are present in the cache, they are served directly. @@ -107,14 +107,6 @@ func Download( } // request validation - if req.Method != http.MethodGet { - dReq.jsonErrorResponse(w, util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.Unknown("request method must be GET"), - }) - return - } - if resErr := dReq.Validate(); resErr != nil { dReq.jsonErrorResponse(w, *resErr) return @@ -305,6 +297,10 @@ func (r *downloadRequest) respondFromLocalFile( }).Info("Responding with file") responseFile = file responseMetadata = r.MediaMetadata + + if len(responseMetadata.UploadName) > 0 { + w.Header().Set("Content-Disposition", fmt.Sprintf(`inline; filename*=utf-8"%s"`, responseMetadata.UploadName)) + } } w.Header().Set("Content-Type", string(responseMetadata.ContentType)) diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index 1051e0e03..2cb0d8757 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -48,7 +48,7 @@ type uploadResponse struct { ContentURI string `json:"content_uri"` } -// Upload implements /upload +// Upload implements POST /upload // This endpoint involves uploading potentially significant amounts of data to the homeserver. // This implementation supports a configurable maximum file size limit in bytes. If a user tries to upload more than this, they will receive an error that their upload is too large. // Uploaded files are processed piece-wise to avoid DoS attacks which would starve the server of memory. @@ -75,13 +75,6 @@ func Upload(req *http.Request, cfg *config.Dendrite, db *storage.Database, activ // all the metadata about the media being uploaded. // Returns either an uploadRequest or an error formatted as a util.JSONResponse func parseAndValidateRequest(req *http.Request, cfg *config.Dendrite) (*uploadRequest, *util.JSONResponse) { - if req.Method != http.MethodPost { - return nil, &util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.Unknown("HTTP request method must be POST."), - } - } - r := &uploadRequest{ MediaMetadata: &types.MediaMetadata{ Origin: cfg.Matrix.ServerName, diff --git a/publicroomsapi/directory/directory.go b/publicroomsapi/directory/directory.go index bb0153850..626a1c153 100644 --- a/publicroomsapi/directory/directory.go +++ b/publicroomsapi/directory/directory.go @@ -19,6 +19,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/publicroomsapi/storage" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -39,7 +40,7 @@ func GetVisibility( var v roomVisibility if isPublic { - v.Visibility = "public" + v.Visibility = gomatrixserverlib.Public } else { v.Visibility = "private" } @@ -61,7 +62,7 @@ func SetVisibility( return *reqErr } - isPublic := v.Visibility == "public" + isPublic := v.Visibility == gomatrixserverlib.Public if err := publicRoomsDatabase.SetRoomVisibility(req.Context(), isPublic, roomID); err != nil { return httputil.LogThenError(req, err) } diff --git a/publicroomsapi/directory/public_rooms.go b/publicroomsapi/directory/public_rooms.go index 100e28e9b..ef7b2662e 100644 --- a/publicroomsapi/directory/public_rooms.go +++ b/publicroomsapi/directory/public_rooms.go @@ -42,8 +42,8 @@ type publicRoomRes struct { Estimate int64 `json:"total_room_count_estimate,omitempty"` } -// GetPublicRooms implements GET /publicRooms -func GetPublicRooms( +// GetPostPublicRooms implements GET and POST /publicRooms +func GetPostPublicRooms( req *http.Request, publicRoomDatabase *storage.PublicRoomsServerDatabase, ) util.JSONResponse { var limit int16 @@ -89,6 +89,7 @@ func GetPublicRooms( // fillPublicRoomsReq fills the Limit, Since and Filter attributes of a GET or POST request // on /publicRooms by parsing the incoming HTTP request +// Filter is only filled for POST requests func fillPublicRoomsReq(httpReq *http.Request, request *publicRoomReq) *util.JSONResponse { if httpReq.Method == http.MethodGet { limit, err := strconv.Atoi(httpReq.FormValue("limit")) diff --git a/publicroomsapi/routing/routing.go b/publicroomsapi/routing/routing.go index 3a1c9eb58..422414bc2 100644 --- a/publicroomsapi/routing/routing.go +++ b/publicroomsapi/routing/routing.go @@ -64,7 +64,7 @@ func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB *storag ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/publicRooms", common.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { - return directory.GetPublicRooms(req, publicRoomsDB) + return directory.GetPostPublicRooms(req, publicRoomsDB) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } diff --git a/publicroomsapi/storage/storage.go b/publicroomsapi/storage/storage.go index eab27041b..aa9806945 100644 --- a/publicroomsapi/storage/storage.go +++ b/publicroomsapi/storage/storage.go @@ -185,7 +185,7 @@ func (d *PublicRoomsServerDatabase) updateNumJoinedUsers( return err } - if membership != "join" { + if membership != gomatrixserverlib.Join { return nil } diff --git a/roomserver/alias/alias.go b/roomserver/alias/alias.go index 6a34aacdd..aeaf5ae94 100644 --- a/roomserver/alias/alias.go +++ b/roomserver/alias/alias.go @@ -33,13 +33,16 @@ import ( type RoomserverAliasAPIDatabase interface { // Save a given room alias with the room ID it refers to. // Returns an error if there was a problem talking to the database. - SetRoomAlias(ctx context.Context, alias string, roomID string) error + SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error // Look up the room ID a given alias refers to. // Returns an error if there was a problem talking to the database. GetRoomIDForAlias(ctx context.Context, alias string) (string, error) // Look up all aliases referring to a given room ID. // Returns an error if there was a problem talking to the database. GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) + // Get the user ID of the creator of an alias. + // Returns an error if there was a problem talking to the database. + GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) // Remove a given room alias. // Returns an error if there was a problem talking to the database. RemoveRoomAlias(ctx context.Context, alias string) error @@ -73,7 +76,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias( response.AliasExists = false // Save the new alias - if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID); err != nil { + if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID, request.UserID); err != nil { return err } @@ -133,6 +136,22 @@ func (r *RoomserverAliasAPI) GetAliasesForRoomID( return nil } +// GetCreatorIDForAlias implements alias.RoomserverAliasAPI +func (r *RoomserverAliasAPI) GetCreatorIDForAlias( + ctx context.Context, + request *roomserverAPI.GetCreatorIDForAliasRequest, + response *roomserverAPI.GetCreatorIDForAliasResponse, +) error { + // Look up the aliases in the database for the given RoomID + creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias) + if err != nil { + return err + } + + response.UserID = creatorID + return nil +} + // RemoveRoomAlias implements alias.RoomserverAliasAPI func (r *RoomserverAliasAPI) RemoveRoomAlias( ctx context.Context, @@ -277,6 +296,34 @@ func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + servMux.Handle( + roomserverAPI.RoomserverGetCreatorIDForAliasPath, + common.MakeInternalAPI("GetCreatorIDForAlias", func(req *http.Request) util.JSONResponse { + var request roomserverAPI.GetCreatorIDForAliasRequest + var response roomserverAPI.GetCreatorIDForAliasResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.GetCreatorIDForAlias(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + servMux.Handle( + roomserverAPI.RoomserverGetAliasesForRoomIDPath, + common.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse { + var request roomserverAPI.GetAliasesForRoomIDRequest + var response roomserverAPI.GetAliasesForRoomIDResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.GetAliasesForRoomID(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) servMux.Handle( roomserverAPI.RoomserverRemoveRoomAliasPath, common.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse { diff --git a/roomserver/alias/alias_test.go b/roomserver/alias/alias_test.go index 4b9ca022d..6ddb63a73 100644 --- a/roomserver/alias/alias_test.go +++ b/roomserver/alias/alias_test.go @@ -30,7 +30,7 @@ type MockRoomserverAliasAPIDatabase struct { } // These methods can be essentially noop -func (db MockRoomserverAliasAPIDatabase) SetRoomAlias(ctx context.Context, alias string, roomID string) error { +func (db MockRoomserverAliasAPIDatabase) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { return nil } @@ -43,6 +43,12 @@ func (db MockRoomserverAliasAPIDatabase) RemoveRoomAlias(ctx context.Context, al return nil } +func (db *MockRoomserverAliasAPIDatabase) GetCreatorIDForAlias( + ctx context.Context, alias string, +) (string, error) { + return "", nil +} + // This method needs to change depending on test case func (db *MockRoomserverAliasAPIDatabase) GetRoomIDForAlias( ctx context.Context, diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index 576710713..cb78f726a 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -62,6 +62,18 @@ type GetAliasesForRoomIDResponse struct { Aliases []string `json:"aliases"` } +// GetCreatorIDForAliasRequest is a request to GetCreatorIDForAlias +type GetCreatorIDForAliasRequest struct { + // The alias we want to find the creator of + Alias string `json:"alias"` +} + +// GetCreatorIDForAliasResponse is a response to GetCreatorIDForAlias +type GetCreatorIDForAliasResponse struct { + // The user ID of the alias creator + UserID string `json:"user_id"` +} + // RemoveRoomAliasRequest is a request to RemoveRoomAlias type RemoveRoomAliasRequest struct { // ID of the user removing the alias @@ -96,6 +108,13 @@ type RoomserverAliasAPI interface { response *GetAliasesForRoomIDResponse, ) error + // Get the user ID of the creator of an alias + GetCreatorIDForAlias( + ctx context.Context, + req *GetCreatorIDForAliasRequest, + response *GetCreatorIDForAliasResponse, + ) error + // Remove a room alias RemoveRoomAlias( ctx context.Context, @@ -113,6 +132,9 @@ const RoomserverGetRoomIDForAliasPath = "/api/roomserver/GetRoomIDForAlias" // RoomserverGetAliasesForRoomIDPath is the HTTP path for the GetAliasesForRoomID API. const RoomserverGetAliasesForRoomIDPath = "/api/roomserver/GetAliasesForRoomID" +// RoomserverGetCreatorIDForAliasPath is the HTTP path for the GetCreatorIDForAlias API. +const RoomserverGetCreatorIDForAliasPath = "/api/roomserver/GetCreatorIDForAlias" + // RoomserverRemoveRoomAliasPath is the HTTP path for the RemoveRoomAlias API. const RoomserverRemoveRoomAliasPath = "/api/roomserver/removeRoomAlias" @@ -169,6 +191,19 @@ func (h *httpRoomserverAliasAPI) GetAliasesForRoomID( return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +// GetCreatorIDForAlias implements RoomserverAliasAPI +func (h *httpRoomserverAliasAPI) GetCreatorIDForAlias( + ctx context.Context, + request *GetCreatorIDForAliasRequest, + response *GetCreatorIDForAliasResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "GetCreatorIDForAlias") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverGetCreatorIDForAliasPath + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + // RemoveRoomAlias implements RoomserverAliasAPI func (h *httpRoomserverAliasAPI) RemoveRoomAlias( ctx context.Context, diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index 2dce6f6dc..5ff1fadad 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -23,7 +23,7 @@ func IsServerAllowed( ) bool { for _, ev := range authEvents { membership, err := ev.Membership() - if err != nil || membership != "join" { + if err != nil || membership != gomatrixserverlib.Join { continue } diff --git a/roomserver/input/membership.go b/roomserver/input/membership.go index 0c3fbb80a..841c5fec6 100644 --- a/roomserver/input/membership.go +++ b/roomserver/input/membership.go @@ -23,13 +23,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// Membership values -// TODO: Factor these out somewhere sensible? -const join = "join" -const leave = "leave" -const invite = "invite" -const ban = "ban" - // updateMembership updates the current membership and the invites for each // user affected by a change in the current state of the room. // Returns a list of output events to write to the kafka log to inform the @@ -91,8 +84,8 @@ func updateMembership( ) ([]api.OutputEvent, error) { var err error // Default the membership to Leave if no event was added or removed. - oldMembership := leave - newMembership := leave + oldMembership := gomatrixserverlib.Leave + newMembership := gomatrixserverlib.Leave if remove != nil { oldMembership, err = remove.Membership() @@ -106,7 +99,7 @@ func updateMembership( return nil, err } } - if oldMembership == newMembership && newMembership != join { + if oldMembership == newMembership && newMembership != gomatrixserverlib.Join { // If the membership is the same then nothing changed and we can return // immediately, unless it's a Join update (e.g. profile update). return updates, nil @@ -118,11 +111,11 @@ func updateMembership( } switch newMembership { - case invite: + case gomatrixserverlib.Invite: return updateToInviteMembership(mu, add, updates) - case join: + case gomatrixserverlib.Join: return updateToJoinMembership(mu, add, updates) - case leave, ban: + case gomatrixserverlib.Leave, gomatrixserverlib.Ban: return updateToLeaveMembership(mu, add, newMembership, updates) default: panic(fmt.Errorf( @@ -183,7 +176,7 @@ func updateToJoinMembership( for _, eventID := range retired { orie := api.OutputRetireInviteEvent{ EventID: eventID, - Membership: join, + Membership: gomatrixserverlib.Join, RetiredByEventID: add.EventID(), TargetUserID: *add.StateKey(), } diff --git a/roomserver/query/query.go b/roomserver/query/query.go index b97d50b17..a62a1f706 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -359,7 +359,7 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID( return nil, err } - if membership == "join" { + if membership == gomatrixserverlib.Join { events = append(events, event) } } diff --git a/roomserver/storage/room_aliases_table.go b/roomserver/storage/room_aliases_table.go index f640c37fe..3ed20e8e3 100644 --- a/roomserver/storage/room_aliases_table.go +++ b/roomserver/storage/room_aliases_table.go @@ -25,14 +25,16 @@ CREATE TABLE IF NOT EXISTS roomserver_room_aliases ( -- Alias of the room alias TEXT NOT NULL PRIMARY KEY, -- Room ID the alias refers to - room_id TEXT NOT NULL + room_id TEXT NOT NULL, + -- User ID of the creator of this alias + creator_id TEXT NOT NULL ); CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id); ` const insertRoomAliasSQL = "" + - "INSERT INTO roomserver_room_aliases (alias, room_id) VALUES ($1, $2)" + "INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)" const selectRoomIDFromAliasSQL = "" + "SELECT room_id FROM roomserver_room_aliases WHERE alias = $1" @@ -40,14 +42,18 @@ const selectRoomIDFromAliasSQL = "" + const selectAliasesFromRoomIDSQL = "" + "SELECT alias FROM roomserver_room_aliases WHERE room_id = $1" +const selectCreatorIDFromAliasSQL = "" + + "SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1" + const deleteRoomAliasSQL = "" + "DELETE FROM roomserver_room_aliases WHERE alias = $1" type roomAliasesStatements struct { - insertRoomAliasStmt *sql.Stmt - selectRoomIDFromAliasStmt *sql.Stmt - selectAliasesFromRoomIDStmt *sql.Stmt - deleteRoomAliasStmt *sql.Stmt + insertRoomAliasStmt *sql.Stmt + selectRoomIDFromAliasStmt *sql.Stmt + selectAliasesFromRoomIDStmt *sql.Stmt + selectCreatorIDFromAliasStmt *sql.Stmt + deleteRoomAliasStmt *sql.Stmt } func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { @@ -59,14 +65,15 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, + {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, }.prepare(db) } func (s *roomAliasesStatements) insertRoomAlias( - ctx context.Context, alias string, roomID string, + ctx context.Context, alias string, roomID string, creatorUserID string, ) (err error) { - _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID) + _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) return } @@ -101,6 +108,16 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID( return } +func (s *roomAliasesStatements) selectCreatorIDFromAlias( + ctx context.Context, alias string, +) (creatorID string, err error) { + err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + if err == sql.ErrNoRows { + return "", nil + } + return +} + func (s *roomAliasesStatements) deleteRoomAlias( ctx context.Context, alias string, ) (err error) { diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index f6c2fccd4..71c13b7ca 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -441,8 +441,8 @@ func (d *Database) GetInvitesForUser( } // SetRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string) error { - return d.statements.insertRoomAlias(ctx, alias, roomID) +func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { + return d.statements.insertRoomAlias(ctx, alias, roomID, creatorUserID) } // GetRoomIDForAlias implements alias.RoomserverAliasAPIDB @@ -455,6 +455,13 @@ func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]st return d.statements.selectAliasesFromRoomID(ctx, roomID) } +// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetCreatorIDForAlias( + ctx context.Context, alias string, +) (string, error) { + return d.statements.selectCreatorIDFromAlias(ctx, alias) +} + // RemoveRoomAlias implements alias.RoomserverAliasAPIDB func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { return d.statements.deleteRoomAlias(ctx, alias) diff --git a/scripts/find-lint.sh b/scripts/find-lint.sh index 6511272b2..25b311f94 100755 --- a/scripts/find-lint.sh +++ b/scripts/find-lint.sh @@ -22,7 +22,15 @@ then args="--fast" fi echo "Installing golangci-lint..." + +# Make a backup of go.{mod,sum} first +# TODO: Once go 1.13 is out, use go get's -mod=readonly option +# https://github.com/golang/go/issues/30667 +cp go.mod go.mod.bak && cp go.sum go.sum.bak go get github.com/golangci/golangci-lint/cmd/golangci-lint echo "Looking for lint..." golangci-lint run $args + +# Restore go.{mod,sum} +mv go.mod.bak go.mod && mv go.sum.bak go.sum diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index d05a76920..f0db56427 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/dendrite/syncapi/types" log "github.com/sirupsen/logrus" sarama "gopkg.in/Shopify/sarama.v1" ) @@ -29,7 +30,7 @@ import ( // OutputClientDataConsumer consumes events that originated in the client API server. type OutputClientDataConsumer struct { clientAPIConsumer *common.ContinualConsumer - db *storage.SyncServerDatabase + db *storage.SyncServerDatasource notifier *sync.Notifier } @@ -38,7 +39,7 @@ func NewOutputClientDataConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, n *sync.Notifier, - store *storage.SyncServerDatabase, + store *storage.SyncServerDatasource, ) *OutputClientDataConsumer { consumer := common.ContinualConsumer{ @@ -78,7 +79,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error "room_id": output.RoomID, }).Info("received data from client API server") - syncStreamPos, err := s.db.UpsertAccountData( + pduPos, err := s.db.UpsertAccountData( context.TODO(), string(msg.Key), output.RoomID, output.Type, ) if err != nil { @@ -89,7 +90,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error }).Panicf("could not save account data") } - s.notifier.OnNewEvent(nil, string(msg.Key), syncStreamPos) + s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.SyncPosition{PDUPosition: pduPos}) return nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 1866a9667..e4f1ab460 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -33,7 +33,7 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { roomServerConsumer *common.ContinualConsumer - db *storage.SyncServerDatabase + db *storage.SyncServerDatasource notifier *sync.Notifier query api.RoomserverQueryAPI } @@ -43,7 +43,7 @@ func NewOutputRoomEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, n *sync.Notifier, - store *storage.SyncServerDatabase, + store *storage.SyncServerDatasource, queryAPI api.RoomserverQueryAPI, ) *OutputRoomEventConsumer { @@ -126,7 +126,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( } } - syncStreamPos, err := s.db.WriteEvent( + pduPos, err := s.db.WriteEvent( ctx, &ev, addsStateEvents, @@ -144,7 +144,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( }).Panicf("roomserver output log: write event failure") return nil } - s.notifier.OnNewEvent(&ev, "", types.StreamPosition(syncStreamPos)) + s.notifier.OnNewEvent(&ev, "", nil, types.SyncPosition{PDUPosition: pduPos}) return nil } @@ -152,7 +152,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( func (s *OutputRoomEventConsumer) onNewInviteEvent( ctx context.Context, msg api.OutputNewInviteEvent, ) error { - syncStreamPos, err := s.db.AddInviteEvent(ctx, msg.Event) + pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) if err != nil { // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ @@ -161,7 +161,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( }).Panicf("roomserver output log: write invite failure") return nil } - s.notifier.OnNewEvent(&msg.Event, "", syncStreamPos) + s.notifier.OnNewEvent(&msg.Event, "", nil, types.SyncPosition{PDUPosition: pduPos}) return nil } diff --git a/syncapi/consumers/typingserver.go b/syncapi/consumers/typingserver.go new file mode 100644 index 000000000..5d998a18a --- /dev/null +++ b/syncapi/consumers/typingserver.go @@ -0,0 +1,96 @@ +// Copyright 2019 Alex Chen +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "encoding/json" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/typingserver/api" + log "github.com/sirupsen/logrus" + sarama "gopkg.in/Shopify/sarama.v1" +) + +// OutputTypingEventConsumer consumes events that originated in the typing server. +type OutputTypingEventConsumer struct { + typingConsumer *common.ContinualConsumer + db *storage.SyncServerDatasource + notifier *sync.Notifier +} + +// NewOutputTypingEventConsumer creates a new OutputTypingEventConsumer. +// Call Start() to begin consuming from the typing server. +func NewOutputTypingEventConsumer( + cfg *config.Dendrite, + kafkaConsumer sarama.Consumer, + n *sync.Notifier, + store *storage.SyncServerDatasource, +) *OutputTypingEventConsumer { + + consumer := common.ContinualConsumer{ + Topic: string(cfg.Kafka.Topics.OutputTypingEvent), + Consumer: kafkaConsumer, + PartitionStore: store, + } + + s := &OutputTypingEventConsumer{ + typingConsumer: &consumer, + db: store, + notifier: n, + } + + consumer.ProcessMessage = s.onMessage + + return s +} + +// Start consuming from typing api +func (s *OutputTypingEventConsumer) Start() error { + s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { + s.notifier.OnNewEvent(nil, roomID, nil, types.SyncPosition{TypingPosition: latestSyncPosition}) + }) + + return s.typingConsumer.Start() +} + +func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { + var output api.OutputTypingEvent + if err := json.Unmarshal(msg.Value, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("typing server output log: message parse failure") + return nil + } + + log.WithFields(log.Fields{ + "room_id": output.Event.RoomID, + "user_id": output.Event.UserID, + "typing": output.Event.Typing, + }).Debug("received data from typing server") + + var typingPos int64 + typingEvent := output.Event + if typingEvent.Typing { + typingPos = s.db.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime) + } else { + typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) + } + + s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.SyncPosition{TypingPosition: typingPos}) + return nil +} diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index cbdcfb6bb..0f5019fc3 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -34,7 +34,7 @@ const pathPrefixR0 = "/_matrix/client/r0" // Due to Setup being used to call many other functions, a gocyclo nolint is // applied: // nolint: gocyclo -func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatabase, deviceDB *devices.Database) { +func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatasource, deviceDB *devices.Database) { r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() authData := auth.Data{ diff --git a/syncapi/routing/state.go b/syncapi/routing/state.go index 6b98a0b7b..87a93d194 100644 --- a/syncapi/routing/state.go +++ b/syncapi/routing/state.go @@ -40,11 +40,14 @@ type stateEventInStateResp struct { // TODO: Check if the user is in the room. If not, check if the room's history // is publicly visible. Current behaviour is returning an empty array if the // user cannot see the room's history. -func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatabase, roomID string) util.JSONResponse { +func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatasource, roomID string) util.JSONResponse { // TODO(#287): Auth request and handle the case where the user has left (where // we should return the state at the poin they left) - stateEvents, err := db.GetStateEventsForRoom(req.Context(), roomID) + stateFilterPart := gomatrixserverlib.DefaultFilterPart() + // TODO: stateFilterPart should not limit the number of state events (or only limits abusive number of events) + + stateEvents, err := db.GetStateEventsForRoom(req.Context(), roomID, &stateFilterPart) if err != nil { return httputil.LogThenError(req, err) } @@ -84,7 +87,7 @@ func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatabase, r // /rooms/{roomID}/state/{type}/{statekey} request. It will look in current // state to see if there is an event with that type and state key, if there // is then (by default) we return the content, otherwise a 404. -func OnIncomingStateTypeRequest(req *http.Request, db *storage.SyncServerDatabase, roomID string, evType, stateKey string) util.JSONResponse { +func OnIncomingStateTypeRequest(req *http.Request, db *storage.SyncServerDatasource, roomID string, evType, stateKey string) util.JSONResponse { // TODO(#287): Auth request and handle the case where the user has left (where // we should return the state at the poin they left) diff --git a/syncapi/storage/account_data_table.go b/syncapi/storage/account_data_table.go index d4d74d158..7b4803e3d 100644 --- a/syncapi/storage/account_data_table.go +++ b/syncapi/storage/account_data_table.go @@ -18,9 +18,9 @@ import ( "context" "database/sql" + "github.com/lib/pq" "github.com/matrix-org/dendrite/common" - - "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -43,7 +43,7 @@ CREATE TABLE IF NOT EXISTS syncapi_account_data_type ( CONSTRAINT syncapi_account_data_unique UNIQUE (user_id, room_id, type) ); -CREATE UNIQUE INDEX IF NOT EXISTS syncapi_account_data_id_idx ON syncapi_account_data_type(id); +CREATE UNIQUE INDEX IF NOT EXISTS syncapi_account_data_id_idx ON syncapi_account_data_type(id, type); ` const insertAccountDataSQL = "" + @@ -55,7 +55,9 @@ const insertAccountDataSQL = "" + const selectAccountDataInRangeSQL = "" + "SELECT room_id, type FROM syncapi_account_data_type" + " WHERE user_id = $1 AND id > $2 AND id <= $3" + - " ORDER BY id ASC" + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + + " AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" + + " ORDER BY id ASC LIMIT $6" const selectMaxAccountDataIDSQL = "" + "SELECT MAX(id) FROM syncapi_account_data_type" @@ -94,7 +96,8 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountDataInRange( ctx context.Context, userID string, - oldPos, newPos types.StreamPosition, + oldPos, newPos int64, + accountDataFilterPart *gomatrixserverlib.FilterPart, ) (data map[string][]string, err error) { data = make(map[string][]string) @@ -105,7 +108,11 @@ func (s *accountDataStatements) selectAccountDataInRange( oldPos-- } - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos) + rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos, + pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.Types)), + pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.NotTypes)), + accountDataFilterPart.Limit, + ) if err != nil { return } diff --git a/syncapi/storage/current_room_state_table.go b/syncapi/storage/current_room_state_table.go index 852bfd760..88e7a76c3 100644 --- a/syncapi/storage/current_room_state_table.go +++ b/syncapi/storage/current_room_state_table.go @@ -17,6 +17,7 @@ package storage import ( "context" "database/sql" + "encoding/json" "github.com/lib/pq" "github.com/matrix-org/dendrite/common" @@ -32,6 +33,10 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state ( event_id TEXT NOT NULL, -- The state event type e.g 'm.room.member' type TEXT NOT NULL, + -- The 'sender' property of the event. + sender TEXT NOT NULL, + -- true if the event content contains a url key + contains_url BOOL NOT NULL, -- The state_key value for this state event e.g '' state_key TEXT NOT NULL, -- The JSON for the event. Stored as TEXT because this should be valid UTF-8. @@ -46,16 +51,16 @@ CREATE TABLE IF NOT EXISTS syncapi_current_room_state ( CONSTRAINT syncapi_room_state_unique UNIQUE (room_id, type, state_key) ); -- for event deletion -CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id); +CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url); -- for querying membership states of users CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; ` const upsertRoomStateSQL = "" + - "INSERT INTO syncapi_current_room_state (room_id, event_id, type, state_key, event_json, membership, added_at)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7)" + + "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, event_json, membership, added_at)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + " ON CONFLICT ON CONSTRAINT syncapi_room_state_unique" + - " DO UPDATE SET event_id = $2, event_json = $5, membership = $6, added_at = $7" + " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, event_json = $7, membership = $8, added_at = $9" const deleteRoomStateByEventIDSQL = "" + "DELETE FROM syncapi_current_room_state WHERE event_id = $1" @@ -64,7 +69,13 @@ const selectRoomIDsWithMembershipSQL = "" + "SELECT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" const selectCurrentStateSQL = "" + - "SELECT event_json FROM syncapi_current_room_state WHERE room_id = $1" + "SELECT event_json FROM syncapi_current_room_state WHERE room_id = $1" + + " AND ( $2::text[] IS NULL OR sender = ANY($2) )" + + " AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" + + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + + " AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" + + " AND ( $6::bool IS NULL OR contains_url = $6 )" + + " LIMIT $7" const selectJoinedUsersSQL = "" + "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'" @@ -166,9 +177,17 @@ func (s *currentRoomStateStatements) selectRoomIDsWithMembership( // CurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) selectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, + stateFilterPart *gomatrixserverlib.FilterPart, ) ([]gomatrixserverlib.Event, error) { stmt := common.TxStmt(txn, s.selectCurrentStateStmt) - rows, err := stmt.QueryContext(ctx, roomID) + rows, err := stmt.QueryContext(ctx, roomID, + pq.StringArray(stateFilterPart.Senders), + pq.StringArray(stateFilterPart.NotSenders), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)), + stateFilterPart.ContainsURL, + stateFilterPart.Limit, + ) if err != nil { return nil, err } @@ -189,12 +208,23 @@ func (s *currentRoomStateStatements) upsertRoomState( ctx context.Context, txn *sql.Tx, event gomatrixserverlib.Event, membership *string, addedAt int64, ) error { + // Parse content as JSON and search for an "url" key + containsURL := false + var content map[string]interface{} + if json.Unmarshal(event.Content(), &content) != nil { + // Set containsURL to true if url is present + _, containsURL = content["url"] + } + + // upsert state event stmt := common.TxStmt(txn, s.upsertRoomStateStmt) _, err := stmt.ExecContext( ctx, event.RoomID(), event.EventID(), event.Type(), + event.Sender(), + containsURL, *event.StateKey(), event.JSON(), membership, diff --git a/syncapi/storage/filtering.go b/syncapi/storage/filtering.go new file mode 100644 index 000000000..27b0b888a --- /dev/null +++ b/syncapi/storage/filtering.go @@ -0,0 +1,36 @@ +// Copyright 2017 Thibaut CHARLES +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "strings" +) + +// filterConvertWildcardToSQL converts wildcards as defined in +// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter +// to SQL wildcards that can be used with LIKE() +func filterConvertTypeWildcardToSQL(values []string) []string { + if values == nil { + // Return nil instead of []string{} so IS NULL can work correctly when + // the return value is passed into SQL queries + return nil + } + + ret := make([]string, len(values)) + for i := range values { + ret[i] = strings.Replace(values[i], "*", "%", -1) + } + return ret +} diff --git a/syncapi/storage/invites_table.go b/syncapi/storage/invites_table.go index 88c98f7e3..9f52087f6 100644 --- a/syncapi/storage/invites_table.go +++ b/syncapi/storage/invites_table.go @@ -23,7 +23,7 @@ CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx -- For deleting old invites CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx - ON syncapi_invite_events(target_user_id, id); + ON syncapi_invite_events (event_id); ` const insertInviteEventSQL = "" + diff --git a/syncapi/storage/output_room_events_table.go b/syncapi/storage/output_room_events_table.go index 035db9882..8fbeb18c9 100644 --- a/syncapi/storage/output_room_events_table.go +++ b/syncapi/storage/output_room_events_table.go @@ -17,13 +17,13 @@ package storage import ( "context" "database/sql" + "encoding/json" "sort" "github.com/matrix-org/dendrite/roomserver/api" "github.com/lib/pq" "github.com/matrix-org/dendrite/common" - "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -44,6 +44,12 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( room_id TEXT NOT NULL, -- The JSON for the event. Stored as TEXT because this should be valid UTF-8. event_json TEXT NOT NULL, + -- The event type e.g 'm.room.member'. + type TEXT NOT NULL, + -- The 'sender' property of the event. + sender TEXT NOT NULL, + -- true if the event content contains a url key. + contains_url BOOL NOT NULL, -- A list of event IDs which represent a delta of added/removed room state. This can be NULL -- if there is no delta. add_state_ids TEXT[], @@ -57,8 +63,8 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev const insertEventSQL = "" + "INSERT INTO syncapi_output_room_events (" + - " room_id, event_id, event_json, add_state_ids, remove_state_ids, device_id, transaction_id" + - ") VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id" + "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, device_id, transaction_id" + + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id" const selectEventsSQL = "" + "SELECT id, event_json FROM syncapi_output_room_events WHERE event_id = ANY($1)" @@ -76,7 +82,13 @@ const selectStateInRangeSQL = "" + "SELECT id, event_json, add_state_ids, remove_state_ids" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + - " ORDER BY id ASC" + " AND ( $3::text[] IS NULL OR sender = ANY($3) )" + + " AND ( $4::text[] IS NULL OR NOT(sender = ANY($4)) )" + + " AND ( $5::text[] IS NULL OR type LIKE ANY($5) )" + + " AND ( $6::text[] IS NULL OR NOT(type LIKE ANY($6)) )" + + " AND ( $7::bool IS NULL OR contains_url = $7 )" + + " ORDER BY id ASC" + + " LIMIT $8" type outputRoomEventsStatements struct { insertEventStmt *sql.Stmt @@ -109,15 +121,24 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { return } -// selectStateInRange returns the state events between the two given stream positions, exclusive of oldPos, inclusive of newPos. +// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) selectStateInRange( - ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, + ctx context.Context, txn *sql.Tx, oldPos, newPos int64, + stateFilterPart *gomatrixserverlib.FilterPart, ) (map[string]map[string]bool, map[string]streamEvent, error) { stmt := common.TxStmt(txn, s.selectStateInRangeStmt) - rows, err := stmt.QueryContext(ctx, oldPos, newPos) + rows, err := stmt.QueryContext( + ctx, oldPos, newPos, + pq.StringArray(stateFilterPart.Senders), + pq.StringArray(stateFilterPart.NotSenders), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)), + pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)), + stateFilterPart.ContainsURL, + stateFilterPart.Limit, + ) if err != nil { return nil, nil, err } @@ -171,7 +192,7 @@ func (s *outputRoomEventsStatements) selectStateInRange( eventIDToEvent[ev.EventID()] = streamEvent{ Event: ev, - streamPosition: types.StreamPosition(streamPos), + streamPosition: streamPos, } } @@ -206,12 +227,23 @@ func (s *outputRoomEventsStatements) insertEvent( txnID = &transactionID.TransactionID } + // Parse content as JSON and search for an "url" key + containsURL := false + var content map[string]interface{} + if json.Unmarshal(event.Content(), &content) != nil { + // Set containsURL to true if url is present + _, containsURL = content["url"] + } + stmt := common.TxStmt(txn, s.insertEventStmt) err = stmt.QueryRowContext( ctx, event.RoomID(), event.EventID(), event.JSON(), + event.Type(), + event.Sender(), + containsURL, pq.StringArray(addState), pq.StringArray(removeState), deviceID, @@ -223,7 +255,7 @@ func (s *outputRoomEventsStatements) insertEvent( // RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'. func (s *outputRoomEventsStatements) selectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, fromPos, toPos types.StreamPosition, limit int, + roomID string, fromPos, toPos int64, limit int, ) ([]streamEvent, error) { stmt := common.TxStmt(txn, s.selectRecentEventsStmt) rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) @@ -236,7 +268,7 @@ func (s *outputRoomEventsStatements) selectRecentEvents( return nil, err } // The events need to be returned from oldest to latest, which isn't - // necessary the way the SQL query returns them, so a sort is necessary to + // necessarily the way the SQL query returns them, so a sort is necessary to // ensure the events are in the right order in the slice. sort.SliceStable(events, func(i int, j int) bool { return events[i].streamPosition < events[j].streamPosition @@ -286,7 +318,7 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) { result = append(result, streamEvent{ Event: ev, - streamPosition: types.StreamPosition(streamPos), + streamPosition: streamPos, transactionID: transactionID, }) } diff --git a/syncapi/storage/syncserver.go b/syncapi/storage/syncserver.go index b0655a0a8..fb883702c 100644 --- a/syncapi/storage/syncserver.go +++ b/syncapi/storage/syncserver.go @@ -17,7 +17,10 @@ package storage import ( "context" "database/sql" + "encoding/json" "fmt" + "strconv" + "time" "github.com/sirupsen/logrus" @@ -28,6 +31,7 @@ import ( _ "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/typingserver/cache" "github.com/matrix-org/gomatrixserverlib" ) @@ -35,33 +39,35 @@ type stateDelta struct { roomID string stateEvents []gomatrixserverlib.Event membership string - // The stream position of the latest membership event for this user, if applicable. + // The PDU stream position of the latest membership event for this user, if applicable. // Can be 0 if there is no membership event in this delta. - membershipPos types.StreamPosition + membershipPos int64 } -// Same as gomatrixserverlib.Event but also has the stream position for this event. +// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. type streamEvent struct { gomatrixserverlib.Event - streamPosition types.StreamPosition + streamPosition int64 transactionID *api.TransactionID } -// SyncServerDatabase represents a sync server database -type SyncServerDatabase struct { +// SyncServerDatabase represents a sync server datasource which manages +// both the database for PDUs and caches for EDUs. +type SyncServerDatasource struct { db *sql.DB common.PartitionOffsetStatements accountData accountDataStatements events outputRoomEventsStatements roomstate currentRoomStateStatements invites inviteEventsStatements + typingCache *cache.TypingCache } // NewSyncServerDatabase creates a new sync server database -func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { - var d SyncServerDatabase +func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) { + var d SyncServerDatasource var err error - if d.db, err = sql.Open("postgres", dataSourceName); err != nil { + if d.db, err = sql.Open("postgres", dbDataSourceName); err != nil { return nil, err } if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { @@ -79,11 +85,12 @@ func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) { if err := d.invites.prepare(d.db); err != nil { return nil, err } + d.typingCache = cache.NewTypingCache() return &d, nil } // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. -func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { +func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { return d.roomstate.selectJoinedUsers(ctx) } @@ -92,7 +99,7 @@ func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[str // If an event is not found in the database then it will be omitted from the list. // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. -func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { +func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) if err != nil { return nil, err @@ -104,38 +111,38 @@ func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]g } // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races -// when generating the stream position for this event. Returns the sync stream position for the inserted event. +// when generating the sync stream position for this event. Returns the sync stream position for the inserted event. // Returns an error if there was a problem inserting this event. -func (d *SyncServerDatabase) WriteEvent( +func (d *SyncServerDatasource) WriteEvent( ctx context.Context, ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID, -) (streamPos types.StreamPosition, returnErr error) { +) (pduPosition int64, returnErr error) { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID) if err != nil { return err } - streamPos = types.StreamPosition(pos) + pduPosition = pos if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { // Nothing to do, the event may have just been a message event. return nil } - return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, streamPos) + return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition) }) return } -func (d *SyncServerDatabase) updateRoomState( +func (d *SyncServerDatasource) updateRoomState( ctx context.Context, txn *sql.Tx, removedEventIDs []string, addedEvents []gomatrixserverlib.Event, - streamPos types.StreamPosition, + pduPosition int64, ) error { // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. for _, eventID := range removedEventIDs { @@ -157,7 +164,7 @@ func (d *SyncServerDatabase) updateRoomState( } membership = &value } - if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, int64(streamPos)); err != nil { + if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { return err } } @@ -168,7 +175,7 @@ func (d *SyncServerDatabase) updateRoomState( // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error -func (d *SyncServerDatabase) GetStateEvent( +func (d *SyncServerDatasource) GetStateEvent( ctx context.Context, roomID, evType, stateKey string, ) (*gomatrixserverlib.Event, error) { return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey) @@ -177,56 +184,60 @@ func (d *SyncServerDatabase) GetStateEvent( // GetStateEventsForRoom fetches the state events for a given room. // Returns an empty slice if no state events could be found for this room. // Returns an error if there was an issue with the retrieval. -func (d *SyncServerDatabase) GetStateEventsForRoom( - ctx context.Context, roomID string, +func (d *SyncServerDatasource) GetStateEventsForRoom( + ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.FilterPart, ) (stateEvents []gomatrixserverlib.Event, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID) + stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) return err }) return } -// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. -func (d *SyncServerDatabase) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { - return d.syncStreamPositionTx(ctx, nil) +// SyncPosition returns the latest positions for syncing. +func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.SyncPosition, error) { + return d.syncPositionTx(ctx, nil) } -func (d *SyncServerDatabase) syncStreamPositionTx( +func (d *SyncServerDatasource) syncPositionTx( ctx context.Context, txn *sql.Tx, -) (types.StreamPosition, error) { - maxID, err := d.events.selectMaxEventID(ctx, txn) +) (sp types.SyncPosition, err error) { + + maxEventID, err := d.events.selectMaxEventID(ctx, txn) if err != nil { - return 0, err + return sp, err } maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) if err != nil { - return 0, err + return sp, err } - if maxAccountDataID > maxID { - maxID = maxAccountDataID + if maxAccountDataID > maxEventID { + maxEventID = maxAccountDataID } maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) if err != nil { - return 0, err + return sp, err } - if maxInviteID > maxID { - maxID = maxInviteID + if maxInviteID > maxEventID { + maxEventID = maxInviteID } - return types.StreamPosition(maxID), nil + sp.PDUPosition = maxEventID + + sp.TypingPosition = d.typingCache.GetLatestSyncPosition() + + return } -// IncrementalSync returns all the data needed in order to create an incremental -// sync response for the given user. Events returned will include any client -// transaction IDs associated with the given device. These transaction IDs come -// from when the device sent the event via an API that included a transaction -// ID. -func (d *SyncServerDatabase) IncrementalSync( +// addPDUDeltaToResponse adds all PDU deltas to a sync response. +// IDs of all rooms the user joined are returned so EDU deltas can be added for them. +func (d *SyncServerDatasource) addPDUDeltaToResponse( ctx context.Context, device authtypes.Device, - fromPos, toPos types.StreamPosition, + fromPos, toPos int64, numRecentEventsPerRoom int, -) (*types.Response, error) { + wantFullState bool, + res *types.Response, +) ([]string, error) { txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) if err != nil { return nil, err @@ -234,16 +245,27 @@ func (d *SyncServerDatabase) IncrementalSync( var succeeded bool defer common.EndTransaction(txn, &succeeded) + stateFilterPart := gomatrixserverlib.DefaultFilterPart() // TODO: use filter provided in request + // Work out which rooms to return in the response. This is done by getting not only the currently - // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. + // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions. // This works out what the 'state' key should be for each room as well as which membership block // to put the room into. - deltas, err := d.getStateDeltas(ctx, &device, txn, fromPos, toPos, device.UserID) + var deltas []stateDelta + var joinedRoomIDs []string + if !wantFullState { + deltas, joinedRoomIDs, err = d.getStateDeltas( + ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart, + ) + } else { + deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync( + ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart, + ) + } if err != nil { return nil, err } - res := types.NewResponse(toPos) for _, delta := range deltas { err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) if err != nil { @@ -257,52 +279,154 @@ func (d *SyncServerDatabase) IncrementalSync( } succeeded = true + return joinedRoomIDs, nil +} + +// addTypingDeltaToResponse adds all typing notifications to a sync response +// since the specified position. +func (d *SyncServerDatasource) addTypingDeltaToResponse( + since int64, + joinedRoomIDs []string, + res *types.Response, +) error { + var jr types.JoinResponse + var ok bool + var err error + for _, roomID := range joinedRoomIDs { + if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter( + roomID, since, + ); updated { + ev := gomatrixserverlib.ClientEvent{ + Type: gomatrixserverlib.MTyping, + } + ev.Content, err = json.Marshal(map[string]interface{}{ + "user_ids": typingUsers, + }) + if err != nil { + return err + } + + if jr, ok = res.Rooms.Join[roomID]; !ok { + jr = *types.NewJoinResponse() + } + jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev) + res.Rooms.Join[roomID] = jr + } + } + return nil +} + +// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if +// the positions of that type are not equal in fromPos and toPos. +func (d *SyncServerDatasource) addEDUDeltaToResponse( + fromPos, toPos types.SyncPosition, + joinedRoomIDs []string, + res *types.Response, +) (err error) { + + if fromPos.TypingPosition != toPos.TypingPosition { + err = d.addTypingDeltaToResponse( + fromPos.TypingPosition, joinedRoomIDs, res, + ) + } + + return +} + +// IncrementalSync returns all the data needed in order to create an incremental +// sync response for the given user. Events returned will include any client +// transaction IDs associated with the given device. These transaction IDs come +// from when the device sent the event via an API that included a transaction +// ID. +func (d *SyncServerDatasource) IncrementalSync( + ctx context.Context, + device authtypes.Device, + fromPos, toPos types.SyncPosition, + numRecentEventsPerRoom int, + wantFullState bool, +) (*types.Response, error) { + nextBatchPos := fromPos.WithUpdates(toPos) + res := types.NewResponse(nextBatchPos) + + var joinedRoomIDs []string + var err error + if fromPos.PDUPosition != toPos.PDUPosition || wantFullState { + joinedRoomIDs, err = d.addPDUDeltaToResponse( + ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res, + ) + } else { + joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( + ctx, nil, device.UserID, gomatrixserverlib.Join, + ) + } + if err != nil { + return nil, err + } + + err = d.addEDUDeltaToResponse( + fromPos, toPos, joinedRoomIDs, res, + ) + if err != nil { + return nil, err + } + return res, nil } -// CompleteSync a complete /sync API response for the given user. -func (d *SyncServerDatabase) CompleteSync( - ctx context.Context, userID string, numRecentEventsPerRoom int, -) (*types.Response, error) { +// getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed +// to it. It returns toPos and joinedRoomIDs for use of adding EDUs. +func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( + ctx context.Context, + userID string, + numRecentEventsPerRoom int, +) ( + res *types.Response, + toPos types.SyncPosition, + joinedRoomIDs []string, + err error, +) { // This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have - // a consistent view of the database throughout. This includes extracting the sync stream position. + // a consistent view of the database throughout. This includes extracting the sync position. // This does have the unfortunate side-effect that all the matrixy logic resides in this function, // but it's better to not hide the fact that this is being done in a transaction. txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) if err != nil { - return nil, err + return } var succeeded bool defer common.EndTransaction(txn, &succeeded) - // Get the current stream position which we will base the sync response on. - pos, err := d.syncStreamPositionTx(ctx, txn) + // Get the current sync position which we will base the sync response on. + toPos, err = d.syncPositionTx(ctx, txn) if err != nil { - return nil, err + return } + res = types.NewResponse(toPos) + // Extract room state and recent events for all rooms the user is joined to. - roomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") + joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) if err != nil { - return nil, err + return } + stateFilterPart := gomatrixserverlib.DefaultFilterPart() // TODO: use filter provided in request + // Build up a /sync response. Add joined rooms. - res := types.NewResponse(pos) - for _, roomID := range roomIDs { + for _, roomID := range joinedRoomIDs { var stateEvents []gomatrixserverlib.Event - stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID) + stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart) if err != nil { - return nil, err + return } // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 var recentStreamEvents []streamEvent recentStreamEvents, err = d.events.selectRecentEvents( - ctx, txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom, + ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom, ) if err != nil { - return nil, err + return } // We don't include a device here as we don't need to send down @@ -311,10 +435,12 @@ func (d *SyncServerDatabase) CompleteSync( stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() - if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 { - jr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String() + if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 { + // Use the short form of batch token for prev_batch + jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) } else { - jr.Timeline.PrevBatch = types.StreamPosition(1).String() + // Use the short form of batch token for prev_batch + jr.Timeline.PrevBatch = "1" } jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = true @@ -322,12 +448,34 @@ func (d *SyncServerDatabase) CompleteSync( res.Rooms.Join[roomID] = *jr } - if err = d.addInvitesToResponse(ctx, txn, userID, 0, pos, res); err != nil { - return nil, err + if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { + return } succeeded = true - return res, err + return res, toPos, joinedRoomIDs, err +} + +// CompleteSync returns a complete /sync API response for the given user. +func (d *SyncServerDatasource) CompleteSync( + ctx context.Context, userID string, numRecentEventsPerRoom int, +) (*types.Response, error) { + res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( + ctx, userID, numRecentEventsPerRoom, + ) + if err != nil { + return nil, err + } + + // Use a zero value SyncPosition for fromPos so all EDU states are added. + err = d.addEDUDeltaToResponse( + types.SyncPosition{}, toPos, joinedRoomIDs, res, + ) + if err != nil { + return nil, err + } + + return res, nil } var txReadOnlySnapshot = sql.TxOptions{ @@ -345,10 +493,11 @@ var txReadOnlySnapshot = sql.TxOptions{ // Returns a map following the format data[roomID] = []dataTypes // If no data is retrieved, returns an empty map // If there was an issue with the retrieval, returns an error -func (d *SyncServerDatabase) GetAccountDataInRange( - ctx context.Context, userID string, oldPos, newPos types.StreamPosition, +func (d *SyncServerDatasource) GetAccountDataInRange( + ctx context.Context, userID string, oldPos, newPos int64, + accountDataFilterPart *gomatrixserverlib.FilterPart, ) (map[string][]string, error) { - return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos) + return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) } // UpsertAccountData keeps track of new or updated account data, by saving the type @@ -357,26 +506,24 @@ func (d *SyncServerDatabase) GetAccountDataInRange( // If no data with the given type, user ID and room ID exists in the database, // creates a new row, else update the existing one // Returns an error if there was an issue with the upsert -func (d *SyncServerDatabase) UpsertAccountData( +func (d *SyncServerDatasource) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, -) (types.StreamPosition, error) { - pos, err := d.accountData.insertAccountData(ctx, userID, roomID, dataType) - return types.StreamPosition(pos), err +) (int64, error) { + return d.accountData.insertAccountData(ctx, userID, roomID, dataType) } // AddInviteEvent stores a new invite event for a user. // If the invite was successfully stored this returns the stream ID it was stored at. // Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatabase) AddInviteEvent( +func (d *SyncServerDatasource) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.Event, -) (types.StreamPosition, error) { - pos, err := d.invites.insertInviteEvent(ctx, inviteEvent) - return types.StreamPosition(pos), err +) (int64, error) { + return d.invites.insertInviteEvent(ctx, inviteEvent) } // RetireInviteEvent removes an old invite event from the database. // Returns an error if there was a problem communicating with the database. -func (d *SyncServerDatabase) RetireInviteEvent( +func (d *SyncServerDatasource) RetireInviteEvent( ctx context.Context, inviteEventID string, ) error { // TODO: Record that invite has been retired in a stream so that we can @@ -385,10 +532,30 @@ func (d *SyncServerDatabase) RetireInviteEvent( return err } -func (d *SyncServerDatabase) addInvitesToResponse( +func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { + d.typingCache.SetTimeoutCallback(fn) +} + +// AddTypingUser adds a typing user to the typing cache. +// Returns the newly calculated sync position for typing notifications. +func (d *SyncServerDatasource) AddTypingUser( + userID, roomID string, expireTime *time.Time, +) int64 { + return d.typingCache.AddTypingUser(userID, roomID, expireTime) +} + +// RemoveTypingUser removes a typing user from the typing cache. +// Returns the newly calculated sync position for typing notifications. +func (d *SyncServerDatasource) RemoveTypingUser( + userID, roomID string, +) int64 { + return d.typingCache.RemoveUser(userID, roomID) +} + +func (d *SyncServerDatasource) addInvitesToResponse( ctx context.Context, txn *sql.Tx, userID string, - fromPos, toPos types.StreamPosition, + fromPos, toPos int64, res *types.Response, ) error { invites, err := d.invites.selectInviteEventsInRange( @@ -409,17 +576,17 @@ func (d *SyncServerDatabase) addInvitesToResponse( } // addRoomDeltaToResponse adds a room state delta to a sync response -func (d *SyncServerDatabase) addRoomDeltaToResponse( +func (d *SyncServerDatasource) addRoomDeltaToResponse( ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos types.StreamPosition, + fromPos, toPos int64, delta stateDelta, numRecentEventsPerRoom int, res *types.Response, ) error { endPos := toPos - if delta.membershipPos > 0 && delta.membership == "leave" { + if delta.membershipPos > 0 && delta.membership == gomatrixserverlib.Leave { // make sure we don't leak recent events after the leave event. // TODO: History visibility makes this somewhat complex to handle correctly. For example: // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join). @@ -437,34 +604,42 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse( recentEvents := streamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back - // Don't bother appending empty room entries - if len(recentEvents) == 0 && len(delta.stateEvents) == 0 { - return nil + var prevPDUPos int64 + + if len(recentEvents) == 0 { + if len(delta.stateEvents) == 0 { + // Don't bother appending empty room entries + return nil + } + + // If full_state=true and since is already up to date, then we'll have + // state events but no recent events. + prevPDUPos = toPos - 1 + } else { + prevPDUPos = recentStreamEvents[0].streamPosition - 1 + } + + if prevPDUPos <= 0 { + prevPDUPos = 1 } switch delta.membership { - case "join": + case gomatrixserverlib.Join: jr := types.NewJoinResponse() - if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 { - jr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String() - } else { - jr.Timeline.PrevBatch = types.StreamPosition(1).String() - } + // Use the short form of batch token for prev_batch + jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) res.Rooms.Join[delta.roomID] = *jr - case "leave": + case gomatrixserverlib.Leave: fallthrough // transitions to leave are the same as ban - case "ban": + case gomatrixserverlib.Ban: // TODO: recentEvents may contain events that this user is not allowed to see because they are // no longer in the room. lr := types.NewLeaveResponse() - if prevBatch := recentStreamEvents[0].streamPosition - 1; prevBatch > 0 { - lr.Timeline.PrevBatch = types.StreamPosition(prevBatch).String() - } else { - lr.Timeline.PrevBatch = types.StreamPosition(1).String() - } + // Use the short form of batch token for prev_batch + lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) @@ -476,7 +651,7 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse( // fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. // Returns a map of room ID to list of events. -func (d *SyncServerDatabase) fetchStateEvents( +func (d *SyncServerDatasource) fetchStateEvents( ctx context.Context, txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, eventIDToEvent map[string]streamEvent, @@ -521,7 +696,7 @@ func (d *SyncServerDatabase) fetchStateEvents( return stateBetween, nil } -func (d *SyncServerDatabase) fetchMissingStateEvents( +func (d *SyncServerDatasource) fetchMissingStateEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]streamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the @@ -560,10 +735,15 @@ func (d *SyncServerDatabase) fetchMissingStateEvents( return events, nil } -func (d *SyncServerDatabase) getStateDeltas( +// getStateDeltas returns the state deltas between fromPos and toPos, +// exclusive of oldPos, inclusive of newPos, for the rooms in which +// the user has new membership events. +// A list of joined room IDs is also returned in case the caller needs it. +func (d *SyncServerDatasource) getStateDeltas( ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos types.StreamPosition, userID string, -) ([]stateDelta, error) { + fromPos, toPos int64, userID string, + stateFilterPart *gomatrixserverlib.FilterPart, +) ([]stateDelta, []string, error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // - Get membership list changes for this user in this sync response // - For each room which has membership list changes: @@ -575,13 +755,13 @@ func (d *SyncServerDatabase) getStateDeltas( var deltas []stateDelta // get all the state events ever between these two positions - stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos) + stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart) if err != nil { - return nil, err + return nil, nil, err } state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) if err != nil { - return nil, err + return nil, nil, err } for roomID, stateStreamEvents := range state { @@ -592,16 +772,12 @@ func (d *SyncServerDatabase) getStateDeltas( // the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to // the timeline. if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { - if membership == "join" { + if membership == gomatrixserverlib.Join { // send full room state down instead of a delta - var allState []gomatrixserverlib.Event - allState, err = d.roomstate.selectCurrentState(ctx, txn, roomID) + var s []streamEvent + s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart) if err != nil { - return nil, err - } - s := make([]streamEvent, len(allState)) - for i := 0; i < len(s); i++ { - s[i] = streamEvent{Event: allState[i], streamPosition: types.StreamPosition(0)} + return nil, nil, err } state[roomID] = s continue // we'll add this room in when we do joined rooms @@ -619,19 +795,94 @@ func (d *SyncServerDatabase) getStateDeltas( } // Add in currently joined rooms - joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") + joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) if err != nil { - return nil, err + return nil, nil, err } for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, stateDelta{ - membership: "join", + membership: gomatrixserverlib.Join, stateEvents: streamEventsToEvents(device, state[joinedRoomID]), roomID: joinedRoomID, }) } - return deltas, nil + return deltas, joinedRoomIDs, nil +} + +// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync +// requests with full_state=true. +// Fetches full state for all joined rooms and uses selectStateInRange to get +// updates for other rooms. +func (d *SyncServerDatasource) getStateDeltasForFullStateSync( + ctx context.Context, device *authtypes.Device, txn *sql.Tx, + fromPos, toPos int64, userID string, + stateFilterPart *gomatrixserverlib.FilterPart, +) ([]stateDelta, []string, error) { + joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) + if err != nil { + return nil, nil, err + } + + // Use a reasonable initial capacity + deltas := make([]stateDelta, 0, len(joinedRoomIDs)) + + // Add full states for all joined rooms + for _, joinedRoomID := range joinedRoomIDs { + s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilterPart) + if stateErr != nil { + return nil, nil, stateErr + } + deltas = append(deltas, stateDelta{ + membership: gomatrixserverlib.Join, + stateEvents: streamEventsToEvents(device, s), + roomID: joinedRoomID, + }) + } + + // Get all the state events ever between these two positions + stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart) + if err != nil { + return nil, nil, err + } + state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) + if err != nil { + return nil, nil, err + } + + for roomID, stateStreamEvents := range state { + for _, ev := range stateStreamEvents { + if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { + if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. + deltas = append(deltas, stateDelta{ + membership: membership, + membershipPos: ev.streamPosition, + stateEvents: streamEventsToEvents(device, stateStreamEvents), + roomID: roomID, + }) + } + + break + } + } + } + + return deltas, joinedRoomIDs, nil +} + +func (d *SyncServerDatasource) currentStateStreamEventsForRoom( + ctx context.Context, txn *sql.Tx, roomID string, + stateFilterPart *gomatrixserverlib.FilterPart, +) ([]streamEvent, error) { + allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) + if err != nil { + return nil, err + } + s := make([]streamEvent, len(allState)) + for i := 0; i < len(s); i++ { + s[i] = streamEvent{Event: allState[i], streamPosition: 0} + } + return s, nil } // streamEventsToEvents converts streamEvent to Event. If device is non-nil and diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index 5ed701d8e..15d6b070c 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -26,7 +26,7 @@ import ( ) // Notifier will wake up sleeping requests when there is some new data. -// It does not tell requests what that data is, only the stream position which +// It does not tell requests what that data is, only the sync position which // they can use to get at it. This is done to prevent races whereby we tell the caller // the event, but the token has already advanced by the time they fetch it, resulting // in missed events. @@ -35,18 +35,18 @@ type Notifier struct { roomIDToJoinedUsers map[string]userIDSet // Protects currPos and userStreams. streamLock *sync.Mutex - // The latest sync stream position - currPos types.StreamPosition + // The latest sync position + currPos types.SyncPosition // A map of user_id => UserStream which can be used to wake a given user's /sync request. userStreams map[string]*UserStream // The last time we cleaned out stale entries from the userStreams map lastCleanUpTime time.Time } -// NewNotifier creates a new notifier set to the given stream position. +// NewNotifier creates a new notifier set to the given sync position. // In order for this to be of any use, the Notifier needs to be told all rooms and // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). -func NewNotifier(pos types.StreamPosition) *Notifier { +func NewNotifier(pos types.SyncPosition) *Notifier { return &Notifier{ currPos: pos, roomIDToJoinedUsers: make(map[string]userIDSet), @@ -58,20 +58,30 @@ func NewNotifier(pos types.StreamPosition) *Notifier { // OnNewEvent is called when a new event is received from the room server. Must only be // called from a single goroutine, to avoid races between updates which could set the -// current position in the stream incorrectly. -// Can be called either with a *gomatrixserverlib.Event, or with an user ID -func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, userID string, pos types.StreamPosition) { +// current sync position incorrectly. +// Chooses which user sync streams to update by a provided *gomatrixserverlib.Event +// (based on the users in the event's room), +// a roomID directly, or a list of user IDs, prioritised by parameter ordering. +// posUpdate contains the latest position(s) for one or more types of events. +// If a position in posUpdate is 0, it means no updates are available of that type. +// Typically a consumer supplies a posUpdate with the latest sync position for the +// event type it handles, leaving other fields as 0. +func (n *Notifier) OnNewEvent( + ev *gomatrixserverlib.Event, roomID string, userIDs []string, + posUpdate types.SyncPosition, +) { // update the current position then notify relevant /sync streams. // This needs to be done PRIOR to waking up users as they will read this value. n.streamLock.Lock() defer n.streamLock.Unlock() - n.currPos = pos + latestPos := n.currPos.WithUpdates(posUpdate) + n.currPos = latestPos n.removeEmptyUserStreams() if ev != nil { // Map this event's room_id to a list of joined users, and wake them up. - userIDs := n.joinedUsers(ev.RoomID()) + usersToNotify := n.joinedUsers(ev.RoomID()) // If this is an invite, also add in the invitee to this list. if ev.Type() == "m.room.member" && ev.StateKey() != nil { targetUserID := *ev.StateKey() @@ -83,26 +93,30 @@ func (n *Notifier) OnNewEvent(ev *gomatrixserverlib.Event, userID string, pos ty } else { // Keep the joined user map up-to-date switch membership { - case "invite": - userIDs = append(userIDs, targetUserID) - case "join": + case gomatrixserverlib.Invite: + usersToNotify = append(usersToNotify, targetUserID) + case gomatrixserverlib.Join: // Manually append the new user's ID so they get notified // along all members in the room - userIDs = append(userIDs, targetUserID) + usersToNotify = append(usersToNotify, targetUserID) n.addJoinedUser(ev.RoomID(), targetUserID) - case "leave": + case gomatrixserverlib.Leave: fallthrough - case "ban": + case gomatrixserverlib.Ban: n.removeJoinedUser(ev.RoomID(), targetUserID) } } } - for _, toNotifyUserID := range userIDs { - n.wakeupUser(toNotifyUserID, pos) - } - } else if len(userID) > 0 { - n.wakeupUser(userID, pos) + n.wakeupUsers(usersToNotify, latestPos) + } else if roomID != "" { + n.wakeupUsers(n.joinedUsers(roomID), latestPos) + } else if len(userIDs) > 0 { + n.wakeupUsers(userIDs, latestPos) + } else { + log.WithFields(log.Fields{ + "posUpdate": posUpdate.String, + }).Warn("Notifier.OnNewEvent called but caller supplied no user to wake up") } } @@ -127,7 +141,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener { } // Load the membership states required to notify users correctly. -func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatabase) error { +func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatasource) error { roomToUsers, err := db.AllJoinedUsersInRooms(ctx) if err != nil { return err @@ -136,8 +150,11 @@ func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatabase) err return nil } -// CurrentPosition returns the current stream position -func (n *Notifier) CurrentPosition() types.StreamPosition { +// CurrentPosition returns the current sync position +func (n *Notifier) CurrentPosition() types.SyncPosition { + n.streamLock.Lock() + defer n.streamLock.Unlock() + return n.currPos } @@ -156,17 +173,19 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { } } -func (n *Notifier) wakeupUser(userID string, newPos types.StreamPosition) { - stream := n.fetchUserStream(userID, false) - if stream == nil { - return +func (n *Notifier) wakeupUsers(userIDs []string, newPos types.SyncPosition) { + for _, userID := range userIDs { + stream := n.fetchUserStream(userID, false) + if stream != nil { + stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream + } } - stream.Broadcast(newPos) // wakeup all goroutines Wait()ing on this stream } // fetchUserStream retrieves a stream unique to the given user. If makeIfNotExists is true, // a stream will be made for this user if one doesn't exist and it will be returned. This // function does not wait for data to be available on the stream. +// NB: Callers should have locked the mutex before calling this function. func (n *Notifier) fetchUserStream(userID string, makeIfNotExists bool) *UserStream { stream, ok := n.userStreams[userID] if !ok && makeIfNotExists { diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 4fa543936..808e07cc7 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -32,19 +32,40 @@ var ( randomMessageEvent gomatrixserverlib.Event aliceInviteBobEvent gomatrixserverlib.Event bobLeaveEvent gomatrixserverlib.Event + syncPositionVeryOld types.SyncPosition + syncPositionBefore types.SyncPosition + syncPositionAfter types.SyncPosition + syncPositionNewEDU types.SyncPosition + syncPositionAfter2 types.SyncPosition ) var ( - streamPositionVeryOld = types.StreamPosition(5) - streamPositionBefore = types.StreamPosition(11) - streamPositionAfter = types.StreamPosition(12) - streamPositionAfter2 = types.StreamPosition(13) - roomID = "!test:localhost" - alice = "@alice:localhost" - bob = "@bob:localhost" + roomID = "!test:localhost" + alice = "@alice:localhost" + bob = "@bob:localhost" ) func init() { + baseSyncPos := types.SyncPosition{ + PDUPosition: 0, + TypingPosition: 0, + } + + syncPositionVeryOld = baseSyncPos + syncPositionVeryOld.PDUPosition = 5 + + syncPositionBefore = baseSyncPos + syncPositionBefore.PDUPosition = 11 + + syncPositionAfter = baseSyncPos + syncPositionAfter.PDUPosition = 12 + + syncPositionNewEDU = syncPositionAfter + syncPositionNewEDU.TypingPosition = 1 + + syncPositionAfter2 = baseSyncPos + syncPositionAfter2.PDUPosition = 13 + var err error randomMessageEvent, err = gomatrixserverlib.NewEventFromTrustedJSON([]byte(`{ "type": "m.room.message", @@ -92,19 +113,19 @@ func init() { // Test that the current position is returned if a request is already behind. func TestImmediateNotification(t *testing.T) { - n := NewNotifier(streamPositionBefore) - pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionVeryOld)) + n := NewNotifier(syncPositionBefore) + pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionVeryOld)) if err != nil { t.Fatalf("TestImmediateNotification error: %s", err) } - if pos != streamPositionBefore { - t.Fatalf("TestImmediateNotification want %d, got %d", streamPositionBefore, pos) + if pos != syncPositionBefore { + t.Fatalf("TestImmediateNotification want %d, got %d", syncPositionBefore, pos) } } // Test that new events to a joined room unblocks the request. func TestNewEventAndJoinedToRoom(t *testing.T) { - n := NewNotifier(streamPositionBefore) + n := NewNotifier(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, }) @@ -112,27 +133,27 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) if err != nil { t.Errorf("TestNewEventAndJoinedToRoom error: %s", err) } - if pos != streamPositionAfter { - t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", streamPositionAfter, pos) + if pos != syncPositionAfter { + t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", syncPositionAfter, pos) } wg.Done() }() - stream := n.fetchUserStream(bob, true) + stream := lockedFetchUserStream(n, bob) waitForBlocking(stream, 1) - n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter) + n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) wg.Wait() } // Test that an invite unblocks the request func TestNewInviteEventForUser(t *testing.T) { - n := NewNotifier(streamPositionBefore) + n := NewNotifier(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, }) @@ -140,27 +161,55 @@ func TestNewInviteEventForUser(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) if err != nil { t.Errorf("TestNewInviteEventForUser error: %s", err) } - if pos != streamPositionAfter { - t.Errorf("TestNewInviteEventForUser want %d, got %d", streamPositionAfter, pos) + if pos != syncPositionAfter { + t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionAfter, pos) } wg.Done() }() - stream := n.fetchUserStream(bob, true) + stream := lockedFetchUserStream(n, bob) waitForBlocking(stream, 1) - n.OnNewEvent(&aliceInviteBobEvent, "", streamPositionAfter) + n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionAfter) + + wg.Wait() +} + +// Test an EDU-only update wakes up the request. +func TestEDUWakeup(t *testing.T) { + n := NewNotifier(syncPositionAfter) + n.setUsersJoinedToRooms(map[string][]string{ + roomID: {alice, bob}, + }) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) + if err != nil { + t.Errorf("TestNewInviteEventForUser error: %s", err) + } + if pos != syncPositionNewEDU { + t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionNewEDU, pos) + } + wg.Done() + }() + + stream := lockedFetchUserStream(n, bob) + waitForBlocking(stream, 1) + + n.OnNewEvent(&aliceInviteBobEvent, "", nil, syncPositionNewEDU) wg.Wait() } // Test that all blocked requests get woken up on a new event. func TestMultipleRequestWakeup(t *testing.T) { - n := NewNotifier(streamPositionBefore) + n := NewNotifier(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, }) @@ -168,12 +217,12 @@ func TestMultipleRequestWakeup(t *testing.T) { var wg sync.WaitGroup wg.Add(3) poll := func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) if err != nil { t.Errorf("TestMultipleRequestWakeup error: %s", err) } - if pos != streamPositionAfter { - t.Errorf("TestMultipleRequestWakeup want %d, got %d", streamPositionAfter, pos) + if pos != syncPositionAfter { + t.Errorf("TestMultipleRequestWakeup want %d, got %d", syncPositionAfter, pos) } wg.Done() } @@ -181,10 +230,10 @@ func TestMultipleRequestWakeup(t *testing.T) { go poll() go poll() - stream := n.fetchUserStream(bob, true) + stream := lockedFetchUserStream(n, bob) waitForBlocking(stream, 3) - n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter) + n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter) wg.Wait() @@ -198,7 +247,7 @@ func TestMultipleRequestWakeup(t *testing.T) { func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { // listen as bob. Make bob leave room. Make alice send event to room. // Make sure alice gets woken up only and not bob as well. - n := NewNotifier(streamPositionBefore) + n := NewNotifier(syncPositionBefore) n.setUsersJoinedToRooms(map[string][]string{ roomID: {alice, bob}, }) @@ -208,38 +257,38 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { // Make bob leave the room leaveWG.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionBefore)) + pos, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionBefore)) if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) } - if pos != streamPositionAfter { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter, pos) + if pos != syncPositionAfter { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter, pos) } leaveWG.Done() }() - bobStream := n.fetchUserStream(bob, true) + bobStream := lockedFetchUserStream(n, bob) waitForBlocking(bobStream, 1) - n.OnNewEvent(&bobLeaveEvent, "", streamPositionAfter) + n.OnNewEvent(&bobLeaveEvent, "", nil, syncPositionAfter) leaveWG.Wait() // send an event into the room. Make sure alice gets it. Bob should not. var aliceWG sync.WaitGroup - aliceStream := n.fetchUserStream(alice, true) + aliceStream := lockedFetchUserStream(n, alice) aliceWG.Add(1) go func() { - pos, err := waitForEvents(n, newTestSyncRequest(alice, streamPositionAfter)) + pos, err := waitForEvents(n, newTestSyncRequest(alice, syncPositionAfter)) if err != nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) } - if pos != streamPositionAfter2 { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", streamPositionAfter2, pos) + if pos != syncPositionAfter2 { + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter2, pos) } aliceWG.Done() }() go func() { // this should timeout with an error (but the main goroutine won't wait for the timeout explicitly) - _, err := waitForEvents(n, newTestSyncRequest(bob, streamPositionAfter)) + _, err := waitForEvents(n, newTestSyncRequest(bob, syncPositionAfter)) if err == nil { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom expect error but got nil") } @@ -248,7 +297,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { waitForBlocking(aliceStream, 1) waitForBlocking(bobStream, 1) - n.OnNewEvent(&randomMessageEvent, "", streamPositionAfter2) + n.OnNewEvent(&randomMessageEvent, "", nil, syncPositionAfter2) aliceWG.Wait() // it's possible that at this point alice has been informed and bob is about to be informed, so wait @@ -256,18 +305,17 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { time.Sleep(1 * time.Millisecond) } -// same as Notifier.WaitForEvents but with a timeout. -func waitForEvents(n *Notifier, req syncRequest) (types.StreamPosition, error) { +func waitForEvents(n *Notifier, req syncRequest) (types.SyncPosition, error) { listener := n.GetListener(req) defer listener.Close() select { case <-time.After(5 * time.Second): - return types.StreamPosition(0), fmt.Errorf( + return types.SyncPosition{}, fmt.Errorf( "waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since, ) case <-listener.GetNotifyChannel(*req.since): - p := listener.GetStreamPosition() + p := listener.GetSyncPosition() return p, nil } } @@ -280,7 +328,16 @@ func waitForBlocking(s *UserStream, numBlocking uint) { } } -func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest { +// lockedFetchUserStream invokes Notifier.fetchUserStream, respecting Notifier.streamLock. +// A new stream is made if it doesn't exist already. +func lockedFetchUserStream(n *Notifier, userID string) *UserStream { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + return n.fetchUserStream(userID, true) +} + +func newTestSyncRequest(userID string, since types.SyncPosition) syncRequest { return syncRequest{ device: authtypes.Device{UserID: userID}, timeout: 1 * time.Minute, diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 35a15f6f9..a5d2f60f4 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -16,8 +16,10 @@ package sync import ( "context" + "errors" "net/http" "strconv" + "strings" "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -36,7 +38,7 @@ type syncRequest struct { device authtypes.Device limit int timeout time.Duration - since *types.StreamPosition // nil means that no since token was supplied + since *types.SyncPosition // nil means that no since token was supplied wantFullState bool log *log.Entry } @@ -73,15 +75,41 @@ func getTimeout(timeoutMS string) time.Duration { } // getSyncStreamPosition tries to parse a 'since' token taken from the API to a -// stream position. If the string is empty then (nil, nil) is returned. -func getSyncStreamPosition(since string) (*types.StreamPosition, error) { +// types.SyncPosition. If the string is empty then (nil, nil) is returned. +// There are two forms of tokens: The full length form containing all PDU and EDU +// positions separated by "_", and the short form containing only the PDU +// position. Short form can be used for, e.g., `prev_batch` tokens. +func getSyncStreamPosition(since string) (*types.SyncPosition, error) { if since == "" { return nil, nil } - i, err := strconv.Atoi(since) - if err != nil { - return nil, err + + posStrings := strings.Split(since, "_") + if len(posStrings) != 2 && len(posStrings) != 1 { + // A token can either be full length or short (PDU-only). + return nil, errors.New("malformed batch token") + } + + positions := make([]int64, len(posStrings)) + for i, posString := range posStrings { + pos, err := strconv.ParseInt(posString, 10, 64) + if err != nil { + return nil, err + } + positions[i] = pos + } + + if len(positions) == 2 { + // Full length token; construct SyncPosition with every entry in + // `positions`. These entries must have the same order with the fields + // in struct SyncPosition, so we disable the govet check below. + return &types.SyncPosition{ //nolint:govet + positions[0], positions[1], + }, nil + } else { + // Token with PDU position only + return &types.SyncPosition{ + PDUPosition: positions[0], + }, nil } - token := types.StreamPosition(i) - return &token, nil } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 89137eb59..6b95f4698 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -31,13 +31,13 @@ import ( // RequestPool manages HTTP long-poll connections for /sync type RequestPool struct { - db *storage.SyncServerDatabase + db *storage.SyncServerDatasource accountDB *accounts.Database notifier *Notifier } // NewRequestPool makes a new RequestPool -func NewRequestPool(db *storage.SyncServerDatabase, n *Notifier, adb *accounts.Database) *RequestPool { +func NewRequestPool(db *storage.SyncServerDatasource, n *Notifier, adb *accounts.Database) *RequestPool { return &RequestPool{db, adb, n} } @@ -65,8 +65,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype currPos := rp.notifier.CurrentPosition() - // If this is an initial sync or timeout=0 we return immediately - if syncReq.since == nil || syncReq.timeout == 0 { + if shouldReturnImmediately(syncReq) { syncData, err = rp.currentSyncForUser(*syncReq, currPos) if err != nil { return httputil.LogThenError(req, err) @@ -92,11 +91,13 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype // respond with, so we skip the return an go back to waiting for content to // be sent down or the request timing out. var hasTimedOut bool + sincePos := *syncReq.since for { select { // Wait for notifier to wake us up - case <-userStreamListener.GetNotifyChannel(currPos): - currPos = userStreamListener.GetStreamPosition() + case <-userStreamListener.GetNotifyChannel(sincePos): + currPos = userStreamListener.GetSyncPosition() + sincePos = currPos // Or for timeout to expire case <-timer.C: // We just need to ensure we get out of the select after reaching the @@ -128,24 +129,26 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype } } -func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) { +func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncPosition) (res *types.Response, err error) { // TODO: handle ignored users if req.since == nil { res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) } else { - res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, currentPos, req.limit) + res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState) } if err != nil { return } - res, err = rp.appendAccountData(res, req.device.UserID, req, currentPos) + accountDataFilter := gomatrixserverlib.DefaultFilterPart() // TODO: use filter provided in req instead + res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter) return } func (rp *RequestPool) appendAccountData( - data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition, + data *types.Response, userID string, req syncRequest, currentPos int64, + accountDataFilter *gomatrixserverlib.FilterPart, ) (*types.Response, error) { // TODO: Account data doesn't have a sync position of its own, meaning that // account data might be sent multiple time to the client if multiple account @@ -179,7 +182,7 @@ func (rp *RequestPool) appendAccountData( } // Sync is not initial, get all account data since the latest sync - dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, *req.since, currentPos) + dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, req.since.PDUPosition, currentPos, accountDataFilter) if err != nil { return nil, err } @@ -214,3 +217,10 @@ func (rp *RequestPool) appendAccountData( return data, nil } + +// shouldReturnImmediately returns whether the /sync request is an initial sync, +// or timeout=0, or full_state=true, in any of the cases the request should +// return immediately. +func shouldReturnImmediately(syncReq *syncRequest) bool { + return syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState +} diff --git a/syncapi/sync/userstream.go b/syncapi/sync/userstream.go index 77d09c202..beb10e487 100644 --- a/syncapi/sync/userstream.go +++ b/syncapi/sync/userstream.go @@ -34,8 +34,8 @@ type UserStream struct { lock sync.Mutex // Closed when there is an update. signalChannel chan struct{} - // The last stream position that there may have been an update for the suser - pos types.StreamPosition + // The last sync position that there may have been an update for the user + pos types.SyncPosition // The last time when we had some listeners waiting timeOfLastChannel time.Time // The number of listeners waiting @@ -51,7 +51,7 @@ type UserStreamListener struct { } // NewUserStream creates a new user stream -func NewUserStream(userID string, currPos types.StreamPosition) *UserStream { +func NewUserStream(userID string, currPos types.SyncPosition) *UserStream { return &UserStream{ UserID: userID, timeOfLastChannel: time.Now(), @@ -84,8 +84,8 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { return listener } -// Broadcast a new stream position for this user. -func (s *UserStream) Broadcast(pos types.StreamPosition) { +// Broadcast a new sync position for this user. +func (s *UserStream) Broadcast(pos types.SyncPosition) { s.lock.Lock() defer s.lock.Unlock() @@ -118,9 +118,9 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time { return s.timeOfLastChannel } -// GetStreamPosition returns last stream position which the UserStream was +// GetStreamPosition returns last sync position which the UserStream was // notified about -func (s *UserStreamListener) GetStreamPosition() types.StreamPosition { +func (s *UserStreamListener) GetSyncPosition() types.SyncPosition { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -132,11 +132,11 @@ func (s *UserStreamListener) GetStreamPosition() types.StreamPosition { // sincePos specifies from which point we want to be notified about. If there // has already been an update after sincePos we'll return a closed channel // immediately. -func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamPosition) <-chan struct{} { +func (s *UserStreamListener) GetNotifyChannel(sincePos types.SyncPosition) <-chan struct{} { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() - if sincePos < s.userStream.pos { + if s.userStream.pos.IsAfter(sincePos) { // If the listener is behind, i.e. missed a potential update, then we // want them to wake up immediately. We do this by returning a new // closed stream, which returns immediately when selected. diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 2db54c3ce..4738feea2 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -28,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" - "github.com/matrix-org/dendrite/syncapi/types" ) // SetupSyncAPIComponent sets up and registers HTTP handlers for the SyncAPI @@ -39,17 +38,17 @@ func SetupSyncAPIComponent( accountsDB *accounts.Database, queryAPI api.RoomserverQueryAPI, ) { - syncDB, err := storage.NewSyncServerDatabase(string(base.Cfg.Database.SyncAPI)) + syncDB, err := storage.NewSyncServerDatasource(string(base.Cfg.Database.SyncAPI)) if err != nil { logrus.WithError(err).Panicf("failed to connect to sync db") } - pos, err := syncDB.SyncStreamPosition(context.Background()) + pos, err := syncDB.SyncPosition(context.Background()) if err != nil { - logrus.WithError(err).Panicf("failed to get stream position") + logrus.WithError(err).Panicf("failed to get sync position") } - notifier := sync.NewNotifier(types.StreamPosition(pos)) + notifier := sync.NewNotifier(pos) err = notifier.Load(context.Background(), syncDB) if err != nil { logrus.WithError(err).Panicf("failed to start notifier") @@ -71,5 +70,12 @@ func SetupSyncAPIComponent( logrus.WithError(err).Panicf("failed to start client data consumer") } + typingConsumer := consumers.NewOutputTypingEventConsumer( + base.Cfg, base.KafkaConsumer, notifier, syncDB, + ) + if err = typingConsumer.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start typing server consumer") + } + routing.Setup(base.APIMux, requestPool, syncDB, deviceDB) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index d0b1c38ab..af7ec865f 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -21,12 +21,38 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// StreamPosition represents the offset in the sync stream a client is at. -type StreamPosition int64 +// SyncPosition contains the PDU and EDU stream sync positions for a client. +type SyncPosition struct { + // PDUPosition is the stream position for PDUs the client is at. + PDUPosition int64 + // TypingPosition is the client's position for typing notifications. + TypingPosition int64 +} // String implements the Stringer interface. -func (sp StreamPosition) String() string { - return strconv.FormatInt(int64(sp), 10) +func (sp SyncPosition) String() string { + return strconv.FormatInt(sp.PDUPosition, 10) + "_" + + strconv.FormatInt(sp.TypingPosition, 10) +} + +// IsAfter returns whether one SyncPosition refers to states newer than another SyncPosition. +func (sp SyncPosition) IsAfter(other SyncPosition) bool { + return sp.PDUPosition > other.PDUPosition || + sp.TypingPosition > other.TypingPosition +} + +// WithUpdates returns a copy of the SyncPosition with updates applied from another SyncPosition. +// If the latter SyncPosition contains a field that is not 0, it is considered an update, +// and its value will replace the corresponding value in the SyncPosition on which WithUpdates is called. +func (sp SyncPosition) WithUpdates(other SyncPosition) SyncPosition { + ret := sp + if other.PDUPosition != 0 { + ret.PDUPosition = other.PDUPosition + } + if other.TypingPosition != 0 { + ret.TypingPosition = other.TypingPosition + } + return ret } // PrevEventRef represents a reference to a previous event in a state event upgrade @@ -53,11 +79,10 @@ type Response struct { } // NewResponse creates an empty response with initialised maps. -func NewResponse(pos StreamPosition) *Response { - res := Response{} - // Make sure we send the next_batch as a string. We don't want to confuse clients by sending this - // as an integer even though (at the moment) it is. - res.NextBatch = pos.String() +func NewResponse(pos SyncPosition) *Response { + res := Response{ + NextBatch: pos.String(), + } // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. res.Rooms.Join = make(map[string]JoinResponse) diff --git a/testfile b/testfile index 0ddaba417..a93dca16f 100644 --- a/testfile +++ b/testfile @@ -42,6 +42,7 @@ POST /join/:room_alias can join a room POST /join/:room_id can join a room POST /join/:room_id can join a room with custom content POST /join/:room_alias can join a room with custom content +POST /rooms/:room_id/join can join a room POST /rooms/:room_id/leave can leave a room POST /rooms/:room_id/invite can send an invite POST /rooms/:room_id/ban can ban a user @@ -142,7 +143,34 @@ Trying to get push rules with unknown rule_id fails with 404 Events come down the correct room local user can join room with version 5 User can invite local user to room with version 5 -Inbound federation can receive room-join requests +Inbound federation can receive v1 room-join requests +Typing events appear in initial sync +Typing events appear in incremental sync +Typing events appear in gapped sync +Inbound federation of state requires event_id as a mandatory paramater +Inbound federation of state_ids requires event_id as a mandatory paramater +POST /register returns the same device_id as that in the request +POST /login returns the same device_id as that in the request +POST /createRoom with creation content +User can create and send/receive messages in a room with version 1 +POST /createRoom ignores attempts to set the room version via creation_content +Inbound federation rejects remote attempts to join local users to rooms +Inbound federation rejects remote attempts to kick local users to rooms +An event which redacts itself should be ignored +A pair of events which redact each other should be ignored +Full state sync includes joined rooms +A message sent after an initial sync appears in the timeline of an incremental sync. +Can add tag +Can remove tag +Can list tags for a room +Tags appear in an initial v2 /sync +Newly updated tags appear in an incremental v2 /sync +Deleted tags appear in an incremental v2 /sync +/event/ on non world readable room does not work +Outbound federation can query profile data +/event/ on joined room works +/event/ does not allow access to events before the user joined +Federation key API allows unsigned requests for keys GET /directory/room/:room_alias yields room ID PUT /directory/room/:room_alias creates alias Room aliases can contain Unicode diff --git a/typingserver/api/output.go b/typingserver/api/output.go index 813b9b7c7..8696acf49 100644 --- a/typingserver/api/output.go +++ b/typingserver/api/output.go @@ -12,14 +12,17 @@ package api +import "time" + // OutputTypingEvent is an entry in typing server output kafka log. // This contains the event with extra fields used to create 'm.typing' event // in clientapi & federation. type OutputTypingEvent struct { // The Event for the typing edu event. Event TypingEvent `json:"event"` - // Users typing in the room when the event was generated. - TypingUsers []string `json:"typing_users"` + // ExpireTime is the interval after which the user should no longer be + // considered typing. Only available if Event.Typing is true. + ExpireTime *time.Time } // TypingEvent represents a matrix edu event of type 'm.typing'. diff --git a/typingserver/cache/cache.go b/typingserver/cache/cache.go index 85d74cd19..3f05c938e 100644 --- a/typingserver/cache/cache.go +++ b/typingserver/cache/cache.go @@ -22,25 +22,66 @@ const defaultTypingTimeout = 10 * time.Second // userSet is a map of user IDs to a timer, timer fires at expiry. type userSet map[string]*time.Timer +// TimeoutCallbackFn is a function called right after the removal of a user +// from the typing user list due to timeout. +// latestSyncPosition is the typing sync position after the removal. +type TimeoutCallbackFn func(userID, roomID string, latestSyncPosition int64) + +type roomData struct { + syncPosition int64 + userSet userSet +} + // TypingCache maintains a list of users typing in each room. type TypingCache struct { sync.RWMutex - data map[string]userSet + latestSyncPosition int64 + data map[string]*roomData + timeoutCallback TimeoutCallbackFn +} + +// Create a roomData with its sync position set to the latest sync position. +// Must only be called after locking the cache. +func (t *TypingCache) newRoomData() *roomData { + return &roomData{ + syncPosition: t.latestSyncPosition, + userSet: make(userSet), + } } // NewTypingCache returns a new TypingCache initialised for use. func NewTypingCache() *TypingCache { - return &TypingCache{data: make(map[string]userSet)} + return &TypingCache{data: make(map[string]*roomData)} +} + +// SetTimeoutCallback sets a callback function that is called right after +// a user is removed from the typing user list due to timeout. +func (t *TypingCache) SetTimeoutCallback(fn TimeoutCallbackFn) { + t.timeoutCallback = fn } // GetTypingUsers returns the list of users typing in a room. -func (t *TypingCache) GetTypingUsers(roomID string) (users []string) { +func (t *TypingCache) GetTypingUsers(roomID string) []string { + users, _ := t.GetTypingUsersIfUpdatedAfter(roomID, 0) + // 0 should work above because the first position used will be 1. + return users +} + +// GetTypingUsersIfUpdatedAfter returns all users typing in this room with +// updated == true if the typing sync position of the room is after the given +// position. Otherwise, returns an empty slice with updated == false. +func (t *TypingCache) GetTypingUsersIfUpdatedAfter( + roomID string, position int64, +) (users []string, updated bool) { t.RLock() - usersMap, ok := t.data[roomID] - t.RUnlock() - if ok { - users = make([]string, 0, len(usersMap)) - for userID := range usersMap { + defer t.RUnlock() + + roomData, ok := t.data[roomID] + if ok && roomData.syncPosition > position { + updated = true + userSet := roomData.userSet + users = make([]string, 0, len(userSet)) + for userID := range userSet { users = append(users, userID) } } @@ -51,53 +92,84 @@ func (t *TypingCache) GetTypingUsers(roomID string) (users []string) { // AddTypingUser sets an user as typing in a room. // expire is the time when the user typing should time out. // if expire is nil, defaultTypingTimeout is assumed. -func (t *TypingCache) AddTypingUser(userID, roomID string, expire *time.Time) { +// Returns the latest sync position for typing after update. +func (t *TypingCache) AddTypingUser( + userID, roomID string, expire *time.Time, +) int64 { expireTime := getExpireTime(expire) if until := time.Until(expireTime); until > 0 { - timer := time.AfterFunc(until, t.timeoutCallback(userID, roomID)) - t.addUser(userID, roomID, timer) + timer := time.AfterFunc(until, func() { + latestSyncPosition := t.RemoveUser(userID, roomID) + if t.timeoutCallback != nil { + t.timeoutCallback(userID, roomID, latestSyncPosition) + } + }) + return t.addUser(userID, roomID, timer) } + return t.GetLatestSyncPosition() } // addUser with mutex lock & replace the previous timer. -func (t *TypingCache) addUser(userID, roomID string, expiryTimer *time.Timer) { +// Returns the latest typing sync position after update. +func (t *TypingCache) addUser( + userID, roomID string, expiryTimer *time.Timer, +) int64 { t.Lock() defer t.Unlock() + t.latestSyncPosition++ + if t.data[roomID] == nil { - t.data[roomID] = make(userSet) + t.data[roomID] = t.newRoomData() + } else { + t.data[roomID].syncPosition = t.latestSyncPosition } // Stop the timer to cancel the call to timeoutCallback - if timer, ok := t.data[roomID][userID]; ok { - // It may happen that at this stage timer fires but now we have a lock on t. - // Hence the execution of timeoutCallback will happen after we unlock. - // So we may lose a typing state, though this event is highly unlikely. - // This can be mitigated by keeping another time.Time in the map and check against it - // before removing. This however is not required in most practical scenario. + if timer, ok := t.data[roomID].userSet[userID]; ok { + // It may happen that at this stage the timer fires, but we now have a lock on + // it. Hence the execution of timeoutCallback will happen after we unlock. So + // we may lose a typing state, though this is highly unlikely. This can be + // mitigated by keeping another time.Time in the map and checking against it + // before removing, but its occurrence is so infrequent it does not seem + // worthwhile. timer.Stop() } - t.data[roomID][userID] = expiryTimer -} + t.data[roomID].userSet[userID] = expiryTimer -// Returns a function which is called after timeout happens. -// This removes the user. -func (t *TypingCache) timeoutCallback(userID, roomID string) func() { - return func() { - t.RemoveUser(userID, roomID) - } + return t.latestSyncPosition } // RemoveUser with mutex lock & stop the timer. -func (t *TypingCache) RemoveUser(userID, roomID string) { +// Returns the latest sync position for typing after update. +func (t *TypingCache) RemoveUser(userID, roomID string) int64 { t.Lock() defer t.Unlock() - if timer, ok := t.data[roomID][userID]; ok { - timer.Stop() - delete(t.data[roomID], userID) + roomData, ok := t.data[roomID] + if !ok { + return t.latestSyncPosition } + + timer, ok := roomData.userSet[userID] + if !ok { + return t.latestSyncPosition + } + + timer.Stop() + delete(roomData.userSet, userID) + + t.latestSyncPosition++ + t.data[roomID].syncPosition = t.latestSyncPosition + + return t.latestSyncPosition +} + +func (t *TypingCache) GetLatestSyncPosition() int64 { + t.Lock() + defer t.Unlock() + return t.latestSyncPosition } func getExpireTime(expire *time.Time) time.Time { diff --git a/typingserver/input/input.go b/typingserver/input/input.go index b9968ce4c..0e2fbe51f 100644 --- a/typingserver/input/input.go +++ b/typingserver/input/input.go @@ -57,15 +57,21 @@ func (t *TypingServerInputAPI) InputTypingEvent( } func (t *TypingServerInputAPI) sendEvent(ite *api.InputTypingEvent) error { - userIDs := t.Cache.GetTypingUsers(ite.RoomID) ev := &api.TypingEvent{ Type: gomatrixserverlib.MTyping, RoomID: ite.RoomID, UserID: ite.UserID, + Typing: ite.Typing, } ote := &api.OutputTypingEvent{ - Event: *ev, - TypingUsers: userIDs, + Event: *ev, + } + + if ev.Typing { + expireTime := ite.OriginServerTS.Time().Add( + time.Duration(ite.Timeout) * time.Millisecond, + ) + ote.ExpireTime = &expireTime } eventJSON, err := json.Marshal(ote)