From f0340988777fa13727d458a56480f2c3f70d8660 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Fri, 28 Jan 2022 12:07:47 +0100 Subject: [PATCH 01/81] "Enable" remote room search (#2099) * "Enable" remote room search Signed-off-by: Till Faelligen * Update go.mod * Fix formatting --- clientapi/routing/directory_public.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index 2e3283be1..0dacfced5 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -63,7 +63,12 @@ func GetPostPublicRooms( serverName := gomatrixserverlib.ServerName(request.Server) if serverName != "" && serverName != cfg.Matrix.ServerName { - res, err := federation.GetPublicRooms(req.Context(), serverName, int(request.Limit), request.Since, false, "") + res, err := federation.GetPublicRoomsFiltered( + req.Context(), serverName, + int(request.Limit), request.Since, + request.Filter.SearchTerms, false, + "", + ) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("failed to get public rooms") return jsonerror.InternalServerError() From 2ea5fd41623e22cd2a00c59e7954950037bfa8bb Mon Sep 17 00:00:00 2001 From: kegsay Date: Fri, 28 Jan 2022 11:14:20 +0000 Subject: [PATCH 02/81] Add debug logging for incoming CSAPI calls on authentication failure (#2116) * Add debug logging for incoming CSAPI calls on authentication failure Will help to debug Complement failures, and just generally useful. * Update httpapi.go Co-authored-by: Neil Alexander --- internal/httputil/httpapi.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 1fbd77da9..1a37a1eec 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -53,12 +53,13 @@ func MakeAuthAPI( f func(*http.Request, *userapi.Device) util.JSONResponse, ) http.Handler { h := func(req *http.Request) util.JSONResponse { + logger := util.GetLogger(req.Context()) device, err := auth.VerifyUserFromRequest(req, userAPI) if err != nil { + logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, err.Code) return *err } // add the user ID to the logger - logger := util.GetLogger((req.Context())) logger = logger.WithField("user_id", device.UserID) req = req.WithContext(util.ContextWithLogger(req.Context(), logger)) // add the user to Sentry, if enabled From e9fbad6f2015ac0375eff76c8d50962791099e17 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 28 Jan 2022 12:33:31 +0000 Subject: [PATCH 03/81] Move hook call when processing room events (#2118) * Move hook call when processing room events * Fix build --- roomserver/internal/input/input.go | 5 ----- roomserver/internal/input/input_events.go | 5 ++++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index e0ddd07cf..e6f325b47 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -25,7 +25,6 @@ import ( "github.com/Arceliar/phony" "github.com/getsentry/sentry-go" fedapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/query" @@ -105,8 +104,6 @@ func (r *Inputer) Start() error { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } - } else { - go hooks.Run(hooks.KindNewEventPersisted, inputRoomEvent.Event) } _ = msg.Ack() }) @@ -176,8 +173,6 @@ func (r *Inputer) InputRoomEvents( if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } - } else { - go hooks.Run(hooks.KindNewEventPersisted, inputRoomEvent.Event) } select { case <-ctx.Done(): diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 5f9115223..334421400 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -25,6 +25,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" @@ -321,7 +322,9 @@ func (r *Inputer) processRoomEvent( } } - // Update the extremities of the event graph for the room + // Everything was OK — the latest events updater didn't error and + // we've sent output events. Finally, generate a hook call. + hooks.Run(hooks.KindNewEventPersisted, headered) return nil } From 8e4002831f6681b5f2de1c6490184ed6ed2275fe Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 28 Jan 2022 13:11:56 +0000 Subject: [PATCH 04/81] Call hooks for outliers (#2119) * Move hook call when processing room events * Fix build * Call hooks for outliers too --- roomserver/internal/input/input_events.go | 1 + 1 file changed, 1 insertion(+) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 334421400..8f262ebe6 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -250,6 +250,7 @@ func (r *Inputer) processRoomEvent( // notify anyone about it. if input.Kind == api.KindOutlier { logger.Debug("Stored outlier") + hooks.Run(hooks.KindNewEventPersisted, headered) return nil } From bde7c1fd8ca040f61f339c96c6589d57feca3dd9 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 28 Jan 2022 14:13:36 +0000 Subject: [PATCH 05/81] Version 0.6 (#2117) * Bump version, release notes * Update changelog * Update changelog --- CHANGES.md | 28 ++++++++++++++++++++++++++++ internal/version.go | 4 ++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 94edc6288..95e24de11 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,33 @@ # Changelog +## Dendrite 0.6.0 (2022-01-28) + +### Features + +* NATS JetStream is now used instead of Kafka and Naffka + * For monolith deployments, a built-in NATS Server is embedded into Dendrite or a standalone NATS Server deployment can be optionally used instead + * For polylith deployments, a standalone NATS Server deployment is required + * Requires the version 2 configuration file — please see the new `dendrite-config.yaml` sample config file + * Kafka and Naffka are no longer supported as of this release +* The roomserver is now responsible for fetching missing events and state instead of the federation API + * Removes a number of race conditions between the federation API and roomserver, which reduces duplicate work and overall lowers CPU usage +* The roomserver input API is now strictly ordered with support for asynchronous requests, smoothing out incoming federation significantly +* Consolidated the federation API, federation sender and signing key server into a single component + * If multiple databases are used, tables for the federation sender and signing key server should be merged into the federation API database (table names have not changed) +* Device list synchronisation is now database-backed rather than using the now-removed Kafka logs + +### Fixes + +* The code for fetching missing events and state now correctly identifies when gaps in history have been closed, so federation traffic will consume less CPU and memory than before +* The stream position is now correctly advanced when typing notifications time out in the sync API +* Event NIDs are now correctly returned when persisting events in the roomserver in SQLite mode + * The built-in SQLite was updated to version 3.37.0 as a result +* The `/event_auth` endpoint now strictly returns the auth chain for the requested event without loading the room state, which should reduce spikes in memory usage +* Filters are now correctly sent when using federated public room directories (contributed by [S7evinK](https://github.com/S7evinK)) +* Login usernames are now squashed to lower-case (contributed by [BernardZhao](https://github.com/BernardZhao)) +* The logs should no longer be flooded with `Failed to get server ACLs for room` warnings at startup +* Backfilling will now attempt federation as a last resort when trying to retrieve missing events from the database fails + ## Dendrite 0.5.1 (2021-11-16) ### Features diff --git a/internal/version.go b/internal/version.go index 88123693f..f09daabd9 100644 --- a/internal/version.go +++ b/internal/version.go @@ -16,8 +16,8 @@ var build string const ( VersionMajor = 0 - VersionMinor = 5 - VersionPatch = 1 + VersionMinor = 6 + VersionPatch = 0 VersionTag = "" // example: "rc1" ) From 2c3dd48bb2daf078fe2e36c3a06a995fcd693d20 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 28 Jan 2022 16:24:01 +0000 Subject: [PATCH 06/81] Require Go 1.16 (#2122) --- README.md | 2 +- cmd/dendrite-demo-yggdrasil/README.md | 2 +- docs/INSTALL.md | 2 +- go.mod | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 3ec9f0296..a077788cf 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ If you have further questions, please take a look at [our FAQ](docs/FAQ.md) or j ## Requirements -To build Dendrite, you will need Go 1.15 or later. +To build Dendrite, you will need Go 1.16 or later. For a usable federating Dendrite deployment, you will also need: - A domain name (or subdomain) diff --git a/cmd/dendrite-demo-yggdrasil/README.md b/cmd/dendrite-demo-yggdrasil/README.md index c471cef22..946333576 100644 --- a/cmd/dendrite-demo-yggdrasil/README.md +++ b/cmd/dendrite-demo-yggdrasil/README.md @@ -1,6 +1,6 @@ # Yggdrasil Demo -This is the Dendrite Yggdrasil demo! It's easy to get started - all you need is Go 1.15 or later. +This is the Dendrite Yggdrasil demo! It's easy to get started - all you need is Go 1.16 or later. To run the homeserver, start at the root of the Dendrite repository and run: diff --git a/docs/INSTALL.md b/docs/INSTALL.md index 2afb43c6a..686ae1dbb 100644 --- a/docs/INSTALL.md +++ b/docs/INSTALL.md @@ -27,7 +27,7 @@ use in production environments just yet! Dendrite requires: -* Go 1.15 or higher +* Go 1.16 or higher * PostgreSQL 12 or higher (if using PostgreSQL databases, not needed for SQLite) If you want to run a polylith deployment, you also need: diff --git a/go.mod b/go.mod index 6d482bd60..5ddcf980b 100644 --- a/go.mod +++ b/go.mod @@ -72,4 +72,4 @@ require ( nhooyr.io/websocket v1.8.7 ) -go 1.15 +go 1.16 From 4281976df9d08c87d367707e6bba437bc0e72745 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 28 Jan 2022 17:31:54 +0000 Subject: [PATCH 07/81] Update Sarama to fix 32-bit builds (#2120) --- go.mod | 2 +- go.sum | 30 ++++++++++++++++++++++++------ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 5ddcf980b..01cff763c 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/MFAshby/stdemuxerhook v1.0.0 github.com/Masterminds/semver/v3 v3.1.1 github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32 - github.com/Shopify/sarama v1.29.0 + github.com/Shopify/sarama v1.31.0 github.com/codeclysm/extract v2.2.0+incompatible github.com/containerd/containerd v1.5.9 // indirect github.com/docker/docker v20.10.12+incompatible diff --git a/go.sum b/go.sum index 3ef5a54aa..f04754295 100644 --- a/go.sum +++ b/go.sum @@ -104,10 +104,13 @@ github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32 h1:i3fOph9 github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32/go.mod h1:ne+jkLlzafIzaE4Q0Ze81T27dNgXe1wxovVEoAtSHTc= github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:HI8ITrYtUY+O+ZhtlqUnD8+KwNPOyugEhfP9fdUIaEQ= -github.com/Shopify/sarama v1.29.0 h1:ARid8o8oieau9XrHI55f/L3EoRAhm9px6sonbD7yuUE= github.com/Shopify/sarama v1.29.0/go.mod h1:2QpgD79wpdAESqNQMxNc0KYMkycd4slxGdV3TWSVqrU= +github.com/Shopify/sarama v1.31.0 h1:gObk7jCPutDxf+E6GA5G21noAZsi1SvP9ftCQYqpzus= +github.com/Shopify/sarama v1.31.0/go.mod h1:BeW3gXRc/CxgAsrSly2RE9nIXUfC9ezb7QHBPVhvzjI= github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/Shopify/toxiproxy/v2 v2.3.0 h1:62YkpiP4bzdhKMH+6uC5E95y608k3zDwdzuBMsnn3uQ= +github.com/Shopify/toxiproxy/v2 v2.3.0/go.mod h1:KvQTtB6RjCJY4zqNJn7C7JDFgsG5uoHYDirfUfpIm0c= github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA= github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= @@ -381,8 +384,9 @@ github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHqu github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k= github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o= -github.com/frankban/quicktest v1.11.3 h1:8sXhOn0uLys67V8EsXLc6eszDs8VXWxL3iRvebPhedY= github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= +github.com/frankban/quicktest v1.14.0 h1:+cqqvzZV87b4adx/5ayVOaYZ2CrvM4ejQvUdBzPPUss= +github.com/frankban/quicktest v1.14.0/go.mod h1:NeW+ay9A/U67EYXNFA1nPE8e/tnQv/09mUdL/ijj8og= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= @@ -494,8 +498,9 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U= github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo= github.com/gologme/log v1.3.0/go.mod h1:yKT+DvIPdDdDoPtqFrFxheooyVmoqi0BAsw+erN3wA4= @@ -512,8 +517,9 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -747,6 +753,7 @@ github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdY github.com/klauspost/compress v1.12.2/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw= github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= @@ -758,8 +765,9 @@ github.com/koron/go-ssdp v0.0.0-20191105050749-2e1c40ed0b5d/go.mod h1:5Ky9EC2xfo github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= @@ -1246,8 +1254,9 @@ github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/9 github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= -github.com/pierrec/lz4 v2.6.0+incompatible h1:Ix9yFKn1nSPBLFl/yZknTp8TU5G4Ps0JDmguYK6iH1A= github.com/pierrec/lz4 v2.6.0+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pierrec/lz4 v2.6.1+incompatible h1:9UY3+iC23yxF0UfGaYrGplQ+79Rg+h/q9FV9ix19jjM= +github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -1307,6 +1316,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= @@ -1426,6 +1437,7 @@ github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5/go.mod h1:70zkFmudgCuE/ github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBnvPM1Su9w= @@ -1456,6 +1468,9 @@ github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPyS github.com/willf/bitset v1.1.11-0.20200630133818-d5bec3311243/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= github.com/willf/bitset v1.1.11/go.mod h1:83CECat5yLh5zVOf4P1ErAgKA5UDvKtgyUABdr3+MjI= github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= +github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= github.com/xdg/scram v1.0.3/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= github.com/xdg/stringprep v1.0.3/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= @@ -1655,6 +1670,7 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210927181540-4e4d966f7476/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211008194852-3b03d305991f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -1788,6 +1804,7 @@ golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7-0.20210503195748-5c7c50ebbd4f/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= @@ -2025,6 +2042,7 @@ gopkg.in/yaml.v2 v2.0.0-20170712054546-1be3d31502d6/go.mod h1:JAlM8MvJe8wmxCU4Bl gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From a271fde8f581389fd793e5064f88f70e90d50f66 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 31 Jan 2022 10:39:33 +0000 Subject: [PATCH 08/81] Only limit context for fetching missing auth/prev events (#2131) --- roomserver/internal/input/input.go | 2 +- roomserver/internal/input/input_events.go | 17 +++++++---------- roomserver/internal/input/input_missing.go | 4 ++++ 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index e6f325b47..7933f9750 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -120,7 +120,7 @@ func (r *Inputer) Start() error { nats.DeliverAll(), // Ensure that NATS doesn't try to resend us something that wasn't done // within the period of time that we might still be processing it. - nats.AckWait(MaximumProcessingTime+(time.Second*10)), + nats.AckWait((MaximumMissingProcessingTime*2)+(time.Second*10)), ) return err } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 8f262ebe6..2dc096674 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -41,7 +41,7 @@ func init() { } // TODO: Does this value make sense? -const MaximumProcessingTime = time.Minute * 2 +const MaximumMissingProcessingTime = time.Minute * 2 var processRoomEventDuration = prometheus.NewHistogramVec( prometheus.HistogramOpts{ @@ -66,11 +66,11 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // TODO: Break up function - we should probably do transaction ID checks before calling this. // nolint:gocyclo func (r *Inputer) processRoomEvent( - inctx context.Context, + ctx context.Context, input *api.InputRoomEvent, ) (err error) { select { - case <-inctx.Done(): + case <-ctx.Done(): // Before we do anything, make sure the context hasn't expired for this pending task. // If it has then we'll give up straight away — it's probably a synchronous input // request and the caller has already given up, but the inbox task was still queued. @@ -78,13 +78,6 @@ func (r *Inputer) processRoomEvent( default: } - // Wrap the context with a time limit. We'll allow no more than MaximumProcessingTime for - // everything that we need to do for this event, or it's possible that we could end up wedging - // the roomserver for a very long time. - var cancel context.CancelFunc - ctx, cancel := context.WithTimeout(inctx, MaximumProcessingTime) - defer cancel() - // Measure how long it takes to process this event. started := time.Now() defer func() { @@ -344,6 +337,10 @@ func (r *Inputer) fetchAuthEvents( known map[string]*types.Event, servers []gomatrixserverlib.ServerName, ) error { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, MaximumMissingProcessingTime) + defer cancel() + unknown := map[string]struct{}{} authEventIDs := event.AuthEventIDs() if len(authEventIDs) == 0 { diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 44710962c..862b3a7fe 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -37,6 +37,10 @@ type missingStateReq struct { func (t *missingStateReq) processEventWithMissingState( ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, ) error { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, MaximumMissingProcessingTime) + defer cancel() + // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the // room. There two ways that we can handle such a gap: From eb8e770e9973fb22371a0a474fee7d10a981d800 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 31 Jan 2022 10:42:41 +0000 Subject: [PATCH 09/81] Revert consumer change --- roomserver/internal/input/input.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 7933f9750..938d5ac1a 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -120,7 +120,7 @@ func (r *Inputer) Start() error { nats.DeliverAll(), // Ensure that NATS doesn't try to resend us something that wasn't done // within the period of time that we might still be processing it. - nats.AckWait((MaximumMissingProcessingTime*2)+(time.Second*10)), + nats.AckWait(MaximumMissingProcessingTime+(time.Second*10)), ) return err } From ba1a9b98b70e340b8dd7c748aab8998e493c9c05 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 31 Jan 2022 10:48:28 +0000 Subject: [PATCH 10/81] Tweak some logging (#2130) * Modify some log levels * Update gomatrixserverlib to matrix-org/gomatrixserverlib@336334f * Update gomatrixserverlib to matrix-org/gomatrixserverlib@cde7ac8 * Demote warning about key change producer * Add more useful roomserver logging * Further tweaking --- clientapi/producers/syncapi.go | 2 +- eduserver/input/input.go | 8 +++--- federationapi/routing/send.go | 28 +++++++++---------- go.mod | 2 +- go.sum | 11 ++------ keyserver/producers/keychange.go | 4 +-- mediaapi/routing/download.go | 26 ++++++++--------- roomserver/internal/input/input.go | 14 ++++++++++ roomserver/internal/input/input_missing.go | 20 ++++++------- roomserver/internal/perform/perform_invite.go | 2 +- roomserver/internal/perform/perform_join.go | 5 ++++ roomserver/internal/perform/perform_leave.go | 10 ++++++- syncapi/consumers/clientapi.go | 2 +- 13 files changed, 77 insertions(+), 57 deletions(-) diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go index bd6af5f1f..9b1d6b1a2 100644 --- a/clientapi/producers/syncapi.go +++ b/clientapi/producers/syncapi.go @@ -51,7 +51,7 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string "user_id": userID, "room_id": roomID, "data_type": dataType, - }).Infof("Producing to topic '%s'", p.Topic) + }).Tracef("Producing to topic '%s'", p.Topic) _, err = p.JetStream.PublishMsg(m) return err diff --git a/eduserver/input/input.go b/eduserver/input/input.go index e7501a907..4f8ab3e34 100644 --- a/eduserver/input/input.go +++ b/eduserver/input/input.go @@ -98,7 +98,7 @@ func (t *EDUServerInputAPI) InputCrossSigningKeyUpdate( logrus.WithFields(logrus.Fields{ "user_id": request.UserID, - }).Infof("Producing to topic '%s'", t.OutputKeyChangeEventTopic) + }).Tracef("Producing to topic '%s'", t.OutputKeyChangeEventTopic) _, err = t.JetStream.PublishMsg(&nats.Msg{ Subject: t.OutputKeyChangeEventTopic, @@ -134,7 +134,7 @@ func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error { "room_id": ite.RoomID, "user_id": ite.UserID, "typing": ite.Typing, - }).Infof("Producing to topic '%s'", t.OutputTypingEventTopic) + }).Tracef("Producing to topic '%s'", t.OutputTypingEventTopic) _, err = t.JetStream.PublishMsg(&nats.Msg{ Subject: t.OutputTypingEventTopic, @@ -175,7 +175,7 @@ func (t *EDUServerInputAPI) sendToDeviceEvent(ise *api.InputSendToDeviceEvent) e "user_id": ise.UserID, "num_devices": len(devices), "type": ise.Type, - }).Infof("Producing to topic '%s'", t.OutputSendToDeviceEventTopic) + }).Tracef("Producing to topic '%s'", t.OutputSendToDeviceEventTopic) for _, device := range devices { ote := &api.OutputSendToDeviceEvent{ UserID: ise.UserID, @@ -208,7 +208,7 @@ func (t *EDUServerInputAPI) InputReceiptEvent( request *api.InputReceiptEventRequest, response *api.InputReceiptEventResponse, ) error { - logrus.WithFields(logrus.Fields{}).Infof("Producing to topic '%s'", t.OutputReceiptEventTopic) + logrus.WithFields(logrus.Fields{}).Tracef("Producing to topic '%s'", t.OutputReceiptEventTopic) output := &api.OutputReceiptEvent{ UserID: request.InputReceiptEvent.UserID, RoomID: request.InputReceiptEvent.RoomID, diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index dbfd3ff92..524fd510e 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -162,7 +162,7 @@ func Send( t.TransactionID = txnID t.Destination = cfg.Matrix.ServerName - util.GetLogger(httpReq.Context()).Infof("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) + util.GetLogger(httpReq.Context()).Debugf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) resp, jsonErr := t.processTransaction(httpReq.Context()) if jsonErr != nil { @@ -221,7 +221,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} verRes := api.QueryRoomVersionForRoomResponse{} if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID) + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) return "" } roomVersions[roomID] = verRes.RoomVersion @@ -234,7 +234,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res RoomID string `json:"room_id"` } if err := json.Unmarshal(pdu, &header); err != nil { - util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to extract room ID from event") + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") // We don't know the event ID at this point so we can't return the // failure in the PDU results continue @@ -255,7 +255,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res JSON: jsonerror.BadJSON("PDU contains bad JSON"), } } - util.GetLogger(ctx).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %s", string(pdu)) + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) continue } if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { @@ -265,7 +265,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res continue } if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = gomatrixserverlib.PDUResult{ Error: err.Error(), } @@ -287,7 +287,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res nil, true, ); err != nil { - util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) + util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) results[event.EventID()] = gomatrixserverlib.PDUResult{ Error: err.Error(), } @@ -314,16 +314,16 @@ func (t *txnReq) processEDUs(ctx context.Context) { Typing bool `json:"typing"` } if err := json.Unmarshal(e.Content, &typingPayload); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal typing event") + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") continue } _, domain, err := gomatrixserverlib.SplitID('@', typingPayload.UserID) if err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to split domain from typing event sender") + util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from typing event sender") continue } if domain != t.Origin { - util.GetLogger(ctx).Warnf("Dropping typing event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + util.GetLogger(ctx).Debugf("Dropping typing event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) continue } if err := eduserverAPI.SendTyping(ctx, t.eduAPI, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { @@ -333,7 +333,7 @@ func (t *txnReq) processEDUs(ctx context.Context) { // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema var directPayload gomatrixserverlib.ToDeviceMessage if err := json.Unmarshal(e.Content, &directPayload); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal send-to-device events") + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") continue } for userID, byUser := range directPayload.Messages { @@ -355,7 +355,7 @@ func (t *txnReq) processEDUs(ctx context.Context) { payload := map[string]eduserverAPI.FederationReceiptMRead{} if err := json.Unmarshal(e.Content, &payload); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal receipt event") + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") continue } @@ -363,11 +363,11 @@ func (t *txnReq) processEDUs(ctx context.Context) { for userID, mread := range receipt.User { _, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to split domain from receipt event sender") + util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") continue } if t.Origin != domain { - util.GetLogger(ctx).Warnf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) continue } if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { @@ -386,7 +386,7 @@ func (t *txnReq) processEDUs(ctx context.Context) { if err := json.Unmarshal(e.Content, &updatePayload); err != nil { util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ "user_id": updatePayload.UserID, - }).Error("Failed to send signing key update to edu server") + }).Debug("Failed to send signing key update to edu server") continue } inputReq := &eduserverAPI.InputCrossSigningKeyUpdateRequest{ diff --git a/go.mod b/go.mod index 01cff763c..bd695e6dd 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220128100033-8d79e0c35e32 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220131095121-cde7ac8c5bb8 github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index f04754295..8952d034e 100644 --- a/go.sum +++ b/go.sum @@ -429,7 +429,6 @@ github.com/go-openapi/jsonreference v0.19.3/go.mod h1:rjx6GuL8TTa9VaixXglHmQmIL9 github.com/go-openapi/spec v0.19.3/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo= github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= -github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= @@ -553,9 +552,7 @@ github.com/gorilla/handlers v0.0.0-20150720190736-60c7bfde3e33/go.mod h1:Qkdc/uu github.com/gorilla/mux v1.7.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= @@ -1024,8 +1021,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220128100033-8d79e0c35e32 h1:DiWPsGAYMlBQq/urm7TJkIeSf9FnfzegcaQUpgwIbUs= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220128100033-8d79e0c35e32/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220131095121-cde7ac8c5bb8 h1:DeGMNY2iJ2u2zEOvUIJdrTPjh5mAa3Mim14hhAvS5zs= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220131095121-cde7ac8c5bb8/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -1169,7 +1166,6 @@ github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= -github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= @@ -1579,7 +1575,6 @@ golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= -golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6 h1:QE6XYQK6naiK1EPAe1g/ILLxN5RBoH5xkJk3CqlMI/Y= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= @@ -1890,9 +1885,7 @@ golang.zx2c4.com/wireguard v0.0.0-20210604143328-f9b48a961cd2/go.mod h1:laHzsbfM golang.zx2c4.com/wireguard v0.0.0-20210927201915-bb745b2ea326/go.mod h1:SDoazCvdy7RDjBPNEMBwrXhomlmtG7svs8mgwWEqtVI= golang.zx2c4.com/wireguard/windows v0.3.14/go.mod h1:3P4IEAsb+BjlKZmpUXgy74c0iX9AVwwr3WcVJ8nPgME= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= -gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM= gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= -gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= google.golang.org/api v0.0.0-20160322025152-9bf6e6e569ff/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index fd143c6cf..9e1c4c645 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -65,7 +65,7 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { logrus.WithFields(logrus.Fields{ "user_id": userID, "num_key_changes": count, - }).Infof("Produced to key change topic '%s'", p.Topic) + }).Tracef("Produced to key change topic '%s'", p.Topic) } return nil } @@ -103,6 +103,6 @@ func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) er logrus.WithFields(logrus.Fields{ "user_id": key.UserID, - }).Infof("Produced to cross-signing update topic '%s'", p.Topic) + }).Tracef("Produced to cross-signing update topic '%s'", p.Topic) return nil } diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 2358915ee..4ce738b6e 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -157,7 +157,7 @@ func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSON // Set status code and write the body w.WriteHeader(res.Code) - r.Logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes)) + r.Logger.WithField("code", res.Code).Tracef("Responding (%d bytes)", len(resBytes)) // we don't really care that much if we fail to write the error response w.Write(resBytes) // nolint: errcheck @@ -293,11 +293,11 @@ func (r *downloadRequest) respondFromLocalFile( "Base64Hash": r.MediaMetadata.Base64Hash, "FileSizeBytes": r.MediaMetadata.FileSizeBytes, "ContentType": r.MediaMetadata.ContentType, - }).Info("No good thumbnail found. Responding with original file.") + }).Trace("No good thumbnail found. Responding with original file.") responseFile = file responseMetadata = r.MediaMetadata } else { - r.Logger.Info("Responding with thumbnail") + r.Logger.Trace("Responding with thumbnail") responseFile = thumbFile responseMetadata = thumbMetadata.MediaMetadata } @@ -307,7 +307,7 @@ func (r *downloadRequest) respondFromLocalFile( "Base64Hash": r.MediaMetadata.Base64Hash, "FileSizeBytes": r.MediaMetadata.FileSizeBytes, "ContentType": r.MediaMetadata.ContentType, - }).Info("Responding with file") + }).Trace("Responding with file") responseFile = file responseMetadata = r.MediaMetadata if err := r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil { @@ -436,7 +436,7 @@ func (r *downloadRequest) getThumbnailFile( "Width": thumbnailSize.Width, "Height": thumbnailSize.Height, "ResizeMethod": thumbnailSize.ResizeMethod, - }).Info("Pre-generating thumbnail for immediate response.") + }).Debug("Pre-generating thumbnail for immediate response.") thumbnail, err = r.generateThumbnail( ctx, filePath, *thumbnailSize, activeThumbnailGeneration, maxThumbnailGenerators, db, @@ -574,7 +574,7 @@ func (r *downloadRequest) getMediaMetadataFromActiveRequest(activeRemoteRequests defer activeRemoteRequests.Unlock() if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok { - r.Logger.Info("Waiting for another goroutine to fetch the remote file.") + r.Logger.Trace("Waiting for another goroutine to fetch the remote file.") // NOTE: Wait unlocks and locks again internally. There is still a deferred Unlock() that will unlock this. activeRemoteRequestResult.Cond.Wait() @@ -604,7 +604,7 @@ func (r *downloadRequest) broadcastMediaMetadata(activeRemoteRequests *types.Act defer activeRemoteRequests.Unlock() mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok { - r.Logger.Info("Signalling other goroutines waiting for this goroutine to fetch the file.") + r.Logger.Trace("Signalling other goroutines waiting for this goroutine to fetch the file.") activeRemoteRequestResult.MediaMetadata = r.MediaMetadata activeRemoteRequestResult.Error = err activeRemoteRequestResult.Cond.Broadcast() @@ -635,7 +635,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( "UploadName": r.MediaMetadata.UploadName, "FileSizeBytes": r.MediaMetadata.FileSizeBytes, "ContentType": r.MediaMetadata.ContentType, - }).Info("Storing file metadata to media repository database") + }).Debug("Storing file metadata to media repository database") // FIXME: timeout db request if err := db.StoreMediaMetadata(ctx, r.MediaMetadata); err != nil { @@ -669,7 +669,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( "Base64Hash": r.MediaMetadata.Base64Hash, "FileSizeBytes": r.MediaMetadata.FileSizeBytes, "ContentType": r.MediaMetadata.ContentType, - }).Infof("Remote file cached") + }).Debug("Remote file cached") return nil } @@ -717,7 +717,7 @@ func (r *downloadRequest) fetchRemoteFile( absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes, ) (types.Path, bool, error) { - r.Logger.Info("Fetching remote file") + r.Logger.Debug("Fetching remote file") // create request for remote file resp, err := r.createRemoteRequest(ctx, client) @@ -762,7 +762,7 @@ func (r *downloadRequest) fetchRemoteFile( } } - r.Logger.Info("Transferring remote file") + r.Logger.Trace("Transferring remote file") // The file data is hashed but is NOT used as the MediaID, unlike in Upload. The hash is useful as a // method of deduplicating files to save storage, as well as a way to conduct @@ -776,7 +776,7 @@ func (r *downloadRequest) fetchRemoteFile( return "", false, errors.New("file could not be downloaded from remote server") } - r.Logger.Info("Remote file transferred") + r.Logger.Trace("Remote file transferred") // It's possible the bytesWritten to the temporary file is different to the reported Content-Length from the remote // request's response. bytesWritten is therefore used as it is what would be sent to clients when reading from the local @@ -790,7 +790,7 @@ func (r *downloadRequest) fetchRemoteFile( return "", false, fmt.Errorf("fileutils.MoveFileWithHashCheck: %w", err) } if duplicate { - r.Logger.WithField("dst", finalPath).Info("File was stored previously - discarding duplicate") + r.Logger.WithField("dst", finalPath).Trace("File was stored previously - discarding duplicate") // Continue on to store the metadata in the database } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 938d5ac1a..a38d56d7e 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -33,6 +33,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) @@ -104,6 +105,11 @@ func (r *Inputer) Start() error { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } + logrus.WithError(err).WithFields(logrus.Fields{ + "room_id": roomID, + "event_id": inputRoomEvent.Event.EventID(), + "type": inputRoomEvent.Event.Type(), + }).Warn("Roomserver failed to process async event") } _ = msg.Ack() }) @@ -146,6 +152,10 @@ func (r *Inputer) InputRoomEvents( return } if _, err = r.JetStream.PublishMsg(msg); err != nil { + logrus.WithError(err).WithFields(logrus.Fields{ + "room_id": roomID, + "event_id": e.Event.EventID(), + }).Error("Roomserver failed to queue async event") return } } @@ -173,6 +183,10 @@ func (r *Inputer) InputRoomEvents( if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } + logrus.WithError(err).WithFields(logrus.Fields{ + "room_id": roomID, + "event_id": inputRoomEvent.Event.EventID(), + }).Warn("Roomserver failed to process sync event") } select { case <-ctx.Done(): diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 862b3a7fe..aa2b94f8a 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -327,7 +327,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room queryReq := api.QueryEventsByIDRequest{ EventIDs: missingEventList, } - util.GetLogger(ctx).WithField("count", len(missingEventList)).Infof("Fetching missing auth events") + util.GetLogger(ctx).WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") var queryRes api.QueryEventsByIDResponse if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { return nil @@ -382,7 +382,7 @@ retryAllowedState: default: return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2) } - util.GetLogger(ctx).Infof("fetched event %s", missing.AuthEventID) + util.GetLogger(ctx).Tracef("fetched event %s", missing.AuthEventID) resolvedStateEvents = append(resolvedStateEvents, h.Unwrap()) goto retryAllowedState default: @@ -429,7 +429,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve missingResp = &m break } else { - logger.WithError(err).Errorf("%s pushed us an event but %q did not respond to /get_missing_events", t.origin, server) + logger.WithError(err).Warnf("%s pushed us an event but %q did not respond to /get_missing_events", t.origin, server) if errors.Is(err, context.DeadlineExceeded) { select { case <-ctx.Done(): // the parent request context timed out @@ -442,7 +442,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve } if missingResp == nil { - logger.WithError(err).Errorf( + logger.WithError(err).Warnf( "%s pushed us an event but %d server(s) couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", t.origin, len(t.servers), ) @@ -454,7 +454,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve // Make sure events from the missingResp are using the cache - missing events // will be added and duplicates will be removed. - logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) + logger.Debugf("get_missing_events returned %d events", len(missingResp.Events)) for i, ev := range missingResp.Events { missingResp.Events[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() } @@ -474,7 +474,7 @@ Event: } if !hasPrevEvent { err = fmt.Errorf("called /get_missing_events but server %s didn't return any prev_events with IDs %v", t.origin, shouldHaveSomeEventIDs) - logger.WithError(err).Errorf( + logger.WithError(err).Warnf( "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", t.origin, ) @@ -565,7 +565,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo concurrentRequests := 8 missingCount := len(missing) - util.GetLogger(ctx).WithField("room_id", roomID).WithField("event_id", eventID).Infof("lookupMissingStateViaStateIDs missing %d/%d events", missingCount, len(wantIDs)) + util.GetLogger(ctx).WithField("room_id", roomID).WithField("event_id", eventID).Debugf("lookupMissingStateViaStateIDs missing %d/%d events", missingCount, len(wantIDs)) // If over 50% of the auth/state events from /state_ids are missing // then we'll just call /state instead, otherwise we'll just end up @@ -577,7 +577,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo "room_id": roomID, "total_state": len(stateIDs.StateEventIDs), "total_auth_events": len(stateIDs.AuthEventIDs), - }).Info("Fetching all state at event") + }).Debug("Fetching all state at event") return t.lookupMissingStateViaState(ctx, roomID, eventID, roomVersion) } @@ -589,7 +589,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo "total_state": len(stateIDs.StateEventIDs), "total_auth_events": len(stateIDs.AuthEventIDs), "concurrent_requests": concurrentRequests, - }).Info("Fetching missing state at event") + }).Debug("Fetching missing state at event") // Create a queue containing all of the missing event IDs that we want // to retrieve. @@ -626,7 +626,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": missingEventID, "room_id": roomID, - }).Info("Failed to fetch missing event") + }).Warn("Failed to fetch missing event") return } haveEventsMutex.Lock() diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 85b2322fe..e23ed47be 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -60,7 +60,7 @@ func (r *Inviter) PerformInvite( "room_version": req.RoomVersion, "target_user_id": targetUserID, "room_info_exists": info != nil, - }).Info("processing invite event") + }).Debug("processing invite event") _, domain, _ := gomatrixserverlib.SplitID('@', targetUserID) isTargetLocal := domain == r.Cfg.Matrix.ServerName diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index a1ffab5dd..2b0bccda6 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -53,6 +53,11 @@ func (r *Joiner) PerformJoin( ) { roomID, joinedVia, err := r.performJoin(ctx, req) if err != nil { + logrus.WithContext(ctx).WithFields(logrus.Fields{ + "room_id": req.RoomIDOrAlias, + "user_id": req.UserID, + "servers": req.ServerNames, + }).WithError(err).Error("Failed to join room") sentry.CaptureException(err) perr, ok := err.(*rsAPI.PerformError) if ok { diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index eac528eaf..b19916491 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) type Leaver struct { @@ -51,7 +52,14 @@ func (r *Leaver) PerformLeave( return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) } if strings.HasPrefix(req.RoomID, "!") { - return r.performLeaveRoomByID(ctx, req, res) + output, err := r.performLeaveRoomByID(ctx, req, res) + if err != nil { + logrus.WithContext(ctx).WithFields(logrus.Fields{ + "room_id": req.RoomID, + "user_id": req.UserID, + }).WithError(err).Error("Failed to leave room") + } + return output, err } return nil, fmt.Errorf("room ID %q is invalid", req.RoomID) } diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 1ec9beb04..3d340a16a 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -85,7 +85,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) { log.WithFields(log.Fields{ "type": output.Type, "room_id": output.RoomID, - }).Info("received data from client API server") + }).Debug("Received data from client API server") streamPos, err := s.db.UpsertAccountData( s.ctx, userID, output.RoomID, output.Type, From 5367e7ed2c96175bf1b8652953da3a90cc2259df Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 31 Jan 2022 10:51:01 +0000 Subject: [PATCH 11/81] Update to matrix-org/gomatrixserverlib@801c51af9f29e3630c8d83b0772c7ba52c0d8908 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index bd695e6dd..33dafa8ac 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220131095121-cde7ac8c5bb8 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29 github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index 8952d034e..800e03c56 100644 --- a/go.sum +++ b/go.sum @@ -1021,8 +1021,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220131095121-cde7ac8c5bb8 h1:DeGMNY2iJ2u2zEOvUIJdrTPjh5mAa3Mim14hhAvS5zs= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220131095121-cde7ac8c5bb8/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29 h1:1t/J3AldUbgRxltlcmMbUefexxzolG5DvV2CkriZ4LM= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= From 2995f73ae0464bc1d267f000c4876bd1f0dfcee3 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 31 Jan 2022 11:16:21 +0000 Subject: [PATCH 12/81] Update prometheus client --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 33dafa8ac..78fe4b0ef 100644 --- a/go.mod +++ b/go.mod @@ -54,7 +54,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/pressly/goose v2.7.0+incompatible - github.com/prometheus/client_golang v1.12.0 + github.com/prometheus/client_golang v1.12.1 github.com/sirupsen/logrus v1.8.1 github.com/tidwall/gjson v1.13.0 github.com/tidwall/sjson v1.2.4 diff --git a/go.sum b/go.sum index 800e03c56..d332071ec 100644 --- a/go.sum +++ b/go.sum @@ -1273,8 +1273,8 @@ github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5Fsn github.com/prometheus/client_golang v1.1.0/go.mod h1:I1FGZT9+L76gKKOs5djB6ezCbFQP1xR9D75/vuwEF3g= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.12.0 h1:C+UIj/QWtmqY13Arb8kwMt5j34/0Z2iKamrJ+ryC0Gg= -github.com/prometheus/client_golang v1.12.0/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= +github.com/prometheus/client_golang v1.12.1 h1:ZiaPsmm9uiBeaSMRznKsCDNtPCS0T3JVDGF+06gjBzk= +github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_model v0.0.0-20171117100541-99fa1f4be8e5/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= From dac762d025acc333f6f08eb494ae4386d1144b96 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 31 Jan 2022 11:47:07 +0000 Subject: [PATCH 13/81] Revert Prometheus client upgrades altogether --- go.mod | 7 ++++++- go.sum | 3 +-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 78fe4b0ef..6ef433c09 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/Masterminds/semver/v3 v3.1.1 github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32 github.com/Shopify/sarama v1.31.0 + github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/codeclysm/extract v2.2.0+incompatible github.com/containerd/containerd v1.5.9 // indirect github.com/docker/docker v20.10.12+incompatible @@ -23,6 +24,7 @@ require ( github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.3 // indirect github.com/hashicorp/golang-lru v0.5.4 + github.com/json-iterator/go v1.1.12 // indirect github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect github.com/klauspost/compress v1.14.2 // indirect github.com/lib/pq v1.10.4 @@ -54,7 +56,9 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/pressly/goose v2.7.0+incompatible - github.com/prometheus/client_golang v1.12.1 + github.com/prometheus/client_golang v1.11.0 + github.com/prometheus/common v0.32.1 // indirect + github.com/prometheus/procfs v0.7.3 // indirect github.com/sirupsen/logrus v1.8.1 github.com/tidwall/gjson v1.13.0 github.com/tidwall/sjson v1.2.4 @@ -66,6 +70,7 @@ require ( golang.org/x/image v0.0.0-20211028202545-6944b10bf410 golang.org/x/mobile v0.0.0-20220112015953-858099ff7816 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd + golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.5 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index d332071ec..af416db70 100644 --- a/go.sum +++ b/go.sum @@ -1272,9 +1272,8 @@ github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDf github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.1.0/go.mod h1:I1FGZT9+L76gKKOs5djB6ezCbFQP1xR9D75/vuwEF3g= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_golang v1.11.0 h1:HNkLOAEQMIDv/K+04rukrLx6ch7msSRwf3/SASFAGtQ= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.12.1 h1:ZiaPsmm9uiBeaSMRznKsCDNtPCS0T3JVDGF+06gjBzk= -github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_model v0.0.0-20171117100541-99fa1f4be8e5/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= From f9547a53d2c378cf08b9a8296d2d727c11f535ba Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 31 Jan 2022 12:01:53 +0000 Subject: [PATCH 14/81] Tweak roomserver logging for rejected events --- roomserver/internal/input/input_events.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 2dc096674..f42168053 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -157,7 +157,7 @@ func (r *Inputer) processRoomEvent( var rejectionErr error if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil { isRejected = true - logger.WithError(rejectionErr).Warnf("Event %s rejected", event.EventID()) + logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) } // Accumulate the auth event NIDs. @@ -176,7 +176,7 @@ func (r *Inputer) processRoomEvent( // current room state. softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) if err != nil { - logger.WithError(err).Info("Error authing soft-failed event") + logger.WithError(err).Warn("Error authing soft-failed event") } } @@ -266,7 +266,10 @@ func (r *Inputer) processRoomEvent( // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. if isRejected || softfail { - logger.WithError(rejectionErr).WithField("soft_fail", softfail).Debug("Stored rejected event") + logger.WithError(rejectionErr).WithFields(logrus.Fields{ + "soft_fail": softfail, + "missing_prev": missingPrev, + }).Warn("Stored rejected event") return rejectionErr } From 1d5fd99cad518dcd7d387aa950c54a710452b71f Mon Sep 17 00:00:00 2001 From: Hoernschen Date: Mon, 31 Jan 2022 14:44:52 +0100 Subject: [PATCH 15/81] Allow uppercase username on login (#2126) * ADD jetstream folder to gitignore * CHANGE login to check on uppercase if lowercase not exists Co-authored-by: kegsay --- .gitignore | 1 + clientapi/auth/password.go | 13 ++++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index dbc84edb1..092f4501c 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ /vendor/bin /docker/build /logs +/jetstream # Architecture specific extensions/prefixes *.[568vq] diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 7dd21b3f2..9179d8da1 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -16,6 +16,7 @@ package auth import ( "context" + "database/sql" "net/http" "strings" @@ -49,8 +50,7 @@ func (t *LoginTypePassword) Request() interface{} { func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { r := req.(*PasswordRequest) - // Squash username to all lowercase letters - username := strings.ToLower(r.Username()) + username := r.Username() if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, @@ -64,8 +64,15 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, JSON: jsonerror.InvalidUsername(err.Error()), } } - _, err = t.GetAccountByPassword(ctx, localpart, r.Password) + // Squash username to all lowercase letters + _, err = t.GetAccountByPassword(ctx, strings.ToLower(localpart), r.Password) if err != nil { + if err == sql.ErrNoRows { + _, err = t.GetAccountByPassword(ctx, localpart, r.Password) + if err == nil { + return &r.Login, nil + } + } // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // but that would leak the existence of the user. return nil, &util.JSONResponse{ From 567fd0442868862c5499a64b8b100623d28fe4f6 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 31 Jan 2022 14:29:13 +0000 Subject: [PATCH 16/81] Update to matrix-org/gomatrixserverlib#286 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 6ef433c09..c36fbe3b3 100644 --- a/go.mod +++ b/go.mod @@ -42,7 +42,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220131142840-8d9c3d71ffb6 github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index af416db70..f72b4f4fb 100644 --- a/go.sum +++ b/go.sum @@ -1021,8 +1021,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29 h1:1t/J3AldUbgRxltlcmMbUefexxzolG5DvV2CkriZ4LM= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220131105022-801c51af9f29/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220131142840-8d9c3d71ffb6 h1:v+WZXRsn9IaW3mta6bPICWbWcaZbnB1u1ZFlGFi/YU8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220131142840-8d9c3d71ffb6/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= From d21f3eace0748e0b1385243be50c89d0d55c32d0 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 31 Jan 2022 14:36:59 +0000 Subject: [PATCH 17/81] Roomserver fixes (#2133) * Improve server selection somewhat * Remove things from the map when we're done * Be less panicky about auth event signatures in case they are not fatal after all * Accept HasState in all cases * Send join asynchronously * Revert "Send join asynchronously" This reverts commit 5b685bfcd0b1150a66c7b1e70fb3a3eda509efd1. * Joins and leaves use background context --- roomserver/internal/input/input_events.go | 47 +++++++++++++------- roomserver/internal/input/input_missing.go | 6 +-- roomserver/internal/perform/perform_join.go | 2 +- roomserver/internal/perform/perform_leave.go | 2 +- 4 files changed, 36 insertions(+), 21 deletions(-) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index f42168053..147103cf5 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -130,7 +130,10 @@ func (r *Inputer) processRoomEvent( return fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) } } - if len(missingRes.MissingAuthEventIDs) > 0 || len(missingRes.MissingPrevEventIDs) > 0 { + missingAuth := len(missingRes.MissingAuthEventIDs) > 0 + missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0 + + if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ RoomID: event.RoomID(), ExcludeSelf: true, @@ -138,9 +141,26 @@ func (r *Inputer) processRoomEvent( if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) } - } - if input.Origin != "" { - serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) + // Sort all of the servers into a map so that we can randomise + // their order. Then make sure that the input origin and the + // event origin are first on the list. + servers := map[gomatrixserverlib.ServerName]struct{}{} + for _, server := range serverRes.ServerNames { + servers[server] = struct{}{} + } + serverRes.ServerNames = serverRes.ServerNames[:0] + if input.Origin != "" { + serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) + delete(servers, input.Origin) + } + if origin := event.Origin(); origin != input.Origin { + serverRes.ServerNames = append(serverRes.ServerNames, origin) + delete(servers, origin) + } + for server := range servers { + serverRes.ServerNames = append(serverRes.ServerNames, server) + delete(servers, server) + } } // First of all, check that the auth events of the event are known. @@ -149,7 +169,7 @@ func (r *Inputer) processRoomEvent( authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return fmt.Errorf("r.checkForMissingAuthEvents: %w", err) + return fmt.Errorf("r.fetchAuthEvents: %w", err) } // Check if the event is allowed by its auth events. If it isn't then @@ -190,7 +210,6 @@ func (r *Inputer) processRoomEvent( // typical federated room join) then we won't bother trying to fetch prev events // because we may not be allowed to see them and we have no choice but to trust // the state event IDs provided to us in the join instead. - missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0 if missingPrev && input.Kind == api.KindNew { // Don't do this for KindOld events, otherwise old events that we fetch // to satisfy missing prev events/state will end up recursively calling @@ -204,13 +223,10 @@ func (r *Inputer) processRoomEvent( federation: r.FSAPI, keys: r.KeyRing, roomsMu: internal.NewMutexByRoom(), - servers: map[gomatrixserverlib.ServerName]struct{}{}, + servers: serverRes.ServerNames, hadEvents: map[string]bool{}, haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, } - for _, serverName := range serverRes.ServerNames { - missingState.servers[serverName] = struct{}{} - } if err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { isRejected = true rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) @@ -399,12 +415,11 @@ func (r *Inputer) fetchAuthEvents( continue } - // Check the signatures of the event. - // TODO: It really makes sense for the federation API to be doing this, - // because then it can attempt another server if one serves up an event - // with an invalid signature. For now this will do. + // Check the signatures of the event. If this fails then we'll simply + // skip it, because gomatrixserverlib.Allowed() will notice a problem + // if a critical event is missing anyway. if err := authEvent.VerifyEventSignatures(ctx, r.FSAPI.KeyRing()); err != nil { - return fmt.Errorf("event.VerifyEventSignatures: %w", err) + continue } // In order to store the new auth event, we need to know its auth chain @@ -457,7 +472,7 @@ func (r *Inputer) calculateAndSetState( var err error roomState := state.NewStateResolution(r.DB, roomInfo) - if input.HasState && !isRejected { + if input.HasState { // Check here if we think we're in the room already. stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index aa2b94f8a..02ff0f8da 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -25,7 +25,7 @@ type missingStateReq struct { keys gomatrixserverlib.JSONVerifier federation fedapi.FederationInternalAPI roomsMu *internal.MutexByRoom - servers map[gomatrixserverlib.ServerName]struct{} + servers []gomatrixserverlib.ServerName hadEvents map[string]bool hadEventsMutex sync.Mutex haveEvents map[string]*gomatrixserverlib.HeaderedEvent @@ -417,7 +417,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve } var missingResp *gomatrixserverlib.RespMissingEvents - for server := range t.servers { + for _, server := range t.servers { var m gomatrixserverlib.RespMissingEvents if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, @@ -700,7 +700,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs } var event *gomatrixserverlib.Event found := false - for serverName := range t.servers { + for _, serverName := range t.servers { reqctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 2b0bccda6..dfa21bcbd 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -51,7 +51,7 @@ func (r *Joiner) PerformJoin( req *rsAPI.PerformJoinRequest, res *rsAPI.PerformJoinResponse, ) { - roomID, joinedVia, err := r.performJoin(ctx, req) + roomID, joinedVia, err := r.performJoin(context.Background(), req) if err != nil { logrus.WithContext(ctx).WithFields(logrus.Fields{ "room_id": req.RoomIDOrAlias, diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index b19916491..3c46e6573 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -52,7 +52,7 @@ func (r *Leaver) PerformLeave( return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) } if strings.HasPrefix(req.RoomID, "!") { - output, err := r.performLeaveRoomByID(ctx, req, res) + output, err := r.performLeaveRoomByID(context.Background(), req, res) if err != nil { logrus.WithContext(ctx).WithFields(logrus.Fields{ "room_id": req.RoomID, From 07d0e72a8b2fd48cac50026346be4810c5768d36 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 31 Jan 2022 15:33:00 +0000 Subject: [PATCH 18/81] Improve roomserver logging --- roomserver/api/input.go | 13 +++++++++++++ roomserver/internal/input/input_events.go | 8 ++++++++ 2 files changed, 21 insertions(+) diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 4b0704b9f..45a9ef497 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -42,6 +42,19 @@ const ( KindOld ) +func (k Kind) String() string { + switch k { + case KindOutlier: + return "KindOutlier" + case KindNew: + return "KindNew" + case KindOld: + return "KindOld" + default: + return "(unknown)" + } +} + // DoNotSendToOtherServers tells us not to send the event to other matrix // servers. const DoNotSendToOtherServers = "" diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 147103cf5..16703616e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -93,8 +93,16 @@ func (r *Inputer) processRoomEvent( logger := util.GetLogger(ctx).WithFields(logrus.Fields{ "event_id": event.EventID(), "room_id": event.RoomID(), + "kind": input.Kind, + "origin": input.Origin, "type": event.Type(), }) + if input.HasState { + logger = logger.WithFields(logrus.Fields{ + "has_state": input.HasState, + "state_ids": len(input.StateEventIDs), + }) + } // if we have already got this event then do not process it again, if the input kind is an outlier. // Outliers contain no extra information which may warrant a re-processing. From 893aa3b1414f46a44900e7377b5b471cb2aff0f3 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 31 Jan 2022 16:01:54 +0000 Subject: [PATCH 19/81] More logging tweaks --- roomserver/internal/input/input_missing.go | 8 ++++---- roomserver/internal/perform/perform_join.go | 14 +++++++++----- roomserver/internal/perform/perform_leave.go | 12 ++++++++---- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 02ff0f8da..d401fa0e9 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -666,7 +666,7 @@ func (t *missingStateReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib for i := range stateIDs.StateEventIDs { ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] if !ok { - logrus.Warnf("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) + logrus.Tracef("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) continue } respState.StateEvents = append(respState.StateEvents, ev.Unwrap()) @@ -674,7 +674,7 @@ func (t *missingStateReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib for i := range stateIDs.AuthEventIDs { ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] if !ok { - logrus.Warnf("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) + logrus.Tracef("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) continue } respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap()) @@ -718,7 +718,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs } event, err = gomatrixserverlib.NewEventFromUntrustedJSON(txn.PDUs[0], roomVersion) if err != nil { - util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Transaction: Failed to parse event JSON of event") + util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warnf("Failed to parse event JSON of event returned from /event") continue } found = true @@ -729,7 +729,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } if err := event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) + util.GetLogger(ctx).WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} } return t.cacheAndReturn(event.Headered(roomVersion)), nil diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index dfa21bcbd..9d2a66d4c 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -51,13 +51,15 @@ func (r *Joiner) PerformJoin( req *rsAPI.PerformJoinRequest, res *rsAPI.PerformJoinResponse, ) { + logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ + "room_id": req.RoomIDOrAlias, + "user_id": req.UserID, + "servers": req.ServerNames, + }) + logger.Info("User requested to room join") roomID, joinedVia, err := r.performJoin(context.Background(), req) if err != nil { - logrus.WithContext(ctx).WithFields(logrus.Fields{ - "room_id": req.RoomIDOrAlias, - "user_id": req.UserID, - "servers": req.ServerNames, - }).WithError(err).Error("Failed to join room") + logger.WithError(err).Error("Failed to join room") sentry.CaptureException(err) perr, ok := err.(*rsAPI.PerformError) if ok { @@ -67,7 +69,9 @@ func (r *Joiner) PerformJoin( Msg: err.Error(), } } + return } + logger.Info("User joined room successfully") res.RoomID = roomID res.JoinedVia = joinedVia } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 3c46e6573..12784e5f5 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -51,13 +51,17 @@ func (r *Leaver) PerformLeave( if domain != r.Cfg.Matrix.ServerName { return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) } + logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ + "room_id": req.RoomID, + "user_id": req.UserID, + }) + logger.Info("User requested to leave join") if strings.HasPrefix(req.RoomID, "!") { output, err := r.performLeaveRoomByID(context.Background(), req, res) if err != nil { - logrus.WithContext(ctx).WithFields(logrus.Fields{ - "room_id": req.RoomID, - "user_id": req.UserID, - }).WithError(err).Error("Failed to leave room") + logger.WithError(err).Error("Failed to leave room") + } else { + logger.Info("User left room successfully") } return output, err } From 9ada4578e36b367c53057c9ee32d044e2ba26395 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 1 Feb 2022 16:03:30 +0000 Subject: [PATCH 20/81] Fix JetStream paths for P2P demo builds --- build/gobind-pinecone/monolith.go | 2 +- build/gobind-yggdrasil/monolith.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 1c9c0ac4e..211b8d653 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -281,7 +281,7 @@ func (m *DendriteMonolith) Start() { cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.PrivateKey = sk cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("file:%s/%s", m.StorageDirectory, prefix)) + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix)) cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix)) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 1aae418d1..3d9ba8aa0 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -86,7 +86,7 @@ func (m *DendriteMonolith) Start() { cfg.Global.ServerName = gomatrixserverlib.ServerName(ygg.DerivedServerName()) cfg.Global.PrivateKey = ygg.PrivateKey() cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("file:%s/", m.StorageDirectory)) + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory)) cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) From a09d71d231861f8825a4f8f1dfd79311c4c236a6 Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 1 Feb 2022 16:36:17 +0000 Subject: [PATCH 21/81] Support CA certificates in CI (#2136) * Support CA setting in generate-keys * Set DNS names correctly * Use generate-config -server not sed --- build/scripts/Complement.Dockerfile | 10 +++- cmd/generate-config/main.go | 2 +- cmd/generate-keys/main.go | 20 +++++-- internal/test/config.go | 93 +++++++++++++++++++++++++---- 4 files changed, 105 insertions(+), 20 deletions(-) diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 55b381ba5..401695abf 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -12,10 +12,14 @@ COPY . . RUN go build ./cmd/dendrite-monolith-server RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config -RUN ./generate-config --ci > dendrite.yaml -RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key +RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost EXPOSE 8008 8448 -CMD sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml && ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml +# At runtime, generate TLS cert based on the CA now mounted at /ca +# At runtime, replace the SERVER_NAME with what we are told +CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /ca/ca.crt --tls-authority-key /ca/ca.key && \ + ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ + cp /ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ + ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index a79470d83..60729672e 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -83,7 +83,7 @@ func main() { if *defaultsForCI { cfg.AppServiceAPI.DisableTLSValidation = true cfg.ClientAPI.RateLimiting.Enabled = false - cfg.FederationAPI.DisableTLSValidation = true + cfg.FederationAPI.DisableTLSValidation = false // don't hit matrix.org when running tests!!! cfg.FederationAPI.KeyPerspectives = config.KeyPerspectives{} cfg.MSCs.MSCs = []string{"msc2836", "msc2946", "msc2444", "msc2753"} diff --git a/cmd/generate-keys/main.go b/cmd/generate-keys/main.go index 743109f13..bddf219dc 100644 --- a/cmd/generate-keys/main.go +++ b/cmd/generate-keys/main.go @@ -32,9 +32,12 @@ Arguments: ` var ( - tlsCertFile = flag.String("tls-cert", "", "An X509 certificate file to generate for use for TLS") - tlsKeyFile = flag.String("tls-key", "", "An RSA private key file to generate for use for TLS") - privateKeyFile = flag.String("private-key", "", "An Ed25519 private key to generate for use for object signing") + tlsCertFile = flag.String("tls-cert", "", "An X509 certificate file to generate for use for TLS") + tlsKeyFile = flag.String("tls-key", "", "An RSA private key file to generate for use for TLS") + privateKeyFile = flag.String("private-key", "", "An Ed25519 private key to generate for use for object signing") + authorityCertFile = flag.String("tls-authority-cert", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.") + authorityKeyFile = flag.String("tls-authority-key", "", "Optional: Create TLS certificate/keys based on this CA authority. Useful for integration testing.") + serverName = flag.String("server", "", "Optional: Create TLS certificate/keys with this domain name set. Useful for integration testing.") ) func main() { @@ -54,8 +57,15 @@ func main() { if *tlsCertFile == "" || *tlsKeyFile == "" { log.Fatal("Zero or both of --tls-key and --tls-cert must be supplied") } - if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil { - panic(err) + if *authorityCertFile == "" && *authorityKeyFile == "" { + if err := test.NewTLSKey(*tlsKeyFile, *tlsCertFile); err != nil { + panic(err) + } + } else { + // generate the TLS cert/key based on the authority given. + if err := test.NewTLSKeyWithAuthority(*serverName, *tlsKeyFile, *tlsCertFile, *authorityKeyFile, *authorityCertFile); err != nil { + panic(err) + } } fmt.Printf("Created TLS cert file: %s\n", *tlsCertFile) fmt.Printf("Created TLS key file: %s\n", *tlsKeyFile) diff --git a/internal/test/config.go b/internal/test/config.go index bb2f8a4c6..4fb6a946c 100644 --- a/internal/test/config.go +++ b/internal/test/config.go @@ -20,6 +20,7 @@ import ( "crypto/x509" "encoding/base64" "encoding/pem" + "errors" "fmt" "io/ioutil" "math/big" @@ -158,11 +159,10 @@ func NewMatrixKey(matrixKeyPath string) (err error) { const certificateDuration = time.Hour * 24 * 365 * 10 -// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file. -func NewTLSKey(tlsKeyPath, tlsCertPath string) error { +func generateTLSTemplate(dnsNames []string) (*rsa.PrivateKey, *x509.Certificate, error) { priv, err := rsa.GenerateKey(rand.Reader, 4096) if err != nil { - return err + return nil, nil, err } notBefore := time.Now() @@ -170,7 +170,7 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error { serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { - return err + return nil, nil, err } template := x509.Certificate{ @@ -180,20 +180,21 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error { KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, + DNSNames: dnsNames, } - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - return err - } + return priv, &template, nil +} + +func writeCertificate(tlsCertPath string, derBytes []byte) error { certOut, err := os.Create(tlsCertPath) if err != nil { return err } defer certOut.Close() // nolint: errcheck - if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { - return err - } + return pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) +} +func writePrivateKey(tlsKeyPath string, priv *rsa.PrivateKey) error { keyOut, err := os.OpenFile(tlsKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err @@ -205,3 +206,73 @@ func NewTLSKey(tlsKeyPath, tlsCertPath string) error { }) return err } + +// NewTLSKey generates a new RSA TLS key and certificate and writes it to a file. +func NewTLSKey(tlsKeyPath, tlsCertPath string) error { + priv, template, err := generateTLSTemplate(nil) + if err != nil { + return err + } + + // Self-signed certificate: template == parent + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + return err + } + + if err = writeCertificate(tlsCertPath, derBytes); err != nil { + return err + } + return writePrivateKey(tlsKeyPath, priv) +} + +func NewTLSKeyWithAuthority(serverName, tlsKeyPath, tlsCertPath, authorityKeyPath, authorityCertPath string) error { + priv, template, err := generateTLSTemplate([]string{serverName}) + if err != nil { + return err + } + + // load the authority key + dat, err := ioutil.ReadFile(authorityKeyPath) + if err != nil { + return err + } + block, _ := pem.Decode([]byte(dat)) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return errors.New("authority .key is not a valid pem encoded rsa private key") + } + authorityPriv, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return err + } + + // load the authority certificate + dat, err = ioutil.ReadFile(authorityCertPath) + if err != nil { + return err + } + block, _ = pem.Decode([]byte(dat)) + if block == nil || block.Type != "CERTIFICATE" { + return errors.New("authority .crt is not a valid pem encoded x509 cert") + } + var caCerts []*x509.Certificate + caCerts, err = x509.ParseCertificates(block.Bytes) + if err != nil { + return err + } + if len(caCerts) != 1 { + return errors.New("authority .crt contains none or more than one cert") + } + authorityCert := caCerts[0] + + // Sign the new certificate using the authority's key/cert + derBytes, err := x509.CreateCertificate(rand.Reader, template, authorityCert, &priv.PublicKey, authorityPriv) + if err != nil { + return err + } + + if err = writeCertificate(tlsCertPath, derBytes); err != nil { + return err + } + return writePrivateKey(tlsKeyPath, priv) +} From 2dee706f9ef2de70516dbc993dcfc8ec6f7fdd52 Mon Sep 17 00:00:00 2001 From: kegsay Date: Wed, 2 Feb 2022 13:30:48 +0000 Subject: [PATCH 22/81] PerformInvite: bugfix and rejig control flow (#2137) * PerformInvite: bugfix and rejig control flow Local clients would not be notified of invites to rooms Dendrite had already joined in all cases due to not returning an `api.OutputNewInviteEvent` for local invites. We now do this. This was an easy mistake to make due to the control flow of the function which doesn't handle the happy case at the end of the function and instead forks the function depending on if the invite was via federation or not. This has now been changed to handle the federated invite as if it were an error (in that we check it, do it and bail out) rather than outstay our welcome. This ends up with the local invite being the happy case, which now both sends an `InputRoomEvent` to the roomserver _and_ a `api.OutputNewInviteEvent` is returned. * Don't send invite pokes in PerformInvite * Move event ID into logger --- roomserver/internal/perform/perform_invite.go | 161 ++++++++++-------- 1 file changed, 88 insertions(+), 73 deletions(-) diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index e23ed47be..6559cd081 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) @@ -54,18 +55,23 @@ func (r *Inviter) PerformInvite( return nil, fmt.Errorf("failed to load RoomInfo: %w", err) } - log.WithFields(log.Fields{ - "event_id": event.EventID(), - "room_id": roomID, - "room_version": req.RoomVersion, - "target_user_id": targetUserID, - "room_info_exists": info != nil, - }).Debug("processing invite event") - _, domain, _ := gomatrixserverlib.SplitID('@', targetUserID) isTargetLocal := domain == r.Cfg.Matrix.ServerName isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName + logger := util.GetLogger(ctx).WithFields(map[string]interface{}{ + "inviter": event.Sender(), + "invitee": *event.StateKey(), + "room_id": roomID, + "event_id": event.EventID(), + }) + logger.WithFields(log.Fields{ + "room_version": req.RoomVersion, + "room_info_exists": info != nil, + "target_local": isTargetLocal, + "origin_local": isOriginLocal, + }).Debug("processing invite event") + inviteState := req.InviteRoomState if len(inviteState) == 0 && info != nil { var is []gomatrixserverlib.InviteV2StrippedState @@ -122,75 +128,17 @@ func (r *Inviter) PerformInvite( Code: api.PerformErrorNotAllowed, Msg: "User is already joined to room", } + logger.Debugf("user already joined") return nil, nil } - if isOriginLocal { - // The invite originated locally. Therefore we have a responsibility to - // try and see if the user is allowed to make this invite. We can't do - // this for invites coming in over federation - we have to take those on - // trust. - _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) - if err != nil { - log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( - "processInviteEvent.checkAuthEvents failed for event", - ) - res.Error = &api.PerformError{ - Msg: err.Error(), - Code: api.PerformErrorNotAllowed, - } - } - - // If the invite originated from us and the target isn't local then we - // should try and send the invite over federation first. It might be - // that the remote user doesn't exist, in which case we can give up - // processing here. - if req.SendAsServer != api.DoNotSendToOtherServers && !isTargetLocal { - fsReq := &federationAPI.PerformInviteRequest{ - RoomVersion: req.RoomVersion, - Event: event, - InviteRoomState: inviteState, - } - fsRes := &federationAPI.PerformInviteResponse{} - if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { - res.Error = &api.PerformError{ - Msg: err.Error(), - Code: api.PerformErrorNotAllowed, - } - log.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") - return nil, nil - } - event = fsRes.Event - } - - // Send the invite event to the roomserver input stream. This will - // notify existing users in the room about the invite, update the - // membership table and ensure that the event is ready and available - // to use as an auth event when accepting the invite. - inputReq := &api.InputRoomEventsRequest{ - InputRoomEvents: []api.InputRoomEvent{ - { - Kind: api.KindNew, - Event: event, - Origin: event.Origin(), - SendAsServer: req.SendAsServer, - }, - }, - } - inputRes := &api.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) - if err = inputRes.Err(); err != nil { - res.Error = &api.PerformError{ - Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), - Code: api.PerformErrorNotAllowed, - } - log.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") - return nil, nil - } - } else { + if !isOriginLocal { // The invite originated over federation. Process the membership // update, which will notify the sync API etc about the incoming - // invite. + // invite. We do NOT send an InputRoomEvent for the invite as it + // will never pass auth checks due to lacking room state, but we + // still need to tell the client about the invite so we can accept + // it, hence we return an output event to send to the sync api. updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion) if err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) @@ -205,10 +153,77 @@ func (r *Inviter) PerformInvite( if err = updater.Commit(); err != nil { return nil, fmt.Errorf("updater.Commit: %w", err) } - + logger.Debugf("updated membership to invite and sending invite OutputEvent") return outputUpdates, nil } + // The invite originated locally. Therefore we have a responsibility to + // try and see if the user is allowed to make this invite. We can't do + // this for invites coming in over federation - we have to take those on + // trust. + _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) + if err != nil { + logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( + "processInviteEvent.checkAuthEvents failed for event", + ) + res.Error = &api.PerformError{ + Msg: err.Error(), + Code: api.PerformErrorNotAllowed, + } + return nil, nil + } + + // If the invite originated from us and the target isn't local then we + // should try and send the invite over federation first. It might be + // that the remote user doesn't exist, in which case we can give up + // processing here. + if req.SendAsServer != api.DoNotSendToOtherServers && !isTargetLocal { + fsReq := &federationAPI.PerformInviteRequest{ + RoomVersion: req.RoomVersion, + Event: event, + InviteRoomState: inviteState, + } + fsRes := &federationAPI.PerformInviteResponse{} + if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { + res.Error = &api.PerformError{ + Msg: err.Error(), + Code: api.PerformErrorNotAllowed, + } + logger.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") + return nil, nil + } + event = fsRes.Event + logger.Debugf("Federated PerformInvite success with event ID %s", event.EventID()) + } + + // Send the invite event to the roomserver input stream. This will + // notify existing users in the room about the invite, update the + // membership table and ensure that the event is ready and available + // to use as an auth event when accepting the invite. + // It will NOT notify the invitee of this invite. + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{ + { + Kind: api.KindNew, + Event: event, + Origin: event.Origin(), + SendAsServer: req.SendAsServer, + }, + }, + } + inputRes := &api.InputRoomEventsResponse{} + r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) + if err = inputRes.Err(); err != nil { + res.Error = &api.PerformError{ + Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), + Code: api.PerformErrorNotAllowed, + } + logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") + return nil, nil + } + + // Don't notify the sync api of this event in the same way as a federated invite so the invitee + // gets the invite, as the roomserver will do this when it processes the m.room.member invite. return nil, nil } From c773b038bb1432f2265759ddf1da5e98b9bda525 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 2 Feb 2022 13:32:48 +0000 Subject: [PATCH 23/81] Use pull consumers (#2140) * Pull consumers * Pull consumers * Only nuke consumers if they are push consumers * Clean up old consumers * Better error handling * Update comments --- appservice/consumers/roomserver.go | 54 ++-- federationapi/consumers/eduserver.go | 329 ++++++++++---------- federationapi/consumers/roomserver.go | 108 ++++--- roomserver/internal/api.go | 4 +- setup/config/config_jetstream.go | 6 +- setup/jetstream/helpers.go | 83 ++++- syncapi/consumers/clientapi.go | 72 ++--- syncapi/consumers/eduserver_receipts.go | 60 ++-- syncapi/consumers/eduserver_sendtodevice.go | 92 +++--- syncapi/consumers/eduserver_typing.go | 70 ++--- syncapi/consumers/roomserver.go | 102 +++--- 11 files changed, 521 insertions(+), 459 deletions(-) diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index 8aea5c347..7b59e3704 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -34,7 +34,7 @@ import ( type OutputRoomEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string asDB storage.Database rsAPI api.RoomserverInternalAPI @@ -66,37 +66,37 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } // onMessage is called when the appservice component receives a new event from // the room server output log. -func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - var output api.OutputEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return true - } - - if output.Type != api.OutputTypeNewRoomEvent { - return true - } - - events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} - events = append(events, output.NewRoomEvent.AddStateEvents...) - - // Send event to any relevant application services - if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { - log.WithError(err).Errorf("roomserver output log: filter error") - return true - } - +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + var output api.OutputEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") return true - }) + } + + if output.Type != api.OutputTypeNewRoomEvent { + return true + } + + events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} + events = append(events, output.NewRoomEvent.AddStateEvents...) + + // Send event to any relevant application services + if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { + log.WithError(err).Errorf("roomserver output log: filter error") + return true + } + + return true } // filterRoomserverEvents takes in events and decides whether any of them need diff --git a/federationapi/consumers/eduserver.go b/federationapi/consumers/eduserver.go index c3e5b4d49..22fedbeb4 100644 --- a/federationapi/consumers/eduserver.go +++ b/federationapi/consumers/eduserver.go @@ -34,7 +34,7 @@ import ( type OutputEDUConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string db storage.Database queues *queue.OutgoingQueues ServerName gomatrixserverlib.ServerName @@ -66,13 +66,22 @@ func NewOutputEDUConsumer( // Start consuming from EDU servers func (t *OutputEDUConsumer) Start() error { - if _, err := t.jetstream.Subscribe(t.typingTopic, t.onTypingEvent, t.durable); err != nil { + if err := jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.typingTopic, t.durable, t.onTypingEvent, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { return err } - if _, err := t.jetstream.Subscribe(t.sendToDeviceTopic, t.onSendToDeviceEvent, t.durable); err != nil { + if err := jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.sendToDeviceTopic, t.durable, t.onSendToDeviceEvent, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { return err } - if _, err := t.jetstream.Subscribe(t.receiptTopic, t.onReceiptEvent, t.durable); err != nil { + if err := jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.receiptTopic, t.durable, t.onReceiptEvent, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { return err } return nil @@ -80,175 +89,169 @@ func (t *OutputEDUConsumer) Start() error { // onSendToDeviceEvent is called in response to a message received on the // send-to-device events topic from the EDU server. -func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *nats.Msg) { +func (t *OutputEDUConsumer) onSendToDeviceEvent(ctx context.Context, msg *nats.Msg) bool { // Extract the send-to-device event from msg. - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var ote api.OutputSendToDeviceEvent - if err := json.Unmarshal(msg.Data, &ote); err != nil { - log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)") - return true - } - - // only send send-to-device events which originated from us - _, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender) - if err != nil { - log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender") - return true - } - if originServerName != t.ServerName { - log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere") - return true - } - - _, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID) - if err != nil { - log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination") - return true - } - - // Pack the EDU and marshal it - edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MDirectToDevice, - Origin: string(t.ServerName), - } - tdm := gomatrixserverlib.ToDeviceMessage{ - Sender: ote.Sender, - Type: ote.Type, - MessageID: util.RandomString(32), - Messages: map[string]map[string]json.RawMessage{ - ote.UserID: { - ote.DeviceID: ote.Content, - }, - }, - } - if edu.Content, err = json.Marshal(tdm); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - return true - } - - log.Infof("Sending send-to-device message into %q destination queue", destServerName) - if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { - log.WithError(err).Error("failed to send EDU") - return false - } - + var ote api.OutputSendToDeviceEvent + if err := json.Unmarshal(msg.Data, &ote); err != nil { + log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)") return true - }) + } + + // only send send-to-device events which originated from us + _, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender) + if err != nil { + log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender") + return true + } + if originServerName != t.ServerName { + log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere") + return true + } + + _, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID) + if err != nil { + log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination") + return true + } + + // Pack the EDU and marshal it + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MDirectToDevice, + Origin: string(t.ServerName), + } + tdm := gomatrixserverlib.ToDeviceMessage{ + Sender: ote.Sender, + Type: ote.Type, + MessageID: util.RandomString(32), + Messages: map[string]map[string]json.RawMessage{ + ote.UserID: { + ote.DeviceID: ote.Content, + }, + }, + } + if edu.Content, err = json.Marshal(tdm); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } + + log.Infof("Sending send-to-device message into %q destination queue", destServerName) + if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } + + return true } // onTypingEvent is called in response to a message received on the typing // events topic from the EDU server. -func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Extract the typing event from msg. - var ote api.OutputTypingEvent - if err := json.Unmarshal(msg.Data, &ote); err != nil { - // Skip this msg but continue processing messages. - log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)") - _ = msg.Ack() - return true - } - - // only send typing events which originated from us - _, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID) - if err != nil { - log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender") - _ = msg.Ack() - return true - } - if typingServerName != t.ServerName { - return true - } - - joined, err := t.db.GetJoinedHosts(t.ctx, ote.Event.RoomID) - if err != nil { - log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room") - return false - } - - names := make([]gomatrixserverlib.ServerName, len(joined)) - for i := range joined { - names[i] = joined[i].ServerName - } - - edu := &gomatrixserverlib.EDU{Type: ote.Event.Type} - if edu.Content, err = json.Marshal(map[string]interface{}{ - "room_id": ote.Event.RoomID, - "user_id": ote.Event.UserID, - "typing": ote.Event.Typing, - }); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - return true - } - - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { - log.WithError(err).Error("failed to send EDU") - return false - } - +func (t *OutputEDUConsumer) onTypingEvent(ctx context.Context, msg *nats.Msg) bool { + // Extract the typing event from msg. + var ote api.OutputTypingEvent + if err := json.Unmarshal(msg.Data, &ote); err != nil { + // Skip this msg but continue processing messages. + log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)") + _ = msg.Ack() return true - }) + } + + // only send typing events which originated from us + _, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID) + if err != nil { + log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender") + _ = msg.Ack() + return true + } + if typingServerName != t.ServerName { + return true + } + + joined, err := t.db.GetJoinedHosts(ctx, ote.Event.RoomID) + if err != nil { + log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room") + return false + } + + names := make([]gomatrixserverlib.ServerName, len(joined)) + for i := range joined { + names[i] = joined[i].ServerName + } + + edu := &gomatrixserverlib.EDU{Type: ote.Event.Type} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "room_id": ote.Event.RoomID, + "user_id": ote.Event.UserID, + "typing": ote.Event.Typing, + }); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } + + if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } + + return true } // onReceiptEvent is called in response to a message received on the receipt // events topic from the EDU server. -func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Extract the typing event from msg. - var receipt api.OutputReceiptEvent - if err := json.Unmarshal(msg.Data, &receipt); err != nil { - // Skip this msg but continue processing messages. - log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)") - return true - } - - // only send receipt events which originated from us - _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID) - if err != nil { - log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") - return true - } - if receiptServerName != t.ServerName { - return true - } - - joined, err := t.db.GetJoinedHosts(t.ctx, receipt.RoomID) - if err != nil { - log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room") - return false - } - - names := make([]gomatrixserverlib.ServerName, len(joined)) - for i := range joined { - names[i] = joined[i].ServerName - } - - content := map[string]api.FederationReceiptMRead{} - content[receipt.RoomID] = api.FederationReceiptMRead{ - User: map[string]api.FederationReceiptData{ - receipt.UserID: { - Data: api.ReceiptTS{ - TS: receipt.Timestamp, - }, - EventIDs: []string{receipt.EventID}, - }, - }, - } - - edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MReceipt, - Origin: string(t.ServerName), - } - if edu.Content, err = json.Marshal(content); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - return true - } - - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { - log.WithError(err).Error("failed to send EDU") - return false - } - +func (t *OutputEDUConsumer) onReceiptEvent(ctx context.Context, msg *nats.Msg) bool { + // Extract the typing event from msg. + var receipt api.OutputReceiptEvent + if err := json.Unmarshal(msg.Data, &receipt); err != nil { + // Skip this msg but continue processing messages. + log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)") return true - }) + } + + // only send receipt events which originated from us + _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID) + if err != nil { + log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") + return true + } + if receiptServerName != t.ServerName { + return true + } + + joined, err := t.db.GetJoinedHosts(ctx, receipt.RoomID) + if err != nil { + log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room") + return false + } + + names := make([]gomatrixserverlib.ServerName, len(joined)) + for i := range joined { + names[i] = joined[i].ServerName + } + + content := map[string]api.FederationReceiptMRead{} + content[receipt.RoomID] = api.FederationReceiptMRead{ + User: map[string]api.FederationReceiptData{ + receipt.UserID: { + Data: api.ReceiptTS{ + TS: receipt.Timestamp, + }, + EventIDs: []string{receipt.EventID}, + }, + }, + } + + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MReceipt, + Origin: string(t.ServerName), + } + if edu.Content, err = json.Marshal(content); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } + + if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } + + return true } diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 25ea78274..ac29f930b 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -37,7 +37,7 @@ type OutputRoomEventConsumer struct { cfg *config.FederationAPI rsAPI api.RoomserverInternalAPI jetstream nats.JetStreamContext - durable nats.SubOpt + durable string db storage.Database queues *queue.OutgoingQueues topic string @@ -66,74 +66,70 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - _, err := s.jetstream.Subscribe( - s.topic, s.onMessage, s.durable, - nats.DeliverAll(), - nats.ManualAck(), + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), ) - return err } // onMessage is called when the federation server receives a new event from the room server output log. // It is unsafe to call this with messages for the same room in multiple gorountines // because updates it will likely fail with a types.EventIDMismatchError when it // realises that it cannot update the room state using the deltas. -func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - var output api.OutputEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return true - } +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + var output api.OutputEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") + return true + } - switch output.Type { - case api.OutputTypeNewRoomEvent: - ev := output.NewRoomEvent.Event + switch output.Type { + case api.OutputTypeNewRoomEvent: + ev := output.NewRoomEvent.Event - if output.NewRoomEvent.RewritesState { - if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil { - log.WithError(err).Errorf("roomserver output log: purge room state failure") - return false - } - } - - if err := s.processMessage(*output.NewRoomEvent); err != nil { - switch err.(type) { - case *queue.ErrorFederationDisabled: - log.WithField("error", output.Type).Info( - err.Error(), - ) - default: - // panic rather than continue with an inconsistent database - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "event": string(ev.JSON()), - "add": output.NewRoomEvent.AddsStateEventIDs, - "del": output.NewRoomEvent.RemovesStateEventIDs, - log.ErrorKey: err, - }).Panicf("roomserver output log: write room event failure") - } - } - - case api.OutputTypeNewInboundPeek: - if err := s.processInboundPeek(*output.NewInboundPeek); err != nil { - log.WithFields(log.Fields{ - "event": output.NewInboundPeek, - log.ErrorKey: err, - }).Panicf("roomserver output log: remote peek event failure") + if output.NewRoomEvent.RewritesState { + if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil { + log.WithError(err).Errorf("roomserver output log: purge room state failure") return false } - - default: - log.WithField("type", output.Type).Debug( - "roomserver output log: ignoring unknown output type", - ) } - return true - }) + if err := s.processMessage(*output.NewRoomEvent); err != nil { + switch err.(type) { + case *queue.ErrorFederationDisabled: + log.WithField("error", output.Type).Info( + err.Error(), + ) + default: + // panic rather than continue with an inconsistent database + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "event": string(ev.JSON()), + "add": output.NewRoomEvent.AddsStateEventIDs, + "del": output.NewRoomEvent.RemovesStateEventIDs, + log.ErrorKey: err, + }).Panicf("roomserver output log: write room event failure") + } + } + + case api.OutputTypeNewInboundPeek: + if err := s.processInboundPeek(*output.NewInboundPeek); err != nil { + log.WithFields(log.Fields{ + "event": output.NewInboundPeek, + log.ErrorKey: err, + }).Panicf("roomserver output log: remote peek event failure") + return false + } + + default: + log.WithField("type", output.Type).Debug( + "roomserver output log: ignoring unknown output type", + ) + } + + return true } // processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any) diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 5b87e623d..fd963ad83 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -41,7 +41,7 @@ type RoomserverInternalAPI struct { fsAPI fsAPI.FederationInternalAPI asAPI asAPI.AppServiceQueryAPI JetStream nats.JetStreamContext - Durable nats.SubOpt + Durable string InputRoomEventTopic string // JetStream topic for new input room events OutputRoomEventTopic string // JetStream topic for new output room events PerspectiveServerNames []gomatrixserverlib.ServerName @@ -87,7 +87,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA InputRoomEventTopic: r.InputRoomEventTopic, OutputRoomEventTopic: r.OutputRoomEventTopic, JetStream: r.JetStream, - Durable: r.Durable, + Durable: nats.Durable(r.Durable), ServerName: r.Cfg.Matrix.ServerName, FSAPI: fsAPI, KeyRing: keyRing, diff --git a/setup/config/config_jetstream.go b/setup/config/config_jetstream.go index 94e2d88b3..9271cd8b4 100644 --- a/setup/config/config_jetstream.go +++ b/setup/config/config_jetstream.go @@ -2,8 +2,6 @@ package config import ( "fmt" - - "github.com/nats-io/nats.go" ) type JetStream struct { @@ -25,8 +23,8 @@ func (c *JetStream) TopicFor(name string) string { return fmt.Sprintf("%s%s", c.TopicPrefix, name) } -func (c *JetStream) Durable(name string) nats.SubOpt { - return nats.Durable(c.TopicFor(name)) +func (c *JetStream) Durable(name string) string { + return c.TopicFor(name) } func (c *JetStream) Defaults(generate bool) { diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 1891b96b3..544b5f0c3 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -1,12 +1,81 @@ package jetstream -import "github.com/nats-io/nats.go" +import ( + "context" + "fmt" -func WithJetStreamMessage(msg *nats.Msg, f func(msg *nats.Msg) bool) { - _ = msg.InProgress() - if f(msg) { - _ = msg.Ack() - } else { - _ = msg.Nak() + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" +) + +func JetStreamConsumer( + ctx context.Context, js nats.JetStreamContext, subj, durable string, + f func(ctx context.Context, msg *nats.Msg) bool, + opts ...nats.SubOpt, +) error { + defer func() { + // If there are existing consumers from before they were pull + // consumers, we need to clean up the old push consumers. However, + // in order to not affect the interest-based policies, we need to + // do this *after* creating the new pull consumers, which have + // "Pull" suffixed to their name. + if _, err := js.ConsumerInfo(subj, durable); err == nil { + if err := js.DeleteConsumer(subj, durable); err != nil { + logrus.WithContext(ctx).Warnf("Failed to clean up old consumer %q", durable) + } + } + }() + + name := durable + "Pull" + sub, err := js.PullSubscribe(subj, name, opts...) + if err != nil { + return fmt.Errorf("nats.SubscribeSync: %w", err) } + go func() { + for { + // The context behaviour here is surprising — we supply a context + // so that we can interrupt the fetch if we want, but NATS will still + // enforce its own deadline (roughly 5 seconds by default). Therefore + // it is our responsibility to check whether our context expired or + // not when a context error is returned. Footguns. Footguns everywhere. + msgs, err := sub.Fetch(1, nats.Context(ctx)) + if err != nil { + if err == context.Canceled || err == context.DeadlineExceeded { + // Work out whether it was the JetStream context that expired + // or whether it was our supplied context. + select { + case <-ctx.Done(): + // The supplied context expired, so we want to stop the + // consumer altogether. + return + default: + // The JetStream context expired, so the fetch probably + // just timed out and we should try again. + continue + } + } else { + // Something else went wrong, so we'll panic. + logrus.WithContext(ctx).WithField("subject", subj).Fatal(err) + } + } + if len(msgs) < 1 { + continue + } + msg := msgs[0] + if err = msg.InProgress(); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) + continue + } + if f(ctx, msg) { + if err = msg.Ack(); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Ack: %w", err)) + } + } else { + if err = msg.Nak(); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) + } + } + } + }() + return nil } diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 3d340a16a..c3650085f 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -34,7 +34,7 @@ import ( type OutputClientDataConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database stream types.StreamProvider @@ -63,45 +63,45 @@ func NewOutputClientDataConsumer( // Start consuming from room servers func (s *OutputClientDataConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } // onMessage is called when the sync server receives a new event from the client API server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - userID := msg.Header.Get(jetstream.UserID) - var output eventutil.AccountData - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("client API server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - log.WithFields(log.Fields{ - "type": output.Type, - "room_id": output.RoomID, - }).Debug("Received data from client API server") - - streamPos, err := s.db.UpsertAccountData( - s.ctx, userID, output.RoomID, output.Type, - ) - if err != nil { - sentry.CaptureException(err) - log.WithFields(log.Fields{ - "type": output.Type, - "room_id": output.RoomID, - log.ErrorKey: err, - }).Panicf("could not save account data") - } - - s.stream.Advance(streamPos) - s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) - +func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + userID := msg.Header.Get(jetstream.UserID) + var output eventutil.AccountData + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("client API server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + log.WithFields(log.Fields{ + "type": output.Type, + "room_id": output.RoomID, + }).Debug("Received data from client API server") + + streamPos, err := s.db.UpsertAccountData( + s.ctx, userID, output.RoomID, output.Type, + ) + if err != nil { + sentry.CaptureException(err) + log.WithFields(log.Fields{ + "type": output.Type, + "room_id": output.RoomID, + log.ErrorKey: err, + }).Panicf("could not save account data") + } + + s.stream.Advance(streamPos) + s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) + + return true } diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go index 57d69d6fb..392840ece 100644 --- a/syncapi/consumers/eduserver_receipts.go +++ b/syncapi/consumers/eduserver_receipts.go @@ -34,7 +34,7 @@ import ( type OutputReceiptEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database stream types.StreamProvider @@ -64,36 +64,36 @@ func NewOutputReceiptEventConsumer( // Start consuming from EDU api func (s *OutputReceiptEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputReceiptEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var output api.OutputReceiptEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - streamPos, err := s.db.StoreReceipt( - s.ctx, - output.RoomID, - output.Type, - output.UserID, - output.EventID, - output.Timestamp, - ) - if err != nil { - sentry.CaptureException(err) - return true - } - - s.stream.Advance(streamPos) - s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) - +func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output api.OutputReceiptEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + streamPos, err := s.db.StoreReceipt( + s.ctx, + output.RoomID, + output.Type, + output.UserID, + output.EventID, + output.Timestamp, + ) + if err != nil { + sentry.CaptureException(err) + return true + } + + s.stream.Advance(streamPos) + s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) + + return true } diff --git a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go index 54e689fa1..b0beef063 100644 --- a/syncapi/consumers/eduserver_sendtodevice.go +++ b/syncapi/consumers/eduserver_sendtodevice.go @@ -36,7 +36,7 @@ import ( type OutputSendToDeviceEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database serverName gomatrixserverlib.ServerName // our server name @@ -68,52 +68,52 @@ func NewOutputSendToDeviceEventConsumer( // Start consuming from EDU api func (s *OutputSendToDeviceEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputSendToDeviceEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var output api.OutputSendToDeviceEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) - if err != nil { - sentry.CaptureException(err) - return true - } - if domain != s.serverName { - return true - } - - util.GetLogger(context.TODO()).WithFields(log.Fields{ - "sender": output.Sender, - "user_id": output.UserID, - "device_id": output.DeviceID, - "event_type": output.Type, - }).Info("sync API received send-to-device event from EDU server") - - streamPos, err := s.db.StoreNewSendForDeviceMessage( - s.ctx, output.UserID, output.DeviceID, output.SendToDeviceEvent, - ) - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Errorf("failed to store send-to-device message") - return false - } - - s.stream.Advance(streamPos) - s.notifier.OnNewSendToDevice( - output.UserID, - []string{output.DeviceID}, - types.StreamingToken{SendToDevicePosition: streamPos}, - ) - +func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output api.OutputSendToDeviceEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) + if err != nil { + sentry.CaptureException(err) + return true + } + if domain != s.serverName { + return true + } + + util.GetLogger(context.TODO()).WithFields(log.Fields{ + "sender": output.Sender, + "user_id": output.UserID, + "device_id": output.DeviceID, + "event_type": output.Type, + }).Info("sync API received send-to-device event from EDU server") + + streamPos, err := s.db.StoreNewSendForDeviceMessage( + s.ctx, output.UserID, output.DeviceID, output.SendToDeviceEvent, + ) + if err != nil { + sentry.CaptureException(err) + log.WithError(err).Errorf("failed to store send-to-device message") + return false + } + + s.stream.Advance(streamPos) + s.notifier.OnNewSendToDevice( + output.UserID, + []string{output.DeviceID}, + types.StreamingToken{SendToDevicePosition: streamPos}, + ) + + return true } diff --git a/syncapi/consumers/eduserver_typing.go b/syncapi/consumers/eduserver_typing.go index de2f6f950..cae5df8a8 100644 --- a/syncapi/consumers/eduserver_typing.go +++ b/syncapi/consumers/eduserver_typing.go @@ -35,7 +35,7 @@ import ( type OutputTypingEventConsumer struct { ctx context.Context jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string eduCache *cache.EDUCache stream types.StreamProvider @@ -66,41 +66,41 @@ func NewOutputTypingEventConsumer( // Start consuming from EDU api func (s *OutputTypingEventConsumer) Start() error { - _, err := s.jetstream.Subscribe(s.topic, s.onMessage, s.durable) - return err + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputTypingEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - var output api.OutputTypingEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - return true - } - - log.WithFields(log.Fields{ - "room_id": output.Event.RoomID, - "user_id": output.Event.UserID, - "typing": output.Event.Typing, - }).Debug("received data from EDU server") - - var typingPos types.StreamPosition - typingEvent := output.Event - if typingEvent.Typing { - typingPos = types.StreamPosition( - s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime), - ) - } else { - typingPos = types.StreamPosition( - s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID), - ) - } - - s.stream.Advance(typingPos) - s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos}) - +func (s *OutputTypingEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output api.OutputTypingEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) return true - }) + } + + log.WithFields(log.Fields{ + "room_id": output.Event.RoomID, + "user_id": output.Event.UserID, + "typing": output.Event.Typing, + }).Debug("received data from EDU server") + + var typingPos types.StreamPosition + typingEvent := output.Event + if typingEvent.Typing { + typingPos = types.StreamPosition( + s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime), + ) + } else { + typingPos = types.StreamPosition( + s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID), + ) + } + + s.stream.Advance(typingPos) + s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos}) + + return true } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index e9c4abe88..7fe52b728 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct { cfg *config.SyncAPI rsAPI api.RoomserverInternalAPI jetstream nats.JetStreamContext - durable nats.SubOpt + durable string topic string db storage.Database pduStream types.StreamProvider @@ -73,65 +73,61 @@ func NewOutputRoomEventConsumer( // Start consuming from room servers func (s *OutputRoomEventConsumer) Start() error { - _, err := s.jetstream.Subscribe( - s.topic, s.onMessage, s.durable, - nats.DeliverAll(), - nats.ManualAck(), + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), ) - return err } // onMessage is called when the sync server receives a new event from the room server output log. // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. -func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { - // Parse out the event JSON - var err error - var output api.OutputEvent - if err = json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - return true - } - - switch output.Type { - case api.OutputTypeNewRoomEvent: - // Ignore redaction events. We will add them to the database when they are - // validated (when we receive OutputTypeRedactedEvent) - event := output.NewRoomEvent.Event - if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil { - // in the special case where the event redacts itself, just pass the message through because - // we will never see the other part of the pair - if event.Redacts() != event.EventID() { - return true - } - } - err = s.onNewRoomEvent(s.ctx, *output.NewRoomEvent) - case api.OutputTypeOldRoomEvent: - err = s.onOldRoomEvent(s.ctx, *output.OldRoomEvent) - case api.OutputTypeNewInviteEvent: - s.onNewInviteEvent(s.ctx, *output.NewInviteEvent) - case api.OutputTypeRetireInviteEvent: - s.onRetireInviteEvent(s.ctx, *output.RetireInviteEvent) - case api.OutputTypeNewPeek: - s.onNewPeek(s.ctx, *output.NewPeek) - case api.OutputTypeRetirePeek: - s.onRetirePeek(s.ctx, *output.RetirePeek) - case api.OutputTypeRedactedEvent: - err = s.onRedactEvent(s.ctx, *output.RedactedEvent) - default: - log.WithField("type", output.Type).Debug( - "roomserver output log: ignoring unknown output type", - ) - } - if err != nil { - log.WithError(err).Error("roomserver output log: failed to process event") - return false - } - +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + // Parse out the event JSON + var err error + var output api.OutputEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") return true - }) + } + + switch output.Type { + case api.OutputTypeNewRoomEvent: + // Ignore redaction events. We will add them to the database when they are + // validated (when we receive OutputTypeRedactedEvent) + event := output.NewRoomEvent.Event + if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil { + // in the special case where the event redacts itself, just pass the message through because + // we will never see the other part of the pair + if event.Redacts() != event.EventID() { + return true + } + } + err = s.onNewRoomEvent(s.ctx, *output.NewRoomEvent) + case api.OutputTypeOldRoomEvent: + err = s.onOldRoomEvent(s.ctx, *output.OldRoomEvent) + case api.OutputTypeNewInviteEvent: + s.onNewInviteEvent(s.ctx, *output.NewInviteEvent) + case api.OutputTypeRetireInviteEvent: + s.onRetireInviteEvent(s.ctx, *output.RetireInviteEvent) + case api.OutputTypeNewPeek: + s.onNewPeek(s.ctx, *output.NewPeek) + case api.OutputTypeRetirePeek: + s.onRetirePeek(s.ctx, *output.RetirePeek) + case api.OutputTypeRedactedEvent: + err = s.onRedactEvent(s.ctx, *output.RedactedEvent) + default: + log.WithField("type", output.Type).Debug( + "roomserver output log: ignoring unknown output type", + ) + } + if err != nil { + log.WithError(err).Error("roomserver output log: failed to process event") + return false + } + + return true } func (s *OutputRoomEventConsumer) onRedactEvent( From 2a5c38fee23439103e2260d184f74bf41f729e09 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 2 Feb 2022 17:33:36 +0000 Subject: [PATCH 24/81] Use background contexts during federated join for clarity (#2134) * Use background contexts for clarity * Don't wait for the context to expire before trying to return * Actually we don't really need a goroutine here --- federationapi/internal/perform.go | 71 +++++++++++++------------------ 1 file changed, 29 insertions(+), 42 deletions(-) diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 4dd53c11b..7850f206c 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -196,29 +196,22 @@ func (r *FederationInternalAPI) performJoinUsingServer( return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err) } - // No longer reuse the request context from this point forward. - // We don't want the client timing out to interrupt the join. - var cancel context.CancelFunc - ctx, cancel = context.WithCancel(context.Background()) - // Try to perform a send_join using the newly built event. respSendJoin, err := r.federation.SendJoin( - ctx, + context.Background(), serverName, event, respMakeJoin.RoomVersion, ) if err != nil { r.statistics.ForServer(serverName).Failure() - cancel() return fmt.Errorf("r.federation.SendJoin: %w", err) } r.statistics.ForServer(serverName).Success() // Sanity-check the join response to ensure that it has a create // event, that the room version is known, etc. - if err := sanityCheckAuthChain(respSendJoin.AuthEvents); err != nil { - cancel() + if err = sanityCheckAuthChain(respSendJoin.AuthEvents); err != nil { return fmt.Errorf("sanityCheckAuthChain: %w", err) } @@ -227,41 +220,35 @@ func (r *FederationInternalAPI) performJoinUsingServer( // to complete, but if the client does give up waiting, we'll // still continue to process the join anyway so that we don't // waste the effort. - go func() { - defer cancel() + // TODO: Can we expand Check here to return a list of missing auth + // events rather than failing one at a time? + var respState *gomatrixserverlib.RespState + respState, err = respSendJoin.Check( + context.Background(), + r.keyRing, + event, + federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), + ) + if err != nil { + return fmt.Errorf("respSendJoin.Check: %w", err) + } - // TODO: Can we expand Check here to return a list of missing auth - // events rather than failing one at a time? - respState, err := respSendJoin.Check(ctx, r.keyRing, event, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)) - if err != nil { - logrus.WithFields(logrus.Fields{ - "room_id": roomID, - "user_id": userID, - }).WithError(err).Error("Failed to process room join response") - return - } + // If we successfully performed a send_join above then the other + // server now thinks we're a part of the room. Send the newly + // returned state to the roomserver to update our local view. + if err = roomserverAPI.SendEventWithState( + context.Background(), + r.rsAPI, + roomserverAPI.KindNew, + respState, + event.Headered(respMakeJoin.RoomVersion), + serverName, + nil, + false, + ); err != nil { + return fmt.Errorf("roomserverAPI.SendEventWithState: %w", err) + } - // If we successfully performed a send_join above then the other - // server now thinks we're a part of the room. Send the newly - // returned state to the roomserver to update our local view. - if err = roomserverAPI.SendEventWithState( - ctx, r.rsAPI, - roomserverAPI.KindNew, - respState, - event.Headered(respMakeJoin.RoomVersion), - serverName, - nil, - false, - ); err != nil { - logrus.WithFields(logrus.Fields{ - "room_id": roomID, - "user_id": userID, - }).WithError(err).Error("Failed to send room join response to roomserver") - return - } - }() - - <-ctx.Done() return nil } From 4d9f5b2e5787d23e1dbcebfda1c6d99d3498ec7e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 2 Feb 2022 17:46:37 +0000 Subject: [PATCH 25/81] Fix panic from closing the input channel before the workers complete (it'll get GC'd either way) --- roomserver/internal/input/input.go | 1 - 1 file changed, 1 deletion(-) diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index a38d56d7e..7834e2edc 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -161,7 +161,6 @@ func (r *Inputer) InputRoomEvents( } } else { responses := make(chan error, len(request.InputRoomEvents)) - defer close(responses) for _, e := range request.InputRoomEvents { inputRoomEvent := e roomID := inputRoomEvent.Event.RoomID() From eb352a5f6bdb48cb2d795e3fe2cd7d354580a761 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 4 Feb 2022 10:39:34 +0000 Subject: [PATCH 26/81] Full roomserver input transactional isolation (#2141) * Add transaction to all database tables in roomserver, rename latest events updater to room updater, use room updater for all RS input * Better transaction management * Tweak order * Handle cases where the room does not exist * Other fixes * More tweaks * Fill some gaps * Fill in the gaps * good lord it gets worse * Don't roll back transactions when events rejected * Pass through errors properly * Fix bugs * Fix incorrect error check * Don't panic on nil txns * Tweaks * Hopefully fix panics for good in SQLite this time * Fix rollback * Minor bug fixes with latest event updater * Some review comments * Revert "Some review comments" This reverts commit 0caf8cf53e62c33f7b83c52e9df1d963871f751e. * Fix a couple of bugs * Clearer commit and rollback results * Remove unnecessary prepares --- roomserver/internal/helpers/auth.go | 13 +- roomserver/internal/input/input.go | 57 +++- roomserver/internal/input/input_events.go | 84 +++--- .../internal/input/input_latest_events.go | 18 +- roomserver/internal/input/input_membership.go | 4 +- roomserver/internal/input/input_missing.go | 12 +- roomserver/state/state.go | 17 +- roomserver/storage/interface.go | 7 +- .../storage/postgres/event_json_table.go | 5 +- .../postgres/event_state_keys_table.go | 10 +- .../storage/postgres/event_types_table.go | 5 +- roomserver/storage/postgres/events_table.go | 30 +- roomserver/storage/postgres/invite_table.go | 13 +- .../storage/postgres/membership_table.go | 67 +++-- .../storage/postgres/published_table.go | 10 +- .../storage/postgres/room_aliases_table.go | 15 +- roomserver/storage/postgres/rooms_table.go | 27 +- .../storage/postgres/state_block_table.go | 11 +- .../storage/postgres/state_snapshot_table.go | 5 +- .../storage/shared/latest_events_updater.go | 133 --------- roomserver/storage/shared/room_updater.go | 262 +++++++++++++++++ roomserver/storage/shared/storage.go | 273 +++++++++++------- .../storage/sqlite3/event_json_table.go | 11 +- .../storage/sqlite3/event_state_keys_table.go | 21 +- .../storage/sqlite3/event_types_table.go | 5 +- roomserver/storage/sqlite3/events_table.go | 19 +- roomserver/storage/sqlite3/invite_table.go | 13 +- .../storage/sqlite3/membership_table.go | 45 ++- roomserver/storage/sqlite3/published_table.go | 10 +- .../storage/sqlite3/room_aliases_table.go | 15 +- roomserver/storage/sqlite3/rooms_table.go | 33 ++- .../storage/sqlite3/state_block_table.go | 9 +- .../storage/sqlite3/state_snapshot_table.go | 3 +- roomserver/storage/sqlite3/storage.go | 42 +-- roomserver/storage/tables/interface.go | 62 ++-- 35 files changed, 867 insertions(+), 499 deletions(-) delete mode 100644 roomserver/storage/shared/latest_events_updater.go create mode 100644 roomserver/storage/shared/room_updater.go diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index ddda8081c..9af0bf591 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -20,17 +20,22 @@ import ( "sort" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) +type checkForAuthAndSoftFailStorage interface { + state.StateResolutionStorage + StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) +} + // CheckForSoftFail returns true if the event should be soft-failed // and false otherwise. The return error value should be checked before // the soft-fail bool. func CheckForSoftFail( ctx context.Context, - db storage.Database, + db checkForAuthAndSoftFailStorage, event *gomatrixserverlib.HeaderedEvent, stateEventIDs []string, ) (bool, error) { @@ -92,7 +97,7 @@ func CheckForSoftFail( // Returns the numeric IDs for the auth events. func CheckAuthEvents( ctx context.Context, - db storage.Database, + db checkForAuthAndSoftFailStorage, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { @@ -193,7 +198,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * // loadAuthEvents loads the events needed for authentication from the supplied room state. func loadAuthEvents( ctx context.Context, - db storage.Database, + db state.StateResolutionStorage, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 7834e2edc..5bdec0a24 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,6 +19,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "sync" "time" @@ -38,6 +39,19 @@ import ( "github.com/tidwall/gjson" ) +type retryAction int +type commitAction int + +const ( + doNotRetry retryAction = iota + retryLater +) + +const ( + commitTransaction commitAction = iota + rollbackTransaction +) + var keyContentFields = map[string]string{ "m.room.join_rules": "join_rule", "m.room.history_visibility": "history_visibility", @@ -101,7 +115,8 @@ func (r *Inputer) Start() error { _ = msg.InProgress() // resets the acknowledgement wait timer defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil { + action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent) + if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) } @@ -111,7 +126,12 @@ func (r *Inputer) Start() error { "type": inputRoomEvent.Event.Type(), }).Warn("Roomserver failed to process async event") } - _ = msg.Ack() + switch action { + case retryLater: + _ = msg.Nak() + case doNotRetry: + _ = msg.Ack() + } }) }, // NATS wants to acknowledge automatically by default when the message is @@ -131,6 +151,37 @@ func (r *Inputer) Start() error { return err } +// processRoomEventUsingUpdater opens up a room updater and tries to +// process the event. It returns whether or not we should positively +// or negatively acknowledge the event (i.e. for NATS) and an error +// if it occurred. +func (r *Inputer) processRoomEventUsingUpdater( + ctx context.Context, + roomID string, + inputRoomEvent *api.InputRoomEvent, +) (retryAction, error) { + roomInfo, err := r.DB.RoomInfo(ctx, roomID) + if err != nil { + return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err) + } + updater, err := r.DB.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err) + } + action, err := r.processRoomEvent(ctx, updater, inputRoomEvent) + switch action { + case commitTransaction: + if cerr := updater.Commit(); cerr != nil { + return retryLater, fmt.Errorf("updater.Commit: %w", cerr) + } + case rollbackTransaction: + if rerr := updater.Rollback(); rerr != nil { + return retryLater, fmt.Errorf("updater.Rollback: %w", rerr) + } + } + return doNotRetry, err +} + // InputRoomEvents implements api.RoomserverInternalAPI func (r *Inputer) InputRoomEvents( ctx context.Context, @@ -177,7 +228,7 @@ func (r *Inputer) InputRoomEvents( worker.Act(nil, func() { defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - err := r.processRoomEvent(ctx, &inputRoomEvent) + _, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent) if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 16703616e..f3fa83d83 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -29,6 +29,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -67,14 +68,15 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, + updater *shared.RoomUpdater, input *api.InputRoomEvent, -) (err error) { +) (commitAction, error) { select { case <-ctx.Done(): // Before we do anything, make sure the context hasn't expired for this pending task. // If it has then we'll give up straight away — it's probably a synchronous input // request and the caller has already given up, but the inbox task was still queued. - return context.DeadlineExceeded + return rollbackTransaction, context.DeadlineExceeded default: } @@ -107,7 +109,7 @@ func (r *Inputer) processRoomEvent( // if we have already got this event then do not process it again, if the input kind is an outlier. // Outliers contain no extra information which may warrant a re-processing. if input.Kind == api.KindOutlier { - evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()}) + evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()}) if err2 == nil && len(evs) == 1 { // check hash matches if we're on early room versions where the event ID was a random string idFormat, err2 := headered.RoomVersion.EventIDFormat() @@ -116,11 +118,11 @@ func (r *Inputer) processRoomEvent( case gomatrixserverlib.EventIDFormatV1: if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) { logger.Debugf("Already processed event; ignoring") - return nil + return rollbackTransaction, nil } default: logger.Debugf("Already processed event; ignoring") - return nil + return rollbackTransaction, nil } } } @@ -134,8 +136,8 @@ func (r *Inputer) processRoomEvent( AuthEventIDs: event.AuthEventIDs(), PrevEventIDs: event.PrevEventIDs(), } - if err = r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { - return fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) + if err := r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { + return rollbackTransaction, fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) } } missingAuth := len(missingRes.MissingAuthEventIDs) > 0 @@ -146,8 +148,8 @@ func (r *Inputer) processRoomEvent( RoomID: event.RoomID(), ExcludeSelf: true, } - if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { - return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) + if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { + return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) } // Sort all of the servers into a map so that we can randomise // their order. Then make sure that the input origin and the @@ -176,8 +178,8 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return fmt.Errorf("r.fetchAuthEvents: %w", err) + if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err) } // Check if the event is allowed by its auth events. If it isn't then @@ -193,7 +195,7 @@ func (r *Inputer) processRoomEvent( authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) for _, authEventID := range authEventIDs { if _, ok := knownEvents[authEventID]; !ok { - return fmt.Errorf("missing auth event %s", authEventID) + return rollbackTransaction, fmt.Errorf("missing auth event %s", authEventID) } authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) } @@ -202,7 +204,8 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindNew { // Check that the event passes authentication checks based on the // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) + var err error + softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs) if err != nil { logger.WithError(err).Warn("Error authing soft-failed event") } @@ -227,7 +230,7 @@ func (r *Inputer) processRoomEvent( origin: input.Origin, inputer: r, queryer: r.Queryer, - db: r.DB, + db: updater, federation: r.FSAPI, keys: r.KeyRing, roomsMu: internal.NewMutexByRoom(), @@ -235,7 +238,7 @@ func (r *Inputer) processRoomEvent( hadEvents: map[string]bool{}, haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, } - if err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + if err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { isRejected = true rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) } else { @@ -248,16 +251,16 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) + _, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected) if err != nil { - return fmt.Errorf("r.DB.StoreEvent: %w", err) + return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err) } // if storing this event results in it being redacted then do so. if !isRejected && redactedEventID == event.EventID() { r, rerr := eventutil.RedactEvent(redactionEvent, event) if rerr != nil { - return fmt.Errorf("eventutil.RedactEvent: %w", rerr) + return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr) } event = r } @@ -268,23 +271,23 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindOutlier { logger.Debug("Stored outlier") hooks.Run(hooks.KindNewEventPersisted, headered) - return nil + return commitTransaction, nil } - roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) + roomInfo, err := updater.RoomInfo(ctx, event.RoomID()) if err != nil { - return fmt.Errorf("r.DB.RoomInfo: %w", err) + return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err) } if roomInfo == nil { - return fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) + return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) } if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected) + err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected) if err != nil { - return fmt.Errorf("r.calculateAndSetState: %w", err) + return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err) } } @@ -294,13 +297,14 @@ func (r *Inputer) processRoomEvent( "soft_fail": softfail, "missing_prev": missingPrev, }).Warn("Stored rejected event") - return rejectionErr + return commitTransaction, rejectionErr } switch input.Kind { case api.KindNew: if err = r.updateLatestEvents( ctx, // context + updater, // room updater roomInfo, // room info for the room being updated stateAtEvent, // state at event (below) event, // event @@ -308,7 +312,7 @@ func (r *Inputer) processRoomEvent( input.TransactionID, // transaction ID input.HasState, // rewrites state? ); err != nil { - return fmt.Errorf("r.updateLatestEvents: %w", err) + return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err) } case api.KindOld: err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{ @@ -320,7 +324,7 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return fmt.Errorf("r.WriteOutputEvents (old): %w", err) + return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err) } } @@ -339,14 +343,14 @@ func (r *Inputer) processRoomEvent( }, }) if err != nil { - return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) + return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err) } } // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) - return nil + return commitTransaction, nil } // fetchAuthEvents will check to see if any of the @@ -358,6 +362,7 @@ func (r *Inputer) processRoomEvent( // they are now in the database. func (r *Inputer) fetchAuthEvents( ctx context.Context, + updater *shared.RoomUpdater, logger *logrus.Entry, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, @@ -375,7 +380,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) + authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -454,9 +459,9 @@ func (r *Inputer) fetchAuthEvents( } // Finally, store the event in the database. - eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) + eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) if err != nil { - return fmt.Errorf("r.DB.StoreEvent: %w", err) + return fmt.Errorf("updater.StoreEvent: %w", err) } // Now we know about this event, it was stored and the signatures were OK. @@ -471,6 +476,7 @@ func (r *Inputer) fetchAuthEvents( func (r *Inputer) calculateAndSetState( ctx context.Context, + updater *shared.RoomUpdater, input *api.InputRoomEvent, roomInfo *types.RoomInfo, stateAtEvent *types.StateAtEvent, @@ -478,14 +484,14 @@ func (r *Inputer) calculateAndSetState( isRejected bool, ) error { var err error - roomState := state.NewStateResolution(r.DB, roomInfo) + roomState := state.NewStateResolution(updater, roomInfo) if input.HasState { // Check here if we think we're in the room already. stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID // Request join memberships only for local users only. - if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { + if joinEventNIDs, err = updater.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { // If we have no local users that are joined to the room then any state about // the room that we have is quite possibly out of date. Therefore in that case // we should overwrite it rather than merge it. @@ -495,13 +501,13 @@ func (r *Inputer) calculateAndSetState( // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. var entries []types.StateEntry - if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { - return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err) + if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { + return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err) } entries = types.DeduplicateStateEntries(entries) - if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { - return fmt.Errorf("r.DB.AddState: %w", err) + if stateAtEvent.BeforeStateSnapshotNID, err = updater.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { + return fmt.Errorf("updater.AddState: %w", err) } } else { stateAtEvent.Overwrite = false @@ -512,7 +518,7 @@ func (r *Inputer) calculateAndSetState( } } - err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + err = updater.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) if err != nil { return fmt.Errorf("r.DB.SetState: %w", err) } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 6137941e1..5173d3ab2 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -20,7 +20,6 @@ import ( "context" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -48,6 +47,7 @@ import ( // Can only be called once at a time func (r *Inputer) updateLatestEvents( ctx context.Context, + updater *shared.RoomUpdater, roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, event *gomatrixserverlib.Event, @@ -55,13 +55,6 @@ func (r *Inputer) updateLatestEvents( transactionID *api.TransactionID, rewritesState bool, ) (err error) { - updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) - if err != nil { - return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) - } - succeeded := false - defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) - u := latestEventsUpdater{ ctx: ctx, api: r, @@ -78,7 +71,6 @@ func (r *Inputer) updateLatestEvents( return fmt.Errorf("u.doUpdateLatestEvents: %w", err) } - succeeded = true return } @@ -89,7 +81,7 @@ func (r *Inputer) updateLatestEvents( type latestEventsUpdater struct { ctx context.Context api *Inputer - updater *shared.LatestEventsUpdater + updater *shared.RoomUpdater roomInfo *types.RoomInfo stateAtEvent types.StateAtEvent event *gomatrixserverlib.Event @@ -199,7 +191,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB, u.roomInfo) + roomState := state.NewStateResolution(u.updater, u.roomInfo) // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the @@ -413,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro if len(extraEventIDs) == 0 { return nil, nil } - extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs) + extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs) if err != nil { return nil, err } @@ -436,7 +428,7 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) stateEventNIDs = append(stateEventNIDs, entry.EventNID) } stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] - return u.api.DB.EventIDs(u.ctx, stateEventNIDs) + return u.updater.EventIDs(u.ctx, stateEventNIDs) } type eventNIDSorter []types.EventNID diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 2511097d0..ff3ed7e5d 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -31,7 +31,7 @@ import ( // consumers about the invites added or retired by the change in current state. func (r *Inputer) updateMemberships( ctx context.Context, - updater *shared.LatestEventsUpdater, + updater *shared.RoomUpdater, removed, added []types.StateEntry, ) ([]api.OutputEvent, error) { changes := membershipChanges(removed, added) @@ -79,7 +79,7 @@ func (r *Inputer) updateMemberships( } func (r *Inputer) updateMembership( - updater *shared.LatestEventsUpdater, + updater *shared.RoomUpdater, targetUserNID types.EventStateKeyNID, remove, add *gomatrixserverlib.Event, updates []api.OutputEvent, diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index d401fa0e9..4cd2b3de1 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -11,7 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/query" - "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -19,7 +19,7 @@ import ( type missingStateReq struct { origin gomatrixserverlib.ServerName - db storage.Database + db *shared.RoomUpdater inputer *Inputer queryer *query.Queryer keys gomatrixserverlib.JSONVerifier @@ -78,7 +78,7 @@ func (t *missingStateReq) processEventWithMissingState( // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ Kind: api.KindNew, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -187,7 +187,7 @@ func (t *missingStateReq) processEventWithMissingState( } // TODO: we could do this concurrently? for _, ire := range outlierRoomEvents { - if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { + if _, err = t.inputer.processRoomEvent(ctx, t.db, &ire); err != nil { return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err) } } @@ -200,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState( stateIDs = append(stateIDs, event.EventID()) } - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ Kind: api.KindOld, Event: backwardsExtremity.Headered(roomVersion), Origin: t.origin, @@ -217,7 +217,7 @@ func (t *missingStateReq) processEventWithMissingState( // they will automatically fast-forward based on the room state at the // extremity in the last step. for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 15d592b46..e5f69521e 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -22,7 +22,6 @@ import ( "sort" "time" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -30,13 +29,25 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) +type StateResolutionStorage interface { + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) + Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) +} + type StateResolution struct { - db storage.Database + db StateResolutionStorage roomInfo *types.RoomInfo events map[types.EventNID]*gomatrixserverlib.Event } -func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution { +func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution { return StateResolution{ db: db, roomInfo: roomInfo, diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 15764366b..a9851e05b 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -86,11 +86,10 @@ type Database interface { // Lookup the event IDs for a batch of event numeric IDs. // Returns an error if the retrieval went wrong. EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) - // Look up the latest events in a room in preparation for an update. - // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. - // Returns the latest events in the room and the last eventID sent to the log along with an updater. + // Opens and returns a room updater, which locks the room and opens a transaction. + // The GetRoomUpdater must have Commit or Rollback called on it if this doesn't return an error. // If this returns an error then no further action is required. - GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error) + GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error) // Look up event references for the latest events in the room and the current state snapshot. // Returns the latest events, the current state and the maximum depth of the latest events plus 1. // Returns an error if there was a problem talking to the database. diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 32e457821..433e445d8 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -81,9 +81,10 @@ func (s *eventJSONStatements) InsertEventJSON( } func (s *eventJSONStatements) BulkSelectEventJSON( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]tables.EventJSONPair, error) { - rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 3a7cf03e3..762b3a1fc 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( - ctx context.Context, eventStateKeys []string, + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { - rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt) + rows, err := stmt.QueryContext( ctx, pq.StringArray(eventStateKeys), ) if err != nil { @@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKey( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { nIDs := make(pq.Int64Array, len(eventStateKeyNIDs)) for i := range eventStateKeyNIDs { nIDs[i] = int64(eventStateKeyNIDs[i]) } - rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt) + rows, err := stmt.QueryContext(ctx, nIDs) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index e558072a5..1d5de5822 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID( } func (s *eventTypeStatements) BulkSelectEventTypeNID( - ctx context.Context, eventTypes []string, + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { - rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 778cd8d73..6c3847752 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -212,9 +212,10 @@ func (s *eventStatements) SelectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateEntry, error) { - rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -254,13 +255,14 @@ func (s *eventStatements) BulkSelectStateEventByID( // bulkSelectStateEventByNID lookups a list of state events by event NID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByNID( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { tuples := stateKeyTupleSorter(stateKeyTuples) sort.Sort(tuples) eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) if err != nil { return nil, err } @@ -291,9 +293,10 @@ func (s *eventStatements) BulkSelectStateEventByNID( // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. func (s *eventStatements) BulkSelectStateAtEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateAtEvent, error) { - rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -428,8 +431,9 @@ func (s *eventStatements) BulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { - rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) +func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } @@ -455,8 +459,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { - rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err } @@ -484,9 +489,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } func (s *eventStatements) SelectRoomNIDsForEventNIDs( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) (map[types.EventNID]types.RoomNID, error) { - rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt) + rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index 344302c8f..176c16e48 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { } func (s *inviteStatements) InsertInviteEvent( - ctx context.Context, - txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, + inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { @@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent( } func (s *inviteStatements) UpdateInviteRetired( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) ([]string, error) { stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) @@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired( // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs func (s *inviteStatements) SelectInviteActiveForUserInRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, []string, error) { - rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt) + rows, err := stmt.QueryContext( ctx, targetUserNID, roomNID, ) if err != nil { diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index b0d906c80..48c2c35cd 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { } func (s *membershipStatements) InsertMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, ) error { stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) @@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership( } func (s *membershipStatements) SelectMembershipForUpdate( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (membership tables.MembershipState, err error) { err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext( ctx, roomNID, targetUserNID, @@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate( } func (s *membershipStatements) SelectMembershipFromRoomAndTarget( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { - err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) + err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID, &forgotten) return } func (s *membershipStatements) SelectMembershipsFromRoom( - ctx context.Context, roomNID types.RoomNID, localOnly bool, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt if localOnly { @@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } else { stmt = s.selectMembershipsFromRoomStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err := stmt.QueryContext(ctx, roomNID) if err != nil { return @@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var rows *sql.Rows @@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } else { stmt = s.selectMembershipsFromRoomAndMembershipStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err = stmt.QueryContext(ctx, roomNID, membership) if err != nil { return @@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } func (s *membershipStatements) UpdateMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, forgotten bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( @@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership( } func (s *membershipStatements) SelectRoomsWithMembership( - ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, + ctx context.Context, txn *sql.Tx, + userID types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) + rows, err := stmt.QueryContext(ctx, membershipState, userID) if err != nil { return nil, err } @@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { +func (s *membershipStatements) SelectJoinedUsersSetForRooms( + ctx context.Context, txn *sql.Tx, + roomNIDs []types.RoomNID, +) (map[types.EventStateKeyNID]int, error) { roomIDarray := make([]int64, len(roomNIDs)) for i := range roomNIDs { roomIDarray[i] = int64(roomNIDs[i]) } - rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) if err != nil { return nil, err } @@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, return result, rows.Err() } -func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) +func (s *membershipStatements) SelectKnownUsers( + ctx context.Context, txn *sql.Tx, + userID types.EventStateKeyNID, searchString string, limit int, +) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt) + rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) if err != nil { return nil, err } @@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } func (s *membershipStatements) UpdateForgetMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, - forget bool, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( ctx, roomNID, targetUserNID, forget, @@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership( return err } -func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { +func (s *membershipStatements) SelectLocalServerInRoom( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, +) (bool, error) { var nid types.RoomNID - err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil @@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room return found, nil } -func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (s *membershipStatements) SelectServerInRoom( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, serverName gomatrixserverlib.ServerName, +) (bool, error) { var nid types.RoomNID - err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 8deb68441..15985fcd6 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished( } func (s *publishedStatements) SelectPublishedFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (published bool, err error) { - err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) + stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt) + err = stmt.QueryRowContext(ctx, roomID).Scan(&published) if err == sql.ErrNoRows { return false, nil } @@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, published bool, + ctx context.Context, txn *sql.Tx, published bool, ) ([]string, error) { - rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err := stmt.QueryContext(ctx, published) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index 031825fee..d13df8e7f 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias( } func (s *roomAliasesStatements) SelectRoomIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (roomID string, err error) { - err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } @@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias( } func (s *roomAliasesStatements) SelectAliasesFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) ([]string, error) { - rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt) + rows, err := stmt.QueryContext(ctx, roomID) if err != nil { return nil, err } @@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( } func (s *roomAliasesStatements) SelectCreatorIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (creatorID string, err error) { - err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index f51eba4d4..b2685084d 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { }.Prepare(db) } -func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { - rows, err := s.selectRoomIDsStmt.QueryContext(ctx) +func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } @@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID( return types.RoomNID(roomNID), err } -func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { +func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDs pq.Int64Array - err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( + stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan( &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs, ) if err == sql.ErrNoRows { @@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs( ) ([]types.EventNID, types.StateSnapshotNID, error) { var nids pq.Int64Array var stateSnapshotNID int64 - stmt := s.selectLatestEventNIDsStmt + stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt) err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) if err != nil { return nil, 0, err @@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs( } func (s *roomStatements) SelectRoomVersionsForRoomNIDs( - ctx context.Context, roomNIDs []types.RoomNID, + ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { - rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs)) + stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt) + rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs)) if err != nil { return nil, err } @@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( return result, nil } -func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) { var array pq.Int64Array for _, nid := range roomNIDs { array = append(array, int64(nid)) } - rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array) + stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx, array) if err != nil { return nil, err } @@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types return roomIDs, nil } -func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) { var array pq.StringArray for _, roomID := range roomIDs { array = append(array, roomID) } - rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array) + stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt) + rows, err := stmt.QueryContext(ctx, array) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index 27d85e83b..6f8f9e1b5 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, - txn *sql.Tx, + ctx context.Context, txn *sql.Tx, entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] @@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData( for _, e := range entries { nids = append(nids, e.EventNID) } - err = s.insertStateDataStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt) + err = stmt.QueryRowContext( ctx, nids.Hash(), eventNIDsAsArray(nids), ).Scan(&id) return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs types.StateBlockNIDs, + ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs, ) ([][]types.EventNID, error) { - rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt) + rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 4fc0fa48a..ce9f24636 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState( } func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]int64, len(stateNIDs)) for i := range stateNIDs { nids[i] = int64(stateNIDs[i]) } - rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids)) + stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids)) if err != nil { return nil, err } diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go deleted file mode 100644 index 36865081a..000000000 --- a/roomserver/storage/shared/latest_events_updater.go +++ /dev/null @@ -1,133 +0,0 @@ -package shared - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type LatestEventsUpdater struct { - transaction - d *Database - roomInfo types.RoomInfo - latestEvents []types.StateAtEventAndReference - lastEventIDSent string - currentStateSnapshotNID types.StateSnapshotNID -} - -func rollback(txn *sql.Tx) { - if txn == nil { - return - } - txn.Rollback() // nolint: errcheck -} - -func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) { - eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) - if err != nil { - rollback(txn) - return nil, err - } - stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) - if err != nil { - rollback(txn) - return nil, err - } - var lastEventIDSent string - if lastEventNIDSent != 0 { - lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent) - if err != nil { - rollback(txn) - return nil, err - } - } - return &LatestEventsUpdater{ - transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, - }, nil -} - -// RoomVersion implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { - return u.roomInfo.RoomVersion -} - -// LatestEvents implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference { - return u.latestEvents -} - -// LastEventIDSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) LastEventIDSent() string { - return u.lastEventIDSent -} - -// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { - return u.currentStateSnapshotNID -} - -// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer -func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - for _, ref := range previousEventReferences { - if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) - } - } - return nil -} - -// IsReferenced implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) - if err == nil { - return true, nil - } - if err == sql.ErrNoRows { - return false, nil - } - return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err) -} - -// SetLatestEvents implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) SetLatestEvents( - roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, - currentStateSnapshotNID types.StateSnapshotNID, -) error { - eventNIDs := make([]types.EventNID, len(latest)) - for i := range latest { - eventNIDs[i] = latest[i].EventNID - } - return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { - return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) - } - if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok { - if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok { - roomInfo.StateSnapshotNID = currentStateSnapshotNID - roomInfo.IsStub = false - u.d.Cache.StoreRoomInfo(roomID, roomInfo) - } - } - return nil - }) -} - -// HasEventBeenSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) -} - -// MarkEventAsSent implements types.RoomRecentEventsUpdater -func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) - }) -} - -func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) -} diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go new file mode 100644 index 000000000..bb9f5dc62 --- /dev/null +++ b/roomserver/storage/shared/room_updater.go @@ -0,0 +1,262 @@ +package shared + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type RoomUpdater struct { + transaction + d *Database + roomInfo *types.RoomInfo + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID +} + +func rollback(txn *sql.Tx) { + if txn == nil { + return + } + txn.Rollback() // nolint: errcheck +} + +func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *types.RoomInfo) (*RoomUpdater, error) { + // If the roomInfo is nil then that means that the room doesn't exist + // yet, so we can't do `SelectLatestEventsNIDsForUpdate` because that + // would involve locking a row on the table that doesn't exist. Instead + // we will just run with a normal database transaction. It'll either + // succeed, processing a create event which creates the room, or it won't. + if roomInfo == nil { + return &RoomUpdater{ + transaction{ctx, txn}, d, nil, nil, "", 0, + }, nil + } + + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) + if err != nil { + rollback(txn) + return nil, err + } + stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + if err != nil { + rollback(txn) + return nil, err + } + var lastEventIDSent string + if lastEventNIDSent != 0 { + lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent) + if err != nil { + rollback(txn) + return nil, err + } + } + return &RoomUpdater{ + transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + }, nil +} + +// Implements sqlutil.Transaction +func (u *RoomUpdater) Commit() error { + if u.txn == nil { // SQLite mode probably + return nil + } + return u.txn.Commit() +} + +// Implements sqlutil.Transaction +func (u *RoomUpdater) Rollback() error { + if u.txn == nil { // SQLite mode probably + return nil + } + return u.txn.Rollback() +} + +// RoomVersion implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { + return u.roomInfo.RoomVersion +} + +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer +func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + for _, ref := range previousEventReferences { + if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) + } + } + return nil + }) +} + +func (u *RoomUpdater) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + return u.d.events(ctx, u.txn, eventNIDs) +} + +func (u *RoomUpdater) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (types.StateSnapshotNID, error) { + return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID) +} + +func (u *RoomUpdater) StoreEvent( + ctx context.Context, event *gomatrixserverlib.Event, + authEventNIDs []types.EventNID, isRejected bool, +) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { + return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected) +} + +func (u *RoomUpdater) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return u.d.stateBlockNIDs(ctx, u.txn, stateNIDs) +} + +func (u *RoomUpdater) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return u.d.stateEntries(ctx, u.txn, stateBlockNIDs) +} + +func (u *RoomUpdater) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return u.d.stateEntriesForTuples(ctx, u.txn, stateBlockNIDs, stateKeyTuples) +} + +func (u *RoomUpdater) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state) +} + +func (u *RoomUpdater) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID) + }) +} + +func (u *RoomUpdater) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return u.d.eventTypeNIDs(ctx, u.txn, eventTypes) +} + +func (u *RoomUpdater) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return u.d.eventStateKeyNIDs(ctx, u.txn, eventStateKeys) +} + +func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { + return u.d.roomInfo(ctx, u.txn, roomID) +} + +func (u *RoomUpdater) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]string, error) { + return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs) +} + +func (u *RoomUpdater) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs) +} + +func (u *RoomUpdater) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs) +} + +func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, eventIDs) +} + +func (u *RoomUpdater) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, +) ([]types.EventNID, error) { + return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly) +} + +// IsReferenced implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { + err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil + } + return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err) +} + +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { + return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) + } + if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok { + if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok { + roomInfo.StateSnapshotNID = currentStateSnapshotNID + roomInfo.IsStub = false + u.d.Cache.StoreRoomInfo(roomID, roomInfo) + } + } + return nil + }) +} + +// HasEventBeenSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { + return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID) +} + +// MarkEventAsSent implements types.RoomRecentEventsUpdater +func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error { + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) + }) +} + +func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index d4c5ebb5b..127cd1f52 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -26,23 +26,23 @@ import ( const redactionsArePermanent = true type Database struct { - DB *sql.DB - Cache caching.RoomServerCaches - Writer sqlutil.Writer - EventsTable tables.Events - EventJSONTable tables.EventJSON - EventTypesTable tables.EventTypes - EventStateKeysTable tables.EventStateKeys - RoomsTable tables.Rooms - StateSnapshotTable tables.StateSnapshot - StateBlockTable tables.StateBlock - RoomAliasesTable tables.RoomAliases - PrevEventsTable tables.PreviousEvents - InvitesTable tables.Invites - MembershipTable tables.Membership - PublishedTable tables.Published - RedactionsTable tables.Redactions - GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) + DB *sql.DB + Cache caching.RoomServerCaches + Writer sqlutil.Writer + EventsTable tables.Events + EventJSONTable tables.EventJSON + EventTypesTable tables.EventTypes + EventStateKeysTable tables.EventStateKeys + RoomsTable tables.Rooms + StateSnapshotTable tables.StateSnapshot + StateBlockTable tables.StateBlock + RoomAliasesTable tables.RoomAliases + PrevEventsTable tables.PreviousEvents + InvitesTable tables.Invites + MembershipTable tables.Membership + PublishedTable tables.Published + RedactionsTable tables.Redactions + GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } func (d *Database) SupportsConcurrentRoomInputs() bool { @@ -51,6 +51,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) EventTypeNIDs( ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return d.eventTypeNIDs(ctx, nil, eventTypes) +} + +func (d *Database) eventTypeNIDs( + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { result := make(map[string]types.EventTypeNID) remaining := []string{} @@ -62,7 +68,7 @@ func (d *Database) EventTypeNIDs( } } if len(remaining) > 0 { - nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining) + nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining) if err != nil { return nil, err } @@ -77,11 +83,17 @@ func (d *Database) EventTypeNIDs( func (d *Database) EventStateKeys( ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { - return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs) + return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs) } func (d *Database) EventStateKeyNIDs( ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return d.eventStateKeyNIDs(ctx, nil, eventStateKeys) +} + +func (d *Database) eventStateKeyNIDs( + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) remaining := []string{} @@ -93,7 +105,7 @@ func (d *Database) EventStateKeyNIDs( } } if len(remaining) > 0 { - nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining) + nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining) if err != nil { return nil, err } @@ -108,23 +120,31 @@ func (d *Database) EventStateKeyNIDs( func (d *Database) StateEntriesForEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { - return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) + return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs) } func (d *Database) StateEntriesForTuples( ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.stateEntriesForTuples(ctx, nil, stateBlockNIDs, stateKeyTuples) +} + +func (d *Database) stateEntriesForTuples( + ctx context.Context, txn *sql.Tx, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( - ctx, stateBlockNIDs, + ctx, txn, stateBlockNIDs, ) if err != nil { return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) } lists := []types.StateEntryList{} for i, entry := range entries { - entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples) + entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, stateKeyTuples) if err != nil { return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) } @@ -137,10 +157,14 @@ func (d *Database) StateEntriesForTuples( } func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { + return d.roomInfo(ctx, nil, roomID) +} + +func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok { return &roomInfo, nil } - roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID) + roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID) if err == nil && roomInfo != nil { d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID) d.Cache.StoreRoomInfo(roomID, *roomInfo) @@ -153,13 +177,22 @@ func (d *Database) AddState( roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + return d.addState(ctx, nil, roomNID, stateBlockNIDs, state) +} + +func (d *Database) addState( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { if len(stateBlockNIDs) > 0 && len(state) > 0 { // Check to see if the event already appears in any of the existing state // blocks. If it does then we should not add it again, as this will just // result in excess state blocks and snapshots. // TODO: Investigate why this is happening - probably input_events.go! - blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs) if berr != nil { return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr) } @@ -180,7 +213,7 @@ func (d *Database) AddState( } } } - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { if len(state) > 0 { // If there's any state left to add then let's add new blocks. var stateBlockNID types.StateBlockNID @@ -205,7 +238,13 @@ func (d *Database) AddState( func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { - return d.EventsTable.BulkSelectEventNID(ctx, eventIDs) + return d.eventNIDs(ctx, nil, eventIDs) +} + +func (d *Database) eventNIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) (map[string]types.EventNID, error) { + return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs) } func (d *Database) SetState( @@ -219,24 +258,34 @@ func (d *Database) SetState( func (d *Database) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { - return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs) + return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs) } func (d *Database) SnapshotNIDFromEventID( ctx context.Context, eventID string, ) (types.StateSnapshotNID, error) { - _, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID) + return d.snapshotNIDFromEventID(ctx, nil, eventID) +} + +func (d *Database) snapshotNIDFromEventID( + ctx context.Context, txn *sql.Tx, eventID string, +) (types.StateSnapshotNID, error) { + _, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID) return stateNID, err } func (d *Database) EventIDs( ctx context.Context, eventNIDs []types.EventNID, ) (map[types.EventNID]string, error) { - return d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) } func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.EventNIDs(ctx, eventIDs) + return d.eventsFromIDs(ctx, nil, eventIDs) +} + +func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) { + nidMap, err := d.eventNIDs(ctx, txn, eventIDs) if err != nil { return nil, err } @@ -246,7 +295,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type nids = append(nids, nid) } - return d.Events(ctx, nids) + return d.events(ctx, txn, nids) } func (d *Database) LatestEventIDs( @@ -271,21 +320,33 @@ func (d *Database) LatestEventIDs( func (d *Database) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { - return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) + return d.stateBlockNIDs(ctx, nil, stateNIDs) +} + +func (d *Database) stateBlockNIDs( + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, txn, stateNIDs) } func (d *Database) StateEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return d.stateEntries(ctx, nil, stateBlockNIDs) +} + +func (d *Database) stateEntries( + ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( - ctx, stateBlockNIDs, + ctx, txn, stateBlockNIDs, ) if err != nil { return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) } lists := make([]types.StateEntryList, 0, len(entries)) for i, entry := range entries { - eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil) + eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, nil) if err != nil { return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) } @@ -304,17 +365,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string } func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias) + return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias) } func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID) + return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID) } func (d *Database) GetCreatorIDForAlias( ctx context.Context, alias string, ) (string, error) { - return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias) + return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias) } func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { @@ -335,7 +396,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req senderMembershipEventNID, senderMembership, isRoomforgotten, err := d.MembershipTable.SelectMembershipFromRoomAndTarget( - ctx, roomNID, requestSenderUserNID, + ctx, nil, roomNID, requestSenderUserNID, ) if err == sql.ErrNoRows { // The user has never been a member of that room @@ -349,14 +410,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req func (d *Database) GetMembershipEventNIDsForRoom( ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool, +) ([]types.EventNID, error) { + return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly) +} + +func (d *Database) getMembershipEventNIDsForRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool, ) ([]types.EventNID, error) { if joinOnly { return d.MembershipTable.SelectMembershipsFromRoomAndMembership( - ctx, roomNID, tables.MembershipStateJoin, localOnly, + ctx, txn, roomNID, tables.MembershipStateJoin, localOnly, ) } - return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly) + return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly) } func (d *Database) GetInvitesForUser( @@ -364,22 +431,28 @@ func (d *Database) GetInvitesForUser( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) { - return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) + return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID) } func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + return d.events(ctx, nil, eventNIDs) +} + +func (d *Database) events( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]types.Event, error) { + eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs) if err != nil { return nil, err } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } var roomNIDs map[types.EventNID]types.RoomNID - roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs) + roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs) if err != nil { return nil, err } @@ -398,7 +471,7 @@ func (d *Database) Events( } fetchNIDList = append(fetchNIDList, n) } - dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList) + dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList) if err != nil { return nil, err } @@ -440,19 +513,19 @@ func (d *Database) MembershipUpdater( return updater, err } -func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomInfo types.RoomInfo, -) (*LatestEventsUpdater, error) { - if d.GetLatestEventsForUpdateFn != nil { - return d.GetLatestEventsForUpdateFn(ctx, roomInfo) +func (d *Database) GetRoomUpdater( + ctx context.Context, roomInfo *types.RoomInfo, +) (*RoomUpdater, error) { + if d.GetRoomUpdaterFn != nil { + return d.GetRoomUpdaterFn(ctx, roomInfo) } txn, err := d.DB.Begin() if err != nil { return nil, err } - var updater *LatestEventsUpdater + var updater *RoomUpdater _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { - updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) + updater, err = NewRoomUpdater(ctx, d, txn, roomInfo) return err }) return updater, err @@ -461,6 +534,13 @@ func (d *Database) GetLatestEventsForUpdate( func (d *Database) StoreEvent( ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID, isRejected bool, +) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { + return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected) +} + +func (d *Database) storeEvent( + ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event, + authEventNIDs []types.EventNID, isRejected bool, ) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { var ( roomNID types.RoomNID @@ -472,8 +552,11 @@ func (d *Database) StoreEvent( redactedEventID string err error ) - - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var txn *sql.Tx + if updater != nil { + txn = updater.txn + } + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { // TODO: Here we should aim to have two different code paths for new rooms // vs existing ones. @@ -546,42 +629,32 @@ func (d *Database) StoreEvent( // events updater because it somewhat works as a mutex, ensuring // that there's a row-level lock on the latest room events (well, // on Postgres at least). - var roomInfo *types.RoomInfo - var updater *LatestEventsUpdater if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { - roomInfo, err = d.RoomInfo(ctx, event.RoomID()) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) - } - if roomInfo == nil && len(prevEvents) > 0 { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) - } // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This // function only does SELECTs though so the created txn (at this point) is just a read txn like // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater // to do writes however then this will need to go inside `Writer.Do`. - updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err) - } - // Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents - // and EndTransaction in a writer then it's possible for a new write txn to be made between the two - // function calls which will then fail with 'database is locked'. This new write txn would HAVE to be - // something like SetRoomAlias/RemoveRoomAlias as normal input events are already done sequentially due to - // SupportsConcurrentRoomInputs() == false on sqlite, though this does not apply to setting room aliases - // as they don't go via InputRoomEvents - err = d.Writer.Do(d.DB, updater.txn, func(txn *sql.Tx) error { - if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { - return fmt.Errorf("updater.StorePreviousEvents: %w", err) + succeeded := false + if updater == nil { + var roomInfo *types.RoomInfo + roomInfo, err = d.RoomInfo(ctx, event.RoomID()) + if err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) } - succeeded := true - err = sqlutil.EndTransaction(updater, &succeeded) - return err - }) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", err + if roomInfo == nil && len(prevEvents) > 0 { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) + } + updater, err = d.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) + } + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) } + if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) + } + succeeded = true } return eventNID, roomNID, types.StateAtEvent{ @@ -603,7 +676,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) } func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { - return d.PublishedTable.SelectAllPublishedRooms(ctx, true) + return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true) } func (d *Database) assignRoomNID( @@ -875,14 +948,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s eventNIDs = append(eventNIDs, e.EventNID) } } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } // return the event requested for _, e := range entries { if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { - data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID}) + data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID}) if err != nil { return nil, err } @@ -922,11 +995,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership } return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) } - roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) + roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState) if err != nil { return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) } - roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs) if err != nil { return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err) } @@ -945,7 +1018,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } // we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which // isn't a failure. - eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes) + eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err) } @@ -965,7 +1038,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } - eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys) + eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err) } @@ -999,11 +1072,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } } } - eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs) + eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) if err != nil { eventIDs = map[types.EventNID]string{} } - events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs) + events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err) } @@ -1027,11 +1100,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { - roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs) + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs) if err != nil { return nil, err } - userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs) + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs) if err != nil { return nil, err } @@ -1041,7 +1114,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) stateKeyNIDs[i] = nid i++ } - nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs) + nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs) if err != nil { return nil, err } @@ -1057,12 +1130,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { - return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID) + return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID) } // GetServerInRoom returns true if we think a server is in a given room or false otherwise. func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { - return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName) + return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName) } // GetKnownUsers searches all users that userID knows about. @@ -1071,17 +1144,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin if err != nil { return nil, err } - return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit) + return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { - return d.RoomsTable.SelectRoomIDs(ctx) + return d.RoomsTable.SelectRoomIDs(ctx, nil) } // ForgetRoom sets a users room to forgotten func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error { - roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID}) + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID}) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 53b219294..f470ea326 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -76,15 +76,20 @@ func (s *eventJSONStatements) InsertEventJSON( } func (s *eventJSONStatements) BulkSelectEventJSON( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]tables.EventJSONPair, error) { iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { iEventNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, selectOrig, iEventNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, selectOrig, iEventNIDs...) + } if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 62fbce2d0..bf12d5b83 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -112,15 +112,20 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( - ctx context.Context, eventStateKeys []string, + ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { iEventStateKeys := make([]interface{}, len(eventStateKeys)) for k, v := range eventStateKeys { iEventStateKeys[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, selectOrig, iEventStateKeys...) + } else { + rows, err = s.db.QueryContext(ctx, selectOrig, iEventStateKeys...) + } if err != nil { return nil, err } @@ -138,15 +143,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID( } func (s *eventStateKeyStatements) BulkSelectEventStateKey( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, ) (map[types.EventStateKeyNID]string, error) { iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) for k, v := range eventStateKeyNIDs { iEventStateKeyNIDs[k] = v } selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1) - - rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + stmt := sqlutil.TxStmt(txn, selectPrep) + rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 22df3fb22..f2c9c42fe 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID( } func (s *eventTypeStatements) BulkSelectEventTypeNID( - ctx context.Context, eventTypes []string, + ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { /////////////// iEventTypes := make([]interface{}, len(eventTypes)) @@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID( if err != nil { return nil, err } + stmt := sqlutil.TxStmt(txn, selectPrep) /////////////// - rows, err := selectPrep.QueryContext(ctx, iEventTypes...) + rows, err := stmt.QueryContext(ctx, iEventTypes...) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 7483e2815..e1e6a597c 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -184,7 +184,7 @@ func (s *eventStatements) SelectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateEntry, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -196,6 +196,7 @@ func (s *eventStatements) BulkSelectStateEventByID( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) @@ -235,7 +236,7 @@ func (s *eventStatements) BulkSelectStateEventByID( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError func (s *eventStatements) BulkSelectStateEventByNID( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { tuples := stateKeyTupleSorter(stateKeyTuples) @@ -263,6 +264,7 @@ func (s *eventStatements) BulkSelectStateEventByNID( if err != nil { return nil, fmt.Errorf("s.db.Prepare: %w", err) } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, params...) if err != nil { return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) @@ -291,7 +293,7 @@ func (s *eventStatements) BulkSelectStateEventByNID( // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. func (s *eventStatements) BulkSelectStateAtEventByID( - ctx context.Context, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StateAtEvent, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) @@ -303,6 +305,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { @@ -381,6 +384,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( if err != nil { return nil, err } + selectPrep = sqlutil.TxStmt(txn, selectPrep) ////////////// rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) @@ -454,7 +458,7 @@ func (s *eventStatements) BulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { +func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { /////////////// iEventNIDs := make([]interface{}, len(eventNIDs)) for k, v := range eventNIDs { @@ -465,6 +469,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) @@ -490,7 +495,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { @@ -501,6 +506,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) /////////////// rows, err := selectStmt.QueryContext(ctx, iEventIDs...) if err != nil { @@ -538,13 +544,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } func (s *eventStatements) SelectRoomNIDsForEventNIDs( - ctx context.Context, eventNIDs []types.EventNID, + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) (map[types.EventNID]types.RoomNID, error) { sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) sqlPrep, err := s.db.Prepare(sqlStr) if err != nil { return nil, err } + sqlPrep = sqlutil.TxStmt(txn, sqlPrep) iEventNIDs := make([]interface{}, len(eventNIDs)) for i, v := range eventNIDs { iEventNIDs[i] = v diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index c1d7347ae..d54d313a9 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { } func (s *inviteStatements) InsertInviteEvent( - ctx context.Context, - txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, + inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte, ) (bool, error) { @@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent( } func (s *inviteStatements) UpdateInviteRetired( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { // gather all the event IDs we will retire stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) @@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired( // selectInviteActiveForUserInRoom returns a list of sender state key NIDs func (s *inviteStatements) SelectInviteActiveForUserInRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, ) ([]types.EventStateKeyNID, []string, error) { - rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt) + rows, err := stmt.QueryContext( ctx, targetUserNID, roomNID, ) if err != nil { diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 2e58431d3..181b4b4c9 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate( } func (s *membershipStatements) SelectMembershipFromRoomAndTarget( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) { - err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) + err = stmt.QueryRowContext( ctx, roomNID, targetUserNID, ).Scan(&membership, &eventNID, &forgotten) return } func (s *membershipStatements) SelectMembershipsFromRoom( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var selectStmt *sql.Stmt @@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } else { selectStmt = s.selectMembershipsFromRoomStmt } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, roomNID) if err != nil { return nil, err @@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom( } func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership tables.MembershipState, localOnly bool, ) (eventNIDs []types.EventNID, err error) { var stmt *sql.Stmt @@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership( } else { stmt = s.selectMembershipsFromRoomAndMembershipStmt } + stmt = sqlutil.TxStmt(txn, stmt) rows, err := stmt.QueryContext(ctx, roomNID, membership) if err != nil { return @@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership( } func (s *membershipStatements) SelectRoomsWithMembership( - ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, + ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState, ) ([]types.RoomNID, error) { - rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt) + rows, err := stmt.QueryContext(ctx, membershipState, userID) if err != nil { return nil, err } @@ -276,13 +280,19 @@ func (s *membershipStatements) SelectRoomsWithMembership( return roomNIDs, nil } -func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v } query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) - rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, query, iRoomNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...) + } if err != nil { return nil, err } @@ -299,8 +309,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, return result, rows.Err() } -func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { - rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt) + rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) if err != nil { return nil, err } @@ -317,8 +328,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type } func (s *membershipStatements) UpdateForgetMembership( - ctx context.Context, - txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool, ) error { _, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext( @@ -327,9 +338,10 @@ func (s *membershipStatements) UpdateForgetMembership( return err } -func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { +func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) { var nid types.RoomNID - err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil @@ -340,9 +352,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room return found, nil } -func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { +func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) { var nid types.RoomNID - err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) + stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt) + err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid) if err != nil { if err == sql.ErrNoRows { return false, nil diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index b07c0ac42..9e416ace3 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished( } func (s *publishedStatements) SelectPublishedFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (published bool, err error) { - err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published) + stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt) + err = stmt.QueryRowContext(ctx, roomID).Scan(&published) if err == sql.ErrNoRows { return false, nil } @@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID( } func (s *publishedStatements) SelectAllPublishedRooms( - ctx context.Context, published bool, + ctx context.Context, txn *sql.Tx, published bool, ) ([]string, error) { - rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published) + stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt) + rows, err := stmt.QueryContext(ctx, published) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index 323945b88..7c7bead95 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias( } func (s *roomAliasesStatements) SelectRoomIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (roomID string, err error) { - err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID) + stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&roomID) if err == sql.ErrNoRows { return "", nil } @@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias( } func (s *roomAliasesStatements) SelectAliasesFromRoomID( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (aliases []string, err error) { aliases = []string{} - rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID) + stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt) + rows, err := stmt.QueryContext(ctx, roomID) if err != nil { return } @@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID( } func (s *roomAliasesStatements) SelectCreatorIDFromAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (creatorID string, err error) { - err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt) + err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID) if err == sql.ErrNoRows { return "", nil } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index c441daec0..5413475e2 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { }.Prepare(db) } -func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { - rows, err := s.selectRoomIDsStmt.QueryContext(ctx) +func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) + rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } @@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { return roomIDs, nil } -func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { +func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDsJSON string - err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan( + stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan( &info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON, ) if err != nil { @@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs( } func (s *roomStatements) SelectRoomVersionsForRoomNIDs( - ctx context.Context, roomNIDs []types.RoomNID, + ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, ) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) { sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) sqlPrep, err := s.db.Prepare(sqlStr) if err != nil { return nil, err } + sqlPrep = sqlutil.TxStmt(txn, sqlPrep) iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v @@ -252,13 +255,19 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs( return result, nil } -func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) { iRoomNIDs := make([]interface{}, len(roomNIDs)) for i, v := range roomNIDs { iRoomNIDs[i] = v } sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) - rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, sqlQuery, iRoomNIDs...) + } else { + rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + } if err != nil { return nil, err } @@ -274,13 +283,19 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types return roomIDs, nil } -func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) { iRoomIDs := make([]interface{}, len(roomIDs)) for i, v := range roomIDs { iRoomIDs[i] = v } sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) - rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + var rows *sql.Rows + var err error + if txn != nil { + rows, err = txn.QueryContext(ctx, sqlQuery, iRoomIDs...) + } else { + rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + } if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 58b0b5dc2..d51fc492d 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, - txn *sql.Tx, + ctx context.Context, txn *sql.Tx, entries types.StateEntries, ) (id types.StateBlockNID, err error) { entries = entries[:util.SortAndUnique(entries)] @@ -94,14 +93,15 @@ func (s *stateBlockStatements) BulkInsertStateData( if err != nil { return 0, fmt.Errorf("json.Marshal: %w", err) } - err = s.insertStateDataStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt) + err = stmt.QueryRowContext( ctx, nids.Hash(), js, ).Scan(&id) return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs types.StateBlockNIDs, + ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs, ) ([][]types.EventNID, error) { intfs := make([]interface{}, len(stateBlockNIDs)) for i := range stateBlockNIDs { @@ -112,6 +112,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, intfs...) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 040d99ae6..3c4bde3f5 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState( } func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, ) ([]types.StateBlockNIDList, error) { nids := make([]interface{}, len(stateNIDs)) for k, v := range stateNIDs { @@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( if err != nil { return nil, err } + selectStmt = sqlutil.TxStmt(txn, selectStmt) rows, err := selectStmt.QueryContext(ctx, nids...) if err != nil { diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 1fcc7989d..325c253b5 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -172,23 +172,23 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { return err } d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: sqlutil.NewExclusiveWriter(), - EventsTable: events, - EventTypesTable: eventTypes, - EventStateKeysTable: eventStateKeys, - EventJSONTable: eventJSON, - RoomsTable: rooms, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, - PrevEventsTable: prevEvents, - RoomAliasesTable: roomAliases, - InvitesTable: invites, - MembershipTable: membership, - PublishedTable: published, - RedactionsTable: redactions, - GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate, + DB: db, + Cache: cache, + Writer: sqlutil.NewExclusiveWriter(), + EventsTable: events, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + RoomsTable: rooms, + StateBlockTable: stateBlock, + StateSnapshotTable: stateSnapshot, + PrevEventsTable: prevEvents, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + RedactionsTable: redactions, + GetRoomUpdaterFn: d.GetRoomUpdater, } return nil } @@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { return false } -func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomInfo types.RoomInfo, -) (*shared.LatestEventsUpdater, error) { +func (d *Database) GetRoomUpdater( + ctx context.Context, roomInfo *types.RoomInfo, +) (*shared.RoomUpdater, error) { // TODO: Do not use transactions. We should be holding open this transaction but we cannot have // multiple write transactions on sqlite. The code will perform additional // write transactions independent of this one which will consistently cause // 'database is locked' errors. As sqlite doesn't support multi-process on the // same DB anyway, and we only execute updates sequentially, the only worries // are for rolling back when things go wrong. (atomicity) - return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo) + return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo) } func (d *Database) MembershipUpdater( diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 6ad7ed2e8..fed39b944 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -18,20 +18,20 @@ type EventJSONPair struct { type EventJSON interface { // Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions). InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error - BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error) + BulkSelectEventJSON(ctx context.Context, tx *sql.Tx, eventNIDs []types.EventNID) ([]EventJSONPair, error) } type EventTypes interface { InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error) - BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error) } type EventStateKeys interface { InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) - BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) - BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) } type Events interface { @@ -42,12 +42,12 @@ type Events interface { SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError - BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) - BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) + BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error) + BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. - BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error) UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error @@ -55,12 +55,12 @@ type Events interface { BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) // BulkSelectEventID returns a map from numeric event ID to string event ID. - BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. - BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) - SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) + SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) } type Rooms interface { @@ -69,29 +69,29 @@ type Rooms interface { SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error - SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) - SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) - SelectRoomIDs(ctx context.Context) ([]string, error) - BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) - BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) + SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) + SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) + SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) + BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) + BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) } type StateSnapshot interface { InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error) - BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) } type StateBlock interface { BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error) - BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) + BulkSelectStateBlockEntries(ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) //BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) } type RoomAliases interface { InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error) - SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error) - SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) - SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error) + SelectRoomIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (roomID string, err error) + SelectAliasesFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) ([]string, error) + SelectCreatorIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (creatorID string, err error) DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error) } @@ -106,7 +106,7 @@ type Invites interface { InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error) UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error) // SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids. - SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) + SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error) } type MembershipState int64 @@ -121,24 +121,24 @@ const ( type Membership interface { InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error) - SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) - SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) - SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) + SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) + SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error - SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) + SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the // counts of how many rooms they are joined. - SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) - SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) + SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error - SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) - SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) + SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) + SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) } type Published interface { UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error) - SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error) - SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error) + SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error) + SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error) } type RedactionInfo struct { From 532f445c4e31396fc3aa4f52e0e078cd499bc39a Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 4 Feb 2022 12:13:07 +0000 Subject: [PATCH 27/81] Remove roomserver input deadlines (#2144) It isn't really clear that the deadlines actually help in any way. Currently we can use up our 2 minutes doing something, run out of context time and then return an error which causes the transaction to rollback and forgetting everything we've done. If the message came to us from NATS then we probably will end up retrying just to be in the same situation. We'd be really a lot better if we just spent the time reconciling the problem in the first place, and then we're much less likely to need to fetch those missing auth or prev events in the future. Also includes matrix-org/gomatrixserverlib#287 so we don't wait so long for servers that are obviously dead. --- go.mod | 2 +- go.sum | 4 ++-- roomserver/internal/input/input_events.go | 4 ---- roomserver/internal/input/input_missing.go | 4 ---- 4 files changed, 3 insertions(+), 11 deletions(-) diff --git a/go.mod b/go.mod index c36fbe3b3..bd94713a8 100644 --- a/go.mod +++ b/go.mod @@ -42,7 +42,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220131142840-8d9c3d71ffb6 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275 github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index f72b4f4fb..038ccef5d 100644 --- a/go.sum +++ b/go.sum @@ -1021,8 +1021,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220131142840-8d9c3d71ffb6 h1:v+WZXRsn9IaW3mta6bPICWbWcaZbnB1u1ZFlGFi/YU8= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220131142840-8d9c3d71ffb6/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275 h1:f6Hh7D3EOTl1uUr76FiyHNA1h4pKBhcVUtyHbxn0hKA= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index f3fa83d83..0ca5c31a9 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -369,10 +369,6 @@ func (r *Inputer) fetchAuthEvents( known map[string]*types.Event, servers []gomatrixserverlib.ServerName, ) error { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, MaximumMissingProcessingTime) - defer cancel() - unknown := map[string]struct{}{} authEventIDs := event.AuthEventIDs() if len(authEventIDs) == 0 { diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 4cd2b3de1..4d3306660 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -37,10 +37,6 @@ type missingStateReq struct { func (t *missingStateReq) processEventWithMissingState( ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, ) error { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, MaximumMissingProcessingTime) - defer cancel() - // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the // room. There two ways that we can handle such a gap: From 9de7efa0b095f40457f0e348632c77326dcb4a42 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Fri, 4 Feb 2022 14:08:13 +0100 Subject: [PATCH 28/81] Remove sarama/saramajetstream dependencies (#2138) * Remove dependency on saramajetstream & sarama Signed-off-by: Till Faelligen * Remove internal.ContinualConsumer from federationapi * Remove internal.ContinualConsumer from syncapi * Remove internal.ContinualConsumer from keyserver * Move to new Prepare function * Remove saramajetstream & sarama dependency * Delete unneeded file * Remove duplicate import * Log error instead of silently irgnoring it * Move `OffsetNewest` and `OffsetOldest` into keyserver types, change them to be more sane values * Fix comments Co-authored-by: Neil Alexander --- appservice/appservice.go | 2 +- clientapi/clientapi.go | 2 +- eduserver/eduserver.go | 2 +- federationapi/consumers/keychange.go | 78 +++++----- federationapi/consumers/roomserver.go | 5 + federationapi/federationapi.go | 4 +- federationapi/storage/interface.go | 2 - go.mod | 4 +- go.sum | 54 ------- internal/consumers.go | 139 ------------------ keyserver/api/api.go | 2 +- keyserver/consumers/cross_signing.go | 62 ++++---- keyserver/keyserver.go | 4 +- keyserver/storage/interface.go | 5 +- .../storage/postgres/key_changes_table.go | 5 - .../storage/sqlite3/key_changes_table.go | 5 - keyserver/storage/storage_test.go | 6 +- keyserver/storage/tables/interface.go | 2 +- keyserver/types/storage.go | 13 +- roomserver/roomserver.go | 2 +- setup/jetstream/nats.go | 19 +-- syncapi/consumers/keychange.go | 83 +++++------ syncapi/internal/keychange.go | 6 +- syncapi/storage/interface.go | 3 - syncapi/syncapi.go | 6 +- 25 files changed, 155 insertions(+), 360 deletions(-) delete mode 100644 internal/consumers.go diff --git a/appservice/appservice.go b/appservice/appservice.go index 924a609ea..7e7c67f53 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -58,7 +58,7 @@ func NewInternalAPI( }, }, } - js, _, _ := jetstream.Prepare(&base.Cfg.Global.JetStream) + js := jetstream.Prepare(&base.Cfg.Global.JetStream) // Create a connection to the appservice postgres DB appserviceDB, err := storage.NewDatabase(&base.Cfg.AppServiceAPI.Database) diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 7c772125a..d678ada96 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -49,7 +49,7 @@ func AddPublicRoutes( extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, ) { - js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) syncProducer := &producers.SyncAPIProducer{ JetStream: js, diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go index db03001ba..febcf2864 100644 --- a/eduserver/eduserver.go +++ b/eduserver/eduserver.go @@ -42,7 +42,7 @@ func NewInternalAPI( ) api.EDUServerInputAPI { cfg := &base.Cfg.EDUServer - js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) return &input.EDUServerInputAPI{ Cache: eduCache, diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 6a737d0ad..1ec9f4c18 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -17,80 +17,73 @@ package consumers import ( "context" "encoding/json" - "fmt" - "github.com/Shopify/sarama" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // KeyChangeConsumer consumes events that originate in key server. type KeyChangeConsumer struct { ctx context.Context - consumer *internal.ContinualConsumer + jetstream nats.JetStreamContext + durable string db storage.Database queues *queue.OutgoingQueues serverName gomatrixserverlib.ServerName rsAPI roomserverAPI.RoomserverInternalAPI + topic string } // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. func NewKeyChangeConsumer( process *process.ProcessContext, cfg *config.KeyServer, - kafkaConsumer sarama.Consumer, + js nats.JetStreamContext, queues *queue.OutgoingQueues, store storage.Database, rsAPI roomserverAPI.RoomserverInternalAPI, ) *KeyChangeConsumer { - c := &KeyChangeConsumer{ - ctx: process.Context(), - consumer: &internal.ContinualConsumer{ - Process: process, - ComponentName: "federationapi/keychange", - Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)), - Consumer: kafkaConsumer, - PartitionStore: store, - }, + return &KeyChangeConsumer{ + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.TopicFor("FederationAPIKeyChangeConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), queues: queues, db: store, serverName: cfg.Matrix.ServerName, rsAPI: rsAPI, } - c.consumer.ProcessMessage = c.onMessage - - return c } // Start consuming from key servers func (t *KeyChangeConsumer) Start() error { - if err := t.consumer.Start(); err != nil { - return fmt.Errorf("t.consumer.Start: %w", err) - } - return nil + return jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.topic, t.durable, t.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { +func (t *KeyChangeConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { var m api.DeviceMessage - if err := json.Unmarshal(msg.Value, &m); err != nil { + if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") - return nil + return true } if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil { // This probably shouldn't happen but stops us from panicking if we come // across an update that doesn't satisfy either types. - return nil + return true } switch m.Type { case api.TypeCrossSigningUpdate: @@ -102,9 +95,9 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { } } -func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { +func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { if m.DeviceKeys == nil { - return nil + return true } logger := logrus.WithField("user_id", m.UserID) @@ -112,10 +105,10 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { _, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID) if err != nil { logger.WithError(err).Error("Failed to extract domain from key change event") - return nil + return true } if originServerName != t.serverName { - return nil + return true } var queryRes roomserverAPI.QueryRoomsForUserResponse @@ -125,13 +118,13 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { }, &queryRes) if err != nil { logger.WithError(err).Error("failed to calculate joined rooms for user") - return nil + return true } // send this key change to all servers who share rooms with this user. destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") - return nil + return true } // Pack the EDU and marshal it @@ -149,24 +142,26 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { Keys: m.KeyJSON, } if edu.Content, err = json.Marshal(event); err != nil { - return err + logger.WithError(err).Error("failed to marshal EDU JSON") + return true } - logrus.Infof("Sending device list update message to %q", destinations) - return t.queues.SendEDU(edu, t.serverName, destinations) + logger.Infof("Sending device list update message to %q", destinations) + err = t.queues.SendEDU(edu, t.serverName, destinations) + return err == nil } -func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { +func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { output := m.CrossSigningKeyUpdate _, host, err := gomatrixserverlib.SplitID('@', output.UserID) if err != nil { logrus.WithError(err).Errorf("fedsender key change consumer: user ID parse failure") - return nil + return true } if host != gomatrixserverlib.ServerName(t.serverName) { // Ignore any messages that didn't originate locally, otherwise we'll // end up parroting information we received from other servers. - return nil + return true } logger := logrus.WithField("user_id", output.UserID) @@ -177,13 +172,13 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { }, &queryRes) if err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user") - return nil + return true } // send this key change to all servers who share rooms with this user. destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) if err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") - return nil + return true } // Pack the EDU and marshal it @@ -193,11 +188,12 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { } if edu.Content, err = json.Marshal(output); err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to marshal output, dropping") - return nil + return true } logger.Infof("Sending cross-signing update message to %q", destinations) - return t.queues.SendEDU(edu, t.serverName, destinations) + err = t.queues.SendEDU(edu, t.serverName, destinations) + return err == nil } func prevID(streamID int) []int { diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index ac29f930b..e9862000a 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -114,6 +114,11 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) } } + case api.OutputTypeNewInviteEvent: + log.WithField("type", output.Type).Debug( + "received new invite, send device keys", + ) + case api.OutputTypeNewInboundPeek: if err := s.processInboundPeek(*output.NewInboundPeek); err != nil { log.WithFields(log.Fields{ diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 63387b9d8..a982d8009 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -92,7 +92,7 @@ func NewInternalAPI( FailuresUntilBlacklist: cfg.FederationMaxRetries, } - js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, @@ -120,7 +120,7 @@ func NewInternalAPI( logrus.WithError(err).Panic("failed to start typing server consumer") } keyConsumer := consumers.NewKeyChangeConsumer( - base.ProcessContext, &base.Cfg.KeyServer, consumer, queues, federationDB, rsAPI, + base.ProcessContext, &base.Cfg.KeyServer, js, queues, federationDB, rsAPI, ) if err := keyConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start key server consumer") diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 21a919f6a..3fa8d1f7a 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -19,12 +19,10 @@ import ( "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/types" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - internal.PartitionStorer gomatrixserverlib.KeyDatabase UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) diff --git a/go.mod b/go.mod index bd94713a8..fc18ce07e 100644 --- a/go.mod +++ b/go.mod @@ -11,13 +11,12 @@ require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect github.com/MFAshby/stdemuxerhook v1.0.0 github.com/Masterminds/semver/v3 v3.1.1 - github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32 - github.com/Shopify/sarama v1.31.0 github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/codeclysm/extract v2.2.0+incompatible github.com/containerd/containerd v1.5.9 // indirect github.com/docker/docker v20.10.12+incompatible github.com/docker/go-connections v0.4.0 + github.com/frankban/quicktest v1.14.0 // indirect github.com/getsentry/sentry-go v0.12.0 github.com/gologme/log v1.3.0 github.com/gorilla/mux v1.8.0 @@ -74,6 +73,7 @@ require ( golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.5 gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect nhooyr.io/websocket v1.8.7 ) diff --git a/go.sum b/go.sum index 038ccef5d..3f8a99f48 100644 --- a/go.sum +++ b/go.sum @@ -100,17 +100,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko github.com/RoaringBitmap/roaring v0.4.7/go.mod h1:8khRDP4HmeXns4xIj9oGrKSz7XTQiJx2zgh7AcNke4w= github.com/RyanCarrier/dijkstra v1.0.0/go.mod h1:5agGUBNEtUAGIANmbw09fuO3a2htPEkc1jNH01qxCWA= github.com/RyanCarrier/dijkstra-1 v0.0.0-20170512020943-0e5801a26345/go.mod h1:OK4EvWJ441LQqGzed5NGB6vKBAE34n3z7iayPcEwr30= -github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32 h1:i3fOph9Hjleo6LbuqN9ODFxnwt7mOtYMpCGeC8qJN50= -github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32/go.mod h1:ne+jkLlzafIzaE4Q0Ze81T27dNgXe1wxovVEoAtSHTc= github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:HI8ITrYtUY+O+ZhtlqUnD8+KwNPOyugEhfP9fdUIaEQ= -github.com/Shopify/sarama v1.29.0/go.mod h1:2QpgD79wpdAESqNQMxNc0KYMkycd4slxGdV3TWSVqrU= -github.com/Shopify/sarama v1.31.0 h1:gObk7jCPutDxf+E6GA5G21noAZsi1SvP9ftCQYqpzus= -github.com/Shopify/sarama v1.31.0/go.mod h1:BeW3gXRc/CxgAsrSly2RE9nIXUfC9ezb7QHBPVhvzjI= -github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc= -github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= -github.com/Shopify/toxiproxy/v2 v2.3.0 h1:62YkpiP4bzdhKMH+6uC5E95y608k3zDwdzuBMsnn3uQ= -github.com/Shopify/toxiproxy/v2 v2.3.0/go.mod h1:KvQTtB6RjCJY4zqNJn7C7JDFgsG5uoHYDirfUfpIm0c= github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA= github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= @@ -353,12 +344,6 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3 github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v0.0.0-20180421182945-02af3965c54e/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/eapache/go-resiliency v1.2.0 h1:v7g92e/KSN71Rq7vSThKaWIq68fL4YHvWyiUKorFR1Q= -github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= -github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= -github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZifjYj7uP3BG/gKcuzL9xWVV/Y+cK33KM= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= @@ -379,8 +364,6 @@ github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 h1:u/UEqS66A5ckRmS4yNp github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6/go.mod h1:1i71OnUq3iUe1ma7Lr6yG6/rjvM3emb6yoL7xLFzcVQ= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= -github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= -github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k= github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o= @@ -498,8 +481,6 @@ github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= -github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gologme/log v1.2.0/go.mod h1:gq31gQ8wEHkR+WekdWsqDuf8pXTUZA9BnnzTuPz1Y9U= github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo= github.com/gologme/log v1.3.0/go.mod h1:yKT+DvIPdDdDoPtqFrFxheooyVmoqi0BAsw+erN3wA4= @@ -552,8 +533,6 @@ github.com/gorilla/handlers v0.0.0-20150720190736-60c7bfde3e33/go.mod h1:Qkdc/uu github.com/gorilla/mux v1.7.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= @@ -580,8 +559,6 @@ github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHh github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= -github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= -github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= @@ -667,18 +644,6 @@ github.com/jbenet/goprocess v0.0.0-20160826012719-b497e2f366b8/go.mod h1:Ly/wlsj github.com/jbenet/goprocess v0.1.3/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4= github.com/jbenet/goprocess v0.1.4 h1:DRGOFReOMqqDNXwW70QkacFW0YN9QnwLV0Vqk+3oU0o= github.com/jbenet/goprocess v0.1.4/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4= -github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= -github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= -github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= -github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= -github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8= -github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= -github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= -github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= -github.com/jcmturner/gokrb5/v8 v8.4.2 h1:6ZIM6b/JJN0X8UM43ZOM6Z4SJzla+a/u7scXFJzodkA= -github.com/jcmturner/gokrb5/v8 v8.4.2/go.mod h1:sb+Xq/fTY5yktf/VxLsE3wlfPqQjp0aWNYyvBVK62bc= -github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= -github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= @@ -747,10 +712,7 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0 github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.11.13/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= -github.com/klauspost/compress v1.12.2/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= -github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= -github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.14.2 h1:S0OHlFk/Gbon/yauFJ4FfJJF5V0fc5HbBTJazi28pRw= github.com/klauspost/compress v1.14.2/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= @@ -1250,9 +1212,6 @@ github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/9 github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= -github.com/pierrec/lz4 v2.6.0+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/pierrec/lz4 v2.6.1+incompatible h1:9UY3+iC23yxF0UfGaYrGplQ+79Rg+h/q9FV9ix19jjM= -github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -1305,8 +1264,6 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1 github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= -github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= @@ -1432,7 +1389,6 @@ github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5/go.mod h1:70zkFmudgCuE/ github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= -github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBnvPM1Su9w= @@ -1463,11 +1419,6 @@ github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPyS github.com/willf/bitset v1.1.11-0.20200630133818-d5bec3311243/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= github.com/willf/bitset v1.1.11/go.mod h1:83CECat5yLh5zVOf4P1ErAgKA5UDvKtgyUABdr3+MjI= github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= -github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= -github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= -github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= -github.com/xdg/scram v1.0.3/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= -github.com/xdg/stringprep v1.0.3/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v0.0.0-20180618132009-1d523034197f/go.mod h1:5yf86TLmAcydyeJq5YvxkGPE2fm/u4myDekKRoLuqhs= @@ -1550,7 +1501,6 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= @@ -1655,7 +1605,6 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210427231257-85d9c07bbe3a/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -1664,7 +1613,6 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210927181540-4e4d966f7476/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211008194852-3b03d305991f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220105145211-5b0dc2dfae98/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -1798,7 +1746,6 @@ golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7-0.20210503195748-5c7c50ebbd4f/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= @@ -2034,7 +1981,6 @@ gopkg.in/yaml.v2 v2.0.0-20170712054546-1be3d31502d6/go.mod h1:JAlM8MvJe8wmxCU4Bl gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/consumers.go b/internal/consumers.go deleted file mode 100644 index 3a4e0b7f8..000000000 --- a/internal/consumers.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "fmt" - - "github.com/Shopify/sarama" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/process" - "github.com/sirupsen/logrus" -) - -// A PartitionStorer has the storage APIs needed by the consumer. -type PartitionStorer interface { - // PartitionOffsets returns the offsets the consumer has reached for each partition. - PartitionOffsets(ctx context.Context, topic string) ([]sqlutil.PartitionOffset, error) - // SetPartitionOffset records where the consumer has reached for a partition. - SetPartitionOffset(ctx context.Context, topic string, partition int32, offset int64) error -} - -// A ContinualConsumer continually consumes logs even across restarts. It requires a PartitionStorer to -// remember the offset it reached. -type ContinualConsumer struct { - // The parent context for the listener, stop consuming when this context is done - Process *process.ProcessContext - // The component name - ComponentName string - // The kafkaesque topic to consume events from. - // This is the name used in kafka to identify the stream to consume events from. - Topic string - // A kafkaesque stream consumer providing the APIs for talking to the event source. - // The interface is taken from a client library for Apache Kafka. - // But any equivalent event streaming protocol could be made to implement the same interface. - Consumer sarama.Consumer - // A thing which can load and save partition offsets for a topic. - PartitionStore PartitionStorer - // ProcessMessage is a function which will be called for each message in the log. Return an error to - // stop processing messages. See ErrShutdown for specific control signals. - ProcessMessage func(msg *sarama.ConsumerMessage) error - // ShutdownCallback is called when ProcessMessage returns ErrShutdown, after the partition has been saved. - // It is optional. - ShutdownCallback func() -} - -// ErrShutdown can be returned from ContinualConsumer.ProcessMessage to stop the ContinualConsumer. -var ErrShutdown = fmt.Errorf("shutdown") - -// Start starts the consumer consuming. -// Starts up a goroutine for each partition in the kafka stream. -// Returns nil once all the goroutines are started. -// Returns an error if it can't start consuming for any of the partitions. -func (c *ContinualConsumer) Start() error { - _, err := c.StartOffsets() - return err -} - -// StartOffsets is the same as Start but returns the loaded offsets as well. -func (c *ContinualConsumer) StartOffsets() ([]sqlutil.PartitionOffset, error) { - offsets := map[int32]int64{} - - partitions, err := c.Consumer.Partitions(c.Topic) - if err != nil { - return nil, err - } - for _, partition := range partitions { - // Default all the offsets to the beginning of the stream. - offsets[partition] = sarama.OffsetOldest - } - - storedOffsets, err := c.PartitionStore.PartitionOffsets(context.TODO(), c.Topic) - if err != nil { - return nil, err - } - for _, offset := range storedOffsets { - // We've already processed events from this partition so advance the offset to where we got to. - // ConsumePartition will start streaming from the message with the given offset (inclusive), - // so increment 1 to avoid getting the same message a second time. - offsets[offset.Partition] = 1 + offset.Offset - } - - var partitionConsumers []sarama.PartitionConsumer - for partition, offset := range offsets { - pc, err := c.Consumer.ConsumePartition(c.Topic, partition, offset) - if err != nil { - for _, p := range partitionConsumers { - p.Close() // nolint: errcheck - } - return nil, err - } - partitionConsumers = append(partitionConsumers, pc) - } - for _, pc := range partitionConsumers { - go c.consumePartition(pc) - if c.Process != nil { - c.Process.ComponentStarted() - go func(pc sarama.PartitionConsumer) { - <-c.Process.WaitForShutdown() - _ = pc.Close() - c.Process.ComponentFinished() - logrus.Infof("Stopped consumer for %q topic %q", c.ComponentName, c.Topic) - }(pc) - } - } - - return storedOffsets, nil -} - -// consumePartition consumes the room events for a single partition of the kafkaesque stream. -func (c *ContinualConsumer) consumePartition(pc sarama.PartitionConsumer) { - defer pc.Close() // nolint: errcheck - for message := range pc.Messages() { - msgErr := c.ProcessMessage(message) - // Advance our position in the stream so that we will start at the right position after a restart. - if err := c.PartitionStore.SetPartitionOffset(context.TODO(), c.Topic, message.Partition, message.Offset); err != nil { - panic(fmt.Errorf("the ContinualConsumer in %q failed to SetPartitionOffset: %w", c.ComponentName, err)) - } - // Shutdown if we were told to do so. - if msgErr == ErrShutdown { - if c.ShutdownCallback != nil { - c.ShutdownCallback() - } - return - } - } -} diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 0eea2f0fa..3933961c1 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -228,7 +228,7 @@ type QueryKeyChangesRequest struct { // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning Offset int64 // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. - // Use sarama.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). + // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). ToOffset int64 } diff --git a/keyserver/consumers/cross_signing.go b/keyserver/consumers/cross_signing.go index 4b2bd4a9a..a533006ff 100644 --- a/keyserver/consumers/cross_signing.go +++ b/keyserver/consumers/cross_signing.go @@ -18,29 +18,30 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" - - "github.com/Shopify/sarama" ) type OutputCrossSigningKeyUpdateConsumer struct { - eduServerConsumer *internal.ContinualConsumer - keyDB storage.Database - keyAPI api.KeyInternalAPI - serverName string + ctx context.Context + keyDB storage.Database + keyAPI api.KeyInternalAPI + serverName string + jetstream nats.JetStreamContext + durable string + topic string } func NewOutputCrossSigningKeyUpdateConsumer( process *process.ProcessContext, cfg *config.Dendrite, - kafkaConsumer sarama.Consumer, + js nats.JetStreamContext, keyDB storage.Database, keyAPI api.KeyInternalAPI, ) *OutputCrossSigningKeyUpdateConsumer { @@ -48,60 +49,58 @@ func NewOutputCrossSigningKeyUpdateConsumer( // topic. We will only produce events where the UserID matches our server name, // and we will only consume events where the UserID does NOT match our server // name (because the update came from a remote server). - consumer := internal.ContinualConsumer{ - Process: process, - ComponentName: "keyserver/keyserver", - Topic: cfg.Global.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), - Consumer: kafkaConsumer, - PartitionStore: keyDB, - } s := &OutputCrossSigningKeyUpdateConsumer{ - eduServerConsumer: &consumer, - keyDB: keyDB, - keyAPI: keyAPI, - serverName: string(cfg.Global.ServerName), + ctx: process.Context(), + keyDB: keyDB, + jetstream: js, + durable: cfg.Global.JetStream.Durable("KeyServerCrossSigningConsumer"), + topic: cfg.Global.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), + keyAPI: keyAPI, + serverName: string(cfg.Global.ServerName), } - consumer.ProcessMessage = s.onMessage return s } func (s *OutputCrossSigningKeyUpdateConsumer) Start() error { - return s.eduServerConsumer.Start() + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } // onMessage is called in response to a message received on the // key change events topic from the key server. -func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(msg *sarama.ConsumerMessage) error { +func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { var m api.DeviceMessage - if err := json.Unmarshal(msg.Value, &m); err != nil { + if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") - return nil + return true } if m.OutputCrossSigningKeyUpdate == nil { // This probably shouldn't happen but stops us from panicking if we come // across an update that doesn't satisfy either types. - return nil + return true } switch m.Type { case api.TypeCrossSigningUpdate: return t.onCrossSigningMessage(m) default: - return nil + return true } } -func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) error { +func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { output := m.CrossSigningKeyUpdate _, host, err := gomatrixserverlib.SplitID('@', output.UserID) if err != nil { logrus.WithError(err).Errorf("eduserver output log: user ID parse failure") - return nil + return true } if host == gomatrixserverlib.ServerName(s.serverName) { // Ignore any messages that contain information about our own users, as // they already originated from this server. - return nil + return true } uploadReq := &api.PerformUploadDeviceKeysRequest{ UserID: output.UserID, @@ -114,5 +113,8 @@ func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.Device } uploadRes := &api.PerformUploadDeviceKeysResponse{} s.keyAPI.PerformUploadDeviceKeys(context.TODO(), uploadReq, uploadRes) - return uploadRes.Error + if uploadRes.Error != nil { + return false + } + return true } diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 8cc50ea0d..61ccc0303 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -40,7 +40,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { func NewInternalAPI( base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, ) api.KeyInternalAPI { - js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) db, err := storage.NewDatabase(&cfg.Database) if err != nil { @@ -66,7 +66,7 @@ func NewInternalAPI( }() keyconsumer := consumers.NewOutputCrossSigningKeyUpdateConsumer( - base.ProcessContext, base.Cfg, consumer, db, ap, + base.ProcessContext, base.Cfg, js, db, ap, ) if err := keyconsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start keyserver EDU server consumer") diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 87feae47d..0110860ea 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -18,15 +18,12 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { - internal.PartitionStorer - // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) @@ -71,7 +68,7 @@ type Database interface { StoreKeyChange(ctx context.Context, userID string) (int64, error) // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). - // A to offset of sarama.OffsetNewest means no upper limit. + // A to offset of types.OffsetNewest means no upper limit. // Returns the offset of the latest key change. KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go index 20d227c24..f93a94bd3 100644 --- a/keyserver/storage/postgres/key_changes_table.go +++ b/keyserver/storage/postgres/key_changes_table.go @@ -17,9 +17,7 @@ package postgres import ( "context" "database/sql" - "math" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -78,9 +76,6 @@ func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID strin func (s *keyChangesStatements) SelectKeyChanges( ctx context.Context, fromOffset, toOffset int64, ) (userIDs []string, latestOffset int64, err error) { - if toOffset == sarama.OffsetNewest { - toOffset = math.MaxInt64 - } latestOffset = fromOffset rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) if err != nil { diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index d43c15ca9..e035e8c9c 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -17,9 +17,7 @@ package sqlite3 import ( "context" "database/sql" - "math" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -76,9 +74,6 @@ func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID strin func (s *keyChangesStatements) SelectKeyChanges( ctx context.Context, fromOffset, toOffset int64, ) (userIDs []string, latestOffset int64, err error) { - if toOffset == sarama.OffsetNewest { - toOffset = math.MaxInt64 - } latestOffset = fromOffset rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) if err != nil { diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 2f8cf809b..c4c99d8c4 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -9,8 +9,8 @@ import ( "reflect" "testing" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/dendrite/setup/config" ) @@ -50,7 +50,7 @@ func TestKeyChanges(t *testing.T) { MustNotError(t, err) deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, sarama.OffsetNewest) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } @@ -74,7 +74,7 @@ func TestKeyChangesNoDupes(t *testing.T) { } deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, 0, sarama.OffsetNewest) + userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 0d94c94cc..e44757e1a 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -46,7 +46,7 @@ type DeviceKeys interface { type KeyChanges interface { InsertKeyChange(ctx context.Context, userID string) (int64, error) // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. - // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset. + // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) Prepare() error diff --git a/keyserver/types/storage.go b/keyserver/types/storage.go index 3480ec65f..7fb90454e 100644 --- a/keyserver/types/storage.go +++ b/keyserver/types/storage.go @@ -14,7 +14,18 @@ package types -import "github.com/matrix-org/gomatrixserverlib" +import ( + "math" + + "github.com/matrix-org/gomatrixserverlib" +) + +const ( + // OffsetNewest tells e.g. the database to get the most current data + OffsetNewest int64 = math.MaxInt64 + // OffsetOldest tells e.g. the database to get the oldest data + OffsetOldest int64 = 0 +) // KeyTypePurposeToInt maps a purpose to an integer, which is used in the // database to reduce the amount of space taken up by this column. diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 669957be1..e1b84b80c 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -50,7 +50,7 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to room server db") } - js, _, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) return internal.NewRoomserverAPI( cfg, roomserverDB, js, diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 5d7937b5c..77ad2b721 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -5,20 +5,17 @@ import ( "sync" "time" - "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/setup/config" "github.com/sirupsen/logrus" - saramajs "github.com/S7evinK/saramajetstream" natsserver "github.com/nats-io/nats-server/v2/server" - "github.com/nats-io/nats.go" natsclient "github.com/nats-io/nats.go" ) var natsServer *natsserver.Server var natsServerMutex sync.Mutex -func Prepare(cfg *config.JetStream) (nats.JetStreamContext, sarama.Consumer, sarama.SyncProducer) { +func Prepare(cfg *config.JetStream) natsclient.JetStreamContext { // check if we need an in-process NATS Server if len(cfg.Addresses) != 0 { return setupNATS(cfg, nil) @@ -52,20 +49,20 @@ func Prepare(cfg *config.JetStream) (nats.JetStreamContext, sarama.Consumer, sar return setupNATS(cfg, nc) } -func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContext, sarama.Consumer, sarama.SyncProducer) { +func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) natsclient.JetStreamContext { if nc == nil { var err error - nc, err = nats.Connect(strings.Join(cfg.Addresses, ",")) + nc, err = natsclient.Connect(strings.Join(cfg.Addresses, ",")) if err != nil { logrus.WithError(err).Panic("Unable to connect to NATS") - return nil, nil, nil + return nil } } s, err := nc.JetStream() if err != nil { logrus.WithError(err).Panic("Unable to get JetStream context") - return nil, nil, nil + return nil } for _, stream := range streams { // streams are defined in streams.go @@ -80,7 +77,7 @@ func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContex // If we're trying to keep everything in memory (e.g. unit tests) // then overwrite the storage policy. if cfg.InMemory { - stream.Storage = nats.MemoryStorage + stream.Storage = natsclient.MemoryStorage } // Namespace the streams without modifying the original streams @@ -93,7 +90,5 @@ func setupNATS(cfg *config.JetStream, nc *natsclient.Conn) (nats.JetStreamContex } } - consumer := saramajs.NewJetStreamConsumer(nc, s, "") - producer := saramajs.NewJetStreamProducer(nc, s, "") - return s, consumer, producer + return s } diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 97685cc04..e806f76e6 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -18,84 +18,81 @@ import ( "context" "encoding/json" - "github.com/Shopify/sarama" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // OutputKeyChangeEventConsumer consumes events that originated in the key server. type OutputKeyChangeEventConsumer struct { - ctx context.Context - keyChangeConsumer *internal.ContinualConsumer - db storage.Database - notifier *notifier.Notifier - stream types.StreamProvider - serverName gomatrixserverlib.ServerName // our server name - rsAPI roomserverAPI.RoomserverInternalAPI - keyAPI api.KeyInternalAPI + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + notifier *notifier.Notifier + stream types.StreamProvider + serverName gomatrixserverlib.ServerName // our server name + rsAPI roomserverAPI.RoomserverInternalAPI + keyAPI api.KeyInternalAPI } // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. // Call Start() to begin consuming from the key server. func NewOutputKeyChangeEventConsumer( process *process.ProcessContext, - serverName gomatrixserverlib.ServerName, + cfg *config.SyncAPI, topic string, - kafkaConsumer sarama.Consumer, + js nats.JetStreamContext, keyAPI api.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, ) *OutputKeyChangeEventConsumer { - - consumer := internal.ContinualConsumer{ - Process: process, - ComponentName: "syncapi/keychange", - Topic: topic, - Consumer: kafkaConsumer, - PartitionStore: store, - } - s := &OutputKeyChangeEventConsumer{ - ctx: process.Context(), - keyChangeConsumer: &consumer, - db: store, - serverName: serverName, - keyAPI: keyAPI, - rsAPI: rsAPI, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"), + topic: topic, + db: store, + serverName: cfg.Matrix.ServerName, + keyAPI: keyAPI, + rsAPI: rsAPI, + notifier: notifier, + stream: stream, } - consumer.ProcessMessage = s.onMessage - return s } // Start consuming from the key server func (s *OutputKeyChangeEventConsumer) Start() error { - return s.keyChangeConsumer.Start() + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ) } -func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { +func (s *OutputKeyChangeEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { var m api.DeviceMessage - if err := json.Unmarshal(msg.Value, &m); err != nil { + if err := json.Unmarshal(msg.Data, &m); err != nil { logrus.WithError(err).Errorf("failed to read device message from key change topic") - return nil + return true } if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil { // This probably shouldn't happen but stops us from panicking if we come // across an update that doesn't satisfy either types. - return nil + return true } switch m.Type { case api.TypeCrossSigningUpdate: @@ -107,9 +104,9 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er } } -func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) error { +func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) bool { if m.DeviceKeys == nil { - return nil + return true } output := m.DeviceKeys // work out who we need to notify about the new key @@ -120,7 +117,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d if err != nil { logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server") sentry.CaptureException(err) - return err + return true } // make sure we get our own key updates too! queryRes.UserIDsToCount[output.UserID] = 1 @@ -131,10 +128,10 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d s.notifier.OnNewKeyChange(types.StreamingToken{DeviceListPosition: posUpdate}, userID, output.UserID) } - return nil + return true } -func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) error { +func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) bool { output := m.CrossSigningKeyUpdate // work out who we need to notify about the new key var queryRes roomserverAPI.QuerySharedUsersResponse @@ -144,7 +141,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage if err != nil { logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server") sentry.CaptureException(err) - return err + return true } // make sure we get our own key updates too! queryRes.UserIDsToCount[output.UserID] = 1 @@ -155,5 +152,5 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage s.notifier.OnNewKeyChange(types.StreamingToken{DeviceListPosition: posUpdate}, userID, output.UserID) } - return nil + return true } diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 41efd4a07..fa1064b70 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -18,8 +18,8 @@ import ( "context" "strings" - "github.com/Shopify/sarama" keyapi "github.com/matrix-org/dendrite/keyserver/api" + keytypes "github.com/matrix-org/dendrite/keyserver/types" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -64,8 +64,8 @@ func DeviceListCatchup( } // now also track users who we already share rooms with but who have updated their devices between the two tokens - offset := sarama.OffsetOldest - toOffset := sarama.OffsetNewest + offset := keytypes.OffsetOldest + toOffset := keytypes.OffsetNewest if to > 0 && to > from { toOffset = int64(to) } diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 9cff4cad1..b464ad9cd 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -19,7 +19,6 @@ import ( eduAPI "github.com/matrix-org/dendrite/eduserver/api" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -27,8 +26,6 @@ import ( ) type Database interface { - internal.PartitionStorer - MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 39bc233ae..72462459c 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -48,7 +48,7 @@ func AddPublicRoutes( federation *gomatrixserverlib.FederationClient, cfg *config.SyncAPI, ) { - js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) + js := jetstream.Prepare(&cfg.Matrix.JetStream) syncDB, err := storage.NewSyncServerDatasource(&cfg.Database) if err != nil { @@ -65,8 +65,8 @@ func AddPublicRoutes( requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier) keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( - process, cfg.Matrix.ServerName, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), - consumer, keyAPI, rsAPI, syncDB, notifier, + process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), + js, keyAPI, rsAPI, syncDB, notifier, streams.DeviceListStreamProvider, ) if err = keyChangeConsumer.Start(); err != nil { From 585ced89bd846e9657dc5eb0535dcecc30476cbd Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 4 Feb 2022 14:44:45 +0000 Subject: [PATCH 29/81] Version 0.6.1 (#2145) --- CHANGES.md | 21 +++++++++++++++++++++ internal/version.go | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 95e24de11..3e8c7157e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,26 @@ # Changelog +## Dendrite 0.6.1 (2022-02-04) + +### Features + +* Roomserver inputs now take place with full transactional isolation in PostgreSQL deployments +* Pull consumers are now used instead of push consumers when retrieving messages from NATS to better guarantee ordering and to reduce redelivery of duplicate messages +* Further logging tweaks, particularly when joining rooms +* Improved calculation of servers in the room, when checking for missing auth/prev events or state +* Dendrite will now skip dead servers more quickly when federating by reducing the TCP dial timeout +* The key change consumers have now been converted to use native NATS code rather than a wrapper +* Go 1.16 is now the minimum supported version for Dendrite + +### Fixes + +* Local clients should now be notified correctly of invites +* The roomserver input API now has more time to process events, particularly when fetching missing events or state, which should fix a number of errors from expired contexts +* Fixed a panic that could happen due to a closed channel in the roomserver input API +* Logging in with uppercase usernames from old installations is now supported again (contributed by [hoernschen](https://github.com/hoernschen)) +* Federated room joins now have more time to complete and should not fail due to expired contexts +* Events that were sent to the roomserver along with a complete state snapshot are now persisted with the correct state, even if they were rejected or soft-failed + ## Dendrite 0.6.0 (2022-01-28) ### Features diff --git a/internal/version.go b/internal/version.go index f09daabd9..0e9b73637 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 6 - VersionPatch = 0 + VersionPatch = 1 VersionTag = "" // example: "rc1" ) From 00cbe75150cdeed263677a33be12c0c2df078bcb Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 4 Feb 2022 16:16:50 +0000 Subject: [PATCH 30/81] Fix CPU spin from key change consumer when an invalid message is supplied (#2146) --- keyserver/consumers/cross_signing.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/keyserver/consumers/cross_signing.go b/keyserver/consumers/cross_signing.go index a533006ff..aae69e960 100644 --- a/keyserver/consumers/cross_signing.go +++ b/keyserver/consumers/cross_signing.go @@ -114,7 +114,10 @@ func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.Device uploadRes := &api.PerformUploadDeviceKeysResponse{} s.keyAPI.PerformUploadDeviceKeys(context.TODO(), uploadReq, uploadRes) if uploadRes.Error != nil { - return false + // If the error is due to a missing or invalid parameter then we'd might + // as well just acknowledge the message, because otherwise otherwise we'll + // just keep getting delivered a faulty message over and over again. + return uploadRes.Error.IsMissingParam || uploadRes.Error.IsInvalidParam } return true } From 908d881a6e2049ab150f58c0697773656fc27a98 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 4 Feb 2022 17:49:01 +0000 Subject: [PATCH 31/81] Version 0.6.2 --- CHANGES.md | 6 ++++++ internal/version.go | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 3e8c7157e..07e09480a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,11 @@ # Changelog +## Dendrite 0.6.2 (2022-02-04) + +### Fixes + +* Resolves an issue where the key change consumer in the keyserver could consume extreme amounts of CPU + ## Dendrite 0.6.1 (2022-02-04) ### Features diff --git a/internal/version.go b/internal/version.go index 0e9b73637..de0b7c8c3 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 6 - VersionPatch = 1 + VersionPatch = 2 VersionTag = "" // example: "rc1" ) From a572f4db034e13aa38e122d1c4233b15e2356494 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 7 Feb 2022 19:10:01 +0000 Subject: [PATCH 32/81] Fix bugs that could wedge rooms (#2154) * Don't flake so badly for rejected events * Moar * Fix panic * Don't count rejected events as missing * Don't treat rejected events without state as missing * Revert "Don't count rejected events as missing" This reverts commit 4b6139b62eb91ba059b47415b0275964b37d9b43. * Missing events should be KindOld * If we have state, use it, regardless of memberships which could be stale now * Fetch missing state for KindOld too * Tweak the condition again * Clean up a bit * Use room updater to get latest events in a race-free way * Return the correct error * Improve errors --- roomserver/internal/input/input_events.go | 14 +++----- roomserver/internal/input/input_missing.go | 41 +++++++++++----------- roomserver/internal/query/query.go | 3 +- roomserver/types/types.go | 6 ++++ 4 files changed, 33 insertions(+), 31 deletions(-) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 0ca5c31a9..85189e476 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -297,7 +297,10 @@ func (r *Inputer) processRoomEvent( "soft_fail": softfail, "missing_prev": missingPrev, }).Warn("Stored rejected event") - return commitTransaction, rejectionErr + if rejectionErr != nil { + return commitTransaction, types.RejectedError(rejectionErr.Error()) + } + return commitTransaction, nil } switch input.Kind { @@ -483,16 +486,7 @@ func (r *Inputer) calculateAndSetState( roomState := state.NewStateResolution(updater, roomInfo) if input.HasState { - // Check here if we think we're in the room already. stateAtEvent.Overwrite = true - var joinEventNIDs []types.EventNID - // Request join memberships only for local users only. - if joinEventNIDs, err = updater.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { - // If we have no local users that are joined to the room then any state about - // the room that we have is quite possibly out of date. Therefore in that case - // we should overwrite it rather than merge it. - stateAtEvent.Overwrite = len(joinEventNIDs) == 0 - } // We've been told what the state at the event is so we don't need to calculate it. // Check that those state events are in the database and store the state. diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 4d3306660..7a72b0381 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -12,6 +12,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -75,13 +76,15 @@ func (t *missingStateReq) processEventWithMissingState( // in the gap in the DAG for _, newEvent := range newEvents { _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{ - Kind: api.KindNew, + Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, SendAsServer: api.DoNotSendToOtherServers, }) if err != nil { - return fmt.Errorf("t.inputer.processRoomEvent: %w", err) + if _, ok := err.(types.RejectedError); !ok { + return fmt.Errorf("t.inputer.processRoomEvent (filling gap): %w", err) + } } } return nil @@ -183,8 +186,11 @@ func (t *missingStateReq) processEventWithMissingState( } // TODO: we could do this concurrently? for _, ire := range outlierRoomEvents { - if _, err = t.inputer.processRoomEvent(ctx, t.db, &ire); err != nil { - return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err) + _, err = t.inputer.processRoomEvent(ctx, t.db, &ire) + if err != nil { + if _, ok := err.(types.RejectedError); !ok { + return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) + } } } @@ -205,7 +211,9 @@ func (t *missingStateReq) processEventWithMissingState( SendAsServer: api.DoNotSendToOtherServers, }) if err != nil { - return fmt.Errorf("t.inputer.processRoomEvent: %w", err) + if _, ok := err.(types.RejectedError); !ok { + return fmt.Errorf("t.inputer.processRoomEvent (backward extremity): %w", err) + } } // Then send all of the newer backfilled events, of which will all be newer @@ -220,7 +228,9 @@ func (t *missingStateReq) processEventWithMissingState( SendAsServer: api.DoNotSendToOtherServers, }) if err != nil { - return fmt.Errorf("t.inputer.processRoomEvent: %w", err) + if _, ok := err.(types.RejectedError); !ok { + return fmt.Errorf("t.inputer.processRoomEvent (fast forward): %w", err) + } } } @@ -395,20 +405,11 @@ retryAllowedState: // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled bool, err error) { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) - needed := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{e}) - // query latest events (our trusted forward extremities) - req := api.QueryLatestEventsAndStateRequest{ - RoomID: e.RoomID(), - StateToFetch: needed.Tuples(), - } - var res api.QueryLatestEventsAndStateResponse - if err = t.queryer.QueryLatestEventsAndState(ctx, &req, &res); err != nil { - logger.WithError(err).Warn("Failed to query latest events") - return nil, false, err - } - latestEvents := make([]string, len(res.LatestEvents)) - for i, ev := range res.LatestEvents { - latestEvents[i] = res.LatestEvents[i].EventID + + latest := t.db.LatestEvents() + latestEvents := make([]string, len(latest)) + for i, ev := range latest { + latestEvents[i] = ev.EventID t.hadEvent(ev.EventID) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 6b4cb5816..845533032 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -149,7 +149,8 @@ func (r *Queryer) QueryMissingAuthPrevEvents( } for _, prevEventID := range request.PrevEventIDs { - if state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}); err != nil || len(state) == 0 { + state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}) + if err != nil || len(state) == 0 { response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID) } } diff --git a/roomserver/types/types.go b/roomserver/types/types.go index d7e03ad61..5e1eebe98 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -209,6 +209,12 @@ type MissingEventError string func (e MissingEventError) Error() string { return string(e) } +// A RejectedError is returned when an event is stored as rejected. The error +// contains the reason why. +type RejectedError string + +func (e RejectedError) Error() string { return string(e) } + // RoomInfo contains metadata about a room type RoomInfo struct { RoomNID RoomNID From a2b4777ae5c4aa06ce7933bbd1b251bd777788b2 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 8 Feb 2022 09:30:21 +0000 Subject: [PATCH 33/81] Update to matrix-org/gomatrixserverlib@a05e156fd8a0c7bd326cbdadfff2bc7b2a70b44a --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index fc18ce07e..4759ed295 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220204112336-a05e156fd8a0 github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index 3f8a99f48..6d7da3717 100644 --- a/go.sum +++ b/go.sum @@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275 h1:f6Hh7D3EOTl1uUr76FiyHNA1h4pKBhcVUtyHbxn0hKA= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220204110702-c559d2019275/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220204112336-a05e156fd8a0 h1:ZCD8xUM9ppUwW99SzXLOFwWLfdfYRKihj/CCDnMuYMw= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220204112336-a05e156fd8a0/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= From 0e26662a552b671ef201490defbbea59df4250c1 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 8 Feb 2022 13:45:48 +0000 Subject: [PATCH 34/81] Allow events to be un-rejected (#2159) * Allow un-rejecting an event later * SQL * Only un-reject, don't re-reject * Clarify ambiguous column reference --- roomserver/storage/postgres/events_table.go | 6 +++--- roomserver/storage/sqlite3/events_table.go | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 6c3847752..ece1d9e3c 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -71,10 +71,10 @@ CREATE TABLE IF NOT EXISTS roomserver_events ( ` const insertEventSQL = "" + - "INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + + "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + - " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique" + - " DO NOTHING" + + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" + + " SET is_rejected = $8 WHERE e.is_rejected = FALSE" + " RETURNING event_nid, state_snapshot_nid" const selectEventSQL = "" + diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index e1e6a597c..cef09fe60 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -49,7 +49,8 @@ const eventsSchema = ` const insertEventSQL = ` INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - ON CONFLICT DO NOTHING + ON CONFLICT DO UPDATE + SET is_rejected = $8 WHERE is_rejected = 0 RETURNING event_nid, state_snapshot_nid; ` From 8a1dfffe3dc7e720964820301f58a7a9e50d5ee6 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 8 Feb 2022 16:16:01 +0000 Subject: [PATCH 35/81] Various updates for renaming the `master` branch to `main` --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- .github/workflows/codeql-analysis.yml | 30 +++++++++++++-------------- .github/workflows/tests.yml | 6 +++--- build.sh | 2 +- docs/CONTRIBUTING.md | 6 +++--- docs/p2p.md | 2 +- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 1204582e2..8014e9414 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,6 +2,6 @@ -* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/master/docs/CONTRIBUTING.md#sign-off) +* [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off) Signed-off-by: `Your Name ` diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index a4ef8b395..de6c79ddc 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -2,9 +2,9 @@ name: "CodeQL" on: push: - branches: [master] + branches: [main] pull_request: - branches: [master] + branches: [main] jobs: analyze: @@ -14,21 +14,21 @@ jobs: strategy: fail-fast: false matrix: - language: ['go'] + language: ["go"] steps: - - name: Checkout repository - uses: actions/checkout@v2 - with: - fetch-depth: 2 + - name: Checkout repository + uses: actions/checkout@v2 + with: + fetch-depth: 2 - - run: git checkout HEAD^2 - if: ${{ github.event_name == 'pull_request' }} + - run: git checkout HEAD^2 + if: ${{ github.event_name == 'pull_request' }} - - name: Initialize CodeQL - uses: github/codeql-action/init@v1 - with: - languages: ${{ matrix.language }} + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v1 + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ad5a2660c..4a1720295 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,7 +2,7 @@ name: Tests on: push: - branches: [ 'master' ] + branches: ["main"] pull_request: concurrency: @@ -33,7 +33,7 @@ jobs: path: dendrite # Attempt to check out the same branch of Complement as the PR. If it - # doesn't exist, fallback to master. + # doesn't exist, fallback to main. - name: Checkout complement shell: bash run: | @@ -68,4 +68,4 @@ jobs: name: Run Complement Tests env: COMPLEMENT_BASE_IMAGE: complement-dendrite:latest - working-directory: complement \ No newline at end of file + working-directory: complement diff --git a/build.sh b/build.sh index 8196fc653..700e6434f 100755 --- a/build.sh +++ b/build.sh @@ -7,7 +7,7 @@ if [ -d ".git" ] then export BUILD=`git rev-parse --short HEAD || ""` export BRANCH=`(git symbolic-ref --short HEAD | tr -d \/ ) || ""` - if [ "$BRANCH" = master ] + if [ "$BRANCH" = main ] then export BRANCH="" fi diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index ea4b2b27d..fe7127c76 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -37,7 +37,7 @@ If a job fails, click the "details" button and you should be taken to the job's logs. ![Click the details button on the failing build -step](https://raw.githubusercontent.com/matrix-org/dendrite/master/docs/images/details-button-location.jpg) +step](https://raw.githubusercontent.com/matrix-org/dendrite/main/docs/images/details-button-location.jpg) Scroll down to the failing step and you should see some log output. Scan the logs until you find what it's complaining about, fix it, submit a new commit, @@ -57,7 +57,7 @@ significant amount of CPU and RAM. Once the code builds, run [Sytest](https://github.com/matrix-org/sytest) according to the guide in -[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/master/docs/sytest.md#using-a-sytest-docker-image) +[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/sytest.md#using-a-sytest-docker-image) so you can see whether something is being broken and whether there are newly passing tests. @@ -94,4 +94,4 @@ For more general questions please use We ask that everyone who contributes to the project signs off their contributions, in accordance with the -[DCO](https://github.com/matrix-org/matrix-doc/blob/master/CONTRIBUTING.rst#sign-off). +[DCO](https://github.com/matrix-org/matrix-doc/blob/main/CONTRIBUTING.rst#sign-off). diff --git a/docs/p2p.md b/docs/p2p.md index e858ba114..4e9a50524 100644 --- a/docs/p2p.md +++ b/docs/p2p.md @@ -6,7 +6,7 @@ These are the instructions for setting up P2P Dendrite, current as of May 2020. #### Build -- The `master` branch has a WASM-only binary for dendrite: `./cmd/dendritejs`. +- The `main` branch has a WASM-only binary for dendrite: `./cmd/dendritejs`. - Build it and copy assets to riot-web. ``` $ ./build-dendritejs.sh From bb39149ff8c0a14f312aec694f93902c6f409cc0 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 8 Feb 2022 16:18:16 +0000 Subject: [PATCH 36/81] Fix DendriteJS dockerfile --- build/docker/DendriteJS.Dockerfile | 74 +++++++++++++++--------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/build/docker/DendriteJS.Dockerfile b/build/docker/DendriteJS.Dockerfile index e8d742b7e..5e1cffcad 100644 --- a/build/docker/DendriteJS.Dockerfile +++ b/build/docker/DendriteJS.Dockerfile @@ -9,9 +9,9 @@ FROM golang:1.14-alpine AS gobuild # Download and build dendrite WORKDIR /build -ADD https://github.com/matrix-org/dendrite/archive/master.tar.gz /build/master.tar.gz -RUN tar xvfz master.tar.gz -WORKDIR /build/dendrite-master +ADD https://github.com/matrix-org/dendrite/archive/main.tar.gz /build/main.tar.gz +RUN tar xvfz main.tar.gz +WORKDIR /build/dendrite-main RUN GOOS=js GOARCH=wasm go build -o main.wasm ./cmd/dendritejs @@ -21,7 +21,7 @@ RUN apt-get update && apt-get -y install python # Download riot-web and libp2p repos WORKDIR /build -ADD https://github.com/matrix-org/go-http-js-libp2p/archive/master.tar.gz /build/libp2p.tar.gz +ADD https://github.com/matrix-org/go-http-js-libp2p/archive/main.tar.gz /build/libp2p.tar.gz RUN tar xvfz libp2p.tar.gz ADD https://github.com/vector-im/element-web/archive/matthew/p2p.tar.gz /build/p2p.tar.gz RUN tar xvfz p2p.tar.gz @@ -31,21 +31,21 @@ WORKDIR /build/element-web-matthew-p2p RUN yarn install RUN ln -s /build/go-http-js-libp2p-master /build/element-web-matthew-p2p/node_modules/go-http-js-libp2p RUN (cd node_modules/go-http-js-libp2p && yarn install) -COPY --from=gobuild /build/dendrite-master/main.wasm ./src/vector/dendrite.wasm +COPY --from=gobuild /build/dendrite-main/main.wasm ./src/vector/dendrite.wasm # build it all RUN yarn build:p2p SHELL ["/bin/bash", "-c"] RUN echo $'\ -{ \n\ + { \n\ "default_server_config": { \n\ - "m.homeserver": { \n\ - "base_url": "https://p2p.riot.im", \n\ - "server_name": "p2p.riot.im" \n\ - }, \n\ - "m.identity_server": { \n\ - "base_url": "https://vector.im" \n\ - } \n\ + "m.homeserver": { \n\ + "base_url": "https://p2p.riot.im", \n\ + "server_name": "p2p.riot.im" \n\ + }, \n\ + "m.identity_server": { \n\ + "base_url": "https://vector.im" \n\ + } \n\ }, \n\ "disable_custom_urls": false, \n\ "disable_guests": true, \n\ @@ -55,57 +55,57 @@ RUN echo $'\ "integrations_ui_url": "https://scalar.vector.im/", \n\ "integrations_rest_url": "https://scalar.vector.im/api", \n\ "integrations_widgets_urls": [ \n\ - "https://scalar.vector.im/_matrix/integrations/v1", \n\ - "https://scalar.vector.im/api", \n\ - "https://scalar-staging.vector.im/_matrix/integrations/v1", \n\ - "https://scalar-staging.vector.im/api", \n\ - "https://scalar-staging.riot.im/scalar/api" \n\ + "https://scalar.vector.im/_matrix/integrations/v1", \n\ + "https://scalar.vector.im/api", \n\ + "https://scalar-staging.vector.im/_matrix/integrations/v1", \n\ + "https://scalar-staging.vector.im/api", \n\ + "https://scalar-staging.riot.im/scalar/api" \n\ ], \n\ "integrations_jitsi_widget_url": "https://scalar.vector.im/api/widgets/jitsi.html", \n\ "bug_report_endpoint_url": "https://riot.im/bugreports/submit", \n\ "defaultCountryCode": "GB", \n\ "showLabsSettings": false, \n\ "features": { \n\ - "feature_pinning": "labs", \n\ - "feature_custom_status": "labs", \n\ - "feature_custom_tags": "labs", \n\ - "feature_state_counters": "labs" \n\ + "feature_pinning": "labs", \n\ + "feature_custom_status": "labs", \n\ + "feature_custom_tags": "labs", \n\ + "feature_state_counters": "labs" \n\ }, \n\ "default_federate": true, \n\ "default_theme": "light", \n\ "roomDirectory": { \n\ - "servers": [ \n\ - "matrix.org" \n\ - ] \n\ + "servers": [ \n\ + "matrix.org" \n\ + ] \n\ }, \n\ "welcomeUserId": "", \n\ "piwik": { \n\ - "url": "https://piwik.riot.im/", \n\ - "whitelistedHSUrls": ["https://matrix.org"], \n\ - "whitelistedISUrls": ["https://vector.im", "https://matrix.org"], \n\ - "siteId": 1 \n\ + "url": "https://piwik.riot.im/", \n\ + "whitelistedHSUrls": ["https://matrix.org"], \n\ + "whitelistedISUrls": ["https://vector.im", "https://matrix.org"], \n\ + "siteId": 1 \n\ }, \n\ "enable_presence_by_hs_url": { \n\ - "https://matrix.org": false, \n\ - "https://matrix-client.matrix.org": false \n\ + "https://matrix.org": false, \n\ + "https://matrix-client.matrix.org": false \n\ }, \n\ "settingDefaults": { \n\ - "breadcrumbs": true \n\ + "breadcrumbs": true \n\ } \n\ -}' > webapp/config.json + }' > webapp/config.json FROM nginx # Add "Service-Worker-Allowed: /" header so the worker can sniff traffic on this domain rather # than just the path this gets hosted under. NB this newline echo syntax only works on bash. SHELL ["/bin/bash", "-c"] RUN echo $'\ -server { \n\ + server { \n\ listen 80; \n\ add_header \'Service-Worker-Allowed\' \'/\'; \n\ location / { \n\ - root /usr/share/nginx/html; \n\ - index index.html index.htm; \n\ + root /usr/share/nginx/html; \n\ + index index.html index.htm; \n\ } \n\ -}' > /etc/nginx/conf.d/default.conf + }' > /etc/nginx/conf.d/default.conf RUN sed -i 's/}/ application\/wasm wasm;\n}/g' /etc/nginx/mime.types COPY --from=jsbuild /build/element-web-matthew-p2p/webapp /usr/share/nginx/html From a84f50f4fb0963e15b6aca60f658bed45b779127 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 8 Feb 2022 16:49:49 +0000 Subject: [PATCH 37/81] Demote logging entry for backoff --- federationapi/queue/destinationqueue.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 1306e8588..09814b31f 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -297,7 +297,7 @@ func (oq *destinationQueue) backgroundSend() { // We haven't backed off yet, so wait for the suggested amount of // time. duration := time.Until(*until) - logrus.Warnf("Backing off %q for %s", oq.destination, duration) + logrus.Debugf("Backing off %q for %s", oq.destination, duration) oq.backingOff.Store(true) destinationQueueBackingOff.Inc() select { From 457a07eac5d668a0f04c273e086d321cab7ea640 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 8 Feb 2022 17:06:13 +0000 Subject: [PATCH 38/81] More relaxed auth event fetching (#2161) * Tweaks around auth event fetching * More tweaking --- roomserver/internal/input/input_events.go | 51 ++++++++++++++++------- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 85189e476..bb35ade9c 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -195,9 +195,26 @@ func (r *Inputer) processRoomEvent( authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) for _, authEventID := range authEventIDs { if _, ok := knownEvents[authEventID]; !ok { - return rollbackTransaction, fmt.Errorf("missing auth event %s", authEventID) + // Unknown auth events only really matter if the event actually failed + // auth. If it passed auth then we can assume that everything that was + // known was sufficient, even if extraneous auth events were specified + // but weren't found. + if isRejected { + if event.StateKey() != nil { + return commitTransaction, fmt.Errorf( + "missing auth event %s for state event %s (type %q, state key %q)", + authEventID, event.EventID(), event.Type(), *event.StateKey(), + ) + } else { + return commitTransaction, fmt.Errorf( + "missing auth event %s for timeline event %s (type %q)", + authEventID, event.EventID(), event.Type(), + ) + } + } + } else { + authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) } - authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) } var softfail bool @@ -416,6 +433,10 @@ func (r *Inputer) fetchAuthEvents( return fmt.Errorf("no servers provided event auth for event ID %q, tried servers %v", event.EventID(), servers) } + // Reuse these to reduce allocations. + authEventNIDs := make([]types.EventNID, 0, 5) + isRejected := false +nextAuthEvent: for _, authEvent := range gomatrixserverlib.ReverseTopologicalOrdering( res.AuthEvents, gomatrixserverlib.TopologicalOrderByAuthEvents, @@ -424,36 +445,30 @@ func (r *Inputer) fetchAuthEvents( // need to store it again or do anything further with it, so just skip // over it rather than wasting cycles. if ev, ok := known[authEvent.EventID()]; ok && ev != nil { - continue + continue nextAuthEvent } // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. if err := authEvent.VerifyEventSignatures(ctx, r.FSAPI.KeyRing()); err != nil { - continue + continue nextAuthEvent } // In order to store the new auth event, we need to know its auth chain // as NIDs for the `auth_event_nids` column. Let's see if we can find those. - authEventNIDs := make([]types.EventNID, 0, len(authEvent.AuthEventIDs())) + authEventNIDs = authEventNIDs[:0] for _, eventID := range authEvent.AuthEventIDs() { knownEvent, ok := known[eventID] if !ok { - return fmt.Errorf("missing auth event %s for %s", eventID, authEvent.EventID()) + continue nextAuthEvent } authEventNIDs = append(authEventNIDs, knownEvent.EventNID) } - // Let's take a note of the fact that we now know about this event. - if err := auth.AddEvent(authEvent); err != nil { - return fmt.Errorf("auth.AddEvent: %w", err) - } - // Check if the auth event should be rejected. - isRejected := false - if err := gomatrixserverlib.Allowed(authEvent, auth); err != nil { - isRejected = true + err := gomatrixserverlib.Allowed(authEvent, auth) + if isRejected = err != nil; isRejected { logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } @@ -463,6 +478,14 @@ func (r *Inputer) fetchAuthEvents( return fmt.Errorf("updater.StoreEvent: %w", err) } + // Let's take a note of the fact that we now know about this event for + // authenticating future events. + if !isRejected { + if err := auth.AddEvent(authEvent); err != nil { + return fmt.Errorf("auth.AddEvent: %w", err) + } + } + // Now we know about this event, it was stored and the signatures were OK. known[authEvent.EventID()] = &types.Event{ EventNID: eventNID, From 2771d93748380aa7dc21adca0ef690348d79f002 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Tue, 8 Feb 2022 18:13:38 +0100 Subject: [PATCH 39/81] Remove OutputKeyChangeEvent consumer on keyserver (#2160) * Remove keyserver consumer * Remove keyserver from eduserver * Directly upload device keys without eduserver * Add passing tests --- eduserver/api/input.go | 6 -- eduserver/eduserver.go | 1 - eduserver/input/input.go | 31 ------- eduserver/inthttp/client.go | 20 +---- eduserver/inthttp/server.go | 13 --- federationapi/consumers/roomserver.go | 1 - federationapi/routing/send.go | 44 ++++++--- keyserver/consumers/cross_signing.go | 123 -------------------------- keyserver/internal/cross_signing.go | 50 +++++------ keyserver/keyserver.go | 8 -- syncapi/internal/keychange.go | 2 + sytest-whitelist | 1 + 12 files changed, 59 insertions(+), 241 deletions(-) delete mode 100644 keyserver/consumers/cross_signing.go diff --git a/eduserver/api/input.go b/eduserver/api/input.go index 2fa253f4d..2aab107b2 100644 --- a/eduserver/api/input.go +++ b/eduserver/api/input.go @@ -100,10 +100,4 @@ type EDUServerInputAPI interface { request *InputReceiptEventRequest, response *InputReceiptEventResponse, ) error - - InputCrossSigningKeyUpdate( - ctx context.Context, - request *InputCrossSigningKeyUpdateRequest, - response *InputCrossSigningKeyUpdateResponse, - ) error } diff --git a/eduserver/eduserver.go b/eduserver/eduserver.go index febcf2864..9b7e21651 100644 --- a/eduserver/eduserver.go +++ b/eduserver/eduserver.go @@ -51,7 +51,6 @@ func NewInternalAPI( OutputTypingEventTopic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputTypingEvent), OutputSendToDeviceEventTopic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputSendToDeviceEvent), OutputReceiptEventTopic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent), - OutputKeyChangeEventTopic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), ServerName: cfg.Matrix.ServerName, } } diff --git a/eduserver/input/input.go b/eduserver/input/input.go index 4f8ab3e34..e58f0dd34 100644 --- a/eduserver/input/input.go +++ b/eduserver/input/input.go @@ -23,7 +23,6 @@ import ( "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/cache" - keyapi "github.com/matrix-org/dendrite/keyserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" @@ -40,8 +39,6 @@ type EDUServerInputAPI struct { OutputSendToDeviceEventTopic string // The kafka topic to output new receipt events to OutputReceiptEventTopic string - // The kafka topic to output new key change events to - OutputKeyChangeEventTopic string // kafka producer JetStream nats.JetStreamContext // Internal user query API @@ -80,34 +77,6 @@ func (t *EDUServerInputAPI) InputSendToDeviceEvent( return t.sendToDeviceEvent(ise) } -// InputCrossSigningKeyUpdate implements api.EDUServerInputAPI -func (t *EDUServerInputAPI) InputCrossSigningKeyUpdate( - ctx context.Context, - request *api.InputCrossSigningKeyUpdateRequest, - response *api.InputCrossSigningKeyUpdateResponse, -) error { - eventJSON, err := json.Marshal(&keyapi.DeviceMessage{ - Type: keyapi.TypeCrossSigningUpdate, - OutputCrossSigningKeyUpdate: &api.OutputCrossSigningKeyUpdate{ - CrossSigningKeyUpdate: request.CrossSigningKeyUpdate, - }, - }) - if err != nil { - return err - } - - logrus.WithFields(logrus.Fields{ - "user_id": request.UserID, - }).Tracef("Producing to topic '%s'", t.OutputKeyChangeEventTopic) - - _, err = t.JetStream.PublishMsg(&nats.Msg{ - Subject: t.OutputKeyChangeEventTopic, - Header: nats.Header{}, - Data: eventJSON, - }) - return err -} - func (t *EDUServerInputAPI) sendTypingEvent(ite *api.InputTypingEvent) error { ev := &api.TypingEvent{ Type: gomatrixserverlib.MTyping, diff --git a/eduserver/inthttp/client.go b/eduserver/inthttp/client.go index 9a6f483c2..0690ed827 100644 --- a/eduserver/inthttp/client.go +++ b/eduserver/inthttp/client.go @@ -12,10 +12,9 @@ import ( // HTTP paths for the internal HTTP APIs const ( - EDUServerInputTypingEventPath = "/eduserver/input" - EDUServerInputSendToDeviceEventPath = "/eduserver/sendToDevice" - EDUServerInputReceiptEventPath = "/eduserver/receipt" - EDUServerInputCrossSigningKeyUpdatePath = "/eduserver/crossSigningKeyUpdate" + EDUServerInputTypingEventPath = "/eduserver/input" + EDUServerInputSendToDeviceEventPath = "/eduserver/sendToDevice" + EDUServerInputReceiptEventPath = "/eduserver/receipt" ) // NewEDUServerClient creates a EDUServerInputAPI implemented by talking to a HTTP POST API. @@ -69,16 +68,3 @@ func (h *httpEDUServerInputAPI) InputReceiptEvent( apiURL := h.eduServerURL + EDUServerInputReceiptEventPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } - -// InputCrossSigningKeyUpdate implements EDUServerInputAPI -func (h *httpEDUServerInputAPI) InputCrossSigningKeyUpdate( - ctx context.Context, - request *api.InputCrossSigningKeyUpdateRequest, - response *api.InputCrossSigningKeyUpdateResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "InputCrossSigningKeyUpdate") - defer span.Finish() - - apiURL := h.eduServerURL + EDUServerInputCrossSigningKeyUpdatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} diff --git a/eduserver/inthttp/server.go b/eduserver/inthttp/server.go index a50ca84f9..a34943750 100644 --- a/eduserver/inthttp/server.go +++ b/eduserver/inthttp/server.go @@ -51,17 +51,4 @@ func AddRoutes(t api.EDUServerInputAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle(EDUServerInputCrossSigningKeyUpdatePath, - httputil.MakeInternalAPI("inputCrossSigningKeyUpdate", func(req *http.Request) util.JSONResponse { - var request api.InputCrossSigningKeyUpdateRequest - var response api.InputCrossSigningKeyUpdateResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := t.InputCrossSigningKeyUpdate(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) } diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index e9862000a..60066bb2f 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/types" diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 524fd510e..dd4fe13a8 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -382,20 +382,8 @@ func (t *txnReq) processEDUs(ctx context.Context) { } } case eduserverAPI.MSigningKeyUpdate: - var updatePayload eduserverAPI.CrossSigningKeyUpdate - if err := json.Unmarshal(e.Content, &updatePayload); err != nil { - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "user_id": updatePayload.UserID, - }).Debug("Failed to send signing key update to edu server") - continue - } - inputReq := &eduserverAPI.InputCrossSigningKeyUpdateRequest{ - CrossSigningKeyUpdate: updatePayload, - } - inputRes := &eduserverAPI.InputCrossSigningKeyUpdateResponse{} - if err := t.eduAPI.InputCrossSigningKeyUpdate(ctx, inputReq, inputRes); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal cross-signing update") - continue + if err := t.processSigningKeyUpdate(ctx, e); err != nil { + logrus.WithError(err).Errorf("Failed to process signing key update") } default: util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") @@ -403,6 +391,34 @@ func (t *txnReq) processEDUs(ctx context.Context) { } } +func (t *txnReq) processSigningKeyUpdate(ctx context.Context, e gomatrixserverlib.EDU) error { + var updatePayload eduserverAPI.CrossSigningKeyUpdate + if err := json.Unmarshal(e.Content, &updatePayload); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "user_id": updatePayload.UserID, + }).Debug("Failed to unmarshal signing key update") + return err + } + + keys := gomatrixserverlib.CrossSigningKeys{} + if updatePayload.MasterKey != nil { + keys.MasterKey = *updatePayload.MasterKey + } + if updatePayload.SelfSigningKey != nil { + keys.SelfSigningKey = *updatePayload.SelfSigningKey + } + uploadReq := &keyapi.PerformUploadDeviceKeysRequest{ + CrossSigningKeys: keys, + UserID: updatePayload.UserID, + } + uploadRes := &keyapi.PerformUploadDeviceKeysResponse{} + t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) + if uploadRes.Error != nil { + return uploadRes.Error + } + return nil +} + // processReceiptEvent sends receipt events to the edu server func (t *txnReq) processReceiptEvent(ctx context.Context, userID, roomID, receiptType string, diff --git a/keyserver/consumers/cross_signing.go b/keyserver/consumers/cross_signing.go deleted file mode 100644 index aae69e960..000000000 --- a/keyserver/consumers/cross_signing.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package consumers - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" -) - -type OutputCrossSigningKeyUpdateConsumer struct { - ctx context.Context - keyDB storage.Database - keyAPI api.KeyInternalAPI - serverName string - jetstream nats.JetStreamContext - durable string - topic string -} - -func NewOutputCrossSigningKeyUpdateConsumer( - process *process.ProcessContext, - cfg *config.Dendrite, - js nats.JetStreamContext, - keyDB storage.Database, - keyAPI api.KeyInternalAPI, -) *OutputCrossSigningKeyUpdateConsumer { - // The keyserver both produces and consumes on the TopicOutputKeyChangeEvent - // topic. We will only produce events where the UserID matches our server name, - // and we will only consume events where the UserID does NOT match our server - // name (because the update came from a remote server). - s := &OutputCrossSigningKeyUpdateConsumer{ - ctx: process.Context(), - keyDB: keyDB, - jetstream: js, - durable: cfg.Global.JetStream.Durable("KeyServerCrossSigningConsumer"), - topic: cfg.Global.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), - keyAPI: keyAPI, - serverName: string(cfg.Global.ServerName), - } - - return s -} - -func (s *OutputCrossSigningKeyUpdateConsumer) Start() error { - return jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, - nats.DeliverAll(), nats.ManualAck(), - ) -} - -// onMessage is called in response to a message received on the -// key change events topic from the key server. -func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { - var m api.DeviceMessage - if err := json.Unmarshal(msg.Data, &m); err != nil { - logrus.WithError(err).Errorf("failed to read device message from key change topic") - return true - } - if m.OutputCrossSigningKeyUpdate == nil { - // This probably shouldn't happen but stops us from panicking if we come - // across an update that doesn't satisfy either types. - return true - } - switch m.Type { - case api.TypeCrossSigningUpdate: - return t.onCrossSigningMessage(m) - default: - return true - } -} - -func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { - output := m.CrossSigningKeyUpdate - _, host, err := gomatrixserverlib.SplitID('@', output.UserID) - if err != nil { - logrus.WithError(err).Errorf("eduserver output log: user ID parse failure") - return true - } - if host == gomatrixserverlib.ServerName(s.serverName) { - // Ignore any messages that contain information about our own users, as - // they already originated from this server. - return true - } - uploadReq := &api.PerformUploadDeviceKeysRequest{ - UserID: output.UserID, - } - if output.MasterKey != nil { - uploadReq.MasterKey = *output.MasterKey - } - if output.SelfSigningKey != nil { - uploadReq.SelfSigningKey = *output.SelfSigningKey - } - uploadRes := &api.PerformUploadDeviceKeysResponse{} - s.keyAPI.PerformUploadDeviceKeys(context.TODO(), uploadReq, uploadRes) - if uploadRes.Error != nil { - // If the error is due to a missing or invalid parameter then we'd might - // as well just acknowledge the message, because otherwise otherwise we'll - // just keep getting delivered a faulty message over and over again. - return uploadRes.Error.IsMissingParam || uploadRes.Error.IsInvalidParam - } - return true -} diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index 1e1871b8b..527990cf9 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -219,25 +219,23 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } // Finally, generate a notification that we updated the keys. - if _, host, err := gomatrixserverlib.SplitID('@', req.UserID); err == nil && host == a.ThisServer { - update := eduserverAPI.CrossSigningKeyUpdate{ - UserID: req.UserID, - } - if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok { - update.MasterKey = &mk - } - if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok { - update.SelfSigningKey = &ssk - } - if update.MasterKey == nil && update.SelfSigningKey == nil { - return - } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), - } - return + update := eduserverAPI.CrossSigningKeyUpdate{ + UserID: req.UserID, + } + if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok { + update.MasterKey = &mk + } + if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok { + update.SelfSigningKey = &ssk + } + if update.MasterKey == nil && update.SelfSigningKey == nil { + return + } + if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } + return } } @@ -310,16 +308,14 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req // Finally, generate a notification that we updated the signatures. for userID := range req.Signatures { - if _, host, err := gomatrixserverlib.SplitID('@', userID); err == nil && host == a.ThisServer { - update := eduserverAPI.CrossSigningKeyUpdate{ - UserID: userID, - } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), - } - return + update := eduserverAPI.CrossSigningKeyUpdate{ + UserID: userID, + } + if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } + return } } } diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 61ccc0303..bd36fd9f9 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -18,7 +18,6 @@ import ( "github.com/gorilla/mux" fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/consumers" "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/keyserver/inthttp" "github.com/matrix-org/dendrite/keyserver/producers" @@ -65,12 +64,5 @@ func NewInternalAPI( } }() - keyconsumer := consumers.NewOutputCrossSigningKeyUpdateConsumer( - base.ProcessContext, base.Cfg, js, db, ap, - ) - if err := keyconsumer.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start keyserver EDU server consumer") - } - return ap } diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index fa1064b70..37a9e2d39 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -282,6 +282,8 @@ func membershipEvents(res *types.Response) (joinUserIDs, leaveUserIDs []string) if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil { if strings.Contains(string(ev.Content), `"join"`) { joinUserIDs = append(joinUserIDs, *ev.StateKey) + } else if strings.Contains(string(ev.Content), `"invite"`) { + joinUserIDs = append(joinUserIDs, *ev.StateKey) } else if strings.Contains(string(ev.Content), `"leave"`) { leaveUserIDs = append(leaveUserIDs, *ev.StateKey) } else if strings.Contains(string(ev.Content), `"ban"`) { diff --git a/sytest-whitelist b/sytest-whitelist index 7d26c610e..c6ce1daad 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -590,3 +590,4 @@ Can reject invites over federation for rooms with version 9 Can receive redactions from regular users over federation in room version 9 Forward extremities remain so even after the next events are populated as outliers If a device list update goes missing, the server resyncs on the next one +uploading self-signing key notifies over federation From b4687f2ed24ae4f397e039776118c6efee306fa9 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 9 Feb 2022 11:24:49 +0000 Subject: [PATCH 40/81] Fix storage bug in PSQL events table --- roomserver/storage/postgres/events_table.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index ece1d9e3c..c136f039a 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -74,7 +74,7 @@ const insertEventSQL = "" + "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" + - " SET is_rejected = $8 WHERE e.is_rejected = FALSE" + + " SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = FALSE" + " RETURNING event_nid, state_snapshot_nid" const selectEventSQL = "" + @@ -192,7 +192,8 @@ func (s *eventStatements) InsertEvent( ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - err := s.insertEventStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertEventStmt) + err := stmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, From cf447dd52a0015c2c5b10813ed11e59a3712607e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 9 Feb 2022 11:41:21 +0000 Subject: [PATCH 41/81] Revert "Fix storage bug in PSQL events table" This reverts commit b4687f2ed24ae4f397e039776118c6efee306fa9. --- roomserver/storage/postgres/events_table.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index c136f039a..ece1d9e3c 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -74,7 +74,7 @@ const insertEventSQL = "" + "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" + - " SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = FALSE" + + " SET is_rejected = $8 WHERE e.is_rejected = FALSE" + " RETURNING event_nid, state_snapshot_nid" const selectEventSQL = "" + @@ -192,8 +192,7 @@ func (s *eventStatements) InsertEvent( ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - stmt := sqlutil.TxStmt(txn, s.insertEventStmt) - err := stmt.QueryRowContext( + err := s.insertEventStmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, From ac25065a54149117761e7a1b471a9b742f920ebc Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Wed, 9 Feb 2022 13:11:43 +0100 Subject: [PATCH 42/81] Fix sytest `uploading signed devices gets propagated over federation` (#2162) * Remove unneeded logging * Add MasterKey & SelfSigningKey to update Avoid panic if signatures are not present * Add passing test * Revert "Add MasterKey & SelfSigningKey to update" This reverts commit 2c81b34884be8b5b875a33420c0f985b578d3fb8. * Send MasterKey & SelfSigningKey with update * Debugging * Remove delete() so we also query signingkeys --- federationapi/consumers/roomserver.go | 6 +----- keyserver/internal/cross_signing.go | 6 +++++- keyserver/internal/internal.go | 7 ++++++- sytest-whitelist | 1 + 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 60066bb2f..ac29f930b 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/types" @@ -113,11 +114,6 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) } } - case api.OutputTypeNewInviteEvent: - log.WithField("type", output.Type).Debug( - "received new invite, send device keys", - ) - case api.OutputTypeNewInboundPeek: if err := s.processInboundPeek(*output.NewInboundPeek); err != nil { log.WithFields(log.Fields{ diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index 527990cf9..bfb2037f8 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -308,8 +308,12 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req // Finally, generate a notification that we updated the signatures. for userID := range req.Signatures { + masterKey := queryRes.MasterKeys[userID] + selfSigningKey := queryRes.SelfSigningKeys[userID] update := eduserverAPI.CrossSigningKeyUpdate{ - UserID: userID, + UserID: userID, + MasterKey: &masterKey, + SelfSigningKey: &selfSigningKey, } if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 259249217..2536c1f76 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -326,8 +326,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques if err = json.Unmarshal(key, &deviceKey); err != nil { continue } + if deviceKey.Signatures == nil { + deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } for sourceUserID, forSourceUser := range sigMap { for sourceKeyID, sourceSig := range forSourceUser { + if _, ok := deviceKey.Signatures[sourceUserID]; !ok { + deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig } } @@ -447,7 +453,6 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( for userID, deviceIDs := range devKeys { if len(deviceIDs) == 0 { userIDsForAllDevices[userID] = struct{}{} - delete(devKeys, userID) } } // for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing diff --git a/sytest-whitelist b/sytest-whitelist index c6ce1daad..04b1bbf36 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -591,3 +591,4 @@ Can receive redactions from regular users over federation in room version 9 Forward extremities remain so even after the next events are populated as outliers If a device list update goes missing, the server resyncs on the next one uploading self-signing key notifies over federation +uploading signed devices gets propagated over federation From cc688a9a386f48e38687a697b50f9be7d2b06fb0 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Wed, 9 Feb 2022 15:46:52 +0100 Subject: [PATCH 43/81] Avoid unnecessary logs and marshaling (#2167) Co-authored-by: kegsay --- federationapi/consumers/eduserver.go | 2 +- federationapi/consumers/keychange.go | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/federationapi/consumers/eduserver.go b/federationapi/consumers/eduserver.go index 22fedbeb4..1f81fa258 100644 --- a/federationapi/consumers/eduserver.go +++ b/federationapi/consumers/eduserver.go @@ -134,7 +134,7 @@ func (t *OutputEDUConsumer) onSendToDeviceEvent(ctx context.Context, msg *nats.M return true } - log.Infof("Sending send-to-device message into %q destination queue", destServerName) + log.Debugf("Sending send-to-device message into %q destination queue", destServerName) if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { log.WithError(err).Error("failed to send EDU") return false diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 1ec9f4c18..22dbc32da 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -127,6 +127,9 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { return true } + if len(destinations) == 0 { + return true + } // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ Type: gomatrixserverlib.MDeviceListUpdate, @@ -146,7 +149,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { return true } - logger.Infof("Sending device list update message to %q", destinations) + logger.Debugf("Sending device list update message to %q", destinations) err = t.queues.SendEDU(edu, t.serverName, destinations) return err == nil } @@ -181,6 +184,10 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { return true } + if len(destinations) == 0 { + return true + } + // Pack the EDU and marshal it edu := &gomatrixserverlib.EDU{ Type: eduserverAPI.MSigningKeyUpdate, @@ -191,7 +198,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { return true } - logger.Infof("Sending cross-signing update message to %q", destinations) + logger.Debugf("Sending cross-signing update message to %q", destinations) err = t.queues.SendEDU(edu, t.serverName, destinations) return err == nil } From aa5c3b88dea207410461820ee480b002d185aa54 Mon Sep 17 00:00:00 2001 From: kegsay Date: Wed, 9 Feb 2022 20:31:24 +0000 Subject: [PATCH 44/81] Unmarshal events at the Dendrite level not GMSL level (#2164) * Use new event json types in gmsl * Fix EventJSON to actually unmarshal events * Update GMSL * Bump GMSL and improve error messages * Send back the correct RespState * Update GMSL --- federationapi/internal/perform.go | 20 ++++-- federationapi/inthttp/client.go | 9 +-- federationapi/routing/eventauth.go | 2 +- federationapi/routing/invite.go | 4 +- federationapi/routing/join.go | 4 +- federationapi/routing/missingevents.go | 2 +- federationapi/routing/peek.go | 4 +- federationapi/routing/state.go | 32 ++++----- federationapi/routing/threepid.go | 7 +- go.mod | 8 +-- go.sum | 16 ++--- roomserver/api/wrapper.go | 9 +-- roomserver/internal/input/input_events.go | 2 +- roomserver/internal/input/input_missing.go | 79 ++++++++++++++-------- setup/mscs/msc2836/msc2836.go | 67 +++++++++++------- 15 files changed, 158 insertions(+), 107 deletions(-) diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 7850f206c..c51ecf146 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -201,7 +201,6 @@ func (r *FederationInternalAPI) performJoinUsingServer( context.Background(), serverName, event, - respMakeJoin.RoomVersion, ) if err != nil { r.statistics.ForServer(serverName).Failure() @@ -209,9 +208,11 @@ func (r *FederationInternalAPI) performJoinUsingServer( } r.statistics.ForServer(serverName).Success() + authEvents := respSendJoin.AuthEvents.UntrustedEvents(respMakeJoin.RoomVersion) + // Sanity-check the join response to ensure that it has a create // event, that the room version is known, etc. - if err = sanityCheckAuthChain(respSendJoin.AuthEvents); err != nil { + if err = sanityCheckAuthChain(authEvents); err != nil { return fmt.Errorf("sanityCheckAuthChain: %w", err) } @@ -225,6 +226,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( var respState *gomatrixserverlib.RespState respState, err = respSendJoin.Check( context.Background(), + respMakeJoin.RoomVersion, r.keyRing, event, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), @@ -392,12 +394,13 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( ctx = context.Background() respState := respPeek.ToRespState() + authEvents := respState.AuthEvents.UntrustedEvents(respPeek.RoomVersion) // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - if err = sanityCheckAuthChain(respState.AuthEvents); err != nil { + if err = sanityCheckAuthChain(authEvents); err != nil { return fmt.Errorf("sanityCheckAuthChain: %w", err) } - if err = respState.Check(ctx, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)); err != nil { + if err = respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)); err != nil { return fmt.Errorf("error checking state returned from peeking: %w", err) } @@ -549,10 +552,15 @@ func (r *FederationInternalAPI) PerformInvite( inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq) if err != nil { - return fmt.Errorf("r.federation.SendInviteV2: %w", err) + return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } + logrus.Infof("GOT INVITE RESPONSE %s", string(inviteRes.Event)) - response.Event = inviteRes.Event.Headered(request.RoomVersion) + inviteEvent, err := inviteRes.Event.UntrustedEvent(request.RoomVersion) + if err != nil { + return fmt.Errorf("r.federation.SendInviteV2 failed to decode event response: %w", err) + } + response.Event = inviteEvent.Headered(request.RoomVersion) return nil } diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index a65df906f..f9b2a33d2 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -387,14 +387,7 @@ func (h *httpFederationInternalAPI) LookupMissingEvents( if request.Err != nil { return res, request.Err } - res.Events = make([]*gomatrixserverlib.Event, 0, len(request.Res.Events)) - for _, js := range request.Res.Events { - ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(js, roomVersion) - if err != nil { - return res, err - } - res.Events = append(res.Events, ev) - } + res.Events = request.Res.Events return res, nil } diff --git a/federationapi/routing/eventauth.go b/federationapi/routing/eventauth.go index d92b66f4b..0a03a0cb4 100644 --- a/federationapi/routing/eventauth.go +++ b/federationapi/routing/eventauth.go @@ -65,7 +65,7 @@ func GetEventAuth( return util.JSONResponse{ Code: http.StatusOK, JSON: gomatrixserverlib.RespEventAuth{ - AuthEvents: gomatrixserverlib.UnwrapEventHeaders(response.AuthChainEvents), + AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), }, } } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 468659651..58bf99f4a 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -178,12 +178,12 @@ func processInvite( if isInviteV2 { return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespInviteV2{Event: &signedEvent}, + JSON: gomatrixserverlib.RespInviteV2{Event: signedEvent.JSON()}, } } else { return util.JSONResponse{ Code: http.StatusOK, - JSON: gomatrixserverlib.RespInvite{Event: &signedEvent}, + JSON: gomatrixserverlib.RespInvite{Event: signedEvent.JSON()}, } } default: diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 7f8d31505..495b8c914 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -351,8 +351,8 @@ func SendJoin( return util.JSONResponse{ Code: http.StatusOK, JSON: gomatrixserverlib.RespSendJoin{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateAndAuthChainResponse.StateEvents), - AuthEvents: gomatrixserverlib.UnwrapEventHeaders(stateAndAuthChainResponse.AuthChainEvents), + StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents), + AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents), Origin: cfg.Matrix.ServerName, }, } diff --git a/federationapi/routing/missingevents.go b/federationapi/routing/missingevents.go index f79a2d2d8..dd3df7aa9 100644 --- a/federationapi/routing/missingevents.go +++ b/federationapi/routing/missingevents.go @@ -62,7 +62,7 @@ func GetMissingEvents( eventsResponse.Events = filterEvents(eventsResponse.Events, roomID) resp := gomatrixserverlib.RespMissingEvents{ - Events: gomatrixserverlib.UnwrapEventHeaders(eventsResponse.Events), + Events: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(eventsResponse.Events), } return util.JSONResponse{ diff --git a/federationapi/routing/peek.go b/federationapi/routing/peek.go index 511329997..827d1116d 100644 --- a/federationapi/routing/peek.go +++ b/federationapi/routing/peek.go @@ -88,8 +88,8 @@ func Peek( } respPeek := gomatrixserverlib.RespPeek{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(response.StateEvents), - AuthEvents: gomatrixserverlib.UnwrapEventHeaders(response.AuthChainEvents), + StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.StateEvents), + AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(response.AuthChainEvents), RoomVersion: response.RoomVersion, LatestEvent: response.LatestEvent.Unwrap(), RenewalInterval: renewalInterval, diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 128df6187..37cbb9d1e 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -35,12 +35,15 @@ func GetState( return *err } - state, err := getState(ctx, request, rsAPI, roomID, eventID) + stateEvents, authChain, err := getState(ctx, request, rsAPI, roomID, eventID) if err != nil { return *err } - return util.JSONResponse{Code: http.StatusOK, JSON: state} + return util.JSONResponse{Code: http.StatusOK, JSON: &gomatrixserverlib.RespState{ + AuthEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(authChain), + StateEvents: gomatrixserverlib.NewEventJSONsFromHeaderedEvents(stateEvents), + }} } // GetStateIDs returns state event IDs & auth event IDs for the roomID, eventID @@ -55,13 +58,13 @@ func GetStateIDs( return *err } - state, err := getState(ctx, request, rsAPI, roomID, eventID) + stateEvents, authEvents, err := getState(ctx, request, rsAPI, roomID, eventID) if err != nil { return *err } - stateEventIDs := getIDsFromEvent(state.StateEvents) - authEventIDs := getIDsFromEvent(state.AuthEvents) + stateEventIDs := getIDsFromEvent(stateEvents) + authEventIDs := getIDsFromEvent(authEvents) return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.RespStateIDs{ StateEventIDs: stateEventIDs, @@ -97,18 +100,18 @@ func getState( rsAPI api.RoomserverInternalAPI, roomID string, eventID string, -) (*gomatrixserverlib.RespState, *util.JSONResponse) { +) (stateEvents, authEvents []*gomatrixserverlib.HeaderedEvent, errRes *util.JSONResponse) { event, resErr := fetchEvent(ctx, rsAPI, eventID) if resErr != nil { - return nil, resErr + return nil, nil, resErr } if event.RoomID() != roomID { - return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: jsonerror.NotFound("event does not belong to this room")} + return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: jsonerror.NotFound("event does not belong to this room")} } resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) if resErr != nil { - return nil, resErr + return nil, nil, resErr } var response api.QueryStateAndAuthChainResponse @@ -123,20 +126,17 @@ func getState( ) if err != nil { resErr := util.ErrorResponse(err) - return nil, &resErr + return nil, nil, &resErr } if !response.RoomExists { - return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: nil} + return nil, nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: nil} } - return &gomatrixserverlib.RespState{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(response.StateEvents), - AuthEvents: gomatrixserverlib.UnwrapEventHeaders(response.AuthChainEvents), - }, nil + return response.StateEvents, response.AuthChainEvents, nil } -func getIDsFromEvent(events []*gomatrixserverlib.Event) []string { +func getIDsFromEvent(events []*gomatrixserverlib.HeaderedEvent) []string { IDs := make([]string, len(events)) for i := range events { IDs[i] = events[i].EventID() diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index b16c68d25..8ae7130c3 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -170,13 +170,18 @@ func ExchangeThirdPartyInvite( util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") return jsonerror.InternalServerError() } + inviteEvent, err := signedEvent.Event.UntrustedEvent(verRes.RoomVersion) + if err != nil { + util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") + return jsonerror.InternalServerError() + } // Send the event to the roomserver if err = api.SendEvents( httpReq.Context(), rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{ - signedEvent.Event.Headered(verRes.RoomVersion), + inviteEvent.Headered(verRes.RoomVersion), }, request.Origin(), cfg.Matrix.ServerName, diff --git a/go.mod b/go.mod index 4759ed295..a1dc04084 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220204112336-a05e156fd8a0 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220209202448-9805ef634335 github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 @@ -59,17 +59,17 @@ require ( github.com/prometheus/common v0.32.1 // indirect github.com/prometheus/procfs v0.7.3 // indirect github.com/sirupsen/logrus v1.8.1 - github.com/tidwall/gjson v1.13.0 + github.com/tidwall/gjson v1.14.0 github.com/tidwall/sjson v1.2.4 github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.2 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8 + golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a golang.org/x/image v0.0.0-20211028202545-6944b10bf410 golang.org/x/mobile v0.0.0-20220112015953-858099ff7816 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd - golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect + golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 gopkg.in/h2non/bimg.v1 v1.1.5 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index 6d7da3717..1483c792f 100644 --- a/go.sum +++ b/go.sum @@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220204112336-a05e156fd8a0 h1:ZCD8xUM9ppUwW99SzXLOFwWLfdfYRKihj/CCDnMuYMw= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220204112336-a05e156fd8a0/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220209202448-9805ef634335 h1:xzK9Q9VGqsZNGx5ANFOCWkJ8R+W1J2BOguxsVZw6m8M= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220209202448-9805ef634335/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -1363,8 +1363,8 @@ github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpP github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tchap/go-patricia v2.2.6+incompatible/go.mod h1:bmLyhP68RS6kStMGxByiQ23RP/odRBOTVjwp2cDyi6I= github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.13.0 h1:3TFY9yxOQShrvmjdM76K+jc66zJeT6D3/VFFYCGQf7M= -github.com/tidwall/gjson v1.13.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.0 h1:6aeJ0bzojgWLa82gDQHcx3S0Lr/O51I9bJ5nv6JFx5w= +github.com/tidwall/gjson v1.14.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= @@ -1509,8 +1509,8 @@ golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5 golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8 h1:kACShD3qhmr/3rLmg1yXyt+N4HcwutKyPRB93s54TIU= -golang.org/x/crypto v0.0.0-20220126234351-aa10faf2a1f8/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo= +golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1734,8 +1734,8 @@ golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0= -golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= +golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index e9b94e48c..012094c62 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -51,7 +51,7 @@ func SendEventWithState( state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { - outliers, err := state.Events() + outliers, err := state.Events(event.RoomVersion) if err != nil { return err } @@ -68,9 +68,10 @@ func SendEventWithState( }) } - stateEventIDs := make([]string, len(state.StateEvents)) - for i := range state.StateEvents { - stateEventIDs[i] = state.StateEvents[i].EventID() + stateEvents := state.StateEvents.UntrustedEvents(event.RoomVersion) + stateEventIDs := make([]string, len(stateEvents)) + for i := range stateEvents { + stateEventIDs[i] = stateEvents[i].EventID() } ires = append(ires, InputRoomEvent{ diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index bb35ade9c..774e71dd3 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -438,7 +438,7 @@ func (r *Inputer) fetchAuthEvents( isRejected := false nextAuthEvent: for _, authEvent := range gomatrixserverlib.ReverseTopologicalOrdering( - res.AuthEvents, + res.AuthEvents.UntrustedEvents(event.RoomVersion), gomatrixserverlib.TopologicalOrderByAuthEvents, ) { // If we already know about this event from the database then we don't diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 7a72b0381..497c049dc 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -18,6 +18,11 @@ import ( "github.com/sirupsen/logrus" ) +type parsedRespState struct { + AuthEvents []*gomatrixserverlib.Event + StateEvents []*gomatrixserverlib.Event +} + type missingStateReq struct { origin gomatrixserverlib.ServerName db *shared.RoomUpdater @@ -98,7 +103,7 @@ func (t *missingStateReq) processEventWithMissingState( // That's because the state will have been through state resolution once // already in QueryStateAfterEvent. trustworthy bool - *gomatrixserverlib.RespState + *parsedRespState } // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. @@ -125,7 +130,7 @@ func (t *missingStateReq) processEventWithMissingState( // 1. Ensures that the state is deduplicated fully for each state-key tuple // 2. Ensures that we pick the latest events from both sets, in the case that // one of the prev_events is quite a bit older than the others - resolvedState := &gomatrixserverlib.RespState{} + resolvedState := &parsedRespState{} switch len(states) { case 0: extremityIsCreate := backwardsExtremity.Type() == gomatrixserverlib.MRoomCreate && backwardsExtremity.StateKeyEquals("") @@ -140,16 +145,16 @@ func (t *missingStateReq) processEventWithMissingState( // local state snapshot which will already have been through state res), // use it as-is. There's no point in resolving it again. if states[0].trustworthy { - resolvedState = states[0].RespState + resolvedState = states[0].parsedRespState break } // Otherwise, if it isn't trustworthy (came from federation), run it through // state resolution anyway for safety, in case there are duplicates. fallthrough default: - respStates := make([]*gomatrixserverlib.RespState, len(states)) + respStates := make([]*parsedRespState, len(states)) for i := range states { - respStates[i] = states[i].RespState + respStates[i] = states[i].parsedRespState } // There's more than one previous state - run them all through state res t.roomsMu.Lock(e.RoomID()) @@ -169,7 +174,7 @@ func (t *missingStateReq) processEventWithMissingState( t.hadEventsMutex.Unlock() // Send outliers first so we can send the new backwards extremity without causing errors - outliers, err := resolvedState.Events() + outliers, err := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) if err != nil { return err } @@ -239,7 +244,7 @@ func (t *missingStateReq) processEventWithMissingState( // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) // added into the mix. -func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*gomatrixserverlib.RespState, bool, error) { +func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (*parsedRespState, bool, error) { // try doing all this locally before we resort to querying federation respState := t.lookupStateAfterEventLocally(ctx, roomID, eventID) if respState != nil { @@ -290,7 +295,7 @@ func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *g return ev } -func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState { +func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState { var res api.QueryStateAfterEventsResponse err := t.queryer.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ RoomID: roomID, @@ -345,7 +350,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room queryRes.Events = nil } - return &gomatrixserverlib.RespState{ + return &parsedRespState{ StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), AuthEvents: authEvents, } @@ -354,13 +359,13 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room // lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what // the server supports. func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( - *gomatrixserverlib.RespState, error) { + *parsedRespState, error) { // Attempt to fetch the missing state using /state_ids and /events return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) } -func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*gomatrixserverlib.RespState, backwardsExtremity *gomatrixserverlib.Event) (*gomatrixserverlib.RespState, error) { +func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity *gomatrixserverlib.Event) (*parsedRespState, error) { var authEventList []*gomatrixserverlib.Event var stateEventList []*gomatrixserverlib.Event for _, state := range states { @@ -379,7 +384,7 @@ retryAllowedState: h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) switch err2.(type) { case verifySigError: - return &gomatrixserverlib.RespState{ + return &parsedRespState{ AuthEvents: authEventList, StateEvents: resolvedStateEvents, }, nil @@ -395,7 +400,7 @@ retryAllowedState: } return nil, err } - return &gomatrixserverlib.RespState{ + return &parsedRespState{ AuthEvents: authEventList, StateEvents: resolvedStateEvents, }, nil @@ -452,12 +457,21 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve // Make sure events from the missingResp are using the cache - missing events // will be added and duplicates will be removed. logger.Debugf("get_missing_events returned %d events", len(missingResp.Events)) - for i, ev := range missingResp.Events { - missingResp.Events[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + missingEvents := make([]*gomatrixserverlib.Event, len(missingResp.Events)) + for i, evJSON := range missingResp.Events { + ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(evJSON, roomVersion) + if err != nil { + logger.WithError(err).WithField("event", string(evJSON)).Warn("NewEventFromUntrustedJSON: failed") + return nil, false, missingPrevEventsError{ + eventID: e.EventID(), + err: err, + } + } + missingEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() } // topologically sort and sanity check that we are making forward progress - newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) + newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingEvents, gomatrixserverlib.TopologicalOrderByPrevEvents) shouldHaveSomeEventIDs := e.PrevEventIDs() hasPrevEvent := false Event: @@ -498,29 +512,37 @@ Event: return newEvents, true, nil } -func (t *missingStateReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - respState *gomatrixserverlib.RespState, err error) { +func (t *missingStateReq) lookupMissingStateViaState( + ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, +) (respState *parsedRespState, err error) { state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) if err != nil { return nil, err } // Check that the returned state is valid. - if err := state.Check(ctx, t.keys, nil); err != nil { + if err := state.Check(ctx, roomVersion, t.keys, nil); err != nil { return nil, err } + parsedState := &parsedRespState{ + AuthEvents: make([]*gomatrixserverlib.Event, len(state.AuthEvents)), + StateEvents: make([]*gomatrixserverlib.Event, len(state.StateEvents)), + } // Cache the results of this state lookup and deduplicate anything we already // have in the cache, freeing up memory. - for i, ev := range state.AuthEvents { - state.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + // We load these as trusted as we called state.Check before which loaded them as untrusted. + for i, evJSON := range state.AuthEvents { + ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) + parsedState.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() } - for i, ev := range state.StateEvents { - state.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + for i, evJSON := range state.StateEvents { + ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) + parsedState.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() } - return &state, nil + return parsedState, nil } func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - *gomatrixserverlib.RespState, error) { + *parsedRespState, error) { util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID) // fetch the state event IDs at the time of the event stateIDs, err := t.federation.LookupStateIDs(ctx, t.origin, roomID, eventID) @@ -652,13 +674,14 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo return resp, err } -func (t *missingStateReq) createRespStateFromStateIDs(stateIDs gomatrixserverlib.RespStateIDs) ( - *gomatrixserverlib.RespState, error) { // nolint:unparam +func (t *missingStateReq) createRespStateFromStateIDs( + stateIDs gomatrixserverlib.RespStateIDs, +) (*parsedRespState, error) { // nolint:unparam t.haveEventsMutex.Lock() defer t.haveEventsMutex.Unlock() // create a RespState response using the response to /state_ids as a guide - respState := gomatrixserverlib.RespState{} + respState := parsedRespState{} for i := range stateIDs.StateEventIDs { ev, ok := t.haveEvents[stateIDs.StateEventIDs[i]] diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 8a35e4143..0af22c19a 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -82,9 +82,15 @@ type EventRelationshipResponse struct { Limited bool `json:"limited"` } -func toClientResponse(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) *EventRelationshipResponse { +type MSC2836EventRelationshipsResponse struct { + gomatrixserverlib.MSC2836EventRelationshipsResponse + ParsedEvents []*gomatrixserverlib.Event + ParsedAuthChain []*gomatrixserverlib.Event +} + +func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: gomatrixserverlib.ToClientEvents(res.Events, gomatrixserverlib.FormatAll), + Events: gomatrixserverlib.ToClientEvents(res.ParsedEvents, gomatrixserverlib.FormatAll), Limited: res.Limited, NextBatch: res.NextBatch, } @@ -210,7 +216,7 @@ func federatedEventRelationship( // add auth chain information requiredAuthEventsSet := make(map[string]bool) var requiredAuthEvents []string - for _, ev := range res.Events { + for _, ev := range res.ParsedEvents { for _, a := range ev.AuthEventIDs() { if requiredAuthEventsSet[a] { continue @@ -227,19 +233,24 @@ func federatedEventRelationship( // they may already have the auth events so don't fail this request util.GetLogger(ctx).WithError(err).Error("Failed to QueryAuthChain") } - res.AuthChain = make([]*gomatrixserverlib.Event, len(queryRes.AuthChain)) + res.AuthChain = make(gomatrixserverlib.EventJSONs, len(queryRes.AuthChain)) for i := range queryRes.AuthChain { - res.AuthChain[i] = queryRes.AuthChain[i].Unwrap() + res.AuthChain[i] = queryRes.AuthChain[i].JSON() + } + + res.Events = make(gomatrixserverlib.EventJSONs, len(res.ParsedEvents)) + for i := range res.ParsedEvents { + res.Events[i] = res.ParsedEvents[i].JSON() } return util.JSONResponse{ Code: 200, - JSON: res, + JSON: res.MSC2836EventRelationshipsResponse, } } -func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsResponse, *util.JSONResponse) { - var res gomatrixserverlib.MSC2836EventRelationshipsResponse +func (rc *reqCtx) process() (*MSC2836EventRelationshipsResponse, *util.JSONResponse) { + var res MSC2836EventRelationshipsResponse var returnEvents []*gomatrixserverlib.HeaderedEvent // Can the user see (according to history visibility) event_id? If no, reject the request, else continue. event := rc.getLocalEvent(rc.req.EventID) @@ -290,11 +301,11 @@ func (rc *reqCtx) process() (*gomatrixserverlib.MSC2836EventRelationshipsRespons ) returnEvents = append(returnEvents, events...) } - res.Events = make([]*gomatrixserverlib.Event, len(returnEvents)) + res.ParsedEvents = make([]*gomatrixserverlib.Event, len(returnEvents)) for i, ev := range returnEvents { // for each event, extract the children_count | hash and add it as unsigned data. rc.addChildMetadata(ev) - res.Events[i] = ev.Unwrap() + res.ParsedEvents[i] = ev.Unwrap() } res.Limited = remaining == 0 || walkLimited return &res, nil @@ -357,7 +368,7 @@ func (rc *reqCtx) fetchUnknownEvent(eventID, roomID string) *gomatrixserverlib.H continue } rc.injectResponseToRoomserver(res) - for _, ev := range res.Events { + for _, ev := range res.ParsedEvents { if ev.EventID() == eventID { return ev.Headered(ev.Version()) } @@ -384,7 +395,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen if rc.hasUnexploredChildren(parentID) { // we need to do a remote request to pull in the children as we are missing them locally. serversToQuery := rc.getServersForEventID(parentID) - var result *gomatrixserverlib.MSC2836EventRelationshipsResponse + var result *MSC2836EventRelationshipsResponse for _, srv := range serversToQuery { res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: parentID, @@ -397,7 +408,12 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen if err != nil { util.GetLogger(rc.ctx).WithError(err).WithField("server", srv).Error("includeChildren: failed to call MSC2836EventRelationships") } else { - result = &res + mscRes := &MSC2836EventRelationshipsResponse{ + MSC2836EventRelationshipsResponse: res, + } + mscRes.ParsedEvents = res.Events.UntrustedEvents(rc.roomVersion) + mscRes.ParsedAuthChain = res.AuthChain.UntrustedEvents(rc.roomVersion) + result = mscRes break } } @@ -467,7 +483,7 @@ func walkThread( } // MSC2836EventRelationships performs an /event_relationships request to a remote server -func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { +func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: eventID, DepthFirst: rc.req.DepthFirst, @@ -481,7 +497,12 @@ func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverli util.GetLogger(rc.ctx).WithError(err).Error("Failed to call MSC2836EventRelationships") return nil, err } - return &res, nil + mscRes := &MSC2836EventRelationshipsResponse{ + MSC2836EventRelationshipsResponse: res, + } + mscRes.ParsedEvents = res.Events.UntrustedEvents(ver) + mscRes.ParsedAuthChain = res.AuthChain.UntrustedEvents(ver) + return mscRes, nil } @@ -550,12 +571,12 @@ func (rc *reqCtx) getServersForEventID(eventID string) []gomatrixserverlib.Serve return serversToQuery } -func (rc *reqCtx) remoteEventRelationships(eventID string) *gomatrixserverlib.MSC2836EventRelationshipsResponse { +func (rc *reqCtx) remoteEventRelationships(eventID string) *MSC2836EventRelationshipsResponse { if rc.isFederatedRequest { return nil // we don't query remote servers for remote requests } serversToQuery := rc.getServersForEventID(eventID) - var res *gomatrixserverlib.MSC2836EventRelationshipsResponse + var res *MSC2836EventRelationshipsResponse var err error for _, srv := range serversToQuery { res, err = rc.MSC2836EventRelationships(eventID, srv, rc.roomVersion) @@ -577,7 +598,7 @@ func (rc *reqCtx) lookForEvent(eventID string) *gomatrixserverlib.HeaderedEvent if queryRes != nil { // inject all the events into the roomserver then return the event in question rc.injectResponseToRoomserver(queryRes) - for _, ev := range queryRes.Events { + for _, ev := range queryRes.ParsedEvents { if ev.EventID() == eventID && rc.req.RoomID == ev.RoomID() { return ev.Headered(ev.Version()) } @@ -619,12 +640,12 @@ func (rc *reqCtx) getLocalEvent(eventID string) *gomatrixserverlib.HeaderedEvent // injectResponseToRoomserver injects the events // into the roomserver as KindOutlier, with auth chains. -func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836EventRelationshipsResponse) { - var stateEvents []*gomatrixserverlib.Event +func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsResponse) { + var stateEvents gomatrixserverlib.EventJSONs var messageEvents []*gomatrixserverlib.Event - for _, ev := range res.Events { + for _, ev := range res.ParsedEvents { if ev.StateKey() != nil { - stateEvents = append(stateEvents, ev) + stateEvents = append(stateEvents, ev.JSON()) } else { messageEvents = append(messageEvents, ev) } @@ -633,7 +654,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *gomatrixserverlib.MSC2836Event AuthEvents: res.AuthChain, StateEvents: stateEvents, } - eventsInOrder, err := respState.Events() + eventsInOrder, err := respState.Events(rc.roomVersion) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to calculate order to send events in MSC2836EventRelationshipsResponse") return From 37cbe263ce89681b0aeb7fef30e05d6125df162f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 10 Feb 2022 09:30:16 +0000 Subject: [PATCH 45/81] Fix transaction issues in events table in PSQL (#2165) * Revert "Revert "Fix storage bug in PSQL events table"" This reverts commit cf447dd52a0015c2c5b10813ed11e59a3712607e. * Membership updater to use updater * Fix membership updater to use transactions properly --- roomserver/internal/input/input_membership.go | 2 +- roomserver/storage/postgres/events_table.go | 5 +++-- roomserver/storage/shared/membership_updater.go | 6 +++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index ff3ed7e5d..3953586b2 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -48,7 +48,7 @@ func (r *Inputer) updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := r.DB.Events(ctx, eventNIDs) + events, err := updater.Events(ctx, eventNIDs) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index ece1d9e3c..c136f039a 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -74,7 +74,7 @@ const insertEventSQL = "" + "INSERT INTO roomserver_events AS e (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique DO UPDATE" + - " SET is_rejected = $8 WHERE e.is_rejected = FALSE" + + " SET is_rejected = $8 WHERE e.event_id = $4 AND e.is_rejected = FALSE" + " RETURNING event_nid, state_snapshot_nid" const selectEventSQL = "" + @@ -192,7 +192,8 @@ func (s *eventStatements) InsertEvent( ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - err := s.insertEventStmt.QueryRowContext( + stmt := sqlutil.TxStmt(txn, s.insertEventStmt) + err := stmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index f1f589a31..66ac2f5b6 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -136,7 +136,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -170,7 +170,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -196,7 +196,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er } if u.membership != tables.MembershipStateKnock { // Look up the NID of the new knock event - nIDs, err := u.d.EventNIDs(u.ctx, []string{event.EventID()}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } From 9130156b131ec5a1714f4a7f6e54f12657d655a1 Mon Sep 17 00:00:00 2001 From: kegsay Date: Thu, 10 Feb 2022 09:37:46 +0000 Subject: [PATCH 46/81] Make the Complement Dockerfile use a fresh directory for runtime (#2168) --- build/scripts/Complement.Dockerfile | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 401695abf..a54fab1d4 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -2,6 +2,10 @@ FROM golang:1.16-stretch as build RUN apt-get update && apt-get install -y sqlite3 WORKDIR /build +# we will dump the binaries and config file to this location to ensure any local untracked files +# that come from the COPY . . file don't contaminate the build +RUN mkdir /dendrite + # Utilise Docker caching when downloading dependencies, this stops us needlessly # downloading dependencies every time. COPY go.mod . @@ -9,9 +13,11 @@ COPY go.sum . RUN go mod download COPY . . -RUN go build ./cmd/dendrite-monolith-server -RUN go build ./cmd/generate-keys -RUN go build ./cmd/generate-config +RUN go build -o /dendrite ./cmd/dendrite-monolith-server +RUN go build -o /dendrite ./cmd/generate-keys +RUN go build -o /dendrite ./cmd/generate-config + +WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost From 2782ae3d5635de4eaa4b8b4ca5b1a289745ba554 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 10 Feb 2022 10:05:14 +0000 Subject: [PATCH 47/81] Fix fetching missing state (#2163) * Check that we have a populated state snapshot when determining if we closed the gap * Do the same in the query API * Use HasState more opportunistically * Try to avoid falling down the hole of using a trustworthy but empty state snapshot for non-create events * Refactor missing state and make sure that we really solve the problem for the new event * Comments * Review comments * Tweak that check again * Tidy up that create check further * Fix build hopefully * Update sendOutliers to use OrderAuthAndStateEvents * Don't go out of bounds on missingEvents --- roomserver/internal/input/input_events.go | 23 +- roomserver/internal/input/input_missing.go | 316 +++++++++++++-------- roomserver/internal/query/query.go | 2 +- roomserver/storage/shared/room_updater.go | 6 + roomserver/storage/shared/storage.go | 2 + roomserver/types/types.go | 4 + 6 files changed, 227 insertions(+), 126 deletions(-) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 774e71dd3..873a051cd 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -255,13 +255,32 @@ func (r *Inputer) processRoomEvent( hadEvents: map[string]bool{}, haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, } - if err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + // Something went wrong with retrieving the missing state, so we can't + // really do anything with the event other than reject it at this point. isRejected = true rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) + } else if stateSnapshot != nil { + // We retrieved some state and we ended up having to call /state_ids for + // the new event in question (probably because closing the gap by using + // /get_missing_events didn't do what we hoped) so we'll instead overwrite + // the state snapshot with the newly resolved state. + missingPrev = false + input.HasState = true + input.StateEventIDs = make([]string, 0, len(stateSnapshot.StateEvents)) + for _, e := range stateSnapshot.StateEvents { + input.StateEventIDs = append(input.StateEventIDs, e.EventID()) + } } else { + // We retrieved some state and it would appear that rolling forward the + // state did everything we needed it to do, so we can just resolve the + // state for the event in the normal way. missingPrev = false } } else { + // We're missing prev events or state for the event, but for some reason + // we don't know any servers to ask. In this case we can't do anything but + // reject the event and hope that it gets unrejected later. isRejected = true rejectionErr = fmt.Errorf("missing prev events and no other servers to ask") } @@ -299,7 +318,7 @@ func (r *Inputer) processRoomEvent( return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) } - if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 { + if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) { // We haven't calculated a state for this event yet. // Lets calculate one. err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected) diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 497c049dc..19771d4bd 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -40,9 +40,10 @@ type missingStateReq struct { // processEventWithMissingState is the entrypoint for a missingStateReq // request, as called from processRoomEvent. +// nolint:gocyclo func (t *missingStateReq) processEventWithMissingState( ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, -) error { +) (*parsedRespState, error) { // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the // room. There two ways that we can handle such a gap: @@ -68,15 +69,15 @@ func (t *missingStateReq) processEventWithMissingState( // - fill in the gap completely then process event `e` returning no backwards extremity // - fail to fill in the gap and tell us to terminate the transaction err=not nil // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction - newEvents, isGapFilled, err := t.getMissingEvents(ctx, e, roomVersion) + newEvents, isGapFilled, prevStatesKnown, err := t.getMissingEvents(ctx, e, roomVersion) if err != nil { - return fmt.Errorf("t.getMissingEvents: %w", err) + return nil, fmt.Errorf("t.getMissingEvents: %w", err) } if len(newEvents) == 0 { - return fmt.Errorf("expected to find missing events but didn't") + return nil, fmt.Errorf("expected to find missing events but didn't") } if isGapFilled { - logger.Infof("gap filled by /get_missing_events, injecting %d new events", len(newEvents)) + logger.Infof("Gap filled by /get_missing_events, injecting %d new events", len(newEvents)) // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { @@ -88,82 +89,31 @@ func (t *missingStateReq) processEventWithMissingState( }) if err != nil { if _, ok := err.(types.RejectedError); !ok { - return fmt.Errorf("t.inputer.processRoomEvent (filling gap): %w", err) + return nil, fmt.Errorf("t.inputer.processRoomEvent (filling gap): %w", err) } } } - return nil } + // If we filled the gap *and* we know the state before the prev events + // then there's nothing else to do, we have everything we need to deal + // with the new event. + if isGapFilled && prevStatesKnown { + logger.Infof("Gap filled and state found for all prev events") + return nil, nil + } + + // Otherwise, if we've reached this point, it's possible that we've + // either not closed the gap, or we did but we still don't seem to + // know the events before the new event. Start by looking up the + // state at the event at the back of the gap and we'll try to roll + // forward the state first. backwardsExtremity := newEvents[0] newEvents = newEvents[1:] - type respState struct { - // A snapshot is considered trustworthy if it came from our own roomserver. - // That's because the state will have been through state resolution once - // already in QueryStateAfterEvent. - trustworthy bool - *parsedRespState - } - - // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. - // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query - // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. - var states []*respState - for _, prevEventID := range backwardsExtremity.PrevEventIDs() { - // Look up what the state is after the backward extremity. This will either - // come from the roomserver, if we know all the required events, or it will - // come from a remote server via /state_ids if not. - prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID) - if lerr != nil { - logger.WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID) - return lerr - } - // Append the state onto the collected state. We'll run this through the - // state resolution next. - states = append(states, &respState{trustworthy, prevState}) - } - - // Now that we have collected all of the state from the prev_events, we'll - // run the state through the appropriate state resolution algorithm for the - // room if needed. This does a couple of things: - // 1. Ensures that the state is deduplicated fully for each state-key tuple - // 2. Ensures that we pick the latest events from both sets, in the case that - // one of the prev_events is quite a bit older than the others - resolvedState := &parsedRespState{} - switch len(states) { - case 0: - extremityIsCreate := backwardsExtremity.Type() == gomatrixserverlib.MRoomCreate && backwardsExtremity.StateKeyEquals("") - if !extremityIsCreate { - // There are no previous states and this isn't the beginning of the - // room - this is an error condition! - logger.Errorf("Failed to lookup any state after prev_events") - return fmt.Errorf("expected %d states but got %d", len(backwardsExtremity.PrevEventIDs()), len(states)) - } - case 1: - // There's only one previous state - if it's trustworthy (came from a - // local state snapshot which will already have been through state res), - // use it as-is. There's no point in resolving it again. - if states[0].trustworthy { - resolvedState = states[0].parsedRespState - break - } - // Otherwise, if it isn't trustworthy (came from federation), run it through - // state resolution anyway for safety, in case there are duplicates. - fallthrough - default: - respStates := make([]*parsedRespState, len(states)) - for i := range states { - respStates[i] = states[i].parsedRespState - } - // There's more than one previous state - run them all through state res - t.roomsMu.Lock(e.RoomID()) - resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity) - t.roomsMu.Unlock(e.RoomID()) - if err != nil { - logger.WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) - return err - } + resolvedState, err := t.lookupResolvedStateBeforeEvent(ctx, backwardsExtremity, roomVersion) + if err != nil { + return nil, fmt.Errorf("t.lookupState (backwards extremity): %w", err) } hadEvents := map[string]bool{} @@ -173,30 +123,37 @@ func (t *missingStateReq) processEventWithMissingState( } t.hadEventsMutex.Unlock() - // Send outliers first so we can send the new backwards extremity without causing errors - outliers, err := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) - if err != nil { - return err - } - var outlierRoomEvents []api.InputRoomEvent - for _, outlier := range outliers { - if hadEvents[outlier.EventID()] { - continue + sendOutliers := func(resolvedState *parsedRespState) error { + outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) + if oerr != nil { + return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr) } - outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{ - Kind: api.KindOutlier, - Event: outlier.Headered(roomVersion), - Origin: t.origin, - }) - } - // TODO: we could do this concurrently? - for _, ire := range outlierRoomEvents { - _, err = t.inputer.processRoomEvent(ctx, t.db, &ire) - if err != nil { - if _, ok := err.(types.RejectedError); !ok { - return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) + var outlierRoomEvents []api.InputRoomEvent + for _, outlier := range outliers { + if hadEvents[outlier.EventID()] { + continue + } + outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{ + Kind: api.KindOutlier, + Event: outlier.Headered(roomVersion), + Origin: t.origin, + }) + } + for _, ire := range outlierRoomEvents { + _, err = t.inputer.processRoomEvent(ctx, t.db, &ire) + if err != nil { + if _, ok := err.(types.RejectedError); !ok { + return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) + } } } + return nil + } + + // Send outliers first so we can send the state along with the new backwards + // extremity without any missing auth events. + if err = sendOutliers(resolvedState); err != nil { + return nil, fmt.Errorf("sendOutliers: %w", err) } // Now send the backward extremity into the roomserver with the @@ -217,7 +174,7 @@ func (t *missingStateReq) processEventWithMissingState( }) if err != nil { if _, ok := err.(types.RejectedError); !ok { - return fmt.Errorf("t.inputer.processRoomEvent (backward extremity): %w", err) + return nil, fmt.Errorf("t.inputer.processRoomEvent (backward extremity): %w", err) } } @@ -234,12 +191,109 @@ func (t *missingStateReq) processEventWithMissingState( }) if err != nil { if _, ok := err.(types.RejectedError); !ok { - return fmt.Errorf("t.inputer.processRoomEvent (fast forward): %w", err) + return nil, fmt.Errorf("t.inputer.processRoomEvent (fast forward): %w", err) } } } - return nil + // Finally, check again if we know everything we need to know in order to + // make forward progress. If the prev state is known then we consider the + // rolled forward state to be sufficient — we now know all of the state + // before the prev events. If we don't then we need to look up the state + // before the new event as well, otherwise we will never make any progress. + if t.isPrevStateKnown(ctx, e) { + return nil, nil + } + + // If we still haven't got the state for the prev events then we'll go and + // ask the federation for it if needed. + resolvedState, err = t.lookupResolvedStateBeforeEvent(ctx, e, roomVersion) + if err != nil { + return nil, fmt.Errorf("t.lookupState (new event): %w", err) + } + + // Send the outliers for the retrieved state. + if err = sendOutliers(resolvedState); err != nil { + return nil, fmt.Errorf("sendOutliers: %w", err) + } + + // Then return the resolved state, for which the caller can replace the + // HasState with the event IDs to create a new state snapshot when we + // process the new event. + return resolvedState, nil +} + +func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) { + type respState struct { + // A snapshot is considered trustworthy if it came from our own roomserver. + // That's because the state will have been through state resolution once + // already in QueryStateAfterEvent. + trustworthy bool + *parsedRespState + } + + // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. + // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query + // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. + var states []*respState + for _, prevEventID := range e.PrevEventIDs() { + // Look up what the state is after the backward extremity. This will either + // come from the roomserver, if we know all the required events, or it will + // come from a remote server via /state_ids if not. + prevState, trustworthy, err := t.lookupStateAfterEvent(ctx, roomVersion, e.RoomID(), prevEventID) + if err != nil { + return nil, fmt.Errorf("t.lookupStateAfterEvent: %w", err) + } + // Append the state onto the collected state. We'll run this through the + // state resolution next. + states = append(states, &respState{trustworthy, prevState}) + } + + // Now that we have collected all of the state from the prev_events, we'll + // run the state through the appropriate state resolution algorithm for the + // room if needed. This does a couple of things: + // 1. Ensures that the state is deduplicated fully for each state-key tuple + // 2. Ensures that we pick the latest events from both sets, in the case that + // one of the prev_events is quite a bit older than the others + resolvedState := &parsedRespState{} + switch len(states) { + case 0: + extremityIsCreate := e.Type() == gomatrixserverlib.MRoomCreate && e.StateKeyEquals("") + if !extremityIsCreate { + // There are no previous states and this isn't the beginning of the + // room - this is an error condition! + return nil, fmt.Errorf("expected %d states but got %d", len(e.PrevEventIDs()), len(states)) + } + case 1: + // There's only one previous state - if it's trustworthy (came from a + // local state snapshot which will already have been through state res), + // use it as-is. There's no point in resolving it again. Only trust a + // trustworthy state snapshot if it actually contains some state for all + // non-create events, otherwise we need to resolve what came from federation. + isCreate := e.Type() == gomatrixserverlib.MRoomCreate && e.StateKeyEquals("") + if states[0].trustworthy && (isCreate || len(states[0].StateEvents) > 0) { + resolvedState = states[0].parsedRespState + break + } + // Otherwise, if it isn't trustworthy (came from federation), run it through + // state resolution anyway for safety, in case there are duplicates. + fallthrough + default: + respStates := make([]*parsedRespState, len(states)) + for i := range states { + respStates[i] = states[i].parsedRespState + } + // There's more than one previous state - run them all through state res + var err error + t.roomsMu.Lock(e.RoomID()) + resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, e) + t.roomsMu.Unlock(e.RoomID()) + if err != nil { + return nil, fmt.Errorf("t.resolveStatesAndCheck: %w", err) + } + } + + return resolvedState, nil } // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) @@ -408,7 +462,7 @@ retryAllowedState: // get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events -func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled bool, err error) { +func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) latest := t.db.LatestEvents() @@ -435,7 +489,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve if errors.Is(err, context.DeadlineExceeded) { select { case <-ctx.Done(): // the parent request context timed out - return nil, false, context.DeadlineExceeded + return nil, false, false, context.DeadlineExceeded default: // this request exceed its own timeout continue } @@ -448,7 +502,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve "%s pushed us an event but %d server(s) couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", t.origin, len(t.servers), ) - return nil, false, missingPrevEventsError{ + return nil, false, false, missingPrevEventsError{ eventID: e.EventID(), err: err, } @@ -457,17 +511,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve // Make sure events from the missingResp are using the cache - missing events // will be added and duplicates will be removed. logger.Debugf("get_missing_events returned %d events", len(missingResp.Events)) - missingEvents := make([]*gomatrixserverlib.Event, len(missingResp.Events)) - for i, evJSON := range missingResp.Events { - ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(evJSON, roomVersion) - if err != nil { - logger.WithError(err).WithField("event", string(evJSON)).Warn("NewEventFromUntrustedJSON: failed") - return nil, false, missingPrevEventsError{ - eventID: e.EventID(), - err: err, - } - } - missingEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events)) + for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { + missingEvents = append(missingEvents, t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()) } // topologically sort and sanity check that we are making forward progress @@ -489,27 +535,51 @@ Event: "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", t.origin, ) - return nil, false, missingPrevEventsError{ + return nil, false, false, missingPrevEventsError{ eventID: e.EventID(), err: err, } } if len(newEvents) == 0 { - return nil, false, nil // TODO: error instead? + return nil, false, false, nil // TODO: error instead? } - // now check if we can fill the gap. Look to see if we have state snapshot IDs for the earliest event earliestNewEvent := newEvents[0] - if state, err := t.db.StateAtEventIDs(ctx, []string{earliestNewEvent.EventID()}); err != nil || len(state) == 0 { - if earliestNewEvent.Type() == gomatrixserverlib.MRoomCreate && earliestNewEvent.StateKeyEquals("") { - // we got to the beginning of the room so there will be no state! It's all good we can process this - return newEvents, true, nil - } - // we don't have the state at this earliest event from /g_m_e so we won't have state for later events either - return newEvents, false, nil + + // If we retrieved back to the beginning of the room then there's nothing else + // to do - we closed the gap. + if len(earliestNewEvent.PrevEventIDs()) == 0 && earliestNewEvent.Type() == gomatrixserverlib.MRoomCreate && earliestNewEvent.StateKeyEquals("") { + return newEvents, true, t.isPrevStateKnown(ctx, e), nil } - // StateAtEventIDs returned some kind of state for the earliest event so we can fill in the gap! - return newEvents, true, nil + + // If our backward extremity was not a known event to us then we obviously didn't + // close the gap. + if state, err := t.db.StateAtEventIDs(ctx, []string{earliestNewEvent.EventID()}); err != nil || len(state) == 0 && state[0].BeforeStateSnapshotNID == 0 { + return newEvents, false, false, nil + } + + // At this point we are satisfied that we know the state both at the earliest + // retrieved event and at the prev events of the new event. + return newEvents, true, t.isPrevStateKnown(ctx, e), nil +} + +func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e *gomatrixserverlib.Event) bool { + expected := len(e.PrevEventIDs()) + state, err := t.db.StateAtEventIDs(ctx, e.PrevEventIDs()) + if err != nil || len(state) != expected { + // We didn't get as many state snapshots as we expected, or there was an error, + // so we haven't completely solved the problem for the new event. + return false + } + // Check to see if we have a populated state snapshot for all of the prev events. + for _, stateAtEvent := range state { + if stateAtEvent.BeforeStateSnapshotNID == 0 { + // One of the prev events still has unknown state, so we haven't really + // solved the problem. + return false + } + } + return true } func (t *missingStateReq) lookupMissingStateViaState( diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 845533032..05cd686f4 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -150,7 +150,7 @@ func (r *Queryer) QueryMissingAuthPrevEvents( for _, prevEventID := range request.PrevEventIDs { state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}) - if err != nil || len(state) == 0 { + if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID) } } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index bb9f5dc62..fc75a2606 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -187,6 +187,12 @@ func (u *RoomUpdater) EventIDs( return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs) } +func (u *RoomUpdater) EventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return u.d.eventNIDs(ctx, u.txn, eventIDs) +} + func (u *RoomUpdater) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 127cd1f52..8319de265 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -603,6 +603,8 @@ func (d *Database) storeEvent( if err == sql.ErrNoRows { // We've already inserted the event so select the numeric event ID eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID()) + } else if err != nil { + return fmt.Errorf("d.EventsTable.InsertEvent: %w", err) } if err != nil { return fmt.Errorf("d.EventsTable.SelectEvent: %w", err) diff --git a/roomserver/types/types.go b/roomserver/types/types.go index 5e1eebe98..5d52ccfcd 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -83,6 +83,10 @@ type StateKeyTuple struct { EventStateKeyNID EventStateKeyNID } +func (a StateKeyTuple) IsCreate() bool { + return a.EventTypeNID == MRoomCreateNID && a.EventStateKeyNID == EmptyStateKeyNID +} + // LessThan returns true if this state key is less than the other state key. // The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries. func (a StateKeyTuple) LessThan(b StateKeyTuple) bool { From 432c35a307860de029c73a69789421a9e7a60eb9 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Thu, 10 Feb 2022 11:05:37 +0100 Subject: [PATCH 48/81] Allow user to forget a room, even if they never were a member (#2166) * Allow user to forget a room, even if they never were a member * Return "M_UNKNOWN" as per the spec Co-authored-by: kegsay --- clientapi/routing/membership.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 4ce820797..58f187608 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -17,6 +17,7 @@ package routing import ( "context" "errors" + "fmt" "net/http" "time" @@ -459,13 +460,7 @@ func SendForget( if membershipRes.IsInRoom { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.Forbidden("user is still a member of the room"), - } - } - if !membershipRes.HasBeenInRoom { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.Forbidden("user did not belong to room"), + JSON: jsonerror.Unknown(fmt.Sprintf("User %s is in room %s", device.UserID, roomID)), } } From c36e4546c36a3381814cd72930349a0df21b1dd4 Mon Sep 17 00:00:00 2001 From: tommie Date: Thu, 10 Feb 2022 11:27:26 +0100 Subject: [PATCH 49/81] Support for `m.login.token` (#2014) * Add GOPATH to PATH in find-lint.sh. The user doesn't necessarily have it in PATH. * Refactor LoginTypePassword and Type to support m.login.token and m.login.sso. For login token: * m.login.token will require deleting the token after completeAuth has generated an access token, so a cleanup function is returned by Type.Login. * Allowing different login types will require parsing the /login body twice: first to extract the "type" and then the type-specific parsing. Thus, we will have to buffer the request JSON in /login, like UserInteractive already does. For SSO: * NewUserInteractive will have to also use GetAccountByLocalpart. It makes more sense to just pass a (narrowed-down) accountDB interface to it than adding more function pointers. Code quality: * Passing around (and down-casting) interface{} for login request types has drawbacks in terms of type-safety, and no inherent benefits. We always decode JSON anyway. Hence renaming to Type.LoginFromJSON. Code that directly uses LoginTypePassword with parsed data can still use Login. * Removed a TODO for SSO. This is already tracked in #1297. * httputil.UnmarshalJSON is useful because it returns a JSONResponse. This change is intended to have no functional changes. * Support login tokens in User API. This adds full lifecycle functions for login tokens: create, query, delete. * Support m.login.token in /login. * Fixes for PR review. * Set @matrix-org/dendrite-core as repository code owner * Return event NID from `StoreEvent`, match PSQL vs SQLite behaviour, tweak backfill persistence (#2071) Co-authored-by: kegsay Co-authored-by: Neil Alexander --- build/scripts/find-lint.sh | 2 +- clientapi/auth/auth.go | 1 + clientapi/auth/authtypes/logintypes.go | 1 + clientapi/auth/login.go | 83 ++++++++ clientapi/auth/login_test.go | 194 ++++++++++++++++++ clientapi/auth/login_token.go | 83 ++++++++ clientapi/auth/password.go | 20 +- clientapi/auth/user_interactive.go | 43 ++-- clientapi/auth/user_interactive_test.go | 8 +- clientapi/httputil/httputil.go | 4 + clientapi/routing/login.go | 16 +- clientapi/routing/routing.go | 2 +- userapi/api/api.go | 2 + userapi/api/api_logintoken.go | 69 +++++++ userapi/api/api_trace_logintoken.go | 39 ++++ userapi/internal/api_logintoken.go | 78 +++++++ userapi/inthttp/client_logintoken.go | 65 ++++++ userapi/inthttp/server.go | 2 + userapi/inthttp/server_logintoken.go | 68 ++++++ userapi/storage/devices/interface.go | 11 + .../devices/postgres/logintoken_table.go | 93 +++++++++ userapi/storage/devices/postgres/storage.go | 72 ++++++- .../devices/sqlite3/logintoken_table.go | 93 +++++++++ userapi/storage/devices/sqlite3/storage.go | 75 ++++++- userapi/storage/devices/storage.go | 10 +- userapi/storage/devices/storage_wasm.go | 4 +- userapi/userapi.go | 21 +- userapi/userapi_test.go | 165 +++++++++++++-- 28 files changed, 1244 insertions(+), 80 deletions(-) create mode 100644 clientapi/auth/login.go create mode 100644 clientapi/auth/login_test.go create mode 100644 clientapi/auth/login_token.go create mode 100644 userapi/api/api_logintoken.go create mode 100644 userapi/api/api_trace_logintoken.go create mode 100644 userapi/internal/api_logintoken.go create mode 100644 userapi/inthttp/client_logintoken.go create mode 100644 userapi/inthttp/server_logintoken.go create mode 100644 userapi/storage/devices/postgres/logintoken_table.go create mode 100644 userapi/storage/devices/sqlite3/logintoken_table.go diff --git a/build/scripts/find-lint.sh b/build/scripts/find-lint.sh index af87e14d7..e3564ae38 100755 --- a/build/scripts/find-lint.sh +++ b/build/scripts/find-lint.sh @@ -33,7 +33,7 @@ echo "Looking for lint..." # Capture exit code to ensure go.{mod,sum} is restored before exiting exit_code=0 -golangci-lint run $args || exit_code=1 +PATH="$PATH:${GOPATH:-~/go}/bin" golangci-lint run $args || exit_code=1 # Restore go.{mod,sum} mv go.mod.bak go.mod && mv go.sum.bak go.sum diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index c850bf91e..575c5377f 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -42,6 +42,7 @@ type DeviceDatabase interface { type AccountDatabase interface { // Look up the account matching the given localpart. GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error) } // VerifyUserFromRequest authenticates the HTTP request, diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index da0324251..f01e48f80 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -10,4 +10,5 @@ const ( LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" + LoginTypeToken = "m.login.token" ) diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go new file mode 100644 index 000000000..1c14c6fbd --- /dev/null +++ b/clientapi/auth/login.go @@ -0,0 +1,83 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "encoding/json" + "io" + "io/ioutil" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// LoginFromJSONReader performs authentication given a login request body reader and +// some context. It returns the basic login information and a cleanup function to be +// called after authorization has completed, with the result of the authorization. +// If the final return value is non-nil, an error occurred and the cleanup function +// is nil. +func LoginFromJSONReader(ctx context.Context, r io.Reader, accountDB AccountDatabase, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { + reqBytes, err := ioutil.ReadAll(r) + if err != nil { + err := &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + } + return nil, nil, err + } + + var header struct { + Type string `json:"type"` + } + if err := json.Unmarshal(reqBytes, &header); err != nil { + err := &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + } + return nil, nil, err + } + + var typ Type + switch header.Type { + case authtypes.LoginTypePassword: + typ = &LoginTypePassword{ + GetAccountByPassword: accountDB.GetAccountByPassword, + Config: cfg, + } + case authtypes.LoginTypeToken: + typ = &LoginTypeToken{ + UserAPI: userAPI, + Config: cfg, + } + default: + err := util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("unhandled login type: " + header.Type), + } + return nil, nil, &err + } + + return typ.LoginFromJSON(ctx, reqBytes) +} + +// UserInternalAPIForLogin contains the aspects of UserAPI required for logging in. +type UserInternalAPIForLogin interface { + uapi.LoginTokenInternalAPI +} diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go new file mode 100644 index 000000000..e295f8f07 --- /dev/null +++ b/clientapi/auth/login_test.go @@ -0,0 +1,194 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "database/sql" + "net/http" + "reflect" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +func TestLoginFromJSONReader(t *testing.T) { + ctx := context.Background() + + tsts := []struct { + Name string + Body string + + WantUsername string + WantDeviceID string + WantDeletedTokens []string + }{ + { + Name: "passwordWorks", + Body: `{ + "type": "m.login.password", + "identifier": { "type": "m.id.user", "user": "alice" }, + "password": "herpassword", + "device_id": "adevice" + }`, + WantUsername: "alice", + WantDeviceID: "adevice", + }, + { + Name: "tokenWorks", + Body: `{ + "type": "m.login.token", + "token": "atoken", + "device_id": "adevice" + }`, + WantUsername: "@auser:example.com", + WantDeviceID: "adevice", + WantDeletedTokens: []string{"atoken"}, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + var accountDB fakeAccountDB + var userAPI fakeUserInternalAPI + cfg := &config.ClientAPI{ + Matrix: &config.Global{ + ServerName: serverName, + }, + } + login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) + if err != nil { + t.Fatalf("LoginFromJSONReader failed: %+v", err) + } + cleanup(ctx, &util.JSONResponse{Code: http.StatusOK}) + + if login.Username() != tst.WantUsername { + t.Errorf("Username: got %q, want %q", login.Username(), tst.WantUsername) + } + + if login.DeviceID == nil { + if tst.WantDeviceID != "" { + t.Errorf("DeviceID: got %v, want %q", login.DeviceID, tst.WantDeviceID) + } + } else { + if *login.DeviceID != tst.WantDeviceID { + t.Errorf("DeviceID: got %q, want %q", *login.DeviceID, tst.WantDeviceID) + } + } + + if !reflect.DeepEqual(userAPI.DeletedTokens, tst.WantDeletedTokens) { + t.Errorf("DeletedTokens: got %+v, want %+v", userAPI.DeletedTokens, tst.WantDeletedTokens) + } + }) + } +} + +func TestBadLoginFromJSONReader(t *testing.T) { + ctx := context.Background() + + tsts := []struct { + Name string + Body string + + WantErrCode string + }{ + {Name: "empty", WantErrCode: "M_BAD_JSON"}, + { + Name: "badUnmarshal", + Body: `badsyntaxJSON`, + WantErrCode: "M_BAD_JSON", + }, + { + Name: "badPassword", + Body: `{ + "type": "m.login.password", + "identifier": { "type": "m.id.user", "user": "alice" }, + "password": "invalidpassword", + "device_id": "adevice" + }`, + WantErrCode: "M_FORBIDDEN", + }, + { + Name: "badToken", + Body: `{ + "type": "m.login.token", + "token": "invalidtoken", + "device_id": "adevice" + }`, + WantErrCode: "M_FORBIDDEN", + }, + { + Name: "badType", + Body: `{ + "type": "m.login.invalid", + "device_id": "adevice" + }`, + WantErrCode: "M_INVALID_ARGUMENT_VALUE", + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + var accountDB fakeAccountDB + var userAPI fakeUserInternalAPI + cfg := &config.ClientAPI{ + Matrix: &config.Global{ + ServerName: serverName, + }, + } + _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) + if errRes == nil { + cleanup(ctx, nil) + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) + } else if merr, ok := errRes.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != tst.WantErrCode { + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) + } + }) + } +} + +type fakeAccountDB struct { + AccountDatabase +} + +func (*fakeAccountDB) GetAccountByPassword(ctx context.Context, localpart, password string) (*uapi.Account, error) { + if password == "invalidpassword" { + return nil, sql.ErrNoRows + } + + return &uapi.Account{}, nil +} + +type fakeUserInternalAPI struct { + UserInternalAPIForLogin + + DeletedTokens []string +} + +func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *uapi.PerformLoginTokenDeletionRequest, res *uapi.PerformLoginTokenDeletionResponse) error { + ua.DeletedTokens = append(ua.DeletedTokens, req.Token) + return nil +} + +func (*fakeUserInternalAPI) QueryLoginToken(ctx context.Context, req *uapi.QueryLoginTokenRequest, res *uapi.QueryLoginTokenResponse) error { + if req.Token == "invalidtoken" { + return nil + } + + res.Data = &uapi.LoginTokenData{UserID: "@auser:example.com"} + return nil +} diff --git a/clientapi/auth/login_token.go b/clientapi/auth/login_token.go new file mode 100644 index 000000000..845eb5de9 --- /dev/null +++ b/clientapi/auth/login_token.go @@ -0,0 +1,83 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// LoginTypeToken describes how to authenticate with a login token. +type LoginTypeToken struct { + UserAPI uapi.LoginTokenInternalAPI + Config *config.ClientAPI +} + +// Name implements Type. +func (t *LoginTypeToken) Name() string { + return authtypes.LoginTypeToken +} + +// LoginFromJSON implements Type. The cleanup function deletes the token from +// the database on success. +func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var r loginTokenRequest + if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil { + return nil, nil, err + } + + var res uapi.QueryLoginTokenResponse + if err := t.UserAPI.QueryLoginToken(ctx, &uapi.QueryLoginTokenRequest{Token: r.Token}, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("UserAPI.QueryLoginToken failed") + jsonErr := jsonerror.InternalServerError() + return nil, nil, &jsonErr + } + if res.Data == nil { + return nil, nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("invalid login token"), + } + } + + r.Login.Identifier.Type = "m.id.user" + r.Login.Identifier.User = res.Data.UserID + + cleanup := func(ctx context.Context, authRes *util.JSONResponse) { + if authRes == nil { + util.GetLogger(ctx).Error("No JSONResponse provided to LoginTokenType cleanup function") + return + } + if authRes.Code == http.StatusOK { + var res uapi.PerformLoginTokenDeletionResponse + if err := t.UserAPI.PerformLoginTokenDeletion(ctx, &uapi.PerformLoginTokenDeletionRequest{Token: r.Token}, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("UserAPI.PerformLoginTokenDeletion failed") + } + } + } + return &r.Login, cleanup, nil +} + +// loginTokenRequest struct to hold the possible parameters from an HTTP request. +type loginTokenRequest struct { + Login + Token string `json:"token"` +} diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 9179d8da1..18cf94979 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -20,6 +20,8 @@ import ( "net/http" "strings" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" @@ -41,16 +43,26 @@ type LoginTypePassword struct { } func (t *LoginTypePassword) Name() string { - return "m.login.password" + return authtypes.LoginTypePassword } -func (t *LoginTypePassword) Request() interface{} { - return &PasswordRequest{} +func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var r PasswordRequest + if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil { + return nil, nil, err + } + + login, err := t.Login(ctx, &r) + if err != nil { + return nil, nil, err + } + + return login, func(context.Context, *util.JSONResponse) {}, nil } func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { r := req.(*PasswordRequest) - username := r.Username() + username := strings.ToLower(r.Username()) if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 30469fc47..9cab7956c 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -32,22 +32,24 @@ import ( type Type interface { // Name returns the name of the auth type e.g `m.login.password` Name() string - // Request returns a pointer to a new request body struct to unmarshal into. - Request() interface{} // Login with the auth type, returning an error response on failure. // Not all types support login, only m.login.password and m.login.token // See https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-login - // `req` is guaranteed to be the type returned from Request() // This function will be called when doing login and when doing 'sudo' style // actions e.g deleting devices. The response must be a 401 as per: // "If the homeserver decides that an attempt on a stage was unsuccessful, but the // client may make a second attempt, it returns the same HTTP status 401 response as above, // with the addition of the standard errcode and error fields describing the error." - Login(ctx context.Context, req interface{}) (login *Login, errRes *util.JSONResponse) + // + // The returned cleanup function must be non-nil on success, and will be called after + // authorization has been completed. Its argument is the final result of authorization. + LoginFromJSON(ctx context.Context, reqBytes []byte) (login *Login, cleanup LoginCleanupFunc, errRes *util.JSONResponse) // TODO: Extend to support Register() flow // Register(ctx context.Context, sessionID string, req interface{}) } +type LoginCleanupFunc func(context.Context, *util.JSONResponse) + // LoginIdentifier represents identifier types // https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types type LoginIdentifier struct { @@ -61,11 +63,8 @@ type LoginIdentifier struct { // Login represents the shared fields used in all forms of login/sudo endpoints. type Login struct { - Type string `json:"type"` - Identifier LoginIdentifier `json:"identifier"` - User string `json:"user"` // deprecated in favour of identifier - Medium string `json:"medium"` // deprecated in favour of identifier - Address string `json:"address"` // deprecated in favour of identifier + LoginIdentifier // Flat fields deprecated in favour of `identifier`. + Identifier LoginIdentifier `json:"identifier"` // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") // Thus a pointer is needed to differentiate between the two @@ -111,12 +110,11 @@ type UserInteractive struct { Sessions map[string][]string } -func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI) *UserInteractive { +func NewUserInteractive(accountDB AccountDatabase, cfg *config.ClientAPI) *UserInteractive { typePassword := &LoginTypePassword{ - GetAccountByPassword: getAccByPass, + GetAccountByPassword: accountDB.GetAccountByPassword, Config: cfg, } - // TODO: Add SSO login return &UserInteractive{ Completed: []string{}, Flows: []userInteractiveFlow{ @@ -236,18 +234,13 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device * } } - r := loginType.Request() - if err := json.Unmarshal([]byte(gjson.GetBytes(bodyBytes, "auth").Raw), r); err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), - } + login, cleanup, resErr := loginType.LoginFromJSON(ctx, []byte(gjson.GetBytes(bodyBytes, "auth").Raw)) + if resErr != nil { + return nil, u.ResponseWithChallenge(sessionID, resErr.JSON) } - login, resErr := loginType.Login(ctx, r) - if resErr == nil { - u.AddCompletedStage(sessionID, authType) - // TODO: Check if there's more stages to go and return an error - return login, nil - } - return nil, u.ResponseWithChallenge(sessionID, resErr.JSON) + + u.AddCompletedStage(sessionID, authType) + cleanup(ctx, nil) + // TODO: Check if there's more stages to go and return an error + return login, nil } diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 0b7df3545..76d161a74 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -24,7 +24,11 @@ var ( } ) -func getAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { +type fakeAccountDatabase struct { + AccountDatabase +} + +func (*fakeAccountDatabase) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { acc, ok := lookup[localpart+" "+plaintextPassword] if !ok { return nil, fmt.Errorf("unknown user/password") @@ -38,7 +42,7 @@ func setup() *UserInteractive { ServerName: serverName, }, } - return NewUserInteractive(getAccountByPassword, cfg) + return NewUserInteractive(&fakeAccountDatabase{}, cfg) } func TestUserInteractiveChallenge(t *testing.T) { diff --git a/clientapi/httputil/httputil.go b/clientapi/httputil/httputil.go index 29d7b0b37..b47701368 100644 --- a/clientapi/httputil/httputil.go +++ b/clientapi/httputil/httputil.go @@ -36,6 +36,10 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon return &resp } + return UnmarshalJSON(body, iface) +} + +func UnmarshalJSON(body []byte, iface interface{}) *util.JSONResponse { if !utf8.Valid(body) { return &util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 589efe0b2..b48b9e93b 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -19,7 +19,6 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" @@ -65,21 +64,14 @@ func Login( JSON: passwordLogin(), } } else if req.Method == http.MethodPost { - typePassword := auth.LoginTypePassword{ - GetAccountByPassword: accountDB.GetAccountByPassword, - Config: cfg, - } - r := typePassword.Request() - resErr := httputil.UnmarshalJSONRequest(req, r) - if resErr != nil { - return *resErr - } - login, authErr := typePassword.Login(req.Context(), r) + login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, accountDB, userAPI, cfg) if authErr != nil { return *authErr } // make a device/access token - return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + authErr2 := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + cleanup(req.Context(), &authErr2) + return authErr2 } return util.JSONResponse{ Code: http.StatusMethodNotAllowed, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 9263c66bb..732066166 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -62,7 +62,7 @@ func Setup( mscCfg *config.MSCs, ) { rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) - userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg) + userInteractiveAuth := auth.NewUserInteractive(accountDB, cfg) unstableFeatures := map[string]bool{ "org.matrix.e2e_cross_signing": true, diff --git a/userapi/api/api.go b/userapi/api/api.go index 04609659c..46a13d971 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -24,6 +24,8 @@ import ( // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { + LoginTokenInternalAPI + InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error diff --git a/userapi/api/api_logintoken.go b/userapi/api/api_logintoken.go new file mode 100644 index 000000000..f3aa037e4 --- /dev/null +++ b/userapi/api/api_logintoken.go @@ -0,0 +1,69 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "time" +) + +type LoginTokenInternalAPI interface { + // PerformLoginTokenCreation creates a new login token and associates it with the provided data. + PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error + + // PerformLoginTokenDeletion ensures the token doesn't exist. Success + // is returned even if the token didn't exist, or was already deleted. + PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error + + // QueryLoginToken returns the data associated with a login token. If + // the token is not valid, success is returned, but res.Data == nil. + QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error +} + +// LoginTokenData is the data that can be retrieved given a login token. This is +// provided by the calling code. +type LoginTokenData struct { + // UserID is the full mxid of the user. + UserID string +} + +// LoginTokenMetadata contains metadata created and maintained by the User API. +type LoginTokenMetadata struct { + Token string + Expiration time.Time +} + +type PerformLoginTokenCreationRequest struct { + Data LoginTokenData +} + +type PerformLoginTokenCreationResponse struct { + Metadata LoginTokenMetadata +} + +type PerformLoginTokenDeletionRequest struct { + Token string +} + +type PerformLoginTokenDeletionResponse struct{} + +type QueryLoginTokenRequest struct { + Token string +} + +type QueryLoginTokenResponse struct { + // Data is nil if the token was invalid. + Data *LoginTokenData +} diff --git a/userapi/api/api_trace_logintoken.go b/userapi/api/api_trace_logintoken.go new file mode 100644 index 000000000..e60dae594 --- /dev/null +++ b/userapi/api/api_trace_logintoken.go @@ -0,0 +1,39 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/util" +) + +func (t *UserInternalAPITrace) PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error { + err := t.Impl.PerformLoginTokenCreation(ctx, req, res) + util.GetLogger(ctx).Infof("PerformLoginTokenCreation req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error { + err := t.Impl.PerformLoginTokenDeletion(ctx, req, res) + util.GetLogger(ctx).Infof("PerformLoginTokenDeletion req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error { + err := t.Impl.QueryLoginToken(ctx, req, res) + util.GetLogger(ctx).Infof("QueryLoginToken req=%+v res=%+v", js(req), js(res)) + return err +} diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go new file mode 100644 index 000000000..86ffc58f3 --- /dev/null +++ b/userapi/internal/api_logintoken.go @@ -0,0 +1,78 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// PerformLoginTokenCreation creates a new login token and associates it with the provided data. +func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *api.PerformLoginTokenCreationRequest, res *api.PerformLoginTokenCreationResponse) error { + util.GetLogger(ctx).WithField("user_id", req.Data.UserID).Info("PerformLoginTokenCreation") + _, domain, err := gomatrixserverlib.SplitID('@', req.Data.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) + } + tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data) + if err != nil { + return err + } + res.Metadata = *tokenMeta + return nil +} + +// PerformLoginTokenDeletion ensures the token doesn't exist. +func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error { + util.GetLogger(ctx).Info("PerformLoginTokenDeletion") + return a.DeviceDB.RemoveLoginToken(ctx, req.Token) +} + +// QueryLoginToken returns the data associated with a login token. If +// the token is not valid, success is returned, but res.Data == nil. +func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { + tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token) + if err != nil { + res.Data = nil + if err == sql.ErrNoRows { + return nil + } + return err + } + localpart, domain, err := gomatrixserverlib.SplitID('@', tokenData.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) + } + if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil { + res.Data = nil + if err == sql.ErrNoRows { + return nil + } + return err + } + res.Data = tokenData + return nil +} diff --git a/userapi/inthttp/client_logintoken.go b/userapi/inthttp/client_logintoken.go new file mode 100644 index 000000000..366a97099 --- /dev/null +++ b/userapi/inthttp/client_logintoken.go @@ -0,0 +1,65 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/opentracing/opentracing-go" +) + +const ( + PerformLoginTokenCreationPath = "/userapi/performLoginTokenCreation" + PerformLoginTokenDeletionPath = "/userapi/performLoginTokenDeletion" + QueryLoginTokenPath = "/userapi/queryLoginToken" +) + +func (h *httpUserInternalAPI) PerformLoginTokenCreation( + ctx context.Context, + request *api.PerformLoginTokenCreationRequest, + response *api.PerformLoginTokenCreationResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformLoginTokenCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformLoginTokenDeletion( + ctx context.Context, + request *api.PerformLoginTokenDeletionRequest, + response *api.PerformLoginTokenDeletionResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion") + defer span.Finish() + + apiURL := h.apiURL + PerformLoginTokenDeletionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryLoginToken( + ctx context.Context, + request *api.QueryLoginTokenRequest, + response *api.QueryLoginTokenResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken") + defer span.Finish() + + apiURL := h.apiURL + QueryLoginTokenPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index ac05bcd09..d00ee042c 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -27,6 +27,8 @@ import ( // nolint: gocyclo func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { + addRoutesLoginToken(internalAPIMux, s) + internalAPIMux.Handle(PerformAccountCreationPath, httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { request := api.PerformAccountCreationRequest{} diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go new file mode 100644 index 000000000..1f2eb34b9 --- /dev/null +++ b/userapi/inthttp/server_logintoken.go @@ -0,0 +1,68 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// addRoutesLoginToken adds routes for all login token API calls. +func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) { + internalAPIMux.Handle(PerformLoginTokenCreationPath, + httputil.MakeInternalAPI("performLoginTokenCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformLoginTokenCreationRequest{} + response := api.PerformLoginTokenCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLoginTokenCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformLoginTokenDeletionPath, + httputil.MakeInternalAPI("performLoginTokenDeletion", func(req *http.Request) util.JSONResponse { + request := api.PerformLoginTokenDeletionRequest{} + response := api.PerformLoginTokenDeletionResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLoginTokenDeletion(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryLoginTokenPath, + httputil.MakeInternalAPI("queryLoginToken", func(req *http.Request) util.JSONResponse { + request := api.QueryLoginTokenRequest{} + response := api.QueryLoginTokenResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryLoginToken(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 95fe99f33..8ff91cf1c 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -38,4 +38,15 @@ type Database interface { RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + + // CreateLoginToken generates a token, stores and returns it. The lifetime is + // determined by the loginTokenLifetime given to the Database constructor. + CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) + + // RemoveLoginToken removes the named token (and may clean up other expired tokens). + RemoveLoginToken(ctx context.Context, token string) error + + // GetLoginTokenDataByToken returns the data associated with the given token. + // May return sql.ErrNoRows. + GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) } diff --git a/userapi/storage/devices/postgres/logintoken_table.go b/userapi/storage/devices/postgres/logintoken_table.go new file mode 100644 index 000000000..f601fc7db --- /dev/null +++ b/userapi/storage/devices/postgres/logintoken_table.go @@ -0,0 +1,93 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectByTokenStmt *sql.Stmt +} + +// execSchema ensures tables and indices exist. +func (s *loginTokenStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS login_tokens ( + -- The random value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- When the token expires + token_expires_at TIMESTAMP NOT NULL, + + -- The mxid for this account + user_id TEXT NOT NULL +); + +-- This index allows efficient garbage collection of expired tokens. +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +`) + return err +} + +// prepare runs statement preparation. +func (s *loginTokenStatements) prepare(db *sql.DB) error { + return sqlutil.StatementList{ + {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, + {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, + {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + }.Prepare(db) +} + +// insert adds an already generated token to the database. +func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { + stmt := sqlutil.TxStmt(txn, s.insertStmt) + _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) + return err +} + +// deleteByToken removes the named token. +// +// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. +// The login_tokens_expiration_idx index should make that efficient. +func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { + stmt := sqlutil.TxStmt(txn, s.deleteStmt) + res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) + if err != nil { + return err + } + if n, err := res.RowsAffected(); err == nil && n > 1 { + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely additional expired token)", n, n-1) + } + return nil +} + +// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. +func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + var data api.LoginTokenData + err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + if err != nil { + return nil, err + } + + return &data, nil +} diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 485234331..fd9d513f1 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -19,6 +19,7 @@ import ( "crypto/rand" "database/sql" "encoding/base64" + "time" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" @@ -27,28 +28,38 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// The length of generated device IDs -var deviceIDByteLength = 6 +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) // Database represents a device database. type Database struct { - db *sql.DB - devices devicesStatements + db *sql.DB + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration } // NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } - d := devicesStatements{} + var d devicesStatements + var lt loginTokenStatements // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns if err = d.execSchema(db); err != nil { return nil, err } + if err = lt.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() deltas.LoadLastSeenTSIP(m) if err = m.RunDeltas(db, dbProperties); err != nil { @@ -58,8 +69,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.prepare(db, serverName); err != nil { return nil, err } + if err = lt.prepare(db); err != nil { + return nil, err + } - return &Database{db, d}, nil + return &Database{db, d, lt, loginTokenLifetime}, nil } // GetDeviceByAccessToken returns the device matching the given access token. @@ -210,3 +224,47 @@ func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) } + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/devices/sqlite3/logintoken_table.go b/userapi/storage/devices/sqlite3/logintoken_table.go new file mode 100644 index 000000000..75ef272f8 --- /dev/null +++ b/userapi/storage/devices/sqlite3/logintoken_table.go @@ -0,0 +1,93 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectByTokenStmt *sql.Stmt +} + +// execSchema ensures tables and indices exist. +func (s *loginTokenStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS login_tokens ( + -- The random value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- When the token expires + token_expires_at TIMESTAMP NOT NULL, + + -- The mxid for this account + user_id TEXT NOT NULL +); + +-- This index allows efficient garbage collection of expired tokens. +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +`) + return err +} + +// prepare runs statement preparation. +func (s *loginTokenStatements) prepare(db *sql.DB) error { + return sqlutil.StatementList{ + {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, + {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, + {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + }.Prepare(db) +} + +// insert adds an already generated token to the database. +func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { + stmt := sqlutil.TxStmt(txn, s.insertStmt) + _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) + return err +} + +// deleteByToken removes the named token. +// +// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. +// The login_tokens_expiration_idx index should make that efficient. +func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { + stmt := sqlutil.TxStmt(txn, s.deleteStmt) + res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) + if err != nil { + return err + } + if n, err := res.RowsAffected(); err == nil && n > 1 { + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely additional expired token)", n, n-1) + } + return nil +} + +// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. +func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + var data api.LoginTokenData + err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + if err != nil { + return nil, err + } + + return &data, nil +} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 538644837..6e90413be 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -19,6 +19,7 @@ import ( "crypto/rand" "database/sql" "encoding/base64" + "time" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" @@ -27,30 +28,41 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// The length of generated device IDs -var deviceIDByteLength = 6 +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + + loginTokenByteLength = 32 +) // Database represents a device database. type Database struct { - db *sql.DB - writer sqlutil.Writer - devices devicesStatements + db *sql.DB + writer sqlutil.Writer + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration } // NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } writer := sqlutil.NewExclusiveWriter() - d := devicesStatements{} + var d devicesStatements + var lt loginTokenStatements // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns if err = d.execSchema(db); err != nil { return nil, err } + if err = lt.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() deltas.LoadLastSeenTSIP(m) if err = m.RunDeltas(db, dbProperties); err != nil { @@ -59,7 +71,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.prepare(db, writer, serverName); err != nil { return nil, err } - return &Database{db, writer, d}, nil + if err = lt.prepare(db); err != nil { + return nil, err + } + return &Database{db, writer, d, lt, loginTokenLifetime}, nil } // GetDeviceByAccessToken returns the device matching the given access token. @@ -210,3 +225,47 @@ func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) } + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go index 3c2034300..15cf8150c 100644 --- a/userapi/storage/devices/storage.go +++ b/userapi/storage/devices/storage.go @@ -19,6 +19,7 @@ package devices import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" @@ -27,13 +28,14 @@ import ( ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (Database, error) { +// and sets postgres connection parameters. loginTokenLifetime determines how long a +// login token from CreateLoginToken is valid. +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName) + return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName) + return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go index f360f9857..3de7880b9 100644 --- a/userapi/storage/devices/storage_wasm.go +++ b/userapi/storage/devices/storage_wasm.go @@ -16,6 +16,7 @@ package devices import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" @@ -25,10 +26,11 @@ import ( func NewDatabase( dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, + loginTokenLifetime time.Duration, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName) + return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/userapi.go b/userapi/userapi.go index 74702020a..c7e1f6674 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -15,6 +15,8 @@ package userapi import ( + "time" + "github.com/gorilla/mux" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -26,6 +28,13 @@ import ( "github.com/sirupsen/logrus" ) +// defaultLoginTokenLifetime determines how old a valid token may be. +// +// NOTSPEC: The current spec says "SHOULD be limited to around five +// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low. +// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325). +const defaultLoginTokenLifetime = 2 * time.Minute + // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { @@ -37,11 +46,21 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { func NewInternalAPI( accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, ) api.UserInternalAPI { - deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName) + deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } + return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI) +} + +func newInternalAPI( + accountDB accounts.Database, + deviceDB devices.Database, + cfg *config.UserAPI, + appServices []config.ApplicationService, + keyAPI keyapi.KeyInternalAPI, +) api.UserInternalAPI { return &internal.UserInternalAPI{ AccountDB: accountDB, DeviceDB: deviceDB, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 0141258e6..266f5ed58 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -1,4 +1,18 @@ -package userapi_test +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package userapi import ( "context" @@ -6,15 +20,16 @@ import ( "net/http" "reflect" "testing" + "time" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" ) @@ -23,31 +38,41 @@ const ( serverName = gomatrixserverlib.ServerName("example.com") ) -func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) { - accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ - ConnectionString: "file::memory:", - }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) +type apiTestOpts struct { + loginTokenLifetime time.Duration +} + +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) { + if opts.loginTokenLifetime == 0 { + opts.loginTokenLifetime = defaultLoginTokenLifetime + } + dbopts := &config.DatabaseOptions{ + ConnectionString: "file::memory:", + MaxOpenConnections: 1, + MaxIdleConnections: 1, + } + accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) if err != nil { t.Fatalf("failed to create account DB: %s", err) } + deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime) + if err != nil { + t.Fatalf("failed to create device DB: %s", err) + } + cfg := &config.UserAPI{ - DeviceDatabase: config.DatabaseOptions{ - ConnectionString: "file::memory:", - MaxOpenConnections: 1, - MaxIdleConnections: 1, - }, Matrix: &config.Global{ ServerName: serverName, }, } - return userapi.NewInternalAPI(accountDB, cfg, nil, nil), accountDB + return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB } func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - userAPI, accountDB := MustMakeInternalAPI(t) + userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") if err != nil { t.Fatalf("failed to make account: %s", err) @@ -106,7 +131,7 @@ func TestQueryProfile(t *testing.T) { t.Run("HTTP API", func(t *testing.T) { router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, userAPI) + AddInternalRoutes(router, userAPI) apiURL, cancel := test.ListenAndServe(t, router, false) defer cancel() httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) @@ -119,3 +144,115 @@ func TestQueryProfile(t *testing.T) { runCases(userAPI) }) } + +func TestLoginToken(t *testing.T) { + ctx := context.Background() + + t.Run("tokenLoginFlow", func(t *testing.T) { + userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) + + _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "") + if err != nil { + t.Fatalf("failed to make account: %s", err) + } + + t.Log("Creating a login token like the SSO callback would...") + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + if cresp.Metadata.Token == "" { + t.Errorf("PerformLoginTokenCreation Token: got %q, want non-empty", cresp.Metadata.Token) + } + if cresp.Metadata.Expiration.Before(time.Now()) { + t.Errorf("PerformLoginTokenCreation Expiration: got %v, want non-expired", cresp.Metadata.Expiration) + } + + t.Log("Querying the login token like /login with m.login.token would...") + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data == nil { + t.Errorf("QueryLoginToken Data: got %v, want non-nil", qresp.Data) + } else if want := "@auser:example.com"; qresp.Data.UserID != want { + t.Errorf("QueryLoginToken UserID: got %q, want %q", qresp.Data.UserID, want) + } + + t.Log("Deleting the login token like /login with m.login.token would...") + + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) + + t.Run("expiredTokenIsNotReturned", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}) + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) + + t.Run("deleteWorks", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "@auser:example.com"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) + + t.Run("deleteUnknownIsNoOp", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + + dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) +} From 9ac27cabc5f624938264ddccf5500478d7d38bf3 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 10 Feb 2022 13:50:13 +0000 Subject: [PATCH 50/81] Version 0.6.3 (#2170) --- CHANGES.md | 19 +++++++++++++++++++ internal/version.go | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 07e09480a..4df8e869a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,24 @@ # Changelog +## Dendrite 0.6.3 (2022-02-10) + +### Features + +* Initial support for `m.login.token` +* A number of regressions from earlier v0.6.x versions should now be corrected + +### Fixes + +* Missing state is now correctly retrieved in cases where a gap in the timeline was closed but some of those events were missing state snapshots, which should help to unstick slow or broken rooms +* Fixed a transaction issue where inserting events into the database could deadlock, which should stop rooms from getting stuck +* Fixed a problem where rejected events could result in rolled back database transactions +* Avoided a potential race condition on fetching latest events by using the room updater instead +* Processing events from `/get_missing_events` will no longer result in potential recursion +* Federation events are now correctly generated for updated self-signing keys and signed devices +* Rejected events can now be un-rejected if they are reprocessed and all of the correct conditions are met +* Fetching missing auth events will no longer error as long as all needed events for auth were satisfied +* Users can now correctly forget rooms if they were not a member of the room + ## Dendrite 0.6.2 (2022-02-04) ### Fixes diff --git a/internal/version.go b/internal/version.go index de0b7c8c3..a07f01b61 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 6 - VersionPatch = 2 + VersionPatch = 3 VersionTag = "" // example: "rc1" ) From f800cae6d250e49fed1d96471da82dbaf25c3564 Mon Sep 17 00:00:00 2001 From: kegsay Date: Thu, 10 Feb 2022 18:12:11 +0000 Subject: [PATCH 51/81] Point to /complement/ca not /ca (#2172) --- build/scripts/Complement.Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index a54fab1d4..1d520b4e7 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -25,7 +25,7 @@ EXPOSE 8008 8448 # At runtime, generate TLS cert based on the CA now mounted at /ca # At runtime, replace the SERVER_NAME with what we are told -CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /ca/ca.crt --tls-authority-key /ca/ca.key && \ +CMD ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ - cp /ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ + cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml From 4e75ab9930842e34314b0af89f60149871f6e6f2 Mon Sep 17 00:00:00 2001 From: kegsay Date: Fri, 11 Feb 2022 12:35:47 +0000 Subject: [PATCH 52/81] Add postgres complement support (#2177) --- build/scripts/ComplementPostgres.Dockerfile | 53 +++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 build/scripts/ComplementPostgres.Dockerfile diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile new file mode 100644 index 000000000..6024ae8da --- /dev/null +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -0,0 +1,53 @@ +FROM golang:1.16-stretch as build +RUN apt-get update && apt-get install -y postgresql +WORKDIR /build + +# No password when connecting over localhost +RUN sed -i "s%127.0.0.1/32 md5%127.0.0.1/32 trust%g" /etc/postgresql/9.6/main/pg_hba.conf && \ + # Bump up max conns for moar concurrency + sed -i 's/max_connections = 100/max_connections = 2000/g' /etc/postgresql/9.6/main/postgresql.conf + +# This entry script starts postgres, waits for it to be up then starts dendrite +RUN echo '\ +#!/bin/bash -eu \n\ +pg_lsclusters \n\ +pg_ctlcluster 9.6 main start \n\ + \n\ +until pg_isready \n\ +do \n\ + echo "Waiting for postgres"; \n\ + sleep 1; \n\ +done \n\ +' > run_postgres.sh && chmod +x run_postgres.sh + +# we will dump the binaries and config file to this location to ensure any local untracked files +# that come from the COPY . . file don't contaminate the build +RUN mkdir /dendrite + +# Utilise Docker caching when downloading dependencies, this stops us needlessly +# downloading dependencies every time. +COPY go.mod . +COPY go.sum . +RUN go mod download + +COPY . . +RUN go build -o /dendrite ./cmd/dendrite-monolith-server +RUN go build -o /dendrite ./cmd/generate-keys +RUN go build -o /dendrite ./cmd/generate-config + +WORKDIR /dendrite +RUN ./generate-keys --private-key matrix_key.pem + +ENV SERVER_NAME=localhost +EXPOSE 8008 8448 + + +# At runtime, generate TLS cert based on the CA now mounted at /ca +# At runtime, replace the SERVER_NAME with what we are told +CMD /build/run_postgres.sh && ./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ + ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ + # Replace the connection string with a single postgres DB, using user/db = 'postgres' and no password, bump max_conns + sed -i "s%connection_string:.*$%connection_string: postgresql://postgres@localhost/postgres?sslmode=disable%g" dendrite.yaml && \ + sed -i 's/max_open_conns:.*$/max_open_conns: 100/g' dendrite.yaml && \ + cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ + ./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml \ No newline at end of file From 88b45d5cd248794237baebbe4945ef708a7598de Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 11 Feb 2022 15:18:14 +0000 Subject: [PATCH 53/81] Drop `m.room.create` events in federation `/send` transaction (#2179) --- federationapi/routing/send.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index dd4fe13a8..745e36de9 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -258,6 +258,9 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) continue } + if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + continue + } if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { results[event.EventID()] = gomatrixserverlib.PDUResult{ Error: "Forbidden by server ACLs", From a566d53b0b763220b93946c44986e7337549769b Mon Sep 17 00:00:00 2001 From: kegsay Date: Fri, 11 Feb 2022 16:26:23 +0000 Subject: [PATCH 54/81] Don't allow parallel complement tests (#2169) Fixes flakiness seemingly. See https://github.com/matrix-org/synapse/pull/11910 --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4a1720295..124940f71 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -63,7 +63,7 @@ jobs: # Run Complement - run: | set -o pipefail && - go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt + go test -v -p 1 -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt shell: bash name: Run Complement Tests env: From a4e7d471af7e2cf902404f6740f0932a088cb660 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Fri, 11 Feb 2022 18:15:44 +0100 Subject: [PATCH 55/81] Remove FederationDisabled error type (#2174) --- federationapi/consumers/roomserver.go | 30 +++++++++++---------------- federationapi/queue/queue.go | 27 ++++++++---------------- 2 files changed, 21 insertions(+), 36 deletions(-) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index ac29f930b..173dcff01 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -19,6 +19,10 @@ import ( "encoding/json" "fmt" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/federationapi/queue" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/types" @@ -26,9 +30,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" ) // OutputRoomEventConsumer consumes events that originated in the room server. @@ -97,21 +98,14 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) } if err := s.processMessage(*output.NewRoomEvent); err != nil { - switch err.(type) { - case *queue.ErrorFederationDisabled: - log.WithField("error", output.Type).Info( - err.Error(), - ) - default: - // panic rather than continue with an inconsistent database - log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "event": string(ev.JSON()), - "add": output.NewRoomEvent.AddsStateEventIDs, - "del": output.NewRoomEvent.RemovesStateEventIDs, - log.ErrorKey: err, - }).Panicf("roomserver output log: write room event failure") - } + // panic rather than continue with an inconsistent database + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "event": string(ev.JSON()), + "add": output.NewRoomEvent.AddsStateEventIDs, + "del": output.NewRoomEvent.RemovesStateEventIDs, + log.ErrorKey: err, + }).Panicf("roomserver output log: write room event failure") } case api.OutputTypeNewInboundPeek: diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 8a6ad1555..dcd090856 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -22,15 +22,16 @@ import ( "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/gomatrixserverlib" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" ) // OutgoingQueues is a collection of queues for sending transactions to other @@ -182,23 +183,14 @@ func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) { destinationQueueTotal.Dec() } -type ErrorFederationDisabled struct { - Message string -} - -func (e *ErrorFederationDisabled) Error() string { - return e.Message -} - // SendEvent sends an event to the destinations func (oqs *OutgoingQueues) SendEvent( ev *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, destinations []gomatrixserverlib.ServerName, ) error { if oqs.disabled { - return &ErrorFederationDisabled{ - Message: "Federation disabled", - } + log.Trace("Federation is disabled, not sending event") + return nil } if origin != oqs.origin { // TODO: Support virtual hosting; gh issue #577. @@ -262,9 +254,8 @@ func (oqs *OutgoingQueues) SendEDU( destinations []gomatrixserverlib.ServerName, ) error { if oqs.disabled { - return &ErrorFederationDisabled{ - Message: "Federation disabled", - } + log.Trace("Federation is disabled, not sending EDU") + return nil } if origin != oqs.origin { // TODO: Support virtual hosting; gh issue #577. From 5106cc807cf22a95420b24f6bfdd5c9ac8aa06de Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 11 Feb 2022 17:40:14 +0000 Subject: [PATCH 56/81] Ensure only one transaction is used for RS input per room (#2178) * Ensure the input API only uses a single transaction * Remove more of the dead query API call * Tidy up * Fix tests hopefully * Don't do unnecessary work for rooms that don't exist * Improve error, fix another case where transaction wasn't used properly * Add a unit test for checking single transaction on RS input API * Fix logic oops when deciding whether to use a transaction in storeEvent --- federationapi/routing/send_test.go | 43 +----- roomserver/api/api.go | 7 - roomserver/api/api_trace.go | 10 -- roomserver/api/query.go | 21 --- roomserver/internal/input/input_events.go | 33 ++--- roomserver/internal/input/input_missing.go | 123 +++++++++--------- roomserver/internal/input/input_test.go | 93 +++++++++++++ roomserver/internal/query/query.go | 33 ----- roomserver/inthttp/client.go | 14 -- roomserver/inthttp/server.go | 14 -- .../storage/postgres/event_json_table.go | 3 +- roomserver/storage/shared/room_updater.go | 29 ++++- roomserver/storage/shared/storage.go | 2 +- 13 files changed, 211 insertions(+), 214 deletions(-) create mode 100644 roomserver/internal/input/input_test.go diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index f1f6169d9..4280643e9 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -93,11 +93,10 @@ func (o *testEDUProducer) InputCrossSigningKeyUpdate( type testRoomserverAPI struct { api.RoomserverInternalAPITrace - inputRoomEvents []api.InputRoomEvent - queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse - queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse - queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse - queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse + inputRoomEvents []api.InputRoomEvent + queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse + queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse + queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse } func (t *testRoomserverAPI) InputRoomEvents( @@ -140,20 +139,6 @@ func (t *testRoomserverAPI) QueryStateAfterEvents( return nil } -// Query the state after a list of events in a room from the room server. -func (t *testRoomserverAPI) QueryMissingAuthPrevEvents( - ctx context.Context, - request *api.QueryMissingAuthPrevEventsRequest, - response *api.QueryMissingAuthPrevEventsResponse, -) error { - response.RoomVersion = testRoomVersion - res := t.queryMissingAuthPrevEvents(request) - response.RoomExists = res.RoomExists - response.MissingAuthEventIDs = res.MissingAuthEventIDs - response.MissingPrevEventIDs = res.MissingPrevEventIDs - return nil -} - // Query a list of events by event ID. func (t *testRoomserverAPI) QueryEventsByID( ctx context.Context, @@ -312,15 +297,7 @@ func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomat // The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on // to the roomserver. It's the most basic test possible. func TestBasicTransaction(t *testing.T) { - rsAPI := &testRoomserverAPI{ - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: []string{}, - } - }, - } + rsAPI := &testRoomserverAPI{} pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } @@ -332,15 +309,7 @@ func TestBasicTransaction(t *testing.T) { // The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver // as it does the auth check. func TestTransactionFailAuthChecks(t *testing.T) { - rsAPI := &testRoomserverAPI{ - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: []string{}, - } - }, - } + rsAPI := &testRoomserverAPI{} pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index d35fd84df..e6d37e8f1 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -83,13 +83,6 @@ type RoomserverInternalAPI interface { response *QueryStateAfterEventsResponse, ) error - // Query whether the roomserver is missing any auth or prev events. - QueryMissingAuthPrevEvents( - ctx context.Context, - request *QueryMissingAuthPrevEventsRequest, - response *QueryMissingAuthPrevEventsResponse, - ) error - // Query a list of events by event ID. QueryEventsByID( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 64cbaca49..16f52abb7 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -129,16 +129,6 @@ func (t *RoomserverInternalAPITrace) QueryStateAfterEvents( return err } -func (t *RoomserverInternalAPITrace) QueryMissingAuthPrevEvents( - ctx context.Context, - req *QueryMissingAuthPrevEventsRequest, - res *QueryMissingAuthPrevEventsResponse, -) error { - err := t.Impl.QueryMissingAuthPrevEvents(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryMissingAuthPrevEvents req=%+v res=%+v", js(req), js(res)) - return err -} - func (t *RoomserverInternalAPITrace) QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 283217157..96d6711c6 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -83,27 +83,6 @@ type QueryStateAfterEventsResponse struct { StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` } -type QueryMissingAuthPrevEventsRequest struct { - // The room ID to query the state in. - RoomID string `json:"room_id"` - // The list of auth events to check the existence of. - AuthEventIDs []string `json:"auth_event_ids"` - // The list of previous events to check the existence of. - PrevEventIDs []string `json:"prev_event_ids"` -} - -type QueryMissingAuthPrevEventsResponse struct { - // Does the room exist on this roomserver? - // If the room doesn't exist all other fields will be empty. - RoomExists bool `json:"room_exists"` - // The room version of the room. - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - // The event IDs of the auth events that we don't know locally. - MissingAuthEventIDs []string `json:"missing_auth_event_ids"` - // The event IDs of the previous events that we don't know locally. - MissingPrevEventIDs []string `json:"missing_prev_event_ids"` -} - // QueryEventsByIDRequest is a request to QueryEventsByID type QueryEventsByIDRequest struct { // The event IDs to look up. diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 873a051cd..4e151699e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -128,20 +128,24 @@ func (r *Inputer) processRoomEvent( } } - missingRes := &api.QueryMissingAuthPrevEventsResponse{} - serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} - if event.Type() != gomatrixserverlib.MRoomCreate || !event.StateKeyEquals("") { - missingReq := &api.QueryMissingAuthPrevEventsRequest{ - RoomID: event.RoomID(), - AuthEventIDs: event.AuthEventIDs(), - PrevEventIDs: event.PrevEventIDs(), - } - if err := r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { - return rollbackTransaction, fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) - } + // Don't waste time processing the event if the room doesn't exist. + // A room entry locally will only be created in response to a create + // event. + isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") + if !updater.RoomExists() && !isCreateEvent { + return rollbackTransaction, fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) + } + + var missingAuth, missingPrev bool + serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} + if !isCreateEvent { + missingAuthIDs, missingPrevIDs, err := updater.MissingAuthPrevEvents(ctx, event) + if err != nil { + return rollbackTransaction, fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) + } + missingAuth = len(missingAuthIDs) > 0 + missingPrev = !input.HasState && len(missingPrevIDs) > 0 } - missingAuth := len(missingRes.MissingAuthEventIDs) > 0 - missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0 if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ @@ -246,14 +250,13 @@ func (r *Inputer) processRoomEvent( missingState := missingStateReq{ origin: input.Origin, inputer: r, - queryer: r.Queryer, db: updater, federation: r.FSAPI, keys: r.KeyRing, roomsMu: internal.NewMutexByRoom(), servers: serverRes.ServerNames, hadEvents: map[string]bool{}, - haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, + haveEvents: map[string]*gomatrixserverlib.Event{}, } if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { // Something went wrong with retrieving the missing state, so we can't diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 19771d4bd..fc3be7987 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -10,7 +10,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/internal/query" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -27,14 +27,13 @@ type missingStateReq struct { origin gomatrixserverlib.ServerName db *shared.RoomUpdater inputer *Inputer - queryer *query.Queryer keys gomatrixserverlib.JSONVerifier federation fedapi.FederationInternalAPI roomsMu *internal.MutexByRoom servers []gomatrixserverlib.ServerName hadEvents map[string]bool hadEventsMutex sync.Mutex - haveEvents map[string]*gomatrixserverlib.HeaderedEvent + haveEvents map[string]*gomatrixserverlib.Event haveEventsMutex sync.Mutex } @@ -326,20 +325,20 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion for i := range respState.StateEvents { se := respState.StateEvents[i] if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) { - respState.StateEvents[i] = h.Unwrap() + respState.StateEvents[i] = h addedToState = true break } } if !addedToState { - respState.StateEvents = append(respState.StateEvents, h.Unwrap()) + respState.StateEvents = append(respState.StateEvents, h) } } return respState, false, nil } -func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent { +func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixserverlib.Event { t.haveEventsMutex.Lock() defer t.haveEventsMutex.Unlock() if cached, exists := t.haveEvents[ev.EventID()]; exists { @@ -350,32 +349,49 @@ func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *g } func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState { - var res api.QueryStateAfterEventsResponse - err := t.queryer.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ - RoomID: roomID, - PrevEventIDs: []string{eventID}, - }, &res) - if err != nil || !res.PrevEventsExist { - util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to query state after %s locally, prev exists=%v", eventID, res.PrevEventsExist) + var res parsedRespState + roomInfo, err := t.db.RoomInfo(ctx, roomID) + if err != nil { return nil } - stateEvents := make([]*gomatrixserverlib.HeaderedEvent, len(res.StateEvents)) - for i, ev := range res.StateEvents { + roomState := state.NewStateResolution(t.db, roomInfo) + stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID}) + if err != nil { + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to get state after %s locally", eventID) + return nil + } + stateEntries, err := roomState.LoadCombinedStateAfterEvents(ctx, stateAtEvents) + if err != nil { + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to load combined state after %s locally", eventID) + return nil + } + stateEventNIDs := make([]types.EventNID, 0, len(stateEntries)) + for _, entry := range stateEntries { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + stateEvents, err := t.db.Events(ctx, stateEventNIDs) + if err != nil { + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to load state events locally") + return nil + } + res.StateEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents)) + for _, ev := range stateEvents { // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this // processEvent request, which is better for memory. - stateEvents[i] = t.cacheAndReturn(ev) + res.StateEvents = append(res.StateEvents, t.cacheAndReturn(ev.Event)) t.hadEvent(ev.EventID()) } - // we should never access res.StateEvents again so we delete it here to make GC faster - res.StateEvents = nil - var authEvents []*gomatrixserverlib.Event + // encourage GC + stateEvents, stateEventNIDs, stateEntries, stateAtEvents = nil, nil, nil, nil // nolint:ineffassign + missingAuthEvents := map[string]bool{} + res.AuthEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents)*3) for _, ev := range stateEvents { t.haveEventsMutex.Lock() for _, ae := range ev.AuthEventIDs() { if aev, ok := t.haveEvents[ae]; ok { - authEvents = append(authEvents, aev.Unwrap()) + res.AuthEvents = append(res.AuthEvents, aev) } else { missingAuthEvents[ae] = true } @@ -389,25 +405,18 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room for evID := range missingAuthEvents { missingEventList = append(missingEventList, evID) } - queryReq := api.QueryEventsByIDRequest{ - EventIDs: missingEventList, - } util.GetLogger(ctx).WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") - var queryRes api.QueryEventsByIDResponse - if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + events, err := t.db.EventsFromIDs(ctx, missingEventList) + if err != nil { return nil } - for i, ev := range queryRes.Events { - authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) + for i, ev := range events { + res.AuthEvents = append(res.AuthEvents, t.cacheAndReturn(events[i].Event)) t.hadEvent(ev.EventID()) } - queryRes.Events = nil } - return &parsedRespState{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), - AuthEvents: authEvents, - } + return &res } // lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what @@ -448,7 +457,7 @@ retryAllowedState: return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2) } util.GetLogger(ctx).Tracef("fetched event %s", missing.AuthEventID) - resolvedStateEvents = append(resolvedStateEvents, h.Unwrap()) + resolvedStateEvents = append(resolvedStateEvents, h) goto retryAllowedState default: } @@ -513,7 +522,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve logger.Debugf("get_missing_events returned %d events", len(missingResp.Events)) missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - missingEvents = append(missingEvents, t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()) + missingEvents = append(missingEvents, t.cacheAndReturn(ev)) } // topologically sort and sanity check that we are making forward progress @@ -602,11 +611,11 @@ func (t *missingStateReq) lookupMissingStateViaState( // We load these as trusted as we called state.Check before which loaded them as untrusted. for i, evJSON := range state.AuthEvents { ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) - parsedState.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + parsedState.AuthEvents[i] = t.cacheAndReturn(ev) } for i, evJSON := range state.StateEvents { ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) - parsedState.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + parsedState.StateEvents[i] = t.cacheAndReturn(ev) } return parsedState, nil } @@ -634,23 +643,22 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo } t.haveEventsMutex.Unlock() - // fetch as many as we can from the roomserver - queryReq := api.QueryEventsByIDRequest{ - EventIDs: missingEventList, + events, err := t.db.EventsFromIDs(ctx, missingEventList) + if err != nil { + return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err) } - var queryRes api.QueryEventsByIDResponse - if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { - return nil, err - } - for i, ev := range queryRes.Events { - queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i]) + + for i, ev := range events { + events[i].Event = t.cacheAndReturn(events[i].Event) t.hadEvent(ev.EventID()) - evID := queryRes.Events[i].EventID() + evID := events[i].EventID() if missing[evID] { delete(missing, evID) } } - queryRes.Events = nil // allow it to be GCed + + // encourage GC + events = nil // nolint:ineffassign concurrentRequests := 8 missingCount := len(missing) @@ -704,7 +712,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo // Define what we'll do in order to fetch the missing event ID. fetch := func(missingEventID string) { - var h *gomatrixserverlib.HeaderedEvent + var h *gomatrixserverlib.Event h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false) switch err.(type) { case verifySigError: @@ -759,7 +767,7 @@ func (t *missingStateReq) createRespStateFromStateIDs( logrus.Tracef("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) continue } - respState.StateEvents = append(respState.StateEvents, ev.Unwrap()) + respState.StateEvents = append(respState.StateEvents, ev) } for i := range stateIDs.AuthEventIDs { ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] @@ -767,7 +775,7 @@ func (t *missingStateReq) createRespStateFromStateIDs( logrus.Tracef("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) continue } - respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap()) + respState.AuthEvents = append(respState.AuthEvents, ev) } // We purposefully do not do auth checks on the returned events, as they will still // be processed in the exact same way, just as a 'rejected' event @@ -775,17 +783,14 @@ func (t *missingStateReq) createRespStateFromStateIDs( return &respState, nil } -func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { +func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) { if localFirst { // fetch from the roomserver - queryReq := api.QueryEventsByIDRequest{ - EventIDs: []string{missingEventID}, - } - var queryRes api.QueryEventsByIDResponse - if err := t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + events, err := t.db.EventsFromIDs(ctx, []string{missingEventID}) + if err != nil { util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) - } else if len(queryRes.Events) == 1 { - return queryRes.Events[0], nil + } else if len(events) == 1 { + return events[0].Event, nil } } var event *gomatrixserverlib.Event @@ -822,7 +827,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs util.GetLogger(ctx).WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} } - return t.cacheAndReturn(event.Headered(roomVersion)), nil + return t.cacheAndReturn(event), nil } func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error { diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go new file mode 100644 index 000000000..4fa966281 --- /dev/null +++ b/roomserver/internal/input/input_test.go @@ -0,0 +1,93 @@ +package input_test + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/input" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +func psqlConnectionString() config.DataSource { + user := os.Getenv("POSTGRES_USER") + if user == "" { + user = "dendrite" + } + dbName := os.Getenv("POSTGRES_DB") + if dbName == "" { + dbName = "dendrite" + } + connStr := fmt.Sprintf( + "user=%s dbname=%s sslmode=disable", user, dbName, + ) + password := os.Getenv("POSTGRES_PASSWORD") + if password != "" { + connStr += fmt.Sprintf(" password=%s", password) + } + host := os.Getenv("POSTGRES_HOST") + if host != "" { + connStr += fmt.Sprintf(" host=%s", host) + } + return config.DataSource(connStr) +} + +func TestSingleTransactionOnInput(t *testing.T) { + deadline, _ := t.Deadline() + if max := time.Now().Add(time.Second * 3); deadline.After(max) { + deadline = max + } + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + event, err := gomatrixserverlib.NewEventFromTrustedJSON( + []byte(`{"auth_events":[],"content":{"creator":"@neilalexander:dendrite.matrix.org","room_version":"6"},"depth":1,"hashes":{"sha256":"jqOqdNEH5r0NiN3xJtj0u5XUVmRqq9YvGbki1wxxuuM"},"origin":"dendrite.matrix.org","origin_server_ts":1644595362726,"prev_events":[],"prev_state":[],"room_id":"!jSZZRknA6GkTBXNP:dendrite.matrix.org","sender":"@neilalexander:dendrite.matrix.org","signatures":{"dendrite.matrix.org":{"ed25519:6jB2aB":"bsQXO1wketf1OSe9xlndDIWe71W9KIundc6rBw4KEZdGPW7x4Tv4zDWWvbxDsG64sS2IPWfIm+J0OOozbrWIDw"}},"state_key":"","type":"m.room.create"}`), + false, gomatrixserverlib.RoomVersionV6, + ) + if err != nil { + t.Fatal(err) + } + in := api.InputRoomEvent{ + Kind: api.KindOutlier, // don't panic if we generate an output event + Event: event.Headered(gomatrixserverlib.RoomVersionV6), + } + cache, err := caching.NewInMemoryLRUCache(false) + if err != nil { + t.Fatal(err) + } + db, err := storage.Open( + &config.DatabaseOptions{ + ConnectionString: psqlConnectionString(), + MaxOpenConnections: 1, + MaxIdleConnections: 1, + }, + cache, + ) + if err != nil { + t.Logf("PostgreSQL not available (%s), skipping", err) + t.SkipNow() + } + inputter := &input.Inputer{ + DB: db, + } + res := &api.InputRoomEventsResponse{} + inputter.InputRoomEvents( + ctx, + &api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{in}, + Asynchronous: false, + }, + res, + ) + // If we fail here then it's because we've hit the test deadline, + // so we probably deadlocked + if err := res.Err(); err != nil { + t.Fatal(err) + } +} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 05cd686f4..c8bbe7705 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -125,39 +125,6 @@ func (r *Queryer) QueryStateAfterEvents( return nil } -// QueryMissingAuthPrevEvents implements api.RoomserverInternalAPI -func (r *Queryer) QueryMissingAuthPrevEvents( - ctx context.Context, - request *api.QueryMissingAuthPrevEventsRequest, - response *api.QueryMissingAuthPrevEventsResponse, -) error { - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - if info == nil { - return errors.New("room doesn't exist") - } - - response.RoomExists = !info.IsStub - response.RoomVersion = info.RoomVersion - - for _, authEventID := range request.AuthEventIDs { - if nids, err := r.DB.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 { - response.MissingAuthEventIDs = append(response.MissingAuthEventIDs, authEventID) - } - } - - for _, prevEventID := range request.PrevEventIDs { - state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}) - if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { - response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID) - } - } - - return nil -} - // QueryEventsByID implements api.RoomserverInternalAPI func (r *Queryer) QueryEventsByID( ctx context.Context, diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 4f6a58bde..a61404efe 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -40,7 +40,6 @@ const ( // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" RoomserverQueryStateAfterEventsPath = "/roomserver/queryStateAfterEvents" - RoomserverQueryMissingAuthPrevEventsPath = "/roomserver/queryMissingAuthPrevEvents" RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID" RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser" RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom" @@ -302,19 +301,6 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } -// QueryStateAfterEvents implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryMissingAuthPrevEvents( - ctx context.Context, - request *api.QueryMissingAuthPrevEventsRequest, - response *api.QueryMissingAuthPrevEventsResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingAuthPrevEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMissingAuthPrevEventsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - // QueryEventsByID implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryEventsByID( ctx context.Context, diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index bf319262f..691a45830 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -149,20 +149,6 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle( - RoomserverQueryMissingAuthPrevEventsPath, - httputil.MakeInternalAPI("queryMissingAuthPrevEvents", func(req *http.Request) util.JSONResponse { - var request api.QueryMissingAuthPrevEventsRequest - var response api.QueryMissingAuthPrevEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMissingAuthPrevEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) internalAPIMux.Handle( RoomserverQueryEventsByIDPath, httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 433e445d8..b3220effd 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -76,7 +76,8 @@ func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) + stmt := sqlutil.TxStmt(txn, s.insertEventJSONStmt) + _, err := stmt.ExecContext(ctx, int64(eventNID), eventJSON) return err } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index fc75a2606..89b878b9d 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -16,6 +16,7 @@ type RoomUpdater struct { latestEvents []types.StateAtEventAndReference lastEventIDSent string currentStateSnapshotNID types.StateSnapshotNID + roomExists bool } func rollback(txn *sql.Tx) { @@ -33,7 +34,7 @@ func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *typ // succeed, processing a create event which creates the room, or it won't. if roomInfo == nil { return &RoomUpdater{ - transaction{ctx, txn}, d, nil, nil, "", 0, + transaction{ctx, txn}, d, nil, nil, "", 0, false, }, nil } @@ -57,10 +58,15 @@ func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *typ } } return &RoomUpdater{ - transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, true, }, nil } +// RoomExists returns true if the room exists and false otherwise. +func (u *RoomUpdater) RoomExists() bool { + return u.roomExists +} + // Implements sqlutil.Transaction func (u *RoomUpdater) Commit() error { if u.txn == nil { // SQLite mode probably @@ -97,6 +103,25 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { return u.currentStateSnapshotNID } +func (u *RoomUpdater) MissingAuthPrevEvents( + ctx context.Context, e *gomatrixserverlib.Event, +) (missingAuth, missingPrev []string, err error) { + for _, authEventID := range e.AuthEventIDs() { + if nids, err := u.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 { + missingAuth = append(missingAuth, authEventID) + } + } + + for _, prevEventID := range e.PrevEventIDs() { + state, err := u.StateAtEventIDs(ctx, []string{prevEventID}) + if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { + missingPrev = append(missingPrev, prevEventID) + } + } + + return +} + // StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 8319de265..e96c77afa 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -553,7 +553,7 @@ func (d *Database) storeEvent( err error ) var txn *sql.Tx - if updater != nil { + if updater != nil && updater.txn != nil { txn = updater.txn } err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { From e22e87c01223eb556184edc719fd703a9b88e7c8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Feb 2022 11:14:43 +0000 Subject: [PATCH 57/81] Update to matrix-org/gomatrixserverlib@20632dd --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index a1dc04084..8cfedd7bb 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220209202448-9805ef634335 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.10 diff --git a/go.sum b/go.sum index 1483c792f..e6678e51a 100644 --- a/go.sum +++ b/go.sum @@ -983,8 +983,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220209202448-9805ef634335 h1:xzK9Q9VGqsZNGx5ANFOCWkJ8R+W1J2BOguxsVZw6m8M= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220209202448-9805ef634335/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed h1:R8EiLWArq7KT96DrUq1xq9scPh8vLwKKeCTnORPyjhU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220214133635-20632dd262ed/go.mod h1:qFvhfbQ5orQxlH9vCiFnP4dW27xxnWHdNUBKyj/fbiY= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf h1:/nqfHUdQHr3WVdbZieaYFvHF1rin5pvDTa/NOZ/qCyE= github.com/matrix-org/pinecone v0.0.0-20220121094951-351265543ddf/go.mod h1:r6dsL+ylE0yXe/7zh8y/Bdh6aBYI1r+u4yZni9A4iyk= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= From 4c8c53244ef3448a63714a24a5d86a3ace4e7189 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Feb 2022 16:27:22 +0000 Subject: [PATCH 58/81] Update prometheus --- go.mod | 6 +----- go.sum | 4 +++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 8cfedd7bb..11f5b0608 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,6 @@ require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect github.com/MFAshby/stdemuxerhook v1.0.0 github.com/Masterminds/semver/v3 v3.1.1 - github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/codeclysm/extract v2.2.0+incompatible github.com/containerd/containerd v1.5.9 // indirect github.com/docker/docker v20.10.12+incompatible @@ -23,7 +22,6 @@ require ( github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.3 // indirect github.com/hashicorp/golang-lru v0.5.4 - github.com/json-iterator/go v1.1.12 // indirect github.com/juju/testing v0.0.0-20211215003918-77eb13d6cad2 // indirect github.com/klauspost/compress v1.14.2 // indirect github.com/lib/pq v1.10.4 @@ -55,9 +53,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/pressly/goose v2.7.0+incompatible - github.com/prometheus/client_golang v1.11.0 - github.com/prometheus/common v0.32.1 // indirect - github.com/prometheus/procfs v0.7.3 // indirect + github.com/prometheus/client_golang v1.12.1 github.com/sirupsen/logrus v1.8.1 github.com/tidwall/gjson v1.14.0 github.com/tidwall/sjson v1.2.4 diff --git a/go.sum b/go.sum index e6678e51a..8732d27ec 100644 --- a/go.sum +++ b/go.sum @@ -1231,8 +1231,9 @@ github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDf github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.1.0/go.mod h1:I1FGZT9+L76gKKOs5djB6ezCbFQP1xR9D75/vuwEF3g= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= -github.com/prometheus/client_golang v1.11.0 h1:HNkLOAEQMIDv/K+04rukrLx6ch7msSRwf3/SASFAGtQ= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= +github.com/prometheus/client_golang v1.12.1 h1:ZiaPsmm9uiBeaSMRznKsCDNtPCS0T3JVDGF+06gjBzk= +github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_model v0.0.0-20171117100541-99fa1f4be8e5/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -1734,6 +1735,7 @@ golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= From f92b048fec29b711ebb853193fac9977af8d5a90 Mon Sep 17 00:00:00 2001 From: kegsay Date: Tue, 15 Feb 2022 17:40:48 +0000 Subject: [PATCH 59/81] Add host mount compatible Complement image (#2187) --- build/scripts/ComplementLocal.Dockerfile | 53 ++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 build/scripts/ComplementLocal.Dockerfile diff --git a/build/scripts/ComplementLocal.Dockerfile b/build/scripts/ComplementLocal.Dockerfile new file mode 100644 index 000000000..60b4d983a --- /dev/null +++ b/build/scripts/ComplementLocal.Dockerfile @@ -0,0 +1,53 @@ +# A local development Complement dockerfile, to be used with host mounts +# /cache -> Contains the entire dendrite code at Dockerfile build time. Builds binaries but only keeps the generate-* ones. Pre-compilation saves time. +# /dendrite -> Host-mounted sources +# /runtime -> Binaries and config go here and are run at runtime +# At runtime, dendrite is built from /dendrite and run in /runtime. +# +# Use these mounts to make use of this dockerfile: +# COMPLEMENT_HOST_MOUNTS='/your/local/dendrite:/dendrite:ro;/your/go/path:/go:ro' +FROM golang:1.16-stretch +RUN apt-get update && apt-get install -y sqlite3 + +WORKDIR /runtime + +ENV SERVER_NAME=localhost +EXPOSE 8008 8448 + +# This script compiles Dendrite for us. +RUN echo '\ +#!/bin/bash -eux \n\ +if test -f "/runtime/dendrite-monolith-server"; then \n\ + echo "Skipping compilation; binaries exist" \n\ + exit 0 \n\ +fi \n\ +cd /dendrite \n\ +go build -v -o /runtime /dendrite/cmd/dendrite-monolith-server \n\ +' > compile.sh && chmod +x compile.sh + +# This script runs Dendrite for us. Must be run in the /runtime directory. +RUN echo '\ +#!/bin/bash -eu \n\ +./generate-keys --private-key matrix_key.pem \n\ +./generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ +./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ +cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ +./dendrite-monolith-server --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ +' > run.sh && chmod +x run.sh + + +WORKDIR /cache +# Pre-download deps; we don't need to do this if the GOPATH is mounted. +COPY go.mod . +COPY go.sum . +RUN go mod download + +# Build the monolith in /cache - we won't actually use this but will rely on build artifacts to speed +# up the real compilation. Build the generate-* binaries in the true /runtime locations. +# If the generate-* source is changed, this dockerfile needs re-running. +COPY . . +RUN go build ./cmd/dendrite-monolith-server && go build -o /runtime ./cmd/generate-keys && go build -o /runtime ./cmd/generate-config + + +WORKDIR /runtime +CMD /runtime/compile.sh && /runtime/run.sh From fa1e12b503376cd34bac730f81aeedc8dca1b2aa Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Feb 2022 11:56:08 +0000 Subject: [PATCH 60/81] Don't panic on retiring an invite that we haven't seen yet (#2189) --- syncapi/consumers/roomserver.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 7fe52b728..15485bb35 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -16,6 +16,7 @@ package consumers import ( "context" + "database/sql" "encoding/json" "fmt" @@ -307,7 +308,9 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( ctx context.Context, msg api.OutputRetireInviteEvent, ) { pduPos, err := s.db.RetireInviteEvent(ctx, msg.EventID) - if err != nil { + // It's possible we just haven't heard of this invite yet, so + // we should not panic if we try to retire it. + if err != nil && err != sql.ErrNoRows { sentry.CaptureException(err) // panic rather than continue with an inconsistent database log.WithFields(log.Fields{ From e9b672a34e08bce9d12b2a2454c19fde6e52036e Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Wed, 16 Feb 2022 17:56:45 +0100 Subject: [PATCH 61/81] Make "Device list doesn't change if remote server is down" pass (#2190) --- keyserver/internal/internal.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 2536c1f76..ffbcac94b 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -513,6 +513,11 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( // drop the error as it's already a failure at this point _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys) } + + // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache + if len(res.DeviceKeys) > 0 { + delete(res.Failures, serverName) + } respMu.Unlock() } From 5a39512f5f35b13adea3afc2e366e01ec73924de Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Wed, 16 Feb 2022 18:55:38 +0100 Subject: [PATCH 62/81] Add account type (#2171) * Add account_type for sqlite3 * Add account_type for postgres * Remove CreateGuestAccount from interface * Add new AccountTypes & update test * Use newly added AccountType for account creation * Add migrations * Reuse type * Add AccounnType to Device, so it can be verified on requests * Rename migration, add missing update for appservices * Rename sqlite3 migration * Add missing AccountType to return value * Update sqlite migration Change allowance check on /admin/whois * Fix migration, add IS NULL * Move accountType to completeRegistration * Fix migrations * Add passing test --- appservice/appservice.go | 5 +- clientapi/routing/admin_whois.go | 4 +- clientapi/routing/register.go | 26 +++++---- cmd/create-account/main.go | 15 ++++-- sytest-whitelist | 1 + userapi/api/api.go | 10 +++- userapi/internal/api.go | 35 +++++++----- userapi/storage/accounts/interface.go | 3 +- .../accounts/postgres/accounts_table.go | 24 +++++---- .../deltas/20200929203058_is_active.go | 4 +- .../2022021013023800_add_account_type.go | 34 ++++++++++++ userapi/storage/accounts/postgres/storage.go | 41 +++++++------- .../accounts/sqlite3/accounts_table.go | 23 ++++---- .../deltas/20200929203058_is_active.go | 4 +- .../2022021012490600_add_account_type.go | 54 +++++++++++++++++++ userapi/storage/accounts/sqlite3/storage.go | 52 +++++++----------- userapi/storage/accounts/storage.go | 3 +- userapi/userapi_test.go | 9 ++-- 18 files changed, 230 insertions(+), 117 deletions(-) create mode 100644 userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go create mode 100644 userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go diff --git a/appservice/appservice.go b/appservice/appservice.go index 7e7c67f53..b33d7b701 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -22,6 +22,8 @@ import ( "time" "github.com/gorilla/mux" + "github.com/sirupsen/logrus" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/inthttp" @@ -34,7 +36,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/sirupsen/logrus" ) // AddInternalRoutes registers HTTP handlers for internal API calls @@ -121,7 +122,7 @@ func generateAppServiceAccount( ) error { var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ - AccountType: userapi.AccountTypeUser, + AccountType: userapi.AccountTypeAppService, Localpart: as.SenderLocalpart, AppServiceID: as.ID, OnConflict: userapi.ConflictUpdate, diff --git a/clientapi/routing/admin_whois.go b/clientapi/routing/admin_whois.go index b448791c3..87bb79366 100644 --- a/clientapi/routing/admin_whois.go +++ b/clientapi/routing/admin_whois.go @@ -47,8 +47,8 @@ func GetAdminWhois( req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, ) util.JSONResponse { - if userID != device.UserID { - // TODO: Still allow if user is admin + allowed := device.AccountType == api.AccountTypeAdmin || userID == device.UserID + if !allowed { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("userID does not match the current user"), diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 8823a41e3..fc275a5d1 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -32,6 +32,12 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/tokens" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" @@ -39,11 +45,6 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/gomatrixserverlib/tokens" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" ) var ( @@ -701,7 +702,7 @@ func handleApplicationServiceRegistration( // application service registration is entirely separate. return completeRegistration( req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, ) } @@ -720,7 +721,7 @@ func checkAndCompleteFlow( // This flow was completed, registration can continue return completeRegistration( req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, ) } @@ -745,6 +746,7 @@ func completeRegistration( username, password, appserviceID, ipAddr, userAgent string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, + accType userapi.AccountType, ) util.JSONResponse { if username == "" { return util.JSONResponse{ @@ -759,13 +761,12 @@ func completeRegistration( JSON: jsonerror.BadJSON("missing password"), } } - var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ AppServiceID: appserviceID, Localpart: username, Password: password, - AccountType: userapi.AccountTypeUser, + AccountType: accType, OnConflict: userapi.ConflictAbort, }, &accRes) if err != nil { @@ -963,5 +964,10 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS return *resErr } deviceID := "shared_secret_registration" - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID) + + accType := userapi.AccountTypeUser + if ssrr.Admin { + accType = userapi.AccountTypeAdmin + } + return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType) } diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 3ac077705..d9202eb0d 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -23,12 +23,14 @@ import ( "os" "strings" - "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" "golang.org/x/term" + + "github.com/matrix-org/dendrite/setup" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" ) const usage = `Usage: %s @@ -57,6 +59,7 @@ var ( pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") askPass = flag.Bool("ask-pass", false, "Ask for the password to use") + isAdmin = flag.Bool("admin", false, "Create an admin account") ) func main() { @@ -81,7 +84,11 @@ func main() { logrus.Fatalln("Failed to connect to the database:", err.Error()) } - _, err = accountDB.CreateAccount(context.Background(), *username, pass, "") + accType := api.AccountTypeUser + if *isAdmin { + accType = api.AccountTypeAdmin + } + _, err = accountDB.CreateAccount(context.Background(), *username, pass, "", accType) if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) } diff --git a/sytest-whitelist b/sytest-whitelist index 04b1bbf36..d739313ac 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -592,3 +592,4 @@ Forward extremities remain so even after the next events are populated as outlie If a device list update goes missing, the server resyncs on the next one uploading self-signing key notifies over federation uploading signed devices gets propagated over federation +Device list doesn't change if remote server is down \ No newline at end of file diff --git a/userapi/api/api.go b/userapi/api/api.go index 46a13d971..2be662e55 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -18,8 +18,9 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) // UserInternalAPI is the internal API for information about users and devices. @@ -353,6 +354,7 @@ type Device struct { // If the device is for an appservice user, // this is the appservice ID. AppserviceID string + AccountType AccountType } // Account represents a Matrix account on this home server. @@ -361,7 +363,7 @@ type Account struct { Localpart string ServerName gomatrixserverlib.ServerName AppServiceID string - // TODO: Other flags like IsAdmin, IsGuest + AccountType AccountType // TODO: Associations (e.g. with application services) } @@ -417,4 +419,8 @@ const ( AccountTypeUser AccountType = 1 // AccountTypeGuest indicates this is a guest account AccountTypeGuest AccountType = 2 + // AccountTypeAdmin indicates this is an admin account + AccountTypeAdmin AccountType = 3 + // AccountTypeAppService indicates this is an appservice account + AccountTypeAppService AccountType = 4 ) diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 5d91383de..f96d4804c 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -21,6 +21,10 @@ import ( "errors" "fmt" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -29,9 +33,6 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/devices" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" ) type UserInternalAPI struct { @@ -58,16 +59,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { - if req.AccountType == api.AccountTypeGuest { - acc, err := a.AccountDB.CreateGuestAccount(ctx) - if err != nil { - return err - } - res.AccountCreated = true - res.Account = acc - return nil - } - acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID) + acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists switch req.OnConflict { @@ -86,10 +78,17 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P Localpart: req.Localpart, ServerName: a.ServerName, UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + AccountType: req.AccountType, } return nil } + if req.AccountType == api.AccountTypeGuest { + res.AccountCreated = true + res.Account = acc + return nil + } + if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { return err } @@ -375,6 +374,15 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc } return err } + localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + return err + } + acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart) + if err != nil { + return err + } + device.AccountType = acc.AccountType res.Device = device return nil } @@ -401,6 +409,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // AS dummy device has AS's token. AccessToken: token, AppserviceID: appService.ID, + AccountType: api.AccountTypeAppService, } localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index f03b3774c..a2185774a 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -32,8 +32,7 @@ type Database interface { // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. - CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error) - CreateGuestAccount(ctx context.Context) (*api.Account, error) + CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index b57aa901f..9e3e456a7 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -19,10 +19,11 @@ import ( "database/sql" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -39,16 +40,18 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- Identifies which application service this account belongs to, if any. appservice_id TEXT, -- If the account is currently active - is_deactivated BOOLEAN DEFAULT FALSE + is_deactivated BOOLEAN DEFAULT FALSE, + -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) + account_type SMALLINT NOT NULL -- TODO: - -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? + -- upgraded_ts, devices, any email reset stuff? ); -- Create sequence for autogenerated numeric usernames CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" @@ -57,7 +60,7 @@ const deactivateAccountSQL = "" + "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" @@ -96,16 +99,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) insertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error - if appserviceID == "" { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) + if accountType != api.AccountTypeAppService { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -116,6 +119,7 @@ func (s *accountsStatements) insertAccount( UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, AppServiceID: appserviceID, + AccountType: accountType, }, nil } @@ -147,7 +151,7 @@ func (s *accountsStatements) selectAccountByLocalpart( var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") diff --git a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go b/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go index 9e14286e0..32d3235be 100644 --- a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go +++ b/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go @@ -4,12 +4,14 @@ import ( "database/sql" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/pressly/goose" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) func LoadFromGoose() { goose.AddMigration(UpIsActive, DownIsActive) + goose.AddMigration(UpAddAccountType, DownAddAccountType) } func LoadIsActive(m *sqlutil.Migrations) { diff --git a/userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go b/userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go new file mode 100644 index 000000000..2fae00cb9 --- /dev/null +++ b/userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go @@ -0,0 +1,34 @@ +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func LoadAddAccountType(m *sqlutil.Migrations) { + m.AddMigration(UpAddAccountType, DownAddAccountType) +} + +func UpAddAccountType(tx *sql.Tx) error { + // initially set every account to useraccount, change appservice and guest accounts afterwards + // (user = 1, guest = 2, admin = 3, appservice = 4) + _, err := tx.Exec(`ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1; +UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$'; +ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`, + ) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownAddAccountType(tx *sql.Tx) error { + _, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN account_type;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 2f8290623..d31efd257 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -23,13 +23,14 @@ import ( "strconv" "time" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" // Import the postgres database driver. _ "github.com/lib/pq" @@ -73,6 +74,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver } m := sqlutil.NewMigrations() deltas.LoadIsActive(m) + deltas.LoadAddAccountType(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err } @@ -155,37 +157,32 @@ func (d *Database) SetPassword( return d.accounts.updatePassword(ctx, localpart, hash) } -// CreateGuestAccount makes a new guest account and creates an empty profile -// for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart := strconv.FormatInt(numLocalpart, 10) - acc, err = d.createAccount(ctx, txn, localpart, "", "") - return err - }) - return acc, err -} - // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, sqlutil.ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, + ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ) (acc *api.Account, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + // For guest accounts, we create a new numeric local part + if accountType == api.AccountTypeGuest { + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart = strconv.FormatInt(numLocalpart, 10) + plaintextPassword = "" + appserviceID = "" + } + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) return err }) return } func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { var account *api.Account var err error @@ -197,7 +194,7 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { + if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { if sqlutil.IsUniqueConstraintViolationErr(err) { return nil, sqlutil.ErrUserExists } diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 8a7c8fba7..5a918e034 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -19,10 +19,11 @@ import ( "database/sql" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -39,14 +40,16 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- Identifies which application service this account belongs to, if any. appservice_id TEXT, -- If the account is currently active - is_deactivated BOOLEAN DEFAULT 0 + is_deactivated BOOLEAN DEFAULT 0, + -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) + account_type INTEGER NOT NULL -- TODO: - -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? + -- upgraded_ts, devices, any email reset stuff? ); ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" @@ -55,7 +58,7 @@ const deactivateAccountSQL = "" + "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" @@ -96,16 +99,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) insertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt var err error - if appserviceID == "" { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) + if accountType != api.AccountTypeAppService { + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -147,7 +150,7 @@ func (s *accountsStatements) selectAccountByLocalpart( var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") diff --git a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go index 9fddb05a1..c69614e83 100644 --- a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go @@ -4,12 +4,14 @@ import ( "database/sql" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/pressly/goose" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) func LoadFromGoose() { goose.AddMigration(UpIsActive, DownIsActive) + goose.AddMigration(UpAddAccountType, DownAddAccountType) } func LoadIsActive(m *sqlutil.Migrations) { diff --git a/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go new file mode 100644 index 000000000..9b058dedd --- /dev/null +++ b/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go @@ -0,0 +1,54 @@ +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/pressly/goose" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func init() { + goose.AddMigration(UpAddAccountType, DownAddAccountType) +} + +func LoadAddAccountType(m *sqlutil.Migrations) { + m.AddMigration(UpAddAccountType, DownAddAccountType) +} + +func UpAddAccountType(tx *sql.Tx) error { + // initially set every account to useraccount, change appservice and guest accounts afterwards + // (user = 1, guest = 2, admin = 3, appservice = 4) + _, err := tx.Exec(`ALTER TABLE account_accounts RENAME TO account_accounts_tmp; +CREATE TABLE account_accounts ( + localpart TEXT NOT NULL PRIMARY KEY, + created_ts BIGINT NOT NULL, + password_hash TEXT, + appservice_id TEXT, + is_deactivated BOOLEAN DEFAULT 0, + account_type INTEGER NOT NULL +); +INSERT + INTO account_accounts ( + localpart, created_ts, password_hash, appservice_id, account_type + ) SELECT + localpart, created_ts, password_hash, appservice_id, 1 + FROM account_accounts_tmp +; +UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*'; +DROP TABLE account_accounts_tmp;`) + if err != nil { + return fmt.Errorf("failed to add column: %w", err) + } + return nil +} + +func DownAddAccountType(tx *sql.Tx) error { + _, err := tx.Exec(`ALTER TABLE account_accounts DROP COLUMN account_type;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 2b731b759..0bab16ca3 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -24,13 +24,14 @@ import ( "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" ) // Database represents an account database @@ -77,6 +78,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver } m := sqlutil.NewMigrations() deltas.LoadIsActive(m) + deltas.LoadAddAccountType(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err } @@ -170,38 +172,11 @@ func (d *Database) SetPassword( }) } -// CreateGuestAccount makes a new guest account and creates an empty profile -// for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { - // We need to lock so we sequentially create numeric localparts. If we don't, two calls to - // this function will cause the same number to be selected and one will fail with 'database is locked' - // when the first txn upgrades to a write txn. We also need to lock the account creation else we can - // race with CreateAccount - // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. - d.profilesMu.Lock() - d.accountDatasMu.Lock() - d.accountsMu.Lock() - defer d.profilesMu.Unlock() - defer d.accountDatasMu.Unlock() - defer d.accountsMu.Unlock() - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart := strconv.FormatInt(numLocalpart, 10) - acc, err = d.createAccount(ctx, txn, localpart, "", "") - return err - }) - return acc, err -} - // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, + ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ) (acc *api.Account, err error) { // Create one account at a time else we can get 'database is locked'. d.profilesMu.Lock() @@ -211,7 +186,18 @@ func (d *Database) CreateAccount( defer d.accountDatasMu.Unlock() defer d.accountsMu.Unlock() err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + // For guest accounts, we create a new numeric local part + if accountType == api.AccountTypeGuest { + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart = strconv.FormatInt(numLocalpart, 10) + plaintextPassword = "" + appserviceID = "" + } + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) return err }) return @@ -220,7 +206,7 @@ func (d *Database) CreateAccount( // WARNING! This function assumes that the relevant mutexes have already // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { var err error var account *api.Account @@ -232,7 +218,7 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { + if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { return nil, sqlutil.ErrUserExists } if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { diff --git a/userapi/storage/accounts/storage.go b/userapi/storage/accounts/storage.go index a21f7d94e..f43f7efd6 100644 --- a/userapi/storage/accounts/storage.go +++ b/userapi/storage/accounts/storage.go @@ -20,10 +20,11 @@ package accounts import ( "fmt" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" - "github.com/matrix-org/gomatrixserverlib" ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 266f5ed58..141dd96d1 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -23,6 +23,9 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" @@ -30,8 +33,6 @@ import ( "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/devices" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" ) const ( @@ -73,7 +74,7 @@ func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) - _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") + _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) } @@ -151,7 +152,7 @@ func TestLoginToken(t *testing.T) { t.Run("tokenLoginFlow", func(t *testing.T) { userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) - _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "") + _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) } From a4681bc7f77530cd73fc510d641c51fbb2855f2e Mon Sep 17 00:00:00 2001 From: kegsay Date: Thu, 17 Feb 2022 10:59:44 +0000 Subject: [PATCH 63/81] Set 'complement' as the shared secret for CI (#2194) --- cmd/generate-config/main.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 60729672e..f87665fbe 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -91,6 +91,7 @@ func main() { cfg.Logging[0].Type = "std" cfg.UserAPI.BCryptCost = bcrypt.MinCost cfg.Global.JetStream.InMemory = true + cfg.ClientAPI.RegistrationSharedSecret = "complement" } j, err := yaml.Marshal(cfg) From f51e2a99e93f03eb9aed59b2643170dede1f4fe8 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Thu, 17 Feb 2022 13:54:29 +0100 Subject: [PATCH 64/81] Remove outbound proxy, http.ProxyFromEnvironment is now used (#2191) --- dendrite-config.yaml | 7 ------- setup/config/config_federationapi.go | 4 ---- setup/config/config_test.go | 5 ----- 3 files changed, 16 deletions(-) diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 38b146d70..b71e8d845 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -204,13 +204,6 @@ federation_api: # enable this option in production as it presents a security risk! disable_tls_validation: false - # Use the following proxy server for outbound federation traffic. - proxy_outbound: - enabled: false - protocol: http - host: localhost - port: 8080 - # Perspective keyservers to use as a backup when direct key fetches fail. This may # be required to satisfy key requests for servers that are no longer online when # joining some rooms. diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 4f5f49de8..95e705033 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -29,8 +29,6 @@ type FederationAPI struct { // on remote federation endpoints. This is not recommended in production! DisableTLSValidation bool `yaml:"disable_tls_validation"` - Proxy Proxy `yaml:"proxy_outbound"` - // Perspective keyservers, to use as a backup when direct key fetch // requests don't succeed KeyPerspectives KeyPerspectives `yaml:"key_perspectives"` @@ -50,8 +48,6 @@ func (c *FederationAPI) Defaults(generate bool) { c.FederationMaxRetries = 16 c.DisableTLSValidation = false - - c.Proxy.Defaults() } func (c *FederationAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 5aa54929e..97c98e57f 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -118,11 +118,6 @@ federation_sender: conn_max_lifetime: -1 send_max_retries: 16 disable_tls_validation: false - proxy_outbound: - enabled: false - protocol: http - host: localhost - port: 8080 key_server: internal_api: listen: http://localhost:7779 From 934491eda5c12a913cb5bcc06aac31aae843c461 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 17 Feb 2022 13:15:35 +0000 Subject: [PATCH 65/81] Update NATS Server to v2.7.2 (#2193) * Update NATS JetStream to v2.7.2 * Remove deprecated option --- go.mod | 4 ++-- go.sum | 15 ++++++++------- setup/jetstream/nats.go | 13 ++++++------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/go.mod b/go.mod index 11f5b0608..2316096df 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/matrix-org/dendrite -replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423 +replace github.com/nats-io/nats-server/v2 => github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad replace github.com/nats-io/nats.go => github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c @@ -45,7 +45,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.10 github.com/morikuni/aec v1.0.0 // indirect github.com/nats-io/nats-server/v2 v2.3.2 - github.com/nats-io/nats.go v1.13.1-0.20211122170419-d7c1d78a50fc + github.com/nats-io/nats.go v1.13.1-0.20220121202836-972a071d373d github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20211220175533-9d16fdc47b31 diff --git a/go.sum b/go.sum index 8732d27ec..e79015e51 100644 --- a/go.sum +++ b/go.sum @@ -1122,8 +1122,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8m github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= -github.com/nats-io/jwt/v2 v2.2.0 h1:Yg/4WFK6vsqMudRg91eBb7Dh6XeVcDMPHycDE8CfltE= -github.com/nats-io/jwt/v2 v2.2.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= +github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296 h1:vU9tpM3apjYlLLeY23zRWJ9Zktr5jp+mloR942LEOpY= +github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= @@ -1132,8 +1132,8 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY github.com/ncw/swift v1.0.47/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423 h1:BLQVdjMH5XD4BYb0fa+c2Oh2Nr1vrO7GKvRnIJDxChc= -github.com/neilalexander/nats-server/v2 v2.3.3-0.20220104162330-c76d5fd70423/go.mod h1:9sdEkBhyZMQG1M9TevnlYUwMusRACn2vlgOeqoHKwVo= +github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad h1:Z2nWMQsXWWqzj89nW6OaLJSdkFknqhaR5whEOz4++Y8= +github.com/neilalexander/nats-server/v2 v2.7.2-0.20220217100407-087330ed46ad/go.mod h1:tckmrt0M6bVaDT3kmh9UrIq/CBOBBse+TpXQi5ldaa8= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c h1:G2qsv7D0rY94HAu8pXmElMluuMHQ85waxIDQBhIzV2Q= github.com/neilalexander/nats.go v1.11.1-0.20220104162523-f4ddebe1061c/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= github.com/neilalexander/utp v0.1.1-0.20210622132614-ee9a34a30488/go.mod h1:NPHGhPc0/wudcaCqL/H5AOddkRf8GPRhzOujuUKGQu8= @@ -1508,8 +1508,8 @@ golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a h1:atOEWVSedO4ksXBe/UrlbSLVxQQ9RxM/tT2Jy10IaHo= golang.org/x/crypto v0.0.0-20220209195652-db638375bc3a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1735,6 +1735,7 @@ golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220111092808-5a964db01320/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220207234003-57398862261d h1:Bm7BNOQt2Qv7ZqysjeLjgCBanX+88Z/OtdvsrEv1Djc= golang.org/x/sys v0.0.0-20220207234003-57398862261d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -1756,10 +1757,10 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 h1:GZokNIeuVkl3aZHJchRrr13WCsols02MLUcz1U9is6M= +golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 77ad2b721..562b0131e 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -24,13 +24,12 @@ func Prepare(cfg *config.JetStream) natsclient.JetStreamContext { if natsServer == nil { var err error natsServer, err = natsserver.NewServer(&natsserver.Options{ - ServerName: "monolith", - DontListen: true, - JetStream: true, - StoreDir: string(cfg.StoragePath), - NoSystemAccount: true, - AllowNewAccounts: false, - MaxPayload: 16 * 1024 * 1024, + ServerName: "monolith", + DontListen: true, + JetStream: true, + StoreDir: string(cfg.StoragePath), + NoSystemAccount: true, + MaxPayload: 16 * 1024 * 1024, }) if err != nil { panic(err) From 89b7519089d4dbebfb5222c7a3e969d6e4786248 Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Thu, 17 Feb 2022 14:15:49 +0100 Subject: [PATCH 66/81] Raise waitTime for network related issues (#2192) --- keyserver/internal/device_list_update.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 1b6e2d428..2f967a40f 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -367,6 +367,9 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam waitTime = fcerr.RetryAfter } else if fcerr.Blacklisted { waitTime = time.Hour * 8 + } else { + // For all other errors (DNS resolution, network etc.) wait 1 hour. + waitTime = time.Hour } } else { waitTime = time.Hour From 353168a9e93803bc9c5608d2e0ec55ba7fc581d9 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 17 Feb 2022 13:25:41 +0000 Subject: [PATCH 67/81] Fix potential panic in `NewStreamTokenFromString` caused by off-by-one error (#2196) Line 291 could panic when trying to set `positions[i]` if `i == len(positions)`. --- syncapi/types/types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 68c308d83..c2e8ed01c 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -279,7 +279,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { parts := strings.Split(tok[1:], "_") var positions [7]StreamPosition for i, p := range parts { - if i > len(positions) { + if i >= len(positions) { break } var pos int From 7dfc7c3d7067c0f247f79f299e80e100244f5121 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 17 Feb 2022 13:53:48 +0000 Subject: [PATCH 68/81] Don't re-send sent events in `add_state_events` (#2195) * Only add events to `add_state_events` that haven't already been sent to the roomserver output before * Filter on event NIDs instead, hopefully bring joy to SQLite * UnsentFilter, review comments --- .../internal/input/input_latest_events.go | 2 +- roomserver/storage/postgres/events_table.go | 27 +++++++++++++-- .../storage/shared/membership_updater.go | 6 ++-- roomserver/storage/shared/room_updater.go | 14 ++++++-- roomserver/storage/shared/storage.go | 26 +++++++++++---- roomserver/storage/sqlite3/events_table.go | 33 ++++++++++++++++--- roomserver/storage/tables/interface.go | 1 + 7 files changed, 90 insertions(+), 19 deletions(-) diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 5173d3ab2..ae28ebefa 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -405,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro if len(extraEventIDs) == 0 { return nil, nil } - extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs) + extraEvents, err := u.updater.UnsentEventsFromIDs(u.ctx, extraEventIDs) if err != nil { return nil, err } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index c136f039a..8012174a0 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -127,6 +127,9 @@ const bulkSelectEventIDSQL = "" + const bulkSelectEventNIDSQL = "" + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)" +const bulkSelectUnsentEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)" @@ -147,6 +150,7 @@ type eventStatements struct { bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt + bulkSelectUnsentEventNIDStmt *sql.Stmt selectMaxEventDepthStmt *sql.Stmt selectRoomNIDsForEventNIDsStmt *sql.Stmt } @@ -173,6 +177,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + {&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL}, }.Prepare(db) @@ -458,10 +463,28 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev return results, nil } -// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { - stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt) + return s.bulkSelectEventNID(ctx, txn, eventIDs, false) +} + +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID +// only for events that haven't already been sent to the roomserver output. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + return s.bulkSelectEventNID(ctx, txn, eventIDs, true) +} + +// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { + var stmt *sql.Stmt + if onlyUnsent { + stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt) + } else { + stmt = sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt) + } rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 66ac2f5b6..8f3f3d631 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -136,7 +136,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } // Look up the NID of the new join event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -170,7 +170,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } // Look up the NID of the new leave event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } @@ -196,7 +196,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er } if u.membership != tables.MembershipStateKnock { // Look up the NID of the new knock event - nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}) + nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false) if err != nil { return fmt.Errorf("u.d.EventNIDs: %w", err) } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 89b878b9d..810a18ef2 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -215,7 +215,13 @@ func (u *RoomUpdater) EventIDs( func (u *RoomUpdater) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { - return u.d.eventNIDs(ctx, u.txn, eventIDs) + return u.d.eventNIDs(ctx, u.txn, eventIDs, NoFilter) +} + +func (u *RoomUpdater) UnsentEventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return u.d.eventNIDs(ctx, u.txn, eventIDs, FilterUnsentOnly) } func (u *RoomUpdater) StateAtEventIDs( @@ -231,7 +237,11 @@ func (u *RoomUpdater) StateEntriesForEventIDs( } func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return u.d.eventsFromIDs(ctx, u.txn, eventIDs) + return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false) +} + +func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true) } func (u *RoomUpdater) GetMembershipEventNIDsForRoom( diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e96c77afa..9f3b8b1da 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -238,13 +238,27 @@ func (d *Database) addState( func (d *Database) EventNIDs( ctx context.Context, eventIDs []string, ) (map[string]types.EventNID, error) { - return d.eventNIDs(ctx, nil, eventIDs) + return d.eventNIDs(ctx, nil, eventIDs, NoFilter) } +type UnsentFilter bool + +const ( + NoFilter UnsentFilter = false + FilterUnsentOnly UnsentFilter = true +) + func (d *Database) eventNIDs( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter, ) (map[string]types.EventNID, error) { - return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs) + switch filter { + case FilterUnsentOnly: + return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs) + case NoFilter: + return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs) + default: + panic("impossible case") + } } func (d *Database) SetState( @@ -281,11 +295,11 @@ func (d *Database) EventIDs( } func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - return d.eventsFromIDs(ctx, nil, eventIDs) + return d.eventsFromIDs(ctx, nil, eventIDs, NoFilter) } -func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.eventNIDs(ctx, txn, eventIDs) +func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter) ([]types.Event, error) { + nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter) if err != nil { return nil, err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index cef09fe60..969a10ce5 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -99,6 +99,9 @@ const bulkSelectEventIDSQL = "" + const bulkSelectEventNIDSQL = "" + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" +const bulkSelectUnsentEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)" + const selectMaxEventDepthSQL = "" + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" @@ -118,8 +121,9 @@ type eventStatements struct { bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt - bulkSelectEventNIDStmt *sql.Stmt - //selectRoomNIDsForEventNIDsStmt *sql.Stmt + //bulkSelectEventNIDStmt *sql.Stmt + //bulkSelectUnsentEventNIDStmt *sql.Stmt + //selectRoomNIDsForEventNIDsStmt *sql.Stmt } func createEventsTable(db *sql.DB) error { @@ -144,7 +148,8 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, - {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + //{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + //{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, }.Prepare(db) } @@ -494,15 +499,33 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev return results, nil } -// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + return s.bulkSelectEventNID(ctx, txn, eventIDs, false) +} + +// BulkSelectEventNIDs returns a map from string event ID to numeric event ID +// only for events that haven't already been sent to the roomserver output. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + return s.bulkSelectEventNID(ctx, txn, eventIDs, true) +} + +// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) { /////////////// iEventIDs := make([]interface{}, len(eventIDs)) for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + var selectOrig string + if onlyUnsent { + selectOrig = strings.Replace(bulkSelectUnsentEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + } else { + selectOrig = strings.Replace(bulkSelectEventNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + } selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index fed39b944..e3fed700b 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -59,6 +59,7 @@ type Events interface { // BulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) + BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) } From 140077265e2842bf8e2d2c6399343490740cd8a6 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 17 Feb 2022 15:02:06 +0000 Subject: [PATCH 69/81] Make GetUserDevices logging entry more useful --- keyserver/internal/device_list_update.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 2f967a40f..c5a5d40c7 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -373,7 +373,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam } } else { waitTime = time.Hour - logger.WithError(err).Warn("GetUserDevices returned unknown error type") + logger.WithError(err).WithField("user_id", userID).Warn("GetUserDevices returned unknown error type") } continue } From 0b123b29f5304603d32d790512c091ac942fb37d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 17 Feb 2022 15:58:54 +0000 Subject: [PATCH 70/81] Use process context for roomserver input (#2198) --- roomserver/internal/api.go | 10 +++++++--- roomserver/internal/input/input.go | 4 +++- roomserver/roomserver.go | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index fd963ad83..e58f11c13 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -32,6 +33,7 @@ type RoomserverInternalAPI struct { *perform.Publisher *perform.Backfiller *perform.Forgetter + ProcessContext *process.ProcessContext DB storage.Database Cfg *config.RoomServer Cache caching.RoomServerCaches @@ -48,12 +50,13 @@ type RoomserverInternalAPI struct { } func NewRoomserverAPI( - cfg *config.RoomServer, roomserverDB storage.Database, consumer nats.JetStreamContext, - inputRoomEventTopic, outputRoomEventTopic string, caches caching.RoomServerCaches, - perspectiveServerNames []gomatrixserverlib.ServerName, + processCtx *process.ProcessContext, cfg *config.RoomServer, roomserverDB storage.Database, + consumer nats.JetStreamContext, inputRoomEventTopic, outputRoomEventTopic string, + caches caching.RoomServerCaches, perspectiveServerNames []gomatrixserverlib.ServerName, ) *RoomserverInternalAPI { serverACLs := acls.NewServerACLs(roomserverDB) a := &RoomserverInternalAPI{ + ProcessContext: processCtx, DB: roomserverDB, Cfg: cfg, Cache: caches, @@ -83,6 +86,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA r.KeyRing = keyRing r.Inputer = &input.Inputer{ + ProcessContext: r.ProcessContext, DB: r.DB, InputRoomEventTopic: r.InputRoomEventTopic, OutputRoomEventTopic: r.OutputRoomEventTopic, diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 5bdec0a24..22e4b67a0 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/prometheus/client_golang/prometheus" @@ -59,6 +60,7 @@ var keyContentFields = map[string]string{ } type Inputer struct { + ProcessContext *process.ProcessContext DB storage.Database JetStream nats.JetStreamContext Durable nats.SubOpt @@ -115,7 +117,7 @@ func (r *Inputer) Start() error { _ = msg.InProgress() // resets the acknowledgement wait timer defer eventsInProgress.Delete(index) defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent) + action, err := r.processRoomEventUsingUpdater(r.ProcessContext.Context(), roomID, &inputRoomEvent) if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { sentry.CaptureException(err) diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index e1b84b80c..950c6b4e7 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -53,7 +53,7 @@ func NewInternalAPI( js := jetstream.Prepare(&cfg.Matrix.JetStream) return internal.NewRoomserverAPI( - cfg, roomserverDB, js, + base.ProcessContext, cfg, roomserverDB, js, cfg.Matrix.JetStream.TopicFor(jetstream.InputRoomEvent), cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent), base.Caches, perspectiveServerNames, From 5dd203fde3d3b86719354245ac341dbf67fa1851 Mon Sep 17 00:00:00 2001 From: kegsay Date: Thu, 17 Feb 2022 17:38:22 +0000 Subject: [PATCH 71/81] Listen for /v3 on CSAPI (#2197) * Listen for /v3 on CSAPI * Docs * More docs * Rename path variable to fix key backup tests * Update routing.go Co-authored-by: Neil Alexander --- clientapi/routing/routing.go | 193 ++++++++++++++++++----------------- 1 file changed, 100 insertions(+), 93 deletions(-) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 732066166..da2ccf2fa 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -117,15 +117,22 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } - r0mux := publicAPIMux.PathPrefix("/r0").Subrouter() + // You can't just do PathPrefix("/(r0|v3)") because regexps only apply when inside named path variables. + // So make a named path variable called 'apiversion' (which we will never read in handlers) and then do + // (r0|v3) - BUT this is a captured group, which makes no sense because you cannot extract this group + // from a match (gorilla/mux exposes no way to do this) so it demands you make it a non-capturing group + // using ?: so the final regexp becomes what is below. We also need a trailing slash to stop 'v33333' matching. + // Note that 'apiversion' is chosen because it must not collide with a variable used in any of the routing! + v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() + unstableMux := publicAPIMux.PathPrefix("/unstable").Subrouter() - r0mux.Handle("/createRoom", + v3mux.Handle("/createRoom", httputil.MakeAuthAPI("createRoom", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CreateRoom(req, device, cfg, accountDB, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/join/{roomIDOrAlias}", + v3mux.Handle("/join/{roomIDOrAlias}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -141,7 +148,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) if mscCfg.Enabled("msc2753") { - r0mux.Handle("/peek/{roomIDOrAlias}", + v3mux.Handle("/peek/{roomIDOrAlias}", httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -156,12 +163,12 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) } - r0mux.Handle("/joined_rooms", + v3mux.Handle("/joined_rooms", httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetJoinedRooms(req, device, rsAPI) }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/join", + v3mux.Handle("/rooms/{roomID}/join", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -175,7 +182,7 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/leave", + v3mux.Handle("/rooms/{roomID}/leave", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -189,7 +196,7 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/unpeek", + v3mux.Handle("/rooms/{roomID}/unpeek", httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -200,7 +207,7 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/ban", + v3mux.Handle("/rooms/{roomID}/ban", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -209,7 +216,7 @@ func Setup( return SendBan(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/invite", + v3mux.Handle("/rooms/{roomID}/invite", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -221,7 +228,7 @@ func Setup( return SendInvite(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/kick", + v3mux.Handle("/rooms/{roomID}/kick", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -230,7 +237,7 @@ func Setup( return SendKick(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/unban", + v3mux.Handle("/rooms/{roomID}/unban", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -239,7 +246,7 @@ func Setup( return SendUnban(req, accountDB, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/send/{eventType}", + v3mux.Handle("/rooms/{roomID}/send/{eventType}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -248,7 +255,7 @@ func Setup( return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", + v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -259,7 +266,7 @@ func Setup( nil, cfg, rsAPI, transactionsCache) }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/event/{eventID}", + v3mux.Handle("/rooms/{roomID}/event/{eventID}", httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -269,7 +276,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -277,7 +284,7 @@ func Setup( return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -285,7 +292,7 @@ func Setup( return GetAliases(req, rsAPI, device, vars["roomID"]) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state/{type:[^/]+/?}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -296,7 +303,7 @@ func Setup( return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -305,7 +312,7 @@ func Setup( return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", + v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -317,7 +324,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", + v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -328,21 +335,21 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r } return Register(req, userAPI, accountDB, cfg) })).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { + v3mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r } return RegisterAvailable(req, cfg, accountDB) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/directory/room/{roomAlias}", + v3mux.Handle("/directory/room/{roomAlias}", httputil.MakeExternalAPI("directory_room", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -352,7 +359,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/directory/room/{roomAlias}", + v3mux.Handle("/directory/room/{roomAlias}", httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -362,7 +369,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/directory/room/{roomAlias}", + v3mux.Handle("/directory/room/{roomAlias}", httputil.MakeAuthAPI("directory_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -371,7 +378,7 @@ func Setup( return RemoveLocalAlias(req, device, vars["roomAlias"], rsAPI) }), ).Methods(http.MethodDelete, http.MethodOptions) - r0mux.Handle("/directory/list/room/{roomID}", + v3mux.Handle("/directory/list/room/{roomID}", httputil.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -381,7 +388,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) // TODO: Add AS support - r0mux.Handle("/directory/list/room/{roomID}", + v3mux.Handle("/directory/list/room/{roomID}", httputil.MakeAuthAPI("directory_list", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -390,25 +397,25 @@ func Setup( return SetVisibility(req, rsAPI, device, vars["roomID"]) }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/publicRooms", + v3mux.Handle("/publicRooms", httputil.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse { return GetPostPublicRooms(req, rsAPI, extRoomsProvider, federation, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - r0mux.Handle("/logout", + v3mux.Handle("/logout", httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Logout(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/logout/all", + v3mux.Handle("/logout/all", httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return LogoutAll(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/typing/{userID}", + v3mux.Handle("/rooms/{roomID}/typing/{userID}", httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -420,7 +427,7 @@ func Setup( return SendTyping(req, device, vars["roomID"], vars["userID"], accountDB, eduAPI, rsAPI) }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/redact/{eventID}", + v3mux.Handle("/rooms/{roomID}/redact/{eventID}", httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -429,7 +436,7 @@ func Setup( return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", + v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", httputil.MakeAuthAPI("rooms_redact", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -439,7 +446,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/sendToDevice/{eventType}/{txnID}", + v3mux.Handle("/sendToDevice/{eventType}/{txnID}", httputil.MakeAuthAPI("send_to_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -464,7 +471,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/account/whoami", + v3mux.Handle("/account/whoami", httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -473,7 +480,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/account/password", + v3mux.Handle("/account/password", httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -482,7 +489,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/account/deactivate", + v3mux.Handle("/account/deactivate", httputil.MakeAuthAPI("deactivate", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -493,7 +500,7 @@ func Setup( // Stub endpoints required by Element - r0mux.Handle("/login", + v3mux.Handle("/login", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -502,14 +509,14 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - r0mux.Handle("/auth/{authType}/fallback/web", + v3mux.Handle("/auth/{authType}/fallback/web", httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { vars := mux.Vars(req) return AuthFallback(w, req, vars["authType"], cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) - r0mux.Handle("/pushrules/", + v3mux.Handle("/pushrules/", httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse { // TODO: Implement push rules API res := json.RawMessage(`{ @@ -530,7 +537,7 @@ func Setup( // Element user settings - r0mux.Handle("/profile/{userID}", + v3mux.Handle("/profile/{userID}", httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -540,7 +547,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/profile/{userID}/avatar_url", + v3mux.Handle("/profile/{userID}/avatar_url", httputil.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -550,7 +557,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/profile/{userID}/avatar_url", + v3mux.Handle("/profile/{userID}/avatar_url", httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -565,7 +572,7 @@ func Setup( // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method - r0mux.Handle("/profile/{userID}/displayname", + v3mux.Handle("/profile/{userID}/displayname", httputil.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -575,7 +582,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/profile/{userID}/displayname", + v3mux.Handle("/profile/{userID}/displayname", httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -590,13 +597,13 @@ func Setup( // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method - r0mux.Handle("/account/3pid", + v3mux.Handle("/account/3pid", httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetAssociated3PIDs(req, accountDB, device) }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/account/3pid", + v3mux.Handle("/account/3pid", httputil.MakeAuthAPI("account_3pid", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return CheckAndSave3PIDAssociation(req, accountDB, device, cfg) }), @@ -608,14 +615,14 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", + v3mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken", httputil.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse { return RequestEmailToken(req, accountDB, cfg) }), ).Methods(http.MethodPost, http.MethodOptions) // Element logs get flooded unless this is handled - r0mux.Handle("/presence/{userID}/status", + v3mux.Handle("/presence/{userID}/status", httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -628,7 +635,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/voip/turnServer", + v3mux.Handle("/voip/turnServer", httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -637,7 +644,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/thirdparty/protocols", + v3mux.Handle("/thirdparty/protocols", httputil.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse { // TODO: Return the third party protcols return util.JSONResponse{ @@ -647,7 +654,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/initialSync", + v3mux.Handle("/rooms/{roomID}/initialSync", httputil.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse { // TODO: Allow people to peek into rooms. return util.JSONResponse{ @@ -657,7 +664,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userID}/account_data/{type}", + v3mux.Handle("/user/{userID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -667,7 +674,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", + v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -677,7 +684,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/user/{userID}/account_data/{type}", + v3mux.Handle("/user/{userID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -687,7 +694,7 @@ func Setup( }), ).Methods(http.MethodGet) - r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", + v3mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", httputil.MakeAuthAPI("user_account_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -697,7 +704,7 @@ func Setup( }), ).Methods(http.MethodGet) - r0mux.Handle("/admin/whois/{userID}", + v3mux.Handle("/admin/whois/{userID}", httputil.MakeAuthAPI("admin_whois", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -707,7 +714,7 @@ func Setup( }), ).Methods(http.MethodGet) - r0mux.Handle("/user/{userID}/openid/request_token", + v3mux.Handle("/user/{userID}/openid/request_token", httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -720,7 +727,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/user_directory/search", + v3mux.Handle("/user_directory/search", httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -745,7 +752,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/members", + v3mux.Handle("/rooms/{roomID}/members", httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -755,7 +762,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/joined_members", + v3mux.Handle("/rooms/{roomID}/joined_members", httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -765,7 +772,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/read_markers", + v3mux.Handle("/rooms/{roomID}/read_markers", httputil.MakeAuthAPI("rooms_read_markers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -778,7 +785,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/forget", + v3mux.Handle("/rooms/{roomID}/forget", httputil.MakeAuthAPI("rooms_forget", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -791,13 +798,13 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/devices", + v3mux.Handle("/devices", httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetDevicesByLocalpart(req, userAPI, device) }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/devices/{deviceID}", + v3mux.Handle("/devices/{deviceID}", httputil.MakeAuthAPI("get_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -807,7 +814,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/devices/{deviceID}", + v3mux.Handle("/devices/{deviceID}", httputil.MakeAuthAPI("device_data", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -817,7 +824,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/devices/{deviceID}", + v3mux.Handle("/devices/{deviceID}", httputil.MakeAuthAPI("delete_device", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -827,14 +834,14 @@ func Setup( }), ).Methods(http.MethodDelete, http.MethodOptions) - r0mux.Handle("/delete_devices", + v3mux.Handle("/delete_devices", httputil.MakeAuthAPI("delete_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return DeleteDevices(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) // Stub implementations for sytest - r0mux.Handle("/events", + v3mux.Handle("/events", httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "chunk": []interface{}{}, @@ -844,7 +851,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/initialSync", + v3mux.Handle("/initialSync", httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "end": "", @@ -852,7 +859,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userId}/rooms/{roomId}/tags", + v3mux.Handle("/user/{userId}/rooms/{roomId}/tags", httputil.MakeAuthAPI("get_tags", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -862,7 +869,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", + v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", httputil.MakeAuthAPI("put_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -872,7 +879,7 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) - r0mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", + v3mux.Handle("/user/{userId}/rooms/{roomId}/tags/{tag}", httputil.MakeAuthAPI("delete_tag", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -882,7 +889,7 @@ func Setup( }), ).Methods(http.MethodDelete, http.MethodOptions) - r0mux.Handle("/capabilities", + v3mux.Handle("/capabilities", httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r @@ -925,11 +932,11 @@ func Setup( return CreateKeyBackupVersion(req, userAPI, device) }) - r0mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) - r0mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) - r0mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/version/{version}", putBackupKeysVersion).Methods(http.MethodPut) + v3mux.Handle("/room_keys/version/{version}", deleteBackupKeysVersion).Methods(http.MethodDelete) + v3mux.Handle("/room_keys/version", postNewBackupKeysVersion).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/room_keys/version/{version}", getBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/version", getLatestBackupKeysVersion).Methods(http.MethodGet, http.MethodOptions) @@ -1021,9 +1028,9 @@ func Setup( return UploadBackupKeys(req, userAPI, device, version, &keyReq) }) - r0mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) - r0mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) - r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) + v3mux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) + v3mux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) + v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", putBackupKeysRoomSession).Methods(http.MethodPut) unstableMux.Handle("/room_keys/keys", putBackupKeys).Methods(http.MethodPut) unstableMux.Handle("/room_keys/keys/{roomID}", putBackupKeysRoom).Methods(http.MethodPut) @@ -1051,9 +1058,9 @@ func Setup( return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"]) }) - r0mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/room_keys/keys/{roomID}/{sessionID}", getBackupKeysRoomSession).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/keys", getBackupKeys).Methods(http.MethodGet, http.MethodOptions) unstableMux.Handle("/room_keys/keys/{roomID}", getBackupKeysRoom).Methods(http.MethodGet, http.MethodOptions) @@ -1071,34 +1078,34 @@ func Setup( return UploadCrossSigningDeviceSignatures(req, keyAPI, device) }) - r0mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) + v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) // Supplying a device ID is deprecated. - r0mux.Handle("/keys/upload/{deviceID}", + v3mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/upload", + v3mux.Handle("/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/query", + v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return QueryKeys(req, keyAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/keys/claim", + v3mux.Handle("/keys/claim", httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return ClaimKeys(req, keyAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", + v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req); r != nil { return *r From e1eb5807b66940490291983be905f2849539aa7f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 18 Feb 2022 10:12:26 +0000 Subject: [PATCH 72/81] Allow preventing guest registration (#2199) * Allow disabling guest registration separately * Update sample config * Set `guests_disabled` to `true` in the sample config --- clientapi/routing/register.go | 7 +++++++ dendrite-config.yaml | 4 ++++ setup/config/config_clientapi.go | 4 ++++ 3 files changed, 15 insertions(+) diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index fc275a5d1..f73cc662f 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -532,6 +532,13 @@ func handleGuestRegistration( cfg *config.ClientAPI, userAPI userapi.UserInternalAPI, ) util.JSONResponse { + if cfg.RegistrationDisabled || cfg.GuestsDisabled { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("Guest registration is disabled"), + } + } + var res userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeGuest, diff --git a/dendrite-config.yaml b/dendrite-config.yaml index b71e8d845..35f72222e 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -142,6 +142,10 @@ client_api: # using the registration shared secret below. registration_disabled: false + # Prevents new guest accounts from being created. Guest registration is also + # disabled implicitly by setting 'registration_disabled' above. + guests_disabled: true + # If set, allows registration by anyone who knows the shared secret, regardless of # whether registration is otherwise disabled. registration_shared_secret: "" diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 75f5e3df3..4590e752b 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -18,6 +18,10 @@ type ClientAPI struct { // If set, allows registration by anyone who also has the shared // secret, even if registration is otherwise disabled. RegistrationSharedSecret string `yaml:"registration_shared_secret"` + // If set, prevents guest accounts from being created. Only takes + // effect if registration is enabled, otherwise guests registration + // is forbidden either way. + GuestsDisabled bool `yaml:"guests_disabled"` // Boolean stating whether catpcha registration is enabled // and required From 131bedc1a11135eb1f67a26389fe8f53c82c537d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 18 Feb 2022 10:58:41 +0000 Subject: [PATCH 73/81] Remove event type and state key caches (#2200) * Don't proactively cache event types and state keys when we don't know if the transaction has persisted yet * Remove event type and state key caches altogether --- internal/caching/cache_roomservernids.go | 42 ------------------ internal/caching/caches.go | 14 +++--- internal/caching/impl_inmemorylru.go | 33 +++------------ roomserver/storage/shared/storage.go | 54 +++++------------------- 4 files changed, 22 insertions(+), 121 deletions(-) diff --git a/internal/caching/cache_roomservernids.go b/internal/caching/cache_roomservernids.go index bf4fe85ed..6d413093f 100644 --- a/internal/caching/cache_roomservernids.go +++ b/internal/caching/cache_roomservernids.go @@ -7,14 +7,6 @@ import ( ) const ( - RoomServerStateKeyNIDsCacheName = "roomserver_statekey_nids" - RoomServerStateKeyNIDsCacheMaxEntries = 1024 - RoomServerStateKeyNIDsCacheMutable = false - - RoomServerEventTypeNIDsCacheName = "roomserver_eventtype_nids" - RoomServerEventTypeNIDsCacheMaxEntries = 64 - RoomServerEventTypeNIDsCacheMutable = false - RoomServerRoomIDsCacheName = "roomserver_room_ids" RoomServerRoomIDsCacheMaxEntries = 1024 RoomServerRoomIDsCacheMutable = false @@ -29,44 +21,10 @@ type RoomServerCaches interface { // RoomServerNIDsCache contains the subset of functions needed for // a roomserver NID cache. type RoomServerNIDsCache interface { - GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool) - StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID) - - GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) - StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) - GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) StoreRoomServerRoomID(roomNID types.RoomNID, roomID string) } -func (c Caches) GetRoomServerStateKeyNID(stateKey string) (types.EventStateKeyNID, bool) { - val, found := c.RoomServerStateKeyNIDs.Get(stateKey) - if found && val != nil { - if stateKeyNID, ok := val.(types.EventStateKeyNID); ok { - return stateKeyNID, true - } - } - return 0, false -} - -func (c Caches) StoreRoomServerStateKeyNID(stateKey string, nid types.EventStateKeyNID) { - c.RoomServerStateKeyNIDs.Set(stateKey, nid) -} - -func (c Caches) GetRoomServerEventTypeNID(eventType string) (types.EventTypeNID, bool) { - val, found := c.RoomServerEventTypeNIDs.Get(eventType) - if found && val != nil { - if eventTypeNID, ok := val.(types.EventTypeNID); ok { - return eventTypeNID, true - } - } - return 0, false -} - -func (c Caches) StoreRoomServerEventTypeNID(eventType string, nid types.EventTypeNID) { - c.RoomServerEventTypeNIDs.Set(eventType, nid) -} - func (c Caches) GetRoomServerRoomID(roomNID types.RoomNID) (string, bool) { val, found := c.RoomServerRoomIDs.Get(strconv.Itoa(int(roomNID))) if found && val != nil { diff --git a/internal/caching/caches.go b/internal/caching/caches.go index f04d05d42..e1642a663 100644 --- a/internal/caching/caches.go +++ b/internal/caching/caches.go @@ -4,14 +4,12 @@ package caching // different implementations as long as they satisfy the Cache // interface. type Caches struct { - RoomVersions Cache // RoomVersionCache - ServerKeys Cache // ServerKeyCache - RoomServerStateKeyNIDs Cache // RoomServerNIDsCache - RoomServerEventTypeNIDs Cache // RoomServerNIDsCache - RoomServerRoomNIDs Cache // RoomServerNIDsCache - RoomServerRoomIDs Cache // RoomServerNIDsCache - RoomInfos Cache // RoomInfoCache - FederationEvents Cache // FederationEventsCache + RoomVersions Cache // RoomVersionCache + ServerKeys Cache // ServerKeyCache + RoomServerRoomNIDs Cache // RoomServerNIDsCache + RoomServerRoomIDs Cache // RoomServerNIDsCache + RoomInfos Cache // RoomInfoCache + FederationEvents Cache // FederationEventsCache } // Cache is the interface that an implementation must satisfy. diff --git a/internal/caching/impl_inmemorylru.go b/internal/caching/impl_inmemorylru.go index f0915d7ca..ccb92852b 100644 --- a/internal/caching/impl_inmemorylru.go +++ b/internal/caching/impl_inmemorylru.go @@ -28,24 +28,6 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { if err != nil { return nil, err } - roomServerStateKeyNIDs, err := NewInMemoryLRUCachePartition( - RoomServerStateKeyNIDsCacheName, - RoomServerStateKeyNIDsCacheMutable, - RoomServerStateKeyNIDsCacheMaxEntries, - enablePrometheus, - ) - if err != nil { - return nil, err - } - roomServerEventTypeNIDs, err := NewInMemoryLRUCachePartition( - RoomServerEventTypeNIDsCacheName, - RoomServerEventTypeNIDsCacheMutable, - RoomServerEventTypeNIDsCacheMaxEntries, - enablePrometheus, - ) - if err != nil { - return nil, err - } roomServerRoomIDs, err := NewInMemoryLRUCachePartition( RoomServerRoomIDsCacheName, RoomServerRoomIDsCacheMutable, @@ -74,18 +56,15 @@ func NewInMemoryLRUCache(enablePrometheus bool) (*Caches, error) { return nil, err } go cacheCleaner( - roomVersions, serverKeys, roomServerStateKeyNIDs, - roomServerEventTypeNIDs, roomServerRoomIDs, + roomVersions, serverKeys, roomServerRoomIDs, roomInfos, federationEvents, ) return &Caches{ - RoomVersions: roomVersions, - ServerKeys: serverKeys, - RoomServerStateKeyNIDs: roomServerStateKeyNIDs, - RoomServerEventTypeNIDs: roomServerEventTypeNIDs, - RoomServerRoomIDs: roomServerRoomIDs, - RoomInfos: roomInfos, - FederationEvents: federationEvents, + RoomVersions: roomVersions, + ServerKeys: serverKeys, + RoomServerRoomIDs: roomServerRoomIDs, + RoomInfos: roomInfos, + FederationEvents: federationEvents, }, nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 9f3b8b1da..b255cfb3f 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -59,23 +59,12 @@ func (d *Database) eventTypeNIDs( ctx context.Context, txn *sql.Tx, eventTypes []string, ) (map[string]types.EventTypeNID, error) { result := make(map[string]types.EventTypeNID) - remaining := []string{} - for _, eventType := range eventTypes { - if nid, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok { - result[eventType] = nid - } else { - remaining = append(remaining, eventType) - } + nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, eventTypes) + if err != nil { + return nil, err } - if len(remaining) > 0 { - nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining) - if err != nil { - return nil, err - } - for eventType, nid := range nids { - result[eventType] = nid - d.Cache.StoreRoomServerEventTypeNID(eventType, nid) - } + for eventType, nid := range nids { + result[eventType] = nid } return result, nil } @@ -96,23 +85,12 @@ func (d *Database) eventStateKeyNIDs( ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) - remaining := []string{} - for _, eventStateKey := range eventStateKeys { - if nid, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok { - result[eventStateKey] = nid - } else { - remaining = append(remaining, eventStateKey) - } + nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) + if err != nil { + return nil, err } - if len(remaining) > 0 { - nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining) - if err != nil { - return nil, err - } - for eventStateKey, nid := range nids { - result[eventStateKey] = nid - d.Cache.StoreRoomServerStateKeyNID(eventStateKey, nid) - } + for eventStateKey, nid := range nids { + result[eventStateKey] = nid } return result, nil } @@ -718,9 +696,6 @@ func (d *Database) assignRoomNID( func (d *Database) assignEventTypeNID( ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { - if eventTypeNID, ok := d.Cache.GetRoomServerEventTypeNID(eventType); ok { - return eventTypeNID, nil - } // Check if we already have a numeric ID in the database. eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) if err == sql.ErrNoRows { @@ -731,18 +706,12 @@ func (d *Database) assignEventTypeNID( eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType) } } - if err == nil { - d.Cache.StoreRoomServerEventTypeNID(eventType, eventTypeNID) - } return eventTypeNID, err } func (d *Database) assignStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { - if eventStateKeyNID, ok := d.Cache.GetRoomServerStateKeyNID(eventStateKey); ok { - return eventStateKeyNID, nil - } // Check if we already have a numeric ID in the database. eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) if err == sql.ErrNoRows { @@ -753,9 +722,6 @@ func (d *Database) assignStateKeyNID( eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey) } } - if err == nil { - d.Cache.StoreRoomServerStateKeyNID(eventStateKey, eventStateKeyNID) - } return eventStateKeyNID, err } From 0a7dea44505f703af1e7e069602ca95aa5a83700 Mon Sep 17 00:00:00 2001 From: kegsay Date: Fri, 18 Feb 2022 11:28:02 +0000 Subject: [PATCH 74/81] Update /whoami response to match Spec v1.2 (#2201) Basically include `is_guest` and `device_id`. The latter is needed for https://github.com/matrix-org/complement/pull/305 --- clientapi/routing/whoami.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/clientapi/routing/whoami.go b/clientapi/routing/whoami.go index 26280f6cc..a1d9d6675 100644 --- a/clientapi/routing/whoami.go +++ b/clientapi/routing/whoami.go @@ -21,7 +21,9 @@ import ( // whoamiResponse represents an response for a `whoami` request type whoamiResponse struct { - UserID string `json:"user_id"` + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + IsGuest bool `json:"is_guest"` } // Whoami implements `/account/whoami` which enables client to query their account user id. @@ -29,6 +31,10 @@ type whoamiResponse struct { func Whoami(req *http.Request, device *api.Device) util.JSONResponse { return util.JSONResponse{ Code: http.StatusOK, - JSON: whoamiResponse{UserID: device.UserID}, + JSON: whoamiResponse{ + UserID: device.UserID, + DeviceID: device.ID, + IsGuest: device.AccountType == api.AccountTypeGuest, + }, } } From 153bfbbea579dfa10e8e804036f17c1a33b6fe80 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 18 Feb 2022 11:31:05 +0000 Subject: [PATCH 75/81] Merge both user API databases into one (#2186) * Merge user API databases into one * Remove DeviceDatabase from config * Fix tests * Try that again * Clean up keyserver device keys when the devices no longer exist in the user API * Tweak ordering * Fix UserExists flag, device check * Allow including empty entries so we can clean them up * Remove logging --- appservice/api/query.go | 4 +- build/gobind-pinecone/monolith.go | 2 - build/gobind-yggdrasil/monolith.go | 1 - clientapi/clientapi.go | 4 +- clientapi/routing/createroom.go | 6 +- clientapi/routing/joinroom.go | 4 +- clientapi/routing/key_crosssigning.go | 4 +- clientapi/routing/login.go | 4 +- clientapi/routing/membership.go | 18 +- clientapi/routing/password.go | 4 +- clientapi/routing/peekroom.go | 6 +- clientapi/routing/profile.go | 14 +- clientapi/routing/register.go | 6 +- clientapi/routing/routing.go | 4 +- clientapi/routing/sendtyping.go | 4 +- clientapi/routing/threepid.go | 12 +- clientapi/threepid/invites.go | 8 +- cmd/create-account/main.go | 13 +- cmd/dendrite-demo-libp2p/main.go | 1 - cmd/dendrite-demo-pinecone/main.go | 1 - cmd/dendrite-demo-yggdrasil/main.go | 1 - cmd/dendritejs-pinecone/main.go | 1 - cmd/dendritejs/main.go | 1 - cmd/generate-config/main.go | 1 - cmd/goose/main.go | 24 +- internal/test/config.go | 1 - keyserver/internal/internal.go | 82 ++++-- keyserver/storage/interface.go | 2 +- .../storage/postgres/device_keys_table.go | 33 ++- keyserver/storage/shared/storage.go | 4 +- .../storage/sqlite3/device_keys_table.go | 31 +- keyserver/storage/storage_test.go | 2 +- keyserver/storage/tables/interface.go | 2 +- setup/base/base.go | 12 +- setup/config/config_userapi.go | 6 - setup/monolith.go | 4 +- userapi/api/api_logintoken.go | 7 + userapi/internal/api.go | 67 +++-- userapi/internal/api_logintoken.go | 8 +- userapi/storage/devices/interface.go | 52 ---- userapi/storage/devices/postgres/storage.go | 270 ----------------- userapi/storage/devices/sqlite3/storage.go | 271 ------------------ userapi/storage/devices/storage.go | 42 --- userapi/storage/devices/storage_wasm.go | 39 --- userapi/storage/{accounts => }/interface.go | 31 +- .../postgres/account_data_table.go | 0 .../{accounts => }/postgres/accounts_table.go | 0 .../deltas/20200929203058_is_active.go | 0 .../deltas/20201001204705_last_seen_ts_ip.go | 5 - .../2022021013023800_add_account_type.go | 0 .../{devices => }/postgres/devices_table.go | 3 + .../postgres/key_backup_table.go | 0 .../postgres/key_backup_version_table.go | 0 .../postgres/logintoken_table.go | 3 + .../{accounts => }/postgres/openid_table.go | 0 .../{accounts => }/postgres/profile_table.go | 0 .../{accounts => }/postgres/storage.go | 216 +++++++++++++- .../{accounts => }/postgres/threepid_table.go | 0 .../sqlite3/account_data_table.go | 0 .../{accounts => }/sqlite3/accounts_table.go | 0 .../{accounts => }/sqlite3/constraint_wasm.go | 0 .../deltas/20200929203058_is_active.go | 0 .../deltas/20201001204705_last_seen_ts_ip.go | 5 - .../2022021012490600_add_account_type.go | 0 .../{devices => }/sqlite3/devices_table.go | 3 + .../sqlite3/key_backup_table.go | 0 .../sqlite3/key_backup_version_table.go | 0 .../{devices => }/sqlite3/logintoken_table.go | 3 + .../{accounts => }/sqlite3/openid_table.go | 0 .../{accounts => }/sqlite3/profile_table.go | 0 .../storage/{accounts => }/sqlite3/storage.go | 216 +++++++++++++- .../{accounts => }/sqlite3/threepid_table.go | 0 userapi/storage/{accounts => }/storage.go | 13 +- .../storage/{accounts => }/storage_wasm.go | 8 +- userapi/userapi.go | 22 +- userapi/userapi_test.go | 15 +- 76 files changed, 727 insertions(+), 899 deletions(-) delete mode 100644 userapi/storage/devices/interface.go delete mode 100644 userapi/storage/devices/postgres/storage.go delete mode 100644 userapi/storage/devices/sqlite3/storage.go delete mode 100644 userapi/storage/devices/storage.go delete mode 100644 userapi/storage/devices/storage_wasm.go rename userapi/storage/{accounts => }/interface.go (67%) rename userapi/storage/{accounts => }/postgres/account_data_table.go (100%) rename userapi/storage/{accounts => }/postgres/accounts_table.go (100%) rename userapi/storage/{accounts => }/postgres/deltas/20200929203058_is_active.go (100%) rename userapi/storage/{devices => }/postgres/deltas/20201001204705_last_seen_ts_ip.go (89%) rename userapi/storage/{accounts => }/postgres/deltas/2022021013023800_add_account_type.go (100%) rename userapi/storage/{devices => }/postgres/devices_table.go (99%) rename userapi/storage/{accounts => }/postgres/key_backup_table.go (100%) rename userapi/storage/{accounts => }/postgres/key_backup_version_table.go (100%) rename userapi/storage/{devices => }/postgres/logintoken_table.go (98%) rename userapi/storage/{accounts => }/postgres/openid_table.go (100%) rename userapi/storage/{accounts => }/postgres/profile_table.go (100%) rename userapi/storage/{accounts => }/postgres/storage.go (70%) rename userapi/storage/{accounts => }/postgres/threepid_table.go (100%) rename userapi/storage/{accounts => }/sqlite3/account_data_table.go (100%) rename userapi/storage/{accounts => }/sqlite3/accounts_table.go (100%) rename userapi/storage/{accounts => }/sqlite3/constraint_wasm.go (100%) rename userapi/storage/{accounts => }/sqlite3/deltas/20200929203058_is_active.go (100%) rename userapi/storage/{devices => }/sqlite3/deltas/20201001204705_last_seen_ts_ip.go (94%) rename userapi/storage/{accounts => }/sqlite3/deltas/2022021012490600_add_account_type.go (100%) rename userapi/storage/{devices => }/sqlite3/devices_table.go (99%) rename userapi/storage/{accounts => }/sqlite3/key_backup_table.go (100%) rename userapi/storage/{accounts => }/sqlite3/key_backup_version_table.go (100%) rename userapi/storage/{devices => }/sqlite3/logintoken_table.go (98%) rename userapi/storage/{accounts => }/sqlite3/openid_table.go (100%) rename userapi/storage/{accounts => }/sqlite3/profile_table.go (100%) rename userapi/storage/{accounts => }/sqlite3/storage.go (71%) rename userapi/storage/{accounts => }/sqlite3/threepid_table.go (100%) rename userapi/storage/{accounts => }/storage.go (81%) rename userapi/storage/{accounts => }/storage_wasm.go (87%) diff --git a/appservice/api/query.go b/appservice/api/query.go index cd74d866c..e53ad4259 100644 --- a/appservice/api/query.go +++ b/appservice/api/query.go @@ -23,7 +23,7 @@ import ( "errors" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -85,7 +85,7 @@ func RetrieveUserProfile( ctx context.Context, userID string, asAPI AppServiceQueryAPI, - accountDB accounts.Database, + accountDB userdb.Database, ) (*authtypes.Profile, error) { localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 211b8d653..acf4406ca 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -283,8 +283,6 @@ func (m *DendriteMonolith) Start() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix)) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix)) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 3d9ba8aa0..8b9c88f2a 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -88,7 +88,6 @@ func (m *DendriteMonolith) Start() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory)) diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index d678ada96..a65f3b70d 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -28,7 +28,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -37,7 +37,7 @@ func AddPublicRoutes( router *mux.Router, synapseAdminRouter *mux.Router, cfg *config.ClientAPI, - accountsDB accounts.Database, + accountsDB userdb.Database, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, eduInputAPI eduServerAPI.EDUServerInputAPI, diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index e89d8ff24..80ac22935 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -30,7 +30,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -137,7 +137,7 @@ type fledglingEvent struct { func CreateRoom( req *http.Request, device *api.Device, cfg *config.ClientAPI, - accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, + accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { // TODO (#267): Check room ID doesn't clash with an existing one, and we @@ -151,7 +151,7 @@ func CreateRoom( func createRoom( req *http.Request, device *api.Device, cfg *config.ClientAPI, roomID string, - accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI, + accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { logger := util.GetLogger(req.Context()) diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 578aaec56..d30a87a57 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomIDOrAlias string, ) util.JSONResponse { // Prepare to ask the roomserver to perform the room join. diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 7b9d8acd2..7ecab9d4e 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -24,7 +24,7 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/util" ) @@ -36,7 +36,7 @@ type crossSigningRequest struct { func UploadCrossSigningDeviceKeys( req *http.Request, userInteractiveAuth *auth.UserInteractive, keyserverAPI api.KeyInternalAPI, device *userapi.Device, - accountDB accounts.Database, cfg *config.ClientAPI, + accountDB userdb.Database, cfg *config.ClientAPI, ) util.JSONResponse { uploadReq := &crossSigningRequest{} uploadRes := &api.PerformUploadDeviceKeysResponse{} diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index b48b9e93b..ec5c998be 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -54,7 +54,7 @@ func passwordLogin() flows { // Login implements GET and POST /login func Login( - req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, + req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI, cfg *config.ClientAPI, ) util.JSONResponse { if req.Method == http.MethodGet { diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 58f187608..112239241 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -30,7 +30,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -39,7 +39,7 @@ import ( var errMissingUserID = errors.New("'user_id' must be supplied") func SendBan( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -81,7 +81,7 @@ func SendBan( return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI) } -func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device, +func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device, roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time, roomVer gomatrixserverlib.RoomVersion, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse { @@ -125,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us } func SendKick( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -165,7 +165,7 @@ func SendKick( } func SendUnban( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -200,7 +200,7 @@ func SendUnban( } func SendInvite( - req *http.Request, accountDB accounts.Database, device *userapi.Device, + req *http.Request, accountDB userdb.Database, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { @@ -271,7 +271,7 @@ func SendInvite( func buildMembershipEvent( ctx context.Context, - targetUserID, reason string, accountDB accounts.Database, + targetUserID, reason string, accountDB userdb.Database, device *userapi.Device, membership, roomID string, isDirect bool, cfg *config.ClientAPI, evTime time.Time, @@ -312,7 +312,7 @@ func loadProfile( ctx context.Context, userID string, cfg *config.ClientAPI, - accountDB accounts.Database, + accountDB userdb.Database, asAPI appserviceAPI.AppServiceQueryAPI, ) (*authtypes.Profile, error) { _, serverName, err := gomatrixserverlib.SplitID('@', userID) @@ -366,7 +366,7 @@ func checkAndProcessThreepid( body *threepid.MembershipRequest, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomID string, evTime time.Time, ) (inviteStored bool, errRes *util.JSONResponse) { diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index b24424430..499510193 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -9,7 +9,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -29,7 +29,7 @@ type newPasswordAuth struct { func Password( req *http.Request, userAPI api.UserInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, device *api.Device, cfg *config.ClientAPI, ) util.JSONResponse { diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index 26aa64ce1..8f89e97f4 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -19,7 +19,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomIDOrAlias string, ) util.JSONResponse { // if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to @@ -82,7 +82,7 @@ func UnpeekRoomByID( req *http.Request, device *api.Device, rsAPI roomserverAPI.RoomserverInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, roomID string, ) util.JSONResponse { unpeekReq := roomserverAPI.PerformUnpeekRequest{ diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 017facd20..717cbda75 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -27,7 +27,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrix" @@ -36,7 +36,7 @@ import ( // GetProfile implements GET /profile/{userID} func GetProfile( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, @@ -65,7 +65,7 @@ func GetProfile( // GetAvatarURL implements GET /profile/{userID}/avatar_url func GetAvatarURL( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -92,7 +92,7 @@ func GetAvatarURL( // SetAvatarURL implements PUT /profile/{userID}/avatar_url func SetAvatarURL( - req *http.Request, accountDB accounts.Database, + req *http.Request, accountDB userdb.Database, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -182,7 +182,7 @@ func SetAvatarURL( // GetDisplayName implements GET /profile/{userID}/displayname func GetDisplayName( - req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI, + req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -209,7 +209,7 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname func SetDisplayName( - req *http.Request, accountDB accounts.Database, + req *http.Request, accountDB userdb.Database, device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { if userID != device.UserID { @@ -302,7 +302,7 @@ func SetDisplayName( // Returns an error when something goes wrong or specifically // eventutil.ErrProfileNoExists when the profile doesn't exist. func getProfile( - ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI, + ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index f73cc662f..d00d9886e 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -44,7 +44,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" ) var ( @@ -448,7 +448,7 @@ func validateApplicationService( func Register( req *http.Request, userAPI userapi.UserInternalAPI, - accountDB accounts.Database, + accountDB userdb.Database, cfg *config.ClientAPI, ) util.JSONResponse { var r registerRequest @@ -899,7 +899,7 @@ type availableResponse struct { func RegisterAvailable( req *http.Request, cfg *config.ClientAPI, - accountDB accounts.Database, + accountDB userdb.Database, ) util.JSONResponse { username := req.URL.Query().Get("username") diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index da2ccf2fa..63dcaa413 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -34,7 +34,7 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -51,7 +51,7 @@ func Setup( eduAPI eduServerAPI.EDUServerInputAPI, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, - accountDB accounts.Database, + accountDB userdb.Database, userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, syncProducer *producers.SyncAPIProducer, diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index 3abf3db27..fd214b34b 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -20,7 +20,7 @@ import ( "github.com/matrix-org/dendrite/eduserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/util" ) @@ -33,7 +33,7 @@ type typingContentJSON struct { // sends the typing events to client API typingProducer func SendTyping( req *http.Request, device *userapi.Device, roomID string, - userID string, accountDB accounts.Database, + userID string, accountDB userdb.Database, eduAPI api.EDUServerInputAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ) util.JSONResponse { diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index f4d233798..d89b62953 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/threepid" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -40,7 +40,7 @@ type threePIDsResponse struct { // RequestEmailToken implements: // POST /account/3pid/email/requestToken // POST /register/email/requestToken -func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse { +func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse { var body threepid.EmailAssociationRequest if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr @@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf Code: http.StatusBadRequest, JSON: jsonerror.MatrixError{ ErrCode: "M_THREEPID_IN_USE", - Err: accounts.Err3PIDInUse.Error(), + Err: userdb.Err3PIDInUse.Error(), }, } } @@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf // CheckAndSave3PIDAssociation implements POST /account/3pid func CheckAndSave3PIDAssociation( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, accountDB userdb.Database, device *api.Device, cfg *config.ClientAPI, ) util.JSONResponse { var body threepid.EmailAssociationCheckRequest @@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation( // GetAssociated3PIDs implements GET /account/3pid func GetAssociated3PIDs( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, accountDB userdb.Database, device *api.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -170,7 +170,7 @@ func GetAssociated3PIDs( } // Forget3PID implements POST /account/3pid/delete -func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse { +func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse { var body authtypes.ThreePID if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index db62ce060..9d9a2ba7a 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -29,7 +29,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -87,7 +87,7 @@ var ( func CheckAndProcessInvite( ctx context.Context, device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI, - rsAPI api.RoomserverInternalAPI, db accounts.Database, + rsAPI api.RoomserverInternalAPI, db userdb.Database, roomID string, evTime time.Time, ) (inviteStoredOnIDServer bool, err error) { @@ -137,7 +137,7 @@ func CheckAndProcessInvite( // Returns an error if a check or a request failed. func queryIDServer( ctx context.Context, - db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, + db userdb.Database, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, roomID string, ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { if err = isTrusted(body.IDServer, cfg); err != nil { @@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe // Returns an error if the request failed to send or if the response couldn't be parsed. func queryIDServerStoreInvite( ctx context.Context, - db accounts.Database, cfg *config.ClientAPI, device *userapi.Device, + db userdb.Database, cfg *config.ClientAPI, device *userapi.Device, body *MembershipRequest, roomID string, ) (*idServerStoreInviteResponse, error) { // Retrieve the sender's profile to get their display name diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index d9202eb0d..3003896c8 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -30,7 +30,7 @@ import ( "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" ) const usage = `Usage: %s @@ -77,9 +77,14 @@ func main() { pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin) - accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ - ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, - }, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS) + accountDB, err := userdb.NewDatabase( + &config.DatabaseOptions{ + ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, + }, + cfg.Global.ServerName, bcrypt.DefaultCost, + cfg.UserAPI.OpenIDTokenLifetimeMS, + api.DefaultLoginTokenLifetime, + ) if err != nil { logrus.Fatalln("Failed to connect to the database:", err.Error()) } diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 7cbd0b6d4..78536901c 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -126,7 +126,6 @@ func main() { cfg.FederationAPI.FederationMaxRetries = 6 cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index a897dcd1a..5810a7f18 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -160,7 +160,6 @@ func main() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 52e69ee59..49e096bd1 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -79,7 +79,6 @@ func main() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName)) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName)) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName)) diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go index 62eea78f2..664f644f3 100644 --- a/cmd/dendritejs-pinecone/main.go +++ b/cmd/dendritejs-pinecone/main.go @@ -164,7 +164,6 @@ func startup() { cfg.Defaults(true) cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" - cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index 59de07cd0..0ea41b4c4 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -167,7 +167,6 @@ func main() { cfg.Defaults(true) cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db" cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db" - cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db" cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db" cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db" cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db" diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index f87665fbe..ba5a87a7a 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -32,7 +32,6 @@ func main() { cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI) cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(*dbURI) } cfg.Global.TrustedIDServers = []string{ "matrix.org", diff --git a/cmd/goose/main.go b/cmd/goose/main.go index 8ed5cbd91..31a5b0050 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -8,12 +8,11 @@ import ( "log" "os" - pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" - slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" - pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas" - sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas" "github.com/pressly/goose" + pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) @@ -26,8 +25,7 @@ const ( RoomServer = "roomserver" SigningKeyServer = "signingkeyserver" SyncAPI = "syncapi" - UserAPIAccounts = "userapi_accounts" - UserAPIDevices = "userapi_devices" + UserAPI = "userapi" ) var ( @@ -35,7 +33,7 @@ var ( flags = flag.NewFlagSet("goose", flag.ExitOnError) component = flags.String("component", "", "dendrite component name") knownDBs = []string{ - AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices, + AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI, } ) @@ -143,18 +141,14 @@ Commands: func loadSQLiteDeltas(component string) { switch component { - case UserAPIAccounts: - slaccounts.LoadFromGoose() - case UserAPIDevices: - sldevices.LoadFromGoose() + case UserAPI: + slusers.LoadFromGoose() } } func loadPostgresDeltas(component string) { switch component { - case UserAPIAccounts: - pgaccounts.LoadFromGoose() - case UserAPIDevices: - pgdevices.LoadFromGoose() + case UserAPI: + pgusers.LoadFromGoose() } } diff --git a/internal/test/config.go b/internal/test/config.go index 4fb6a946c..0372fb9c6 100644 --- a/internal/test/config.go +++ b/internal/test/config.go @@ -95,7 +95,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con cfg.RoomServer.Database.ConnectionString = config.DataSource(database) cfg.SyncAPI.Database.ConnectionString = config.DataSource(database) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database) - cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database) cfg.AppServiceAPI.InternalAPI.Listen = assignAddress() cfg.EDUServer.InternalAPI.Listen = assignAddress() diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index ffbcac94b..1c6b06776 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne } func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) { - msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil) + msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query DB for device keys: %s", err), @@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques domain := string(serverName) // query local devices if serverName == a.ThisServer { - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), @@ -525,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, ) error { - keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) // if we can't query the db or there are fewer keys than requested, fetch from remote. if err != nil { return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) @@ -554,10 +554,60 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( } func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + // get a list of devices from the user API that actually exist, as + // we won't store keys for devices that don't exist + uapidevices := &userapi.QueryDevicesResponse{} + if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { + res.Error = &api.KeyError{ + Err: err.Error(), + } + return + } + if !uapidevices.UserExists { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("user %q does not exist", req.UserID), + } + return + } + existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices)) + for _, key := range uapidevices.Devices { + existingDeviceMap[key.ID] = struct{}{} + } + + // Get all of the user existing device keys so we can check for changes. + existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), + } + return + } + + // Work out whether we have device keys in the keyserver for devices that + // no longer exist in the user API. This is mostly an exercise to ensure + // that we keep some integrity between the two. + var toClean []gomatrixserverlib.KeyID + for _, k := range existingKeys { + if _, ok := existingDeviceMap[k.DeviceID]; !ok { + toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID)) + } + } + + if len(toClean) > 0 { + if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to clean device keys: %s", err.Error()), + } + return + } + logrus.WithField("user_id", req.UserID).Infof("Cleaned up %d stale keyserver device key entries", len(toClean)) + } + var keysToStore []api.DeviceMessage // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { - _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID) + var serverName gomatrixserverlib.ServerName + _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID) if err != nil { continue // ignore invalid users } @@ -568,6 +618,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per keysToStore = append(keysToStore, key.WithStreamID(0)) continue // deleted keys don't need sanity checking } + // check that the device in question actually exists in the user + // API before we try and store a key for it + if _, ok := existingDeviceMap[key.DeviceID]; !ok { + continue + } gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str if gotUserID == key.UserID && gotDeviceID == key.DeviceID { @@ -583,29 +638,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per }) } - // get existing device keys so we can check for changes - existingKeys := make([]api.DeviceMessage, len(keysToStore)) - for i := range keysToStore { - existingKeys[i] = api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - UserID: keysToStore[i].UserID, - DeviceID: keysToStore[i].DeviceID, - }, - } - } - if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), - } - return - } if req.OnlyDisplayNameUpdates { // add the display name field from keysToStore into existingKeys keysToStore = appendDisplayNames(existingKeys, keysToStore) } // store the device keys and emit changes - err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore) + err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 0110860ea..4dffe695c 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -53,7 +53,7 @@ type Database interface { // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying // cross-signing signatures relating to that device. diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index 5ae0da969..628301cf7 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -69,14 +72,15 @@ const deleteAllDeviceKeysSQL = "" + "DELETE FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - countStreamIDsForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + countStreamIDsForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { + return nil, err + } if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } @@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql return err } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { - rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) if err != nil { return nil, err } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 5914d28e1..deee76eb4 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe }) } -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { - return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) +func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) } func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index fa1c930db..b461424c6 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -65,13 +68,14 @@ const deleteAllDeviceKeysSQL = "" + "DELETE FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { + return nil, err + } if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } @@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql return err } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { deviceIDMap[d] = true } - rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) if err != nil { return nil, err } diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index c4c99d8c4..4d5137249 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { } // Querying for device keys returns the latest stream IDs - msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}) + msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false) if err != nil { t.Fatalf("DeviceKeysForUser returned error: %s", err) } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index e44757e1a..ff70a2366 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -38,7 +38,7 @@ type DeviceKeys interface { InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error } diff --git a/setup/base/base.go b/setup/base/base.go index 819fe1ad4..e39977541 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -38,7 +38,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/gorilla/mux" @@ -273,8 +273,14 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. -func (b *BaseDendrite) CreateAccountsDB() accounts.Database { - db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS) +func (b *BaseDendrite) CreateAccountsDB() userdb.Database { + db, err := userdb.NewDatabase( + &b.Cfg.UserAPI.AccountDatabase, + b.Cfg.Global.ServerName, + b.Cfg.UserAPI.BCryptCost, + b.Cfg.UserAPI.OpenIDTokenLifetimeMS, + userapi.DefaultLoginTokenLifetime, + ) if err != nil { logrus.WithError(err).Panicf("failed to connect to accounts db") } diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go index b2cde2e96..1cb5eba18 100644 --- a/setup/config/config_userapi.go +++ b/setup/config/config_userapi.go @@ -16,9 +16,6 @@ type UserAPI struct { // The Account database stores the login details and account information // for local users. It is accessed by the UserAPI. AccountDatabase DatabaseOptions `yaml:"account_database"` - // The Device database stores session information for the devices of logged - // in local users. It is accessed by the UserAPI. - DeviceDatabase DatabaseOptions `yaml:"device_database"` } const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes @@ -27,10 +24,8 @@ func (c *UserAPI) Defaults(generate bool) { c.InternalAPI.Listen = "http://localhost:7781" c.InternalAPI.Connect = "http://localhost:7781" c.AccountDatabase.Defaults(10) - c.DeviceDatabase.Defaults(10) if generate { c.AccountDatabase.ConnectionString = "file:userapi_accounts.db" - c.DeviceDatabase.ConnectionString = "file:userapi_devices.db" } c.BCryptCost = bcrypt.DefaultCost c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS @@ -40,6 +35,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen)) checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect)) checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString)) - checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString)) checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS) } diff --git a/setup/monolith.go b/setup/monolith.go index e6c955222..61125e4a9 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -30,7 +30,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" + userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -38,7 +38,7 @@ import ( // all components of Dendrite, for use in monolith mode. type Monolith struct { Config *config.Dendrite - AccountDB accounts.Database + AccountDB userdb.Database KeyRing *gomatrixserverlib.KeyRing Client *gomatrixserverlib.Client FedClient *gomatrixserverlib.FederationClient diff --git a/userapi/api/api_logintoken.go b/userapi/api/api_logintoken.go index f3aa037e4..e2207bb53 100644 --- a/userapi/api/api_logintoken.go +++ b/userapi/api/api_logintoken.go @@ -19,6 +19,13 @@ import ( "time" ) +// DefaultLoginTokenLifetime determines how old a valid token may be. +// +// NOTSPEC: The current spec says "SHOULD be limited to around five +// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low. +// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325). +const DefaultLoginTokenLifetime = 2 * time.Minute + type LoginTokenInternalAPI interface { // PerformLoginTokenCreation creates a new login token and associates it with the provided data. PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error diff --git a/userapi/internal/api.go b/userapi/internal/api.go index f96d4804c..f54cc6137 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -31,13 +31,11 @@ import ( keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage" ) type UserInternalAPI struct { - AccountDB accounts.Database - DeviceDB devices.Database + DB storage.Database ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService @@ -55,11 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc if req.DataType == "" { return fmt.Errorf("data type must not be empty") } - return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) + return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { - acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) + acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists switch req.OnConflict { @@ -89,7 +87,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { + if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { return err } @@ -99,7 +97,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P } func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { - if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil { return err } res.PasswordUpdated = true @@ -112,7 +110,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") - dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) + dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) if err != nil { return err } @@ -137,12 +135,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) + devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } } else { - err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs) } if err != nil { return err @@ -196,7 +194,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate( if err != nil { return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { + if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil { return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) } return nil @@ -208,7 +206,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") return err } - dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) + dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID) if err == sql.ErrNoRows { res.DeviceExists = false return nil @@ -223,7 +221,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf return nil } - err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) if err != nil { util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") return err @@ -261,7 +259,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if domain != a.ServerName { return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) } - prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local) + prof, err := a.DB.GetProfileByLocalpart(ctx, local) if err != nil { if err == sql.ErrNoRows { return nil @@ -275,7 +273,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil } func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { - profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit) + profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit) if err != nil { return err } @@ -284,7 +282,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer } func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error { - devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs) + devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs) if err != nil { return err } @@ -312,10 +310,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if domain != a.ServerName { return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) } - devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local) + devs, err := a.DB.GetDevicesByLocalpart(ctx, local) if err != nil { return err } + res.UserExists = true res.Devices = devs return nil } @@ -330,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } if req.DataType != "" { var data json.RawMessage - data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) if err != nil { return err } @@ -348,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } return nil } - global, rooms, err := a.AccountDB.GetAccountData(ctx, local) + global, rooms, err := a.DB.GetAccountData(ctx, local) if err != nil { return err } @@ -367,7 +366,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc return nil } - device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken) + device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken) if err != nil { if err == sql.ErrNoRows { return nil @@ -378,7 +377,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc if err != nil { return err } - acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart) + acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) if err != nil { return err } @@ -419,7 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe if localpart != "" { // AS is masquerading as another user // Verify that the user is registered - account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart) + account, err := a.DB.GetAccountByLocalpart(ctx, localpart) // Verify that the account exists and either appServiceID matches or // it belongs to the appservice user namespaces if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { @@ -437,7 +436,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { - err := a.AccountDB.DeactivateAccount(ctx, req.Localpart) + err := a.DB.DeactivateAccount(ctx, req.Localpart) res.AccountDeactivated = err == nil return err } @@ -446,7 +445,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error { token := util.RandomString(24) - exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID) + exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID) res.Token = api.OpenIDToken{ Token: token, @@ -459,7 +458,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a // QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { - openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token) + openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token) if err != nil { return err } @@ -481,7 +480,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } return nil } - exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) + exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version) if err != nil { res.Error = fmt.Sprintf("failed to delete backup: %s", err) } @@ -494,7 +493,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } // Create metadata if req.Version == "" { - version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) + version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData) if err != nil { res.Error = fmt.Sprintf("failed to create backup: %s", err) } @@ -507,7 +506,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } // Update metadata if len(req.Keys.Rooms) == 0 { - err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) + err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData) if err != nil { res.Error = fmt.Sprintf("failed to update backup: %s", err) } @@ -528,7 +527,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { // you can only upload keys for the CURRENT version - version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "") + version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "") if err != nil { res.Error = fmt.Sprintf("failed to query version: %s", err) return @@ -556,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform }) } } - count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads) + count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads) if err != nil { res.Error = fmt.Sprintf("failed to upsert keys: %s", err) return @@ -566,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform } func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { - version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version) + version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) res.Version = version if err != nil { if err == sql.ErrNoRows { @@ -582,14 +581,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB res.Exists = !deleted if !req.ReturnKeys { - res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID) + res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID) if err != nil { res.Error = fmt.Sprintf("failed to count keys: %s", err) } return } - result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) + result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) if err != nil { res.Error = fmt.Sprintf("failed to query keys: %s", err) return diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index 86ffc58f3..f1bf391e4 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap if domain != a.ServerName { return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) } - tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data) + tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data) if err != nil { return err } @@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap // PerformLoginTokenDeletion ensures the token doesn't exist. func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error { util.GetLogger(ctx).Info("PerformLoginTokenDeletion") - return a.DeviceDB.RemoveLoginToken(ctx, req.Token) + return a.DB.RemoveLoginToken(ctx, req.Token) } // QueryLoginToken returns the data associated with a login token. If // the token is not valid, success is returned, but res.Data == nil. func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { - tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token) + tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token) if err != nil { res.Data = nil if err == sql.ErrNoRows { @@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog if domain != a.ServerName { return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) } - if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil { + if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { res.Data = nil if err == sql.ErrNoRows { return nil diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go deleted file mode 100644 index 8ff91cf1c..000000000 --- a/userapi/storage/devices/interface.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devices - -import ( - "context" - - "github.com/matrix-org/dendrite/userapi/api" -) - -type Database interface { - GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) - GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) - GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) - // CreateDevice makes a new device associated with the given user ID localpart. - // If there is already a device with the same device ID for this user, that access token will be revoked - // and replaced with the given accessToken. If the given accessToken is already in use for another device, - // an error will be returned. - // If no device ID is given one is generated. - // Returns the device on success. - CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) - UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error - UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error - RemoveDevice(ctx context.Context, deviceID, localpart string) error - RemoveDevices(ctx context.Context, localpart string, devices []string) error - // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) - - // CreateLoginToken generates a token, stores and returns it. The lifetime is - // determined by the loginTokenLifetime given to the Database constructor. - CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) - - // RemoveLoginToken removes the named token (and may clean up other expired tokens). - RemoveLoginToken(ctx context.Context, token string) error - - // GetLoginTokenDataByToken returns the data associated with the given token. - // May return sql.ErrNoRows. - GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) -} diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go deleted file mode 100644 index fd9d513f1..000000000 --- a/userapi/storage/devices/postgres/storage.go +++ /dev/null @@ -1,270 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "time" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas" - "github.com/matrix-org/gomatrixserverlib" -) - -const ( - // The length of generated device IDs - deviceIDByteLength = 6 - loginTokenByteLength = 32 -) - -// Database represents a device database. -type Database struct { - db *sql.DB - devices devicesStatements - loginTokens loginTokenStatements - loginTokenLifetime time.Duration -} - -// NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - var d devicesStatements - var lt loginTokenStatements - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.execSchema(db); err != nil { - return nil, err - } - if err = lt.execSchema(db); err != nil { - return nil, err - } - - m := sqlutil.NewMigrations() - deltas.LoadLastSeenTSIP(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - - if err = d.prepare(db, serverName); err != nil { - return nil, err - } - if err = lt.prepare(db); err != nil { - return nil, err - } - - return &Database{db, d, lt, loginTokenLifetime}, nil -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} - -// CreateLoginToken generates a token, stores and returns it. The lifetime is -// determined by the loginTokenLifetime given to the Database constructor. -func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { - tok, err := generateLoginToken() - if err != nil { - return nil, err - } - meta := &api.LoginTokenMetadata{ - Token: tok, - Expiration: time.Now().Add(d.loginTokenLifetime), - } - - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.loginTokens.insert(ctx, txn, meta, data) - }) - if err != nil { - return nil, err - } - - return meta, nil -} - -func generateLoginToken() (string, error) { - b := make([]byte, loginTokenByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// RemoveLoginToken removes the named token (and may clean up other expired tokens). -func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.loginTokens.deleteByToken(ctx, txn, token) - }) -} - -// GetLoginTokenDataByToken returns the data associated with the given token. -// May return sql.ErrNoRows. -func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { - return d.loginTokens.selectByToken(ctx, token) -} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go deleted file mode 100644 index 6e90413be..000000000 --- a/userapi/storage/devices/sqlite3/storage.go +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "time" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas" - "github.com/matrix-org/gomatrixserverlib" -) - -const ( - // The length of generated device IDs - deviceIDByteLength = 6 - - loginTokenByteLength = 32 -) - -// Database represents a device database. -type Database struct { - db *sql.DB - writer sqlutil.Writer - devices devicesStatements - loginTokens loginTokenStatements - loginTokenLifetime time.Duration -} - -// NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { - db, err := sqlutil.Open(dbProperties) - if err != nil { - return nil, err - } - writer := sqlutil.NewExclusiveWriter() - var d devicesStatements - var lt loginTokenStatements - - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.execSchema(db); err != nil { - return nil, err - } - if err = lt.execSchema(db); err != nil { - return nil, err - } - - m := sqlutil.NewMigrations() - deltas.LoadLastSeenTSIP(m) - if err = m.RunDeltas(db, dbProperties); err != nil { - return nil, err - } - if err = d.prepare(db, writer, serverName); err != nil { - return nil, err - } - if err = lt.prepare(db); err != nil { - return nil, err - } - return &Database{db, writer, d, lt, loginTokenLifetime}, nil -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} - -// CreateLoginToken generates a token, stores and returns it. The lifetime is -// determined by the loginTokenLifetime given to the Database constructor. -func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { - tok, err := generateLoginToken() - if err != nil { - return nil, err - } - meta := &api.LoginTokenMetadata{ - Token: tok, - Expiration: time.Now().Add(d.loginTokenLifetime), - } - - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.loginTokens.insert(ctx, txn, meta, data) - }) - if err != nil { - return nil, err - } - - return meta, nil -} - -func generateLoginToken() (string, error) { - b := make([]byte, loginTokenByteLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// RemoveLoginToken removes the named token (and may clean up other expired tokens). -func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.loginTokens.deleteByToken(ctx, txn, token) - }) -} - -// GetLoginTokenDataByToken returns the data associated with the given token. -// May return sql.ErrNoRows. -func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { - return d.loginTokens.selectByToken(ctx, token) -} diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go deleted file mode 100644 index 15cf8150c..000000000 --- a/userapi/storage/devices/storage.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !wasm -// +build !wasm - -package devices - -import ( - "fmt" - "time" - - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" - "github.com/matrix-org/gomatrixserverlib" -) - -// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters. loginTokenLifetime determines how long a -// login token from CreateLoginToken is valid. -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) - case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime) - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go deleted file mode 100644 index 3de7880b9..000000000 --- a/userapi/storage/devices/storage_wasm.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package devices - -import ( - "fmt" - "time" - - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" - "github.com/matrix-org/gomatrixserverlib" -) - -func NewDatabase( - dbProperties *config.DatabaseOptions, - serverName gomatrixserverlib.ServerName, - loginTokenLifetime time.Duration, -) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) - case dbProperties.ConnectionString.IsPostgres(): - return nil, fmt.Errorf("can't use Postgres implementation") - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/interface.go similarity index 67% rename from userapi/storage/accounts/interface.go rename to userapi/storage/interface.go index a2185774a..a131dac47 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/interface.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package storage import ( "context" @@ -60,6 +60,35 @@ type Database interface { UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) + + GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) + GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) + // CreateDevice makes a new device associated with the given user ID localpart. + // If there is already a device with the same device ID for this user, that access token will be revoked + // and replaced with the given accessToken. If the given accessToken is already in use for another device, + // an error will be returned. + // If no device ID is given one is generated. + // Returns the device on success. + CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) + UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error + RemoveDevice(ctx context.Context, deviceID, localpart string) error + RemoveDevices(ctx context.Context, localpart string, devices []string) error + // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. + RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + + // CreateLoginToken generates a token, stores and returns it. The lifetime is + // determined by the loginTokenLifetime given to the Database constructor. + CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) + + // RemoveLoginToken removes the named token (and may clean up other expired tokens). + RemoveLoginToken(ctx context.Context, token string) error + + // GetLoginTokenDataByToken returns the data associated with the given token. + // May return sql.ErrNoRows. + GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go similarity index 100% rename from userapi/storage/accounts/postgres/account_data_table.go rename to userapi/storage/postgres/account_data_table.go diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go similarity index 100% rename from userapi/storage/accounts/postgres/accounts_table.go rename to userapi/storage/postgres/accounts_table.go diff --git a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go b/userapi/storage/postgres/deltas/20200929203058_is_active.go similarity index 100% rename from userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go rename to userapi/storage/postgres/deltas/20200929203058_is_active.go diff --git a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go similarity index 89% rename from userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go rename to userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go index 290f854c8..1bbb0a9d3 100644 --- a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go @@ -5,13 +5,8 @@ import ( "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/pressly/goose" ) -func LoadFromGoose() { - goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) -} - func LoadLastSeenTSIP(m *sqlutil.Migrations) { m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) } diff --git a/userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go similarity index 100% rename from userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go rename to userapi/storage/postgres/deltas/2022021013023800_add_account_type.go diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go similarity index 99% rename from userapi/storage/devices/postgres/devices_table.go rename to userapi/storage/postgres/devices_table.go index 7de9f5f9e..64cc0b71a 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -117,6 +117,9 @@ func (s *devicesStatements) execSchema(db *sql.DB) error { } func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + if err = s.execSchema(db); err != nil { + return + } if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { return } diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go similarity index 100% rename from userapi/storage/accounts/postgres/key_backup_table.go rename to userapi/storage/postgres/key_backup_table.go diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/postgres/key_backup_version_table.go similarity index 100% rename from userapi/storage/accounts/postgres/key_backup_version_table.go rename to userapi/storage/postgres/key_backup_version_table.go diff --git a/userapi/storage/devices/postgres/logintoken_table.go b/userapi/storage/postgres/logintoken_table.go similarity index 98% rename from userapi/storage/devices/postgres/logintoken_table.go rename to userapi/storage/postgres/logintoken_table.go index f601fc7db..508a68989 100644 --- a/userapi/storage/devices/postgres/logintoken_table.go +++ b/userapi/storage/postgres/logintoken_table.go @@ -51,6 +51,9 @@ CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_exp // prepare runs statement preparation. func (s *loginTokenStatements) prepare(db *sql.DB) error { + if err := s.execSchema(db); err != nil { + return err + } return sqlutil.StatementList{ {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go similarity index 100% rename from userapi/storage/accounts/postgres/openid_table.go rename to userapi/storage/postgres/openid_table.go diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go similarity index 100% rename from userapi/storage/accounts/postgres/profile_table.go rename to userapi/storage/postgres/profile_table.go diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/postgres/storage.go similarity index 70% rename from userapi/storage/accounts/postgres/storage.go rename to userapi/storage/postgres/storage.go index d31efd257..734192798 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -16,7 +16,9 @@ package postgres import ( "context" + "crypto/rand" "database/sql" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -30,7 +32,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" // Import the postgres database driver. _ "github.com/lib/pq" @@ -47,14 +49,23 @@ type Database struct { threepids threepidStatements openIDTokens tokenStatements keyBackupVersions keyBackupVersionStatements + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration keyBackups keyBackupStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 } +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) + // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err @@ -63,6 +74,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver serverName: serverName, db: db, writer: sqlutil.NewDummyWriter(), + loginTokenLifetime: loginTokenLifetime, bcryptCost: bcryptCost, openIDTokenLifetimeMS: openIDTokenLifetimeMS, } @@ -74,6 +86,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver } m := sqlutil.NewMigrations() deltas.LoadIsActive(m) + //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err @@ -103,6 +116,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.keyBackups.prepare(db); err != nil { return nil, err } + if err = d.devices.prepare(db, serverName); err != nil { + return nil, err + } + if err = d.loginTokens.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -515,3 +534,196 @@ func (d *Database) UpsertBackupKeys( }) return } + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*api.Device, error) { + return d.devices.selectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + return d.devices.selectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") +} + +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.devices.selectDevicesByID(ctx, deviceIDs) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, ipAddr, userAgent string, +) (dev *api.Device, returnErr error) { + if deviceID != nil { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart, exceptDeviceID string, +) (devices []api.Device, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + if err != nil { + return err + } + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + return err + } + return nil + }) + return +} + +// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) + }) +} + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/accounts/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go similarity index 100% rename from userapi/storage/accounts/postgres/threepid_table.go rename to userapi/storage/postgres/threepid_table.go diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/account_data_table.go rename to userapi/storage/sqlite3/account_data_table.go diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/accounts_table.go rename to userapi/storage/sqlite3/accounts_table.go diff --git a/userapi/storage/accounts/sqlite3/constraint_wasm.go b/userapi/storage/sqlite3/constraint_wasm.go similarity index 100% rename from userapi/storage/accounts/sqlite3/constraint_wasm.go rename to userapi/storage/sqlite3/constraint_wasm.go diff --git a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go similarity index 100% rename from userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go rename to userapi/storage/sqlite3/deltas/20200929203058_is_active.go diff --git a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go similarity index 94% rename from userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go rename to userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index 262098265..ebf908001 100644 --- a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -5,13 +5,8 @@ import ( "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/pressly/goose" ) -func LoadFromGoose() { - goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) -} - func LoadLastSeenTSIP(m *sqlutil.Migrations) { m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) } diff --git a/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go similarity index 100% rename from userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go rename to userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go similarity index 99% rename from userapi/storage/devices/sqlite3/devices_table.go rename to userapi/storage/sqlite3/devices_table.go index 955d8ac7f..119ecdf93 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -106,6 +106,9 @@ func (s *devicesStatements) execSchema(db *sql.DB) error { func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { s.db = db s.writer = writer + if err = s.execSchema(db); err != nil { + return + } if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { return } diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/key_backup_table.go rename to userapi/storage/sqlite3/key_backup_table.go diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/sqlite3/key_backup_version_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/key_backup_version_table.go rename to userapi/storage/sqlite3/key_backup_version_table.go diff --git a/userapi/storage/devices/sqlite3/logintoken_table.go b/userapi/storage/sqlite3/logintoken_table.go similarity index 98% rename from userapi/storage/devices/sqlite3/logintoken_table.go rename to userapi/storage/sqlite3/logintoken_table.go index 75ef272f8..52322b46a 100644 --- a/userapi/storage/devices/sqlite3/logintoken_table.go +++ b/userapi/storage/sqlite3/logintoken_table.go @@ -51,6 +51,9 @@ CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_exp // prepare runs statement preparation. func (s *loginTokenStatements) prepare(db *sql.DB) error { + if err := s.execSchema(db); err != nil { + return err + } return sqlutil.StatementList{ {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/openid_table.go rename to userapi/storage/sqlite3/openid_table.go diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/profile_table.go rename to userapi/storage/sqlite3/profile_table.go diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go similarity index 71% rename from userapi/storage/accounts/sqlite3/storage.go rename to userapi/storage/sqlite3/storage.go index 0bab16ca3..56ec1b6af 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -16,7 +16,9 @@ package sqlite3 import ( "context" + "crypto/rand" "database/sql" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -31,7 +33,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) // Database represents an account database @@ -47,6 +49,9 @@ type Database struct { openIDTokens tokenStatements keyBackupVersions keyBackupVersionStatements keyBackups keyBackupStatements + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -57,8 +62,14 @@ type Database struct { threepidsMu sync.Mutex } +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) + // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err @@ -67,6 +78,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver serverName: serverName, db: db, writer: sqlutil.NewExclusiveWriter(), + loginTokenLifetime: loginTokenLifetime, bcryptCost: bcryptCost, openIDTokenLifetimeMS: openIDTokenLifetimeMS, } @@ -78,6 +90,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver } m := sqlutil.NewMigrations() deltas.LoadIsActive(m) + //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err @@ -108,6 +121,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.keyBackups.prepare(db); err != nil { return nil, err } + if err = d.devices.prepare(db, d.writer, serverName); err != nil { + return nil, err + } + if err = d.loginTokens.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -547,3 +566,196 @@ func (d *Database) UpsertBackupKeys( }) return } + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*api.Device, error) { + return d.devices.selectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + return d.devices.selectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") +} + +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.devices.selectDevicesByID(ctx, deviceIDs) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, ipAddr, userAgent string, +) (dev *api.Device, returnErr error) { + if deviceID != nil { + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + var err error + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart, exceptDeviceID string, +) (devices []api.Device, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + if err != nil { + return err + } + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + return err + } + return nil + }) + return +} + +// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) + }) +} + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go similarity index 100% rename from userapi/storage/accounts/sqlite3/threepid_table.go rename to userapi/storage/sqlite3/threepid_table.go diff --git a/userapi/storage/accounts/storage.go b/userapi/storage/storage.go similarity index 81% rename from userapi/storage/accounts/storage.go rename to userapi/storage/storage.go index f43f7efd6..4711439af 100644 --- a/userapi/storage/accounts/storage.go +++ b/userapi/storage/storage.go @@ -15,26 +15,27 @@ //go:build !wasm // +build !wasm -package accounts +package storage import ( "fmt" + "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/accounts/storage_wasm.go b/userapi/storage/storage_wasm.go similarity index 87% rename from userapi/storage/accounts/storage_wasm.go rename to userapi/storage/storage_wasm.go index 11a88a20a..701dcd833 100644 --- a/userapi/storage/accounts/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package storage import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) @@ -27,10 +28,11 @@ func NewDatabase( serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, + loginTokenLifetime time.Duration, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS) + return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/userapi.go b/userapi/userapi.go index c7e1f6674..4a5793abb 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -23,18 +23,10 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage" "github.com/sirupsen/logrus" ) -// defaultLoginTokenLifetime determines how old a valid token may be. -// -// NOTSPEC: The current spec says "SHOULD be limited to around five -// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low. -// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325). -const defaultLoginTokenLifetime = 2 * time.Minute - // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { @@ -44,26 +36,24 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, ) api.UserInternalAPI { - deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime) + db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } - return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI) + return newInternalAPI(db, cfg, appServices, keyAPI) } func newInternalAPI( - accountDB accounts.Database, - deviceDB devices.Database, + db storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, ) api.UserInternalAPI { return &internal.UserInternalAPI{ - AccountDB: accountDB, - DeviceDB: deviceDB, + DB: db, ServerName: cfg.Matrix.ServerName, AppServices: appServices, KeyAPI: keyAPI, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 141dd96d1..4214c07f7 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -31,8 +31,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/inthttp" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" + "github.com/matrix-org/dendrite/userapi/storage" ) const ( @@ -43,23 +42,19 @@ type apiTestOpts struct { loginTokenLifetime time.Duration } -func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) { +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, storage.Database) { if opts.loginTokenLifetime == 0 { - opts.loginTokenLifetime = defaultLoginTokenLifetime + opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } dbopts := &config.DatabaseOptions{ ConnectionString: "file::memory:", MaxOpenConnections: 1, MaxIdleConnections: 1, } - accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) + accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime) if err != nil { t.Fatalf("failed to create account DB: %s", err) } - deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime) - if err != nil { - t.Fatalf("failed to create device DB: %s", err) - } cfg := &config.UserAPI{ Matrix: &config.Global{ @@ -67,7 +62,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, a }, } - return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB + return newInternalAPI(accountDB, cfg, nil, nil), accountDB } func TestQueryProfile(t *testing.T) { From 9bd5e414c9afec735b2309769876b62c4bbd2b8f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 18 Feb 2022 11:32:45 +0000 Subject: [PATCH 76/81] Missing commit from #2186 --- build/gobind-pinecone/monolith.go | 1 + keyserver/internal/internal.go | 8 +++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index acf4406ca..aa8cc6e6e 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -283,6 +283,7 @@ func (m *DendriteMonolith) Start() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix)) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix)) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory)) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix)) cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix)) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix)) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 1c6b06776..0c264b718 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -595,12 +595,10 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per if len(toClean) > 0 { if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to clean device keys: %s", err.Error()), - } - return + logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean)) + } else { + logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean)) } - logrus.WithField("user_id", req.UserID).Infof("Cleaned up %d stale keyserver device key entries", len(toClean)) } var keysToStore []api.DeviceMessage From 9f4a39e8e0334e99cc2b8fe3ef33ebc126c8bced Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 18 Feb 2022 13:51:59 +0000 Subject: [PATCH 77/81] Refactor user API storage (#2202) * Refactor User API database * Fix migration bugs --- .../storage/postgres/account_data_table.go | 16 +- userapi/storage/postgres/accounts_table.go | 28 +- userapi/storage/postgres/devices_table.go | 78 +- userapi/storage/postgres/key_backup_table.go | 22 +- .../postgres/key_backup_version_table.go | 20 +- userapi/storage/postgres/logintoken_table.go | 53 +- userapi/storage/postgres/openid_table.go | 29 +- userapi/storage/postgres/profile_table.go | 24 +- userapi/storage/postgres/storage.go | 718 ++--------------- userapi/storage/postgres/threepid_table.go | 23 +- userapi/storage/shared/storage.go | 672 ++++++++++++++++ userapi/storage/sqlite3/account_data_table.go | 21 +- userapi/storage/sqlite3/accounts_table.go | 30 +- userapi/storage/sqlite3/devices_table.go | 82 +- userapi/storage/sqlite3/key_backup_table.go | 22 +- .../sqlite3/key_backup_version_table.go | 20 +- userapi/storage/sqlite3/logintoken_table.go | 49 +- userapi/storage/sqlite3/openid_table.go | 31 +- userapi/storage/sqlite3/profile_table.go | 25 +- userapi/storage/sqlite3/storage.go | 755 ++---------------- userapi/storage/sqlite3/threepid_table.go | 23 +- userapi/storage/tables/interface.go | 95 +++ 22 files changed, 1165 insertions(+), 1671 deletions(-) create mode 100644 userapi/storage/shared/storage.go create mode 100644 userapi/storage/tables/interface.go diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 8ba890e75..67113367b 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const accountDataSchema = ` @@ -56,19 +57,20 @@ type accountDataStatements struct { selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(accountDataSchema) +func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { + s := &accountDataStatements{} + _, err := db.Exec(accountDataSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, }.Prepare(db) } -func (s *accountDataStatements) insertAccountData( +func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) @@ -76,7 +78,7 @@ func (s *accountDataStatements) insertAccountData( return } -func (s *accountDataStatements) selectAccountData( +func (s *accountDataStatements) SelectAccountData( ctx context.Context, localpart string, ) ( /* global */ map[string]json.RawMessage, @@ -114,7 +116,7 @@ func (s *accountDataStatements) selectAccountData( return global, rooms, rows.Err() } -func (s *accountDataStatements) selectAccountDataByType( +func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 9e3e456a7..92311d56d 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" log "github.com/sirupsen/logrus" ) @@ -78,14 +79,15 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) execSchema(db *sql.DB) error { +func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { + s := &accountsStatements{ + serverName: serverName, + } _, err := db.Exec(accountsSchema) - return err -} - -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.serverName = server - return sqlutil.StatementList{ + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountStmt, insertAccountSQL}, {&s.updatePasswordStmt, updatePasswordSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL}, @@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. -func (s *accountsStatements) insertAccount( +func (s *accountsStatements) InsertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -123,28 +125,28 @@ func (s *accountsStatements) insertAccount( }, nil } -func (s *accountsStatements) updatePassword( +func (s *accountsStatements) UpdatePassword( ctx context.Context, localpart, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) return } -func (s *accountsStatements) deactivateAccount( +func (s *accountsStatements) DeactivateAccount( ctx context.Context, localpart string, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) return } -func (s *accountsStatements) selectPasswordHash( +func (s *accountsStatements) SelectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) return } -func (s *accountsStatements) selectAccountByLocalpart( +func (s *accountsStatements) SelectAccountByLocalpart( ctx context.Context, localpart string, ) (*api.Account, error) { var appserviceIDPtr sql.NullString @@ -168,7 +170,7 @@ func (s *accountsStatements) selectAccountByLocalpart( return &acc, nil } -func (s *accountsStatements) selectNewNumericLocalpart( +func (s *accountsStatements) SelectNewNumericLocalpart( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 64cc0b71a..7bc5dc69b 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -111,53 +112,32 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) execSchema(db *sql.DB) error { +func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { + s := &devicesStatements{ + serverName: serverName, + } _, err := db.Exec(devicesSchema) - return err -} - -func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - if err = s.execSchema(db); err != nil { - return + if err != nil { + return nil, err } - if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { - return - } - if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { - return - } - if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { - return - } - if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { - return - } - if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { - return - } - if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { - return - } - if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { - return - } - if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { - return - } - if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { - return - } - if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil { - return - } - s.serverName = server - return + return s, sqlutil.StatementList{ + {&s.insertDeviceStmt, insertDeviceSQL}, + {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL}, + {&s.selectDeviceByIDStmt, selectDeviceByIDSQL}, + {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL}, + {&s.updateDeviceNameStmt, updateDeviceNameSQL}, + {&s.deleteDeviceStmt, deleteDeviceSQL}, + {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL}, + {&s.deleteDevicesStmt, deleteDevicesSQL}, + {&s.selectDevicesByIDStmt, selectDevicesByIDSQL}, + {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen}, + }.Prepare(db) } // insertDevice creates a new device. Returns an error if any device with the same access token already exists. // Returns an error if the user already has a device with the given device ID. // Returns the device on success. -func (s *devicesStatements) insertDevice( +func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { @@ -179,7 +159,7 @@ func (s *devicesStatements) insertDevice( } // deleteDevice removes a single device by id and user localpart. -func (s *devicesStatements) deleteDevice( +func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) @@ -189,7 +169,7 @@ func (s *devicesStatements) deleteDevice( // deleteDevices removes a single or multiple devices by ids and user localpart. // Returns an error if the execution failed. -func (s *devicesStatements) deleteDevices( +func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) @@ -199,7 +179,7 @@ func (s *devicesStatements) deleteDevices( // deleteDevicesByLocalpart removes all devices for the // given user localpart. -func (s *devicesStatements) deleteDevicesByLocalpart( +func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -207,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart( return err } -func (s *devicesStatements) updateDeviceName( +func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -215,7 +195,7 @@ func (s *devicesStatements) updateDeviceName( return err } -func (s *devicesStatements) selectDeviceByToken( +func (s *devicesStatements) SelectDeviceByToken( ctx context.Context, accessToken string, ) (*api.Device, error) { var dev api.Device @@ -231,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID -func (s *devicesStatements) selectDeviceByID( +func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -248,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID( return &dev, err } -func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { +func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) if err != nil { return nil, err @@ -271,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) selectDevicesByLocalpart( +func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -313,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index c1402d4d2..ac0e80617 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupTableSchema = ` @@ -72,12 +73,13 @@ type keyBackupStatements struct { selectKeysByRoomIDAndSessionIDStmt *sql.Stmt } -func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupTableSchema) +func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { + s := &keyBackupStatements{} + _, err := db.Exec(keyBackupTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, @@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s keyBackupStatements) countKeys( +func (s keyBackupStatements) CountKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (count int64, err error) { err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) return } -func (s *keyBackupStatements) insertBackupKey( +func (s *keyBackupStatements) InsertBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( @@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey( return } -func (s *keyBackupStatements) updateBackupKey( +func (s *keyBackupStatements) UpdateBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( @@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey( return } -func (s *keyBackupStatements) selectKeys( +func (s *keyBackupStatements) SelectKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) @@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomID( +func (s *keyBackupStatements) SelectKeysByRoomID( ctx context.Context, txn *sql.Tx, userID, version, roomID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) @@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( +func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID( ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) diff --git a/userapi/storage/postgres/key_backup_version_table.go b/userapi/storage/postgres/key_backup_version_table.go index d73447b49..e78e4cd51 100644 --- a/userapi/storage/postgres/key_backup_version_table.go +++ b/userapi/storage/postgres/key_backup_version_table.go @@ -22,6 +22,7 @@ import ( "strconv" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupVersionTableSchema = ` @@ -69,12 +70,13 @@ type keyBackupVersionStatements struct { updateKeyBackupETagStmt *sql.Stmt } -func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupVersionTableSchema) +func NewPostgresKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) { + s := &keyBackupVersionStatements{} + _, err := db.Exec(keyBackupVersionTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, @@ -84,7 +86,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *keyBackupVersionStatements) insertKeyBackup( +func (s *keyBackupVersionStatements) InsertKeyBackup( ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ) (version string, err error) { var versionInt int64 @@ -92,7 +94,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup( return strconv.FormatInt(versionInt, 10), err } -func (s *keyBackupVersionStatements) updateKeyBackupAuthData( +func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData( ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -103,7 +105,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( return err } -func (s *keyBackupVersionStatements) updateKeyBackupETag( +func (s *keyBackupVersionStatements) UpdateKeyBackupETag( ctx context.Context, txn *sql.Tx, userID, version, etag string, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -114,7 +116,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag( return err } -func (s *keyBackupVersionStatements) deleteKeyBackup( +func (s *keyBackupVersionStatements) DeleteKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (bool, error) { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -132,7 +134,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( return ra == 1, nil } -func (s *keyBackupVersionStatements) selectKeyBackup( +func (s *keyBackupVersionStatements) SelectKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 diff --git a/userapi/storage/postgres/logintoken_table.go b/userapi/storage/postgres/logintoken_table.go index 508a68989..4de96f839 100644 --- a/userapi/storage/postgres/logintoken_table.go +++ b/userapi/storage/postgres/logintoken_table.go @@ -21,18 +21,11 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/util" ) -type loginTokenStatements struct { - insertStmt *sql.Stmt - deleteStmt *sql.Stmt - selectByTokenStmt *sql.Stmt -} - -// execSchema ensures tables and indices exist. -func (s *loginTokenStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(` +const loginTokenSchema = ` CREATE TABLE IF NOT EXISTS login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, @@ -45,24 +38,38 @@ CREATE TABLE IF NOT EXISTS login_tokens ( -- This index allows efficient garbage collection of expired tokens. CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); -`) - return err +` + +const insertLoginTokenSQL = "" + + "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + +const deleteLoginTokenSQL = "" + + "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + +const selectLoginTokenSQL = "" + + "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt } -// prepare runs statement preparation. -func (s *loginTokenStatements) prepare(db *sql.DB) error { - if err := s.execSchema(db); err != nil { - return err +func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { + s := &loginTokenStatements{} + _, err := db.Exec(loginTokenSchema) + if err != nil { + return nil, err } - return sqlutil.StatementList{ - {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, - {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, - {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + return s, sqlutil.StatementList{ + {&s.insertStmt, insertLoginTokenSQL}, + {&s.deleteStmt, deleteLoginTokenSQL}, + {&s.selectStmt, selectLoginTokenSQL}, }.Prepare(db) } // insert adds an already generated token to the database. -func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { +func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { stmt := sqlutil.TxStmt(txn, s.insertStmt) _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) return err @@ -72,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata // // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // The login_tokens_expiration_idx index should make that efficient. -func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { +func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) if err != nil { @@ -85,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t } // selectByToken returns the data associated with the given token. May return sql.ErrNoRows. -func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { +func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) { var data api.LoginTokenData - err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) if err != nil { return nil, err } diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 190d141b7..29c3ddcb4 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -22,33 +23,35 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ); ` -const insertTokenSQL = "" + +const insertOpenIDTokenSQL = "" + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" -const selectTokenSQL = "" + +const selectOpenIDTokenSQL = "" + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" -type tokenStatements struct { +type openIDTokenStatements struct { insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt serverName gomatrixserverlib.ServerName } -func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - _, err = db.Exec(openIDTokenSchema) - if err != nil { - return +func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { + s := &openIDTokenStatements{ + serverName: serverName, } - s.serverName = server - return sqlutil.StatementList{ - {&s.insertTokenStmt, insertTokenSQL}, - {&s.selectTokenStmt, selectTokenSQL}, + _, err := db.Exec(openIDTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertTokenStmt, insertOpenIDTokenSQL}, + {&s.selectTokenStmt, selectOpenIDTokenSQL}, }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. // Returns new token, otherwise returns error if the token already exists. -func (s *tokenStatements) insertToken( +func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, token, localpart string, @@ -61,7 +64,7 @@ func (s *tokenStatements) insertToken( // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // Returns the existing token's attributes, or err if no token is found -func (s *tokenStatements) selectOpenIDTokenAtrributes( +func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ctx context.Context, token string, ) (*api.OpenIDTokenAttributes, error) { diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 9313864be..32a4b5506 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const profilesSchema = ` @@ -59,12 +60,13 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(profilesSchema) +func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) { + s := &profilesStatements{} + _, err := db.Exec(profilesSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertProfileStmt, insertProfileSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL}, @@ -73,14 +75,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *profilesStatements) insertProfile( +func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, localpart string, ) (err error) { _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return } -func (s *profilesStatements) selectProfileByLocalpart( +func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { var profile authtypes.Profile @@ -93,21 +95,21 @@ func (s *profilesStatements) selectProfileByLocalpart( return &profile, nil } -func (s *profilesStatements) setAvatarURL( - ctx context.Context, localpart string, avatarURL string, +func (s *profilesStatements) SetAvatarURL( + ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ) (err error) { _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) return } -func (s *profilesStatements) setDisplayName( - ctx context.Context, localpart string, displayName string, +func (s *profilesStatements) SetDisplayName( + ctx context.Context, txn *sql.Tx, localpart string, displayName string, ) (err error) { _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) return } -func (s *profilesStatements) selectProfilesBySearch( +func (s *profilesStatements) SelectProfilesBySearch( ctx context.Context, searchString string, limit int, ) ([]authtypes.Profile, error) { var profiles []authtypes.Profile diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 734192798..ac5c59b81 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -15,76 +15,33 @@ package postgres import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "encoding/json" - "errors" "fmt" - "strconv" "time" "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/shared" // Import the postgres database driver. _ "github.com/lib/pq" ) -// Database represents an account database -type Database struct { - db *sql.DB - writer sqlutil.Writer - sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - openIDTokens tokenStatements - keyBackupVersions keyBackupVersionStatements - devices devicesStatements - loginTokens loginTokenStatements - loginTokenLifetime time.Duration - keyBackups keyBackupStatements - serverName gomatrixserverlib.ServerName - bcryptCost int - openIDTokenLifetimeMS int64 -} - -const ( - // The length of generated device IDs - deviceIDByteLength = 6 - loginTokenByteLength = 32 -) - // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } - d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewDummyWriter(), - loginTokenLifetime: loginTokenLifetime, - bcryptCost: bcryptCost, - openIDTokenLifetimeMS: openIDTokenLifetimeMS, - } - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.accounts.execSchema(db); err != nil { + m := sqlutil.NewMigrations() + if _, err = db.Exec(accountsSchema); err != nil { + // do this so that the migration can and we don't fail on + // preparing statements for columns that don't exist yet return nil, err } - m := sqlutil.NewMigrations() deltas.LoadIsActive(m) //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) @@ -92,638 +49,57 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver return nil, err } - if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil { - return nil, err - } - if err = d.accounts.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.profiles.prepare(db); err != nil { - return nil, err - } - if err = d.accountDatas.prepare(db); err != nil { - return nil, err - } - if err = d.threepids.prepare(db); err != nil { - return nil, err - } - if err = d.openIDTokens.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.keyBackupVersions.prepare(db); err != nil { - return nil, err - } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } - if err = d.devices.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.loginTokens.prepare(db); err != nil { - return nil, err - } - - return d, nil -} - -// GetAccountByPassword returns the account associated with the given localpart and password. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByPassword( - ctx context.Context, localpart, plaintextPassword string, -) (*api.Account, error) { - hash, err := d.accounts.selectPasswordHash(ctx, localpart) + accountDataTable, err := NewPostgresAccountDataTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) } - if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { - return nil, err - } - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// GetProfileByLocalpart returns the profile associated with the given localpart. -// Returns sql.ErrNoRows if no profile exists which matches the given localpart. -func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, -) (*authtypes.Profile, error) { - return d.profiles.selectProfileByLocalpart(ctx, localpart) -} - -// SetAvatarURL updates the avatar URL of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, -) error { - return d.profiles.setAvatarURL(ctx, localpart, avatarURL) -} - -// SetDisplayName updates the display name of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, -) error { - return d.profiles.setDisplayName(ctx, localpart, displayName) -} - -// SetPassword sets the account password to the given hash. -func (d *Database) SetPassword( - ctx context.Context, localpart, plaintextPassword string, -) error { - hash, err := d.hashPassword(plaintextPassword) + accountsTable, err := NewPostgresAccountsTable(db, serverName) if err != nil { - return err + return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) } - return d.accounts.updatePassword(ctx, localpart, hash) -} - -// CreateAccount makes a new account with the given login name and password, and creates an empty profile -// for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, sqlutil.ErrUserExists. -func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, -) (acc *api.Account, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - // For guest accounts, we create a new numeric local part - if accountType == api.AccountTypeGuest { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart = strconv.FormatInt(numLocalpart, 10) - plaintextPassword = "" - appserviceID = "" - } - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) - return err - }) - return -} - -func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, -) (*api.Account, error) { - var account *api.Account - var err error - // Generate a password hash if this is not a password-less user - hash := "" - if plaintextPassword != "" { - hash, err = d.hashPassword(plaintextPassword) - if err != nil { - return nil, err - } - } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { - if sqlutil.IsUniqueConstraintViolationErr(err) { - return nil, sqlutil.ErrUserExists - } - return nil, err - } - if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { - return nil, err - } - if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { - return nil, err - } - return account, nil -} - -// SaveAccountData saves new account data for a given user and a given room. -// If the account data is not specific to a room, the room ID should be an empty string -// If an account data already exists for a given set (user, room, data type), it will -// update the corresponding row with the new content -// Returns a SQL error if there was an issue with the insertion/update -func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) - }) -} - -// GetAccountData returns account data related to a given localpart -// If no account data could be found, returns an empty arrays -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global map[string]json.RawMessage, - rooms map[string]map[string]json.RawMessage, - err error, -) { - return d.accountDatas.selectAccountData(ctx, localpart) -} - -// GetAccountDataByType returns account data matching a given -// localpart, room ID and type. -// If no account data could be found, returns nil -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, -) (data json.RawMessage, err error) { - return d.accountDatas.selectAccountDataByType( - ctx, localpart, roomID, dataType, - ) -} - -// GetNewNumericLocalpart generates and returns a new unused numeric localpart -func (d *Database) GetNewNumericLocalpart( - ctx context.Context, -) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx, nil) -} - -func (d *Database) hashPassword(plaintext string) (hash string, err error) { - hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost) - return string(hashBytes), err -} - -// Err3PIDInUse is the error returned when trying to save an association involving -// a third-party identifier which is already associated to a local user. -var Err3PIDInUse = errors.New("this third-party identifier is already in use") - -// SaveThreePIDAssociation saves the association between a third party identifier -// and a local Matrix user (identified by the user's ID's local part). -// If the third-party identifier is already part of an association, returns Err3PIDInUse. -// Returns an error if there was a problem talking to the database. -func (d *Database) SaveThreePIDAssociation( - ctx context.Context, threepid, localpart, medium string, -) (err error) { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID( - ctx, txn, threepid, medium, - ) - if err != nil { - return err - } - - if len(user) > 0 { - return Err3PIDInUse - } - - return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) - }) -} - -// RemoveThreePIDAssociation removes the association involving a given third-party -// identifier. -// If no association exists involving this third-party identifier, returns nothing. -// If there was a problem talking to the database, returns an error. -func (d *Database) RemoveThreePIDAssociation( - ctx context.Context, threepid string, medium string, -) (err error) { - return d.threepids.deleteThreePID(ctx, threepid, medium) -} - -// GetLocalpartForThreePID looks up the localpart associated with a given third-party -// identifier. -// If no association involves the given third-party idenfitier, returns an empty -// string. -// Returns an error if there was a problem talking to the database. -func (d *Database) GetLocalpartForThreePID( - ctx context.Context, threepid string, medium string, -) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) -} - -// GetThreePIDsForLocalpart looks up the third-party identifiers associated with -// a given local user. -// If no association is known for this user, returns an empty slice. -// Returns an error if there was an issue talking to the database. -func (d *Database) GetThreePIDsForLocalpart( - ctx context.Context, localpart string, -) (threepids []authtypes.ThreePID, err error) { - return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) -} - -// CheckAccountAvailability checks if the username/localpart is already present -// in the database. -// If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) - if err == sql.ErrNoRows { - return true, nil - } - return false, err -} - -// GetAccountByLocalpart returns the account associated with the given localpart. -// This function assumes the request is authenticated or the account data is used only internally. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*api.Account, error) { - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// SearchProfiles returns all profiles where the provided localpart or display name -// match any part of the profiles in the database. -func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, -) ([]authtypes.Profile, error) { - return d.profiles.selectProfilesBySearch(ctx, searchString, limit) -} - -// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { - return d.accounts.deactivateAccount(ctx, localpart) -} - -// CreateOpenIDToken persists a new token that was issued through OpenID Connect -func (d *Database) CreateOpenIDToken( - ctx context.Context, - token, localpart string, -) (int64, error) { - expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS - err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) - }) - return expiresAtMS, err -} - -// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token -func (d *Database) GetOpenIDTokenAttributes( - ctx context.Context, - token string, -) (*api.OpenIDTokenAttributes, error) { - return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) -} - -func (d *Database) CreateKeyBackup( - ctx context.Context, userID, algorithm string, authData json.RawMessage, -) (version string, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") - return err - }) - return -} - -func (d *Database) UpdateKeyBackupAuthData( - ctx context.Context, userID, version string, authData json.RawMessage, -) (err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) - }) - return -} - -func (d *Database) DeleteKeyBackup( - ctx context.Context, userID, version string, -) (exists bool, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetKeyBackup( - ctx context.Context, userID, version string, -) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetBackupKeys( - ctx context.Context, version, userID, filterRoomID, filterSessionID string, -) (result map[string]map[string]api.KeyBackupSession, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if filterSessionID != "" { - result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) - return err - } - if filterRoomID != "" { - result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) - return err - } - result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) CountBackupKeys( - ctx context.Context, version, userID string, -) (count int64, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - return nil - }) - return -} - -// nolint:nakedret -func (d *Database) UpsertBackupKeys( - ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, -) (count int64, etag string, err error) { - // wrap the following logic in a txn to ensure we atomically upload keys - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - if err != nil { - return err - } - if deleted { - return fmt.Errorf("backup was deleted") - } - // pull out all keys for this (user_id, version) - existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) - if err != nil { - return err - } - - changed := false - // loop over all the new keys (which should be smaller than the set of backed up keys) - for _, newKey := range uploads { - // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. - existingRoom := existingKeys[newKey.RoomID] - if existingRoom != nil { - existingSession, ok := existingRoom[newKey.SessionID] - if ok { - if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { - err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) - } - } - // if we shouldn't replace the key we do nothing with it - continue - } - } - // if we're here, either the room or session are new, either way, we insert - err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) - } - } - - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - if changed { - // update the etag - var newETag string - if oldETag == "" { - newETag = "1" - } else { - oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) - if err != nil { - return fmt.Errorf("failed to parse old etag: %s", err) - } - newETag = strconv.FormatInt(oldETagInt+1, 10) - } - etag = newETag - return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) - } else { - etag = oldETag - } - return nil - }) - return -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) + devicesTable, err := NewPostgresDevicesTable(db, serverName) if err != nil { - return "", err + return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err) } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} - -// CreateLoginToken generates a token, stores and returns it. The lifetime is -// determined by the loginTokenLifetime given to the Database constructor. -func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { - tok, err := generateLoginToken() + keyBackupTable, err := NewPostgresKeyBackupTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewPostgresKeyBackupTable: %w", err) } - meta := &api.LoginTokenMetadata{ - Token: tok, - Expiration: time.Now().Add(d.loginTokenLifetime), - } - - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.loginTokens.insert(ctx, txn, meta, data) - }) + keyBackupVersionTable, err := NewPostgresKeyBackupVersionTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewPostgresKeyBackupVersionTable: %w", err) } - - return meta, nil -} - -func generateLoginToken() (string, error) { - b := make([]byte, loginTokenByteLength) - _, err := rand.Read(b) + loginTokenTable, err := NewPostgresLoginTokenTable(db) if err != nil { - return "", err + return nil, fmt.Errorf("NewPostgresLoginTokenTable: %w", err) } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// RemoveLoginToken removes the named token (and may clean up other expired tokens). -func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.loginTokens.deleteByToken(ctx, txn, token) - }) -} - -// GetLoginTokenDataByToken returns the data associated with the given token. -// May return sql.ErrNoRows. -func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { - return d.loginTokens.selectByToken(ctx, token) + openIDTable, err := NewPostgresOpenIDTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err) + } + profilesTable, err := NewPostgresProfilesTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err) + } + threePIDTable, err := NewPostgresThreePIDTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err) + } + return &shared.Database{ + AccountDatas: accountDataTable, + Accounts: accountsTable, + Devices: devicesTable, + KeyBackups: keyBackupTable, + KeyBackupVersions: keyBackupVersionTable, + LoginTokens: loginTokenTable, + OpenIDTokens: openIDTable, + Profiles: profilesTable, + ThreePIDs: threePIDTable, + ServerName: serverName, + DB: db, + Writer: sqlutil.NewDummyWriter(), + LoginTokenLifetime: loginTokenLifetime, + BcryptCost: bcryptCost, + OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + }, nil } diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 9280fc87c..63c08d61f 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -58,12 +59,13 @@ type threepidStatements struct { deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(threepidSchema) +func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { + s := &threepidStatements{} + _, err := db.Exec(threepidSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL}, @@ -71,7 +73,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *threepidStatements) selectLocalpartForThreePID( +func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) @@ -82,7 +84,7 @@ func (s *threepidStatements) selectLocalpartForThreePID( return } -func (s *threepidStatements) selectThreePIDsForLocalpart( +func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) @@ -106,7 +108,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( return } -func (s *threepidStatements) insertThreePID( +func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) @@ -114,8 +116,9 @@ func (s *threepidStatements) insertThreePID( return } -func (s *threepidStatements) deleteThreePID( - ctx context.Context, threepid string, medium string) (err error) { - _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) +func (s *threepidStatements) DeleteThreePID( + ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { + stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium) return } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go new file mode 100644 index 000000000..5f1f95005 --- /dev/null +++ b/userapi/storage/shared/storage.go @@ -0,0 +1,672 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strconv" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +// Database represents an account database +type Database struct { + DB *sql.DB + Writer sqlutil.Writer + Accounts tables.AccountsTable + Profiles tables.ProfileTable + AccountDatas tables.AccountDataTable + ThreePIDs tables.ThreePIDTable + OpenIDTokens tables.OpenIDTable + KeyBackups tables.KeyBackupTable + KeyBackupVersions tables.KeyBackupVersionTable + Devices tables.DevicesTable + LoginTokens tables.LoginTokenTable + LoginTokenLifetime time.Duration + ServerName gomatrixserverlib.ServerName + BcryptCost int + OpenIDTokenLifetimeMS int64 +} + +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) + +// GetAccountByPassword returns the account associated with the given localpart and password. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByPassword( + ctx context.Context, localpart, plaintextPassword string, +) (*api.Account, error) { + hash, err := d.Accounts.SelectPasswordHash(ctx, localpart) + if err != nil { + return nil, err + } + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { + return nil, err + } + return d.Accounts.SelectAccountByLocalpart(ctx, localpart) +} + +// GetProfileByLocalpart returns the profile associated with the given localpart. +// Returns sql.ErrNoRows if no profile exists which matches the given localpart. +func (d *Database) GetProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + return d.Profiles.SelectProfileByLocalpart(ctx, localpart) +} + +// SetAvatarURL updates the avatar URL of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) + }) +} + +// SetDisplayName updates the display name of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetDisplayName( + ctx context.Context, localpart string, displayName string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) + }) +} + +// SetPassword sets the account password to the given hash. +func (d *Database) SetPassword( + ctx context.Context, localpart, plaintextPassword string, +) error { + hash, err := d.hashPassword(plaintextPassword) + if err != nil { + return err + } + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Accounts.UpdatePassword(ctx, localpart, hash) + }) +} + +// CreateAccount makes a new account with the given login name and password, and creates an empty profile +// for this account. If no password is supplied, the account will be a passwordless account. If the +// account already exists, it will return nil, ErrUserExists. +func (d *Database) CreateAccount( + ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, +) (acc *api.Account, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // For guest accounts, we create a new numeric local part + if accountType == api.AccountTypeGuest { + var numLocalpart int64 + numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart = strconv.FormatInt(numLocalpart, 10) + plaintextPassword = "" + appserviceID = "" + } + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) + return err + }) + return +} + +// WARNING! This function assumes that the relevant mutexes have already +// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). +func (d *Database) createAccount( + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, +) (*api.Account, error) { + var err error + var account *api.Account + // Generate a password hash if this is not a password-less user + hash := "" + if plaintextPassword != "" { + hash, err = d.hashPassword(plaintextPassword) + if err != nil { + return nil, err + } + } + if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { + return nil, sqlutil.ErrUserExists + } + if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { + return nil, err + } + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ + "global": { + "content": [], + "override": [], + "room": [], + "sender": [], + "underride": [] + } + }`)); err != nil { + return nil, err + } + return account, nil +} + +// SaveAccountData saves new account data for a given user and a given room. +// If the account data is not specific to a room, the room ID should be an empty string +// If an account data already exists for a given set (user, room, data type), it will +// update the corresponding row with the new content +// Returns a SQL error if there was an issue with the insertion/update +func (d *Database) SaveAccountData( + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content) + }) +} + +// GetAccountData returns account data related to a given localpart +// If no account data could be found, returns an empty arrays +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountData(ctx context.Context, localpart string) ( + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, + err error, +) { + return d.AccountDatas.SelectAccountData(ctx, localpart) +} + +// GetAccountDataByType returns account data matching a given +// localpart, room ID and type. +// If no account data could be found, returns nil +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data json.RawMessage, err error) { + return d.AccountDatas.SelectAccountDataByType( + ctx, localpart, roomID, dataType, + ) +} + +// GetNewNumericLocalpart generates and returns a new unused numeric localpart +func (d *Database) GetNewNumericLocalpart( + ctx context.Context, +) (int64, error) { + return d.Accounts.SelectNewNumericLocalpart(ctx, nil) +} + +func (d *Database) hashPassword(plaintext string) (hash string, err error) { + hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.BcryptCost) + return string(hashBytes), err +} + +// Err3PIDInUse is the error returned when trying to save an association involving +// a third-party identifier which is already associated to a local user. +var Err3PIDInUse = errors.New("this third-party identifier is already in use") + +// SaveThreePIDAssociation saves the association between a third party identifier +// and a local Matrix user (identified by the user's ID's local part). +// If the third-party identifier is already part of an association, returns Err3PIDInUse. +// Returns an error if there was a problem talking to the database. +func (d *Database) SaveThreePIDAssociation( + ctx context.Context, threepid, localpart, medium string, +) (err error) { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + user, err := d.ThreePIDs.SelectLocalpartForThreePID( + ctx, txn, threepid, medium, + ) + if err != nil { + return err + } + + if len(user) > 0 { + return Err3PIDInUse + } + + return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart) + }) +} + +// RemoveThreePIDAssociation removes the association involving a given third-party +// identifier. +// If no association exists involving this third-party identifier, returns nothing. +// If there was a problem talking to the database, returns an error. +func (d *Database) RemoveThreePIDAssociation( + ctx context.Context, threepid string, medium string, +) (err error) { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.ThreePIDs.DeleteThreePID(ctx, txn, threepid, medium) + }) +} + +// GetLocalpartForThreePID looks up the localpart associated with a given third-party +// identifier. +// If no association involves the given third-party idenfitier, returns an empty +// string. +// Returns an error if there was a problem talking to the database. +func (d *Database) GetLocalpartForThreePID( + ctx context.Context, threepid string, medium string, +) (localpart string, err error) { + return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) +} + +// GetThreePIDsForLocalpart looks up the third-party identifiers associated with +// a given local user. +// If no association is known for this user, returns an empty slice. +// Returns an error if there was an issue talking to the database. +func (d *Database) GetThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart) +} + +// CheckAccountAvailability checks if the username/localpart is already present +// in the database. +// If the DB returns sql.ErrNoRows the Localpart isn't taken. +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { + _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart) + if err == sql.ErrNoRows { + return true, nil + } + return false, err +} + +// GetAccountByLocalpart returns the account associated with the given localpart. +// This function assumes the request is authenticated or the account data is used only internally. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, +) (*api.Account, error) { + return d.Accounts.SelectAccountByLocalpart(ctx, localpart) +} + +// SearchProfiles returns all profiles where the provided localpart or display name +// match any part of the profiles in the database. +func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + return d.Profiles.SelectProfilesBySearch(ctx, searchString, limit) +} + +// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. +func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Accounts.DeactivateAccount(ctx, localpart) + }) +} + +// CreateOpenIDToken persists a new token that was issued for OpenID Connect +func (d *Database) CreateOpenIDToken( + ctx context.Context, + token, localpart string, +) (int64, error) { + expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS) + }) + return expiresAtMS, err +} + +// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token +func (d *Database) GetOpenIDTokenAttributes( + ctx context.Context, + token string, +) (*api.OpenIDTokenAttributes, error) { + return d.OpenIDTokens.SelectOpenIDTokenAtrributes(ctx, token) +} + +func (d *Database) CreateKeyBackup( + ctx context.Context, userID, algorithm string, authData json.RawMessage, +) (version string, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + version, err = d.KeyBackupVersions.InsertKeyBackup(ctx, txn, userID, algorithm, authData, "") + return err + }) + return +} + +func (d *Database) UpdateKeyBackupAuthData( + ctx context.Context, userID, version string, authData json.RawMessage, +) (err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.KeyBackupVersions.UpdateKeyBackupAuthData(ctx, txn, userID, version, authData) + }) + return +} + +func (d *Database) DeleteKeyBackup( + ctx context.Context, userID, version string, +) (exists bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + exists, err = d.KeyBackupVersions.DeleteKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetKeyBackup( + ctx context.Context, userID, version string, +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + versionResult, algorithm, authData, etag, deleted, err = d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) GetBackupKeys( + ctx context.Context, version, userID, filterRoomID, filterSessionID string, +) (result map[string]map[string]api.KeyBackupSession, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if filterSessionID != "" { + result, err = d.KeyBackups.SelectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) + return err + } + if filterRoomID != "" { + result, err = d.KeyBackups.SelectKeysByRoomID(ctx, txn, userID, version, filterRoomID) + return err + } + result, err = d.KeyBackups.SelectKeys(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) CountBackupKeys( + ctx context.Context, version, userID string, +) (count int64, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version) + if err != nil { + return err + } + return nil + }) + return +} + +// nolint:nakedret +func (d *Database) UpsertBackupKeys( + ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, +) (count int64, etag string, err error) { + // wrap the following logic in a txn to ensure we atomically upload keys + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version) + if err != nil { + return err + } + if deleted { + return fmt.Errorf("backup was deleted") + } + // pull out all keys for this (user_id, version) + existingKeys, err := d.KeyBackups.SelectKeys(ctx, txn, userID, version) + if err != nil { + return err + } + + changed := false + // loop over all the new keys (which should be smaller than the set of backed up keys) + for _, newKey := range uploads { + // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. + existingRoom := existingKeys[newKey.RoomID] + if existingRoom != nil { + existingSession, ok := existingRoom[newKey.SessionID] + if ok { + if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { + err = d.KeyBackups.UpdateBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.KeyBackups.UpdateBackupKey: %w", err) + } + } + // if we shouldn't replace the key we do nothing with it + continue + } + } + // if we're here, either the room or session are new, either way, we insert + err = d.KeyBackups.InsertBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return fmt.Errorf("d.KeyBackups.InsertBackupKey: %w", err) + } + } + + count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version) + if err != nil { + return err + } + if changed { + // update the etag + var newETag string + if oldETag == "" { + newETag = "1" + } else { + oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse old etag: %s", err) + } + newETag = strconv.FormatInt(oldETagInt+1, 10) + } + etag = newETag + return d.KeyBackupVersions.UpdateKeyBackupETag(ctx, txn, userID, version, newETag) + } else { + etag = oldETag + } + + return nil + }) + return +} + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*api.Device, error) { + return d.Devices.SelectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + return d.Devices.SelectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "") +} + +func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + return d.Devices.SelectDevicesByID(ctx, deviceIDs) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, ipAddr, userAgent string, +) (dev *api.Device, returnErr error) { + if deviceID != nil { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart, exceptDeviceID string, +) (devices []api.Device, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + if err != nil { + return err + } + if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + return err + } + return nil + }) + return +} + +// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) + }) +} + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.LoginTokenLifetime), + } + + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.LoginTokens.InsertLoginToken(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.LoginTokens.DeleteLoginToken(ctx, txn, token) + }) +} + +// GetLoginTokenDataByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.LoginTokens.SelectLoginToken(ctx, token) +} diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index 871f996e0..cfd8568a9 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const accountDataSchema = ` @@ -56,27 +57,29 @@ type accountDataStatements struct { selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(accountDataSchema) - if err != nil { - return +func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { + s := &accountDataStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(accountDataSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, }.Prepare(db) } -func (s *accountDataStatements) insertAccountData( +func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) error { _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) return err } -func (s *accountDataStatements) selectAccountData( +func (s *accountDataStatements) SelectAccountData( ctx context.Context, localpart string, ) ( /* global */ map[string]json.RawMessage, @@ -113,7 +116,7 @@ func (s *accountDataStatements) selectAccountData( return global, rooms, nil } -func (s *accountDataStatements) selectAccountDataByType( +func (s *accountDataStatements) SelectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 5a918e034..e6c37e58e 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" log "github.com/sirupsen/logrus" ) @@ -77,15 +78,16 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) execSchema(db *sql.DB) error { +func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { + s := &accountsStatements{ + db: db, + serverName: serverName, + } _, err := db.Exec(accountsSchema) - return err -} - -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.db = db - s.serverName = server - return sqlutil.StatementList{ + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountStmt, insertAccountSQL}, {&s.updatePasswordStmt, updatePasswordSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL}, @@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. -func (s *accountsStatements) insertAccount( +func (s *accountsStatements) InsertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -122,28 +124,28 @@ func (s *accountsStatements) insertAccount( }, nil } -func (s *accountsStatements) updatePassword( +func (s *accountsStatements) UpdatePassword( ctx context.Context, localpart, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) return } -func (s *accountsStatements) deactivateAccount( +func (s *accountsStatements) DeactivateAccount( ctx context.Context, localpart string, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) return } -func (s *accountsStatements) selectPasswordHash( +func (s *accountsStatements) SelectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) return } -func (s *accountsStatements) selectAccountByLocalpart( +func (s *accountsStatements) SelectAccountByLocalpart( ctx context.Context, localpart string, ) (*api.Account, error) { var appserviceIDPtr sql.NullString @@ -167,7 +169,7 @@ func (s *accountsStatements) selectAccountByLocalpart( return &acc, nil } -func (s *accountsStatements) selectNewNumericLocalpart( +func (s *accountsStatements) SelectNewNumericLocalpart( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 119ecdf93..423640e90 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/gomatrixserverlib" @@ -84,7 +85,6 @@ const updateDeviceLastSeen = "" + type devicesStatements struct { db *sql.DB - writer sqlutil.Writer insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt @@ -98,55 +98,33 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) execSchema(db *sql.DB) error { +func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { + s := &devicesStatements{ + db: db, + serverName: serverName, + } _, err := db.Exec(devicesSchema) - return err -} - -func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { - s.db = db - s.writer = writer - if err = s.execSchema(db); err != nil { - return + if err != nil { + return nil, err } - if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { - return - } - if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { - return - } - if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { - return - } - if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { - return - } - if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { - return - } - if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { - return - } - if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { - return - } - if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { - return - } - if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { - return - } - if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil { - return - } - s.serverName = server - return + return s, sqlutil.StatementList{ + {&s.insertDeviceStmt, insertDeviceSQL}, + {&s.selectDevicesCountStmt, selectDevicesCountSQL}, + {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL}, + {&s.selectDeviceByIDStmt, selectDeviceByIDSQL}, + {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL}, + {&s.updateDeviceNameStmt, updateDeviceNameSQL}, + {&s.deleteDeviceStmt, deleteDeviceSQL}, + {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL}, + {&s.selectDevicesByIDStmt, selectDevicesByIDSQL}, + {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen}, + }.Prepare(db) } // insertDevice creates a new device. Returns an error if any device with the same access token already exists. // Returns an error if the user already has a device with the given device ID. // Returns the device on success. -func (s *devicesStatements) insertDevice( +func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { @@ -172,7 +150,7 @@ func (s *devicesStatements) insertDevice( }, nil } -func (s *devicesStatements) deleteDevice( +func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) @@ -180,7 +158,7 @@ func (s *devicesStatements) deleteDevice( return err } -func (s *devicesStatements) deleteDevices( +func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) @@ -198,7 +176,7 @@ func (s *devicesStatements) deleteDevices( return err } -func (s *devicesStatements) deleteDevicesByLocalpart( +func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -206,7 +184,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart( return err } -func (s *devicesStatements) updateDeviceName( +func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -214,7 +192,7 @@ func (s *devicesStatements) updateDeviceName( return err } -func (s *devicesStatements) selectDeviceByToken( +func (s *devicesStatements) SelectDeviceByToken( ctx context.Context, accessToken string, ) (*api.Device, error) { var dev api.Device @@ -230,7 +208,7 @@ func (s *devicesStatements) selectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID -func (s *devicesStatements) selectDeviceByID( +func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -247,7 +225,7 @@ func (s *devicesStatements) selectDeviceByID( return &dev, err } -func (s *devicesStatements) selectDevicesByLocalpart( +func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -288,7 +266,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, nil } -func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { +func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) iDeviceIDs := make([]interface{}, len(deviceIDs)) for i := range deviceIDs { @@ -317,7 +295,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index 837d38cf1..81726edf9 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupTableSchema = ` @@ -72,12 +73,13 @@ type keyBackupStatements struct { selectKeysByRoomIDAndSessionIDStmt *sql.Stmt } -func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupTableSchema) +func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { + s := &keyBackupStatements{} + _, err := db.Exec(keyBackupTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, @@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s keyBackupStatements) countKeys( +func (s keyBackupStatements) CountKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (count int64, err error) { err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) return } -func (s *keyBackupStatements) insertBackupKey( +func (s *keyBackupStatements) InsertBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( @@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey( return } -func (s *keyBackupStatements) updateBackupKey( +func (s *keyBackupStatements) UpdateBackupKey( ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ) (err error) { _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( @@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey( return } -func (s *keyBackupStatements) selectKeys( +func (s *keyBackupStatements) SelectKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) @@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomID( +func (s *keyBackupStatements) SelectKeysByRoomID( ctx context.Context, txn *sql.Tx, userID, version, roomID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) @@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID( return unpackKeys(ctx, rows) } -func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( +func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID( ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ) (map[string]map[string]api.KeyBackupSession, error) { rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) diff --git a/userapi/storage/sqlite3/key_backup_version_table.go b/userapi/storage/sqlite3/key_backup_version_table.go index 4211ed0f1..e85e6f08b 100644 --- a/userapi/storage/sqlite3/key_backup_version_table.go +++ b/userapi/storage/sqlite3/key_backup_version_table.go @@ -22,6 +22,7 @@ import ( "strconv" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const keyBackupVersionTableSchema = ` @@ -67,12 +68,13 @@ type keyBackupVersionStatements struct { updateKeyBackupETagStmt *sql.Stmt } -func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(keyBackupVersionTableSchema) +func NewSQLiteKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) { + s := &keyBackupVersionStatements{} + _, err := db.Exec(keyBackupVersionTableSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, @@ -82,7 +84,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *keyBackupVersionStatements) insertKeyBackup( +func (s *keyBackupVersionStatements) InsertKeyBackup( ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ) (version string, err error) { var versionInt int64 @@ -90,7 +92,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup( return strconv.FormatInt(versionInt, 10), err } -func (s *keyBackupVersionStatements) updateKeyBackupAuthData( +func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData( ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -101,7 +103,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData( return err } -func (s *keyBackupVersionStatements) updateKeyBackupETag( +func (s *keyBackupVersionStatements) UpdateKeyBackupETag( ctx context.Context, txn *sql.Tx, userID, version, etag string, ) error { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -112,7 +114,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag( return err } -func (s *keyBackupVersionStatements) deleteKeyBackup( +func (s *keyBackupVersionStatements) DeleteKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (bool, error) { versionInt, err := strconv.ParseInt(version, 10, 64) @@ -130,7 +132,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup( return ra == 1, nil } -func (s *keyBackupVersionStatements) selectKeyBackup( +func (s *keyBackupVersionStatements) SelectKeyBackup( ctx context.Context, txn *sql.Tx, userID, version string, ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { var versionInt int64 diff --git a/userapi/storage/sqlite3/logintoken_table.go b/userapi/storage/sqlite3/logintoken_table.go index 52322b46a..78d42029a 100644 --- a/userapi/storage/sqlite3/logintoken_table.go +++ b/userapi/storage/sqlite3/logintoken_table.go @@ -21,18 +21,17 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/util" ) type loginTokenStatements struct { - insertStmt *sql.Stmt - deleteStmt *sql.Stmt - selectByTokenStmt *sql.Stmt + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt } -// execSchema ensures tables and indices exist. -func (s *loginTokenStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(` +const loginTokenSchema = ` CREATE TABLE IF NOT EXISTS login_tokens ( -- The random value of the token issued to a user token TEXT NOT NULL PRIMARY KEY, @@ -45,24 +44,32 @@ CREATE TABLE IF NOT EXISTS login_tokens ( -- This index allows efficient garbage collection of expired tokens. CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); -`) - return err -} +` -// prepare runs statement preparation. -func (s *loginTokenStatements) prepare(db *sql.DB) error { - if err := s.execSchema(db); err != nil { - return err +const insertLoginTokenSQL = "" + + "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)" + +const deleteLoginTokenSQL = "" + + "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2" + +const selectLoginTokenSQL = "" + + "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2" + +func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) { + s := &loginTokenStatements{} + _, err := db.Exec(loginTokenSchema) + if err != nil { + return nil, err } - return sqlutil.StatementList{ - {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, - {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, - {&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, + return s, sqlutil.StatementList{ + {&s.insertStmt, insertLoginTokenSQL}, + {&s.deleteStmt, deleteLoginTokenSQL}, + {&s.selectStmt, selectLoginTokenSQL}, }.Prepare(db) } // insert adds an already generated token to the database. -func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { +func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { stmt := sqlutil.TxStmt(txn, s.insertStmt) _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) return err @@ -72,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata // // As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // The login_tokens_expiration_idx index should make that efficient. -func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { +func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error { stmt := sqlutil.TxStmt(txn, s.deleteStmt) res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) if err != nil { @@ -85,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t } // selectByToken returns the data associated with the given token. May return sql.ErrNoRows. -func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { +func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) { var data api.LoginTokenData - err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) + err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) if err != nil { return nil, err } diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index 98c0488b1..d6090e0da 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -22,35 +23,37 @@ CREATE TABLE IF NOT EXISTS open_id_tokens ( ); ` -const insertTokenSQL = "" + +const insertOpenIDTokenSQL = "" + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" -const selectTokenSQL = "" + +const selectOpenIDTokenSQL = "" + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" -type tokenStatements struct { +type openIDTokenStatements struct { db *sql.DB insertTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt serverName gomatrixserverlib.ServerName } -func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.db = db - _, err = db.Exec(openIDTokenSchema) - if err != nil { - return err +func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) { + s := &openIDTokenStatements{ + db: db, + serverName: serverName, } - s.serverName = server - return sqlutil.StatementList{ - {&s.insertTokenStmt, insertTokenSQL}, - {&s.selectTokenStmt, selectTokenSQL}, + _, err := db.Exec(openIDTokenSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertTokenStmt, insertOpenIDTokenSQL}, + {&s.selectTokenStmt, selectOpenIDTokenSQL}, }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. // Returns new token, otherwise returns error if the token already exists. -func (s *tokenStatements) insertToken( +func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, token, localpart string, @@ -63,7 +66,7 @@ func (s *tokenStatements) insertToken( // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // Returns the existing token's attributes, or err if no token is found -func (s *tokenStatements) selectOpenIDTokenAtrributes( +func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( ctx context.Context, token string, ) (*api.OpenIDTokenAttributes, error) { diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index a92e95663..d85b19c7b 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) const profilesSchema = ` @@ -60,13 +61,15 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(profilesSchema) - if err != nil { - return +func NewSQLiteProfilesTable(db *sql.DB) (tables.ProfileTable, error) { + s := &profilesStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(profilesSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertProfileStmt, insertProfileSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL}, @@ -75,14 +78,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *profilesStatements) insertProfile( +func (s *profilesStatements) InsertProfile( ctx context.Context, txn *sql.Tx, localpart string, ) error { _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return err } -func (s *profilesStatements) selectProfileByLocalpart( +func (s *profilesStatements) SelectProfileByLocalpart( ctx context.Context, localpart string, ) (*authtypes.Profile, error) { var profile authtypes.Profile @@ -95,7 +98,7 @@ func (s *profilesStatements) selectProfileByLocalpart( return &profile, nil } -func (s *profilesStatements) setAvatarURL( +func (s *profilesStatements) SetAvatarURL( ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) @@ -103,7 +106,7 @@ func (s *profilesStatements) setAvatarURL( return } -func (s *profilesStatements) setDisplayName( +func (s *profilesStatements) SetDisplayName( ctx context.Context, txn *sql.Tx, localpart string, displayName string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) @@ -111,7 +114,7 @@ func (s *profilesStatements) setDisplayName( return } -func (s *profilesStatements) selectProfilesBySearch( +func (s *profilesStatements) SelectProfilesBySearch( ctx context.Context, searchString string, limit int, ) ([]authtypes.Profile, error) { var profiles []authtypes.Profile diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 56ec1b6af..98c244977 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -15,80 +15,34 @@ package sqlite3 import ( - "context" - "crypto/rand" - "database/sql" - "encoding/base64" - "encoding/json" - "errors" "fmt" - "strconv" - "sync" "time" "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi/api" + + "github.com/matrix-org/dendrite/userapi/storage/shared" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" -) -// Database represents an account database -type Database struct { - db *sql.DB - writer sqlutil.Writer - - sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - openIDTokens tokenStatements - keyBackupVersions keyBackupVersionStatements - keyBackups keyBackupStatements - devices devicesStatements - loginTokens loginTokenStatements - loginTokenLifetime time.Duration - serverName gomatrixserverlib.ServerName - bcryptCost int - openIDTokenLifetimeMS int64 - - accountsMu sync.Mutex - profilesMu sync.Mutex - accountDatasMu sync.Mutex - threepidsMu sync.Mutex -} - -const ( - // The length of generated device IDs - deviceIDByteLength = 6 - loginTokenByteLength = 32 + // Import the postgres database driver. + _ "github.com/lib/pq" ) // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } - d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewExclusiveWriter(), - loginTokenLifetime: loginTokenLifetime, - bcryptCost: bcryptCost, - openIDTokenLifetimeMS: openIDTokenLifetimeMS, - } - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - if err = d.accounts.execSchema(db); err != nil { + m := sqlutil.NewMigrations() + if _, err = db.Exec(accountsSchema); err != nil { + // do this so that the migration can and we don't fail on + // preparing statements for columns that don't exist yet return nil, err } - m := sqlutil.NewMigrations() deltas.LoadIsActive(m) //deltas.LoadLastSeenTSIP(m) deltas.LoadAddAccountType(m) @@ -96,666 +50,57 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver return nil, err } - partitions := sqlutil.PartitionOffsetStatements{} - if err = partitions.Prepare(db, d.writer, "account"); err != nil { - return nil, err - } - if err = d.accounts.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.profiles.prepare(db); err != nil { - return nil, err - } - if err = d.accountDatas.prepare(db); err != nil { - return nil, err - } - if err = d.threepids.prepare(db); err != nil { - return nil, err - } - if err = d.openIDTokens.prepare(db, serverName); err != nil { - return nil, err - } - if err = d.keyBackupVersions.prepare(db); err != nil { - return nil, err - } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } - if err = d.devices.prepare(db, d.writer, serverName); err != nil { - return nil, err - } - if err = d.loginTokens.prepare(db); err != nil { - return nil, err - } - - return d, nil -} - -// GetAccountByPassword returns the account associated with the given localpart and password. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByPassword( - ctx context.Context, localpart, plaintextPassword string, -) (*api.Account, error) { - hash, err := d.accounts.selectPasswordHash(ctx, localpart) + accountDataTable, err := NewSQLiteAccountDataTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) } - if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { - return nil, err - } - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// GetProfileByLocalpart returns the profile associated with the given localpart. -// Returns sql.ErrNoRows if no profile exists which matches the given localpart. -func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, -) (*authtypes.Profile, error) { - return d.profiles.selectProfileByLocalpart(ctx, localpart) -} - -// SetAvatarURL updates the avatar URL of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, -) error { - d.profilesMu.Lock() - defer d.profilesMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL) - }) -} - -// SetDisplayName updates the display name of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, -) error { - d.profilesMu.Lock() - defer d.profilesMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.profiles.setDisplayName(ctx, txn, localpart, displayName) - }) -} - -// SetPassword sets the account password to the given hash. -func (d *Database) SetPassword( - ctx context.Context, localpart, plaintextPassword string, -) error { - hash, err := d.hashPassword(plaintextPassword) + accountsTable, err := NewSQLiteAccountsTable(db, serverName) if err != nil { - return err + return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) } - return d.writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.accounts.updatePassword(ctx, localpart, hash) - }) -} - -// CreateAccount makes a new account with the given login name and password, and creates an empty profile -// for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, ErrUserExists. -func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, -) (acc *api.Account, err error) { - // Create one account at a time else we can get 'database is locked'. - d.profilesMu.Lock() - d.accountDatasMu.Lock() - d.accountsMu.Lock() - defer d.profilesMu.Unlock() - defer d.accountDatasMu.Unlock() - defer d.accountsMu.Unlock() - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - // For guest accounts, we create a new numeric local part - if accountType == api.AccountTypeGuest { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart = strconv.FormatInt(numLocalpart, 10) - plaintextPassword = "" - appserviceID = "" - } - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) - return err - }) - return -} - -// WARNING! This function assumes that the relevant mutexes have already -// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). -func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, -) (*api.Account, error) { - var err error - var account *api.Account - // Generate a password hash if this is not a password-less user - hash := "" - if plaintextPassword != "" { - hash, err = d.hashPassword(plaintextPassword) - if err != nil { - return nil, err - } - } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { - return nil, sqlutil.ErrUserExists - } - if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { - return nil, err - } - if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { - return nil, err - } - return account, nil -} - -// SaveAccountData saves new account data for a given user and a given room. -// If the account data is not specific to a room, the room ID should be an empty string -// If an account data already exists for a given set (user, room, data type), it will -// update the corresponding row with the new content -// Returns a SQL error if there was an issue with the insertion/update -func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, -) error { - d.accountDatasMu.Lock() - defer d.accountDatasMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) - }) -} - -// GetAccountData returns account data related to a given localpart -// If no account data could be found, returns an empty arrays -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global map[string]json.RawMessage, - rooms map[string]map[string]json.RawMessage, - err error, -) { - return d.accountDatas.selectAccountData(ctx, localpart) -} - -// GetAccountDataByType returns account data matching a given -// localpart, room ID and type. -// If no account data could be found, returns nil -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, -) (data json.RawMessage, err error) { - return d.accountDatas.selectAccountDataByType( - ctx, localpart, roomID, dataType, - ) -} - -// GetNewNumericLocalpart generates and returns a new unused numeric localpart -func (d *Database) GetNewNumericLocalpart( - ctx context.Context, -) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx, nil) -} - -func (d *Database) hashPassword(plaintext string) (hash string, err error) { - hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost) - return string(hashBytes), err -} - -// Err3PIDInUse is the error returned when trying to save an association involving -// a third-party identifier which is already associated to a local user. -var Err3PIDInUse = errors.New("this third-party identifier is already in use") - -// SaveThreePIDAssociation saves the association between a third party identifier -// and a local Matrix user (identified by the user's ID's local part). -// If the third-party identifier is already part of an association, returns Err3PIDInUse. -// Returns an error if there was a problem talking to the database. -func (d *Database) SaveThreePIDAssociation( - ctx context.Context, threepid, localpart, medium string, -) (err error) { - d.threepidsMu.Lock() - defer d.threepidsMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID( - ctx, txn, threepid, medium, - ) - if err != nil { - return err - } - - if len(user) > 0 { - return Err3PIDInUse - } - - return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) - }) -} - -// RemoveThreePIDAssociation removes the association involving a given third-party -// identifier. -// If no association exists involving this third-party identifier, returns nothing. -// If there was a problem talking to the database, returns an error. -func (d *Database) RemoveThreePIDAssociation( - ctx context.Context, threepid string, medium string, -) (err error) { - d.threepidsMu.Lock() - defer d.threepidsMu.Unlock() - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.threepids.deleteThreePID(ctx, txn, threepid, medium) - }) -} - -// GetLocalpartForThreePID looks up the localpart associated with a given third-party -// identifier. -// If no association involves the given third-party idenfitier, returns an empty -// string. -// Returns an error if there was a problem talking to the database. -func (d *Database) GetLocalpartForThreePID( - ctx context.Context, threepid string, medium string, -) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) -} - -// GetThreePIDsForLocalpart looks up the third-party identifiers associated with -// a given local user. -// If no association is known for this user, returns an empty slice. -// Returns an error if there was an issue talking to the database. -func (d *Database) GetThreePIDsForLocalpart( - ctx context.Context, localpart string, -) (threepids []authtypes.ThreePID, err error) { - return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) -} - -// CheckAccountAvailability checks if the username/localpart is already present -// in the database. -// If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) - if err == sql.ErrNoRows { - return true, nil - } - return false, err -} - -// GetAccountByLocalpart returns the account associated with the given localpart. -// This function assumes the request is authenticated or the account data is used only internally. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*api.Account, error) { - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// SearchProfiles returns all profiles where the provided localpart or display name -// match any part of the profiles in the database. -func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, -) ([]authtypes.Profile, error) { - return d.profiles.selectProfilesBySearch(ctx, searchString, limit) -} - -// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { - return d.writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.accounts.deactivateAccount(ctx, localpart) - }) -} - -// CreateOpenIDToken persists a new token that was issued for OpenID Connect -func (d *Database) CreateOpenIDToken( - ctx context.Context, - token, localpart string, -) (int64, error) { - expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS - err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) - }) - return expiresAtMS, err -} - -// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token -func (d *Database) GetOpenIDTokenAttributes( - ctx context.Context, - token string, -) (*api.OpenIDTokenAttributes, error) { - return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) -} - -func (d *Database) CreateKeyBackup( - ctx context.Context, userID, algorithm string, authData json.RawMessage, -) (version string, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") - return err - }) - return -} - -func (d *Database) UpdateKeyBackupAuthData( - ctx context.Context, userID, version string, authData json.RawMessage, -) (err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) - }) - return -} - -func (d *Database) DeleteKeyBackup( - ctx context.Context, userID, version string, -) (exists bool, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetKeyBackup( - ctx context.Context, userID, version string, -) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) GetBackupKeys( - ctx context.Context, version, userID, filterRoomID, filterSessionID string, -) (result map[string]map[string]api.KeyBackupSession, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if filterSessionID != "" { - result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) - return err - } - if filterRoomID != "" { - result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) - return err - } - result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) - return err - }) - return -} - -func (d *Database) CountBackupKeys( - ctx context.Context, version, userID string, -) (count int64, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - return nil - }) - return -} - -// nolint:nakedret -func (d *Database) UpsertBackupKeys( - ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, -) (count int64, etag string, err error) { - // wrap the following logic in a txn to ensure we atomically upload keys - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) - if err != nil { - return err - } - if deleted { - return fmt.Errorf("backup was deleted") - } - // pull out all keys for this (user_id, version) - existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) - if err != nil { - return err - } - - changed := false - // loop over all the new keys (which should be smaller than the set of backed up keys) - for _, newKey := range uploads { - // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. - existingRoom := existingKeys[newKey.RoomID] - if existingRoom != nil { - existingSession, ok := existingRoom[newKey.SessionID] - if ok { - if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { - err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) - } - } - // if we shouldn't replace the key we do nothing with it - continue - } - } - // if we're here, either the room or session are new, either way, we insert - err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) - changed = true - if err != nil { - return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) - } - } - - count, err = d.keyBackups.countKeys(ctx, txn, userID, version) - if err != nil { - return err - } - if changed { - // update the etag - var newETag string - if oldETag == "" { - newETag = "1" - } else { - oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) - if err != nil { - return fmt.Errorf("failed to parse old etag: %s", err) - } - newETag = strconv.FormatInt(oldETagInt+1, 10) - } - etag = newETag - return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) - } else { - etag = oldETag - } - - return nil - }) - return -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*api.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*api.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") -} - -func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { - return d.devices.selectDevicesByID(ctx, deviceIDs) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, -) (dev *api.Device, returnErr error) { - if deviceID != nil { - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) - return err - }) - if returnErr == nil { - return - } - } - } - return -} - -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) + devicesTable, err := NewSQLiteDevicesTable(db, serverName) if err != nil { - return "", err + return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err) } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, -) (devices []api.Device, err error) { - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) - if err != nil { - return err - } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { - return err - } - return nil - }) - return -} - -// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) - }) -} - -// CreateLoginToken generates a token, stores and returns it. The lifetime is -// determined by the loginTokenLifetime given to the Database constructor. -func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { - tok, err := generateLoginToken() + keyBackupTable, err := NewSQLiteKeyBackupTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewSQLiteKeyBackupTable: %w", err) } - meta := &api.LoginTokenMetadata{ - Token: tok, - Expiration: time.Now().Add(d.loginTokenLifetime), - } - - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.loginTokens.insert(ctx, txn, meta, data) - }) + keyBackupVersionTable, err := NewSQLiteKeyBackupVersionTable(db) if err != nil { - return nil, err + return nil, fmt.Errorf("NewSQLiteKeyBackupVersionTable: %w", err) } - - return meta, nil -} - -func generateLoginToken() (string, error) { - b := make([]byte, loginTokenByteLength) - _, err := rand.Read(b) + loginTokenTable, err := NewSQLiteLoginTokenTable(db) if err != nil { - return "", err + return nil, fmt.Errorf("NewSQLiteLoginTokenTable: %w", err) } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// RemoveLoginToken removes the named token (and may clean up other expired tokens). -func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - return d.loginTokens.deleteByToken(ctx, txn, token) - }) -} - -// GetLoginTokenDataByToken returns the data associated with the given token. -// May return sql.ErrNoRows. -func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { - return d.loginTokens.selectByToken(ctx, token) + openIDTable, err := NewSQLiteOpenIDTable(db, serverName) + if err != nil { + return nil, fmt.Errorf("NewSQLiteOpenIDTable: %w", err) + } + profilesTable, err := NewSQLiteProfilesTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteProfilesTable: %w", err) + } + threePIDTable, err := NewSQLiteThreePIDTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err) + } + return &shared.Database{ + AccountDatas: accountDataTable, + Accounts: accountsTable, + Devices: devicesTable, + KeyBackups: keyBackupTable, + KeyBackupVersions: keyBackupVersionTable, + LoginTokens: loginTokenTable, + OpenIDTokens: openIDTable, + Profiles: profilesTable, + ThreePIDs: threePIDTable, + ServerName: serverName, + DB: db, + Writer: sqlutil.NewExclusiveWriter(), + LoginTokenLifetime: loginTokenLifetime, + BcryptCost: bcryptCost, + OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + }, nil } diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index 9dc0e2d22..fa174eed5 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -60,13 +61,15 @@ type threepidStatements struct { deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { - s.db = db - _, err = db.Exec(threepidSchema) - if err != nil { - return +func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { + s := &threepidStatements{ + db: db, } - return sqlutil.StatementList{ + _, err := db.Exec(threepidSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL}, @@ -74,7 +77,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *threepidStatements) selectLocalpartForThreePID( +func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) @@ -85,7 +88,7 @@ func (s *threepidStatements) selectLocalpartForThreePID( return } -func (s *threepidStatements) selectThreePIDsForLocalpart( +func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) @@ -109,7 +112,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( return threepids, rows.Err() } -func (s *threepidStatements) insertThreePID( +func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) @@ -117,7 +120,7 @@ func (s *threepidStatements) insertThreePID( return err } -func (s *threepidStatements) deleteThreePID( +func (s *threepidStatements) DeleteThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) _, err = stmt.ExecContext(ctx, threepid, medium) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go new file mode 100644 index 000000000..12939ced5 --- /dev/null +++ b/userapi/storage/tables/interface.go @@ -0,0 +1,95 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/userapi/api" +) + +type AccountDataTable interface { + InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error + SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) + SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) +} + +type AccountsTable interface { + InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) + DeactivateAccount(ctx context.Context, localpart string) (err error) + SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) + SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) +} + +type DevicesTable interface { + InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error + DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error + DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error + UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error + SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error) + SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) + SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error) + SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) + UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error +} + +type KeyBackupTable interface { + CountKeys(ctx context.Context, txn *sql.Tx, userID, version string) (count int64, err error) + InsertBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error) + UpdateBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error) + SelectKeys(ctx context.Context, txn *sql.Tx, userID, version string) (map[string]map[string]api.KeyBackupSession, error) + SelectKeysByRoomID(ctx context.Context, txn *sql.Tx, userID, version, roomID string) (map[string]map[string]api.KeyBackupSession, error) + SelectKeysByRoomIDAndSessionID(ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string) (map[string]map[string]api.KeyBackupSession, error) +} + +type KeyBackupVersionTable interface { + InsertKeyBackup(ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string) (version string, err error) + UpdateKeyBackupAuthData(ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage) error + UpdateKeyBackupETag(ctx context.Context, txn *sql.Tx, userID, version, etag string) error + DeleteKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (bool, error) + SelectKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) +} + +type LoginTokenTable interface { + InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error + DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error + SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) +} + +type OpenIDTable interface { + InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error) + SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) +} + +type ProfileTable interface { + InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error + SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error) + SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) +} + +type ThreePIDTable interface { + SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error) + SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) + DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) +} From dbded875257703eb63c8eb8af8d47d74c811642f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 18 Feb 2022 14:14:16 +0000 Subject: [PATCH 78/81] Expose sync endpoints via `/v3` (#2203) --- syncapi/routing/routing.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index e2ff27395..005a33555 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -39,14 +39,14 @@ func Setup( rsAPI api.RoomserverInternalAPI, cfg *config.SyncAPI, ) { - r0mux := csMux.PathPrefix("/r0").Subrouter() + v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() // TODO: Add AS support for all handlers below. - r0mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingSyncRequest(req, device) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -54,7 +54,7 @@ func Setup( return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, federation, rsAPI, cfg, srp) })).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/user/{userId}/filter", + v3mux.Handle("/user/{userId}/filter", httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -64,7 +64,7 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) - r0mux.Handle("/user/{userId}/filter/{filterId}", + v3mux.Handle("/user/{userId}/filter/{filterId}", httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -74,7 +74,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - r0mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingKeyChangeRequest(req, device) })).Methods(http.MethodGet, http.MethodOptions) } From 002429c9e24cc746e0929b41eccbe429f89a6e1f Mon Sep 17 00:00:00 2001 From: S7evinK <2353100+S7evinK@users.noreply.github.com> Date: Fri, 18 Feb 2022 16:05:03 +0100 Subject: [PATCH 79/81] Implement server notices (#2180) * Add server_notices config * Disallow rejecting "server notice" invites * Update config * Slightly refactor sendEvent and CreateRoom so it can be reused * Implement unspecced server notices * Validate the request * Set the user api when starting * Rename function/variables * Update comments * Update config * Set the avatar on account creation * Update test * Only create the account when starting Only add routes if sever notices are enabled * Use reserver username Check that we actually got roomData * Add check for admin account Enable server notices for CI Return same values as Synapse * Add custom error for rejecting server notice invite * Move building an invite to it's own function, for reusability * Don't create new rooms, use the existing one (follow Synapse behavior) Co-authored-by: kegsay --- clientapi/jsonerror/jsonerror.go | 9 + clientapi/routing/createroom.go | 79 ++--- clientapi/routing/leaveroom.go | 6 + clientapi/routing/membership.go | 35 +- clientapi/routing/routing.go | 45 +++ clientapi/routing/sendevent.go | 52 +-- clientapi/routing/server_notices.go | 343 +++++++++++++++++++ clientapi/routing/server_notices_test.go | 83 +++++ cmd/dendrite-monolith-server/main.go | 1 + dendrite-config.yaml | 12 + roomserver/api/api.go | 5 +- roomserver/api/api_trace.go | 10 +- roomserver/api/perform.go | 2 + roomserver/internal/api.go | 5 + roomserver/internal/perform/perform_leave.go | 43 ++- roomserver/inthttp/client.go | 6 + setup/config/config_global.go | 30 ++ setup/config/config_test.go | 5 + 18 files changed, 689 insertions(+), 82 deletions(-) create mode 100644 clientapi/routing/server_notices.go create mode 100644 clientapi/routing/server_notices_test.go diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go index caa216e62..97c597030 100644 --- a/clientapi/jsonerror/jsonerror.go +++ b/clientapi/jsonerror/jsonerror.go @@ -149,6 +149,15 @@ func MissingParam(msg string) *MatrixError { return &MatrixError{"M_MISSING_PARAM", msg} } +// LeaveServerNoticeError is an error returned when trying to reject an invite +// for a server notice room. +func LeaveServerNoticeError() *MatrixError { + return &MatrixError{ + ErrCode: "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM", + Err: "You cannot reject this invite", + } +} + type IncompatibleRoomVersionError struct { RoomVersion string `json:"room_version"` Error string `json:"error"` diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 80ac22935..fcacc76c0 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -15,6 +15,7 @@ package routing import ( + "context" "encoding/json" "fmt" "net/http" @@ -140,33 +141,14 @@ func CreateRoom( accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { - // TODO (#267): Check room ID doesn't clash with an existing one, and we - // probably shouldn't be using pseudo-random strings, maybe GUIDs? - roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName) - return createRoom(req, device, cfg, roomID, accountDB, rsAPI, asAPI) -} - -// createRoom implements /createRoom -// nolint: gocyclo -func createRoom( - req *http.Request, device *api.Device, - cfg *config.ClientAPI, roomID string, - accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, - asAPI appserviceAPI.AppServiceQueryAPI, -) util.JSONResponse { - logger := util.GetLogger(req.Context()) - userID := device.UserID var r createRoomRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { return *resErr } - // TODO: apply rate-limit - if resErr = r.Validate(); resErr != nil { return *resErr } - evTime, err := httputil.ParseTSParam(req) if err != nil { return util.JSONResponse{ @@ -174,6 +156,25 @@ func createRoom( JSON: jsonerror.InvalidArgumentValue(err.Error()), } } + return createRoom(req.Context(), r, device, cfg, accountDB, rsAPI, asAPI, evTime) +} + +// createRoom implements /createRoom +// nolint: gocyclo +func createRoom( + ctx context.Context, + r createRoomRequest, device *api.Device, + cfg *config.ClientAPI, + accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI, + asAPI appserviceAPI.AppServiceQueryAPI, + evTime time.Time, +) util.JSONResponse { + // TODO (#267): Check room ID doesn't clash with an existing one, and we + // probably shouldn't be using pseudo-random strings, maybe GUIDs? + roomID := fmt.Sprintf("!%s:%s", util.RandomString(16), cfg.Matrix.ServerName) + + logger := util.GetLogger(ctx) + userID := device.UserID // Clobber keys: creator, room_version @@ -200,16 +201,16 @@ func createRoom( "roomVersion": roomVersion, }).Info("Creating new room") - profile, err := appserviceAPI.RetrieveUserProfile(req.Context(), userID, asAPI, accountDB) + profile, err := appserviceAPI.RetrieveUserProfile(ctx, userID, asAPI, accountDB) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") + util.GetLogger(ctx).WithError(err).Error("appserviceAPI.RetrieveUserProfile failed") return jsonerror.InternalServerError() } createContent := map[string]interface{}{} if len(r.CreationContent) > 0 { if err = json.Unmarshal(r.CreationContent, &createContent); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for creation_content failed") + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for creation_content failed") return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("invalid create content"), @@ -230,7 +231,7 @@ func createRoom( // Merge powerLevelContentOverride fields by unmarshalling it atop the defaults err = json.Unmarshal(r.PowerLevelContentOverride, &powerLevelContent) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal for power_level_content_override failed") + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for power_level_content_override failed") return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON("malformed power_level_content_override"), @@ -319,9 +320,9 @@ func createRoom( } var aliasResp roomserverAPI.GetRoomIDForAliasResponse - err = rsAPI.GetRoomIDForAlias(req.Context(), &hasAliasReq, &aliasResp) + err = rsAPI.GetRoomIDForAlias(ctx, &hasAliasReq, &aliasResp) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") + util.GetLogger(ctx).WithError(err).Error("aliasAPI.GetRoomIDForAlias failed") return jsonerror.InternalServerError() } if aliasResp.RoomID != "" { @@ -426,7 +427,7 @@ func createRoom( } err = builder.SetContent(e.Content) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed") + util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") return jsonerror.InternalServerError() } if i > 0 { @@ -435,12 +436,12 @@ func createRoom( var ev *gomatrixserverlib.Event ev, err = buildEvent(&builder, &authEvents, cfg, evTime, roomVersion) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("buildEvent failed") + util.GetLogger(ctx).WithError(err).Error("buildEvent failed") return jsonerror.InternalServerError() } if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.Allowed failed") + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") return jsonerror.InternalServerError() } @@ -448,7 +449,7 @@ func createRoom( builtEvents = append(builtEvents, ev.Headered(roomVersion)) err = authEvents.AddEvent(ev) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("authEvents.AddEvent failed") + util.GetLogger(ctx).WithError(err).Error("authEvents.AddEvent failed") return jsonerror.InternalServerError() } } @@ -462,8 +463,8 @@ func createRoom( SendAsServer: roomserverAPI.DoNotSendToOtherServers, }) } - if err = roomserverAPI.SendInputRoomEvents(req.Context(), rsAPI, inputs, false); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") + if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil { + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") return jsonerror.InternalServerError() } @@ -478,9 +479,9 @@ func createRoom( } var aliasResp roomserverAPI.SetRoomAliasResponse - err = rsAPI.SetRoomAlias(req.Context(), &aliasReq, &aliasResp) + err = rsAPI.SetRoomAlias(ctx, &aliasReq, &aliasResp) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("aliasAPI.SetRoomAlias failed") + util.GetLogger(ctx).WithError(err).Error("aliasAPI.SetRoomAlias failed") return jsonerror.InternalServerError() } @@ -519,11 +520,11 @@ func createRoom( for _, invitee := range r.Invite { // Build the invite event. inviteEvent, err := buildMembershipEvent( - req.Context(), invitee, "", accountDB, device, gomatrixserverlib.Invite, + ctx, invitee, "", accountDB, device, gomatrixserverlib.Invite, roomID, true, cfg, evTime, rsAPI, asAPI, ) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed") + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") continue } inviteStrippedState := append( @@ -532,7 +533,7 @@ func createRoom( ) // Send the invite event to the roomserver. err = roomserverAPI.SendInvite( - req.Context(), + ctx, rsAPI, inviteEvent.Headered(roomVersion), inviteStrippedState, // invite room state @@ -544,7 +545,7 @@ func createRoom( return e.JSONResponse() case nil: default: - util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInvite failed") + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError(), @@ -556,13 +557,13 @@ func createRoom( if r.Visibility == "public" { // expose this room in the published room list var pubRes roomserverAPI.PerformPublishResponse - rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ + rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{ RoomID: roomID, Visibility: "public", }, &pubRes) if pubRes.Error != nil { // treat as non-fatal since the room is already made by this point - util.GetLogger(req.Context()).WithError(pubRes.Error).Error("failed to visibility:public") + util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public") } } diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go index 38cef118e..a34dd02d3 100644 --- a/clientapi/routing/leaveroom.go +++ b/clientapi/routing/leaveroom.go @@ -38,6 +38,12 @@ func LeaveRoomByID( // Ask the roomserver to perform the leave. if err := rsAPI.PerformLeave(req.Context(), &leaveReq, &leaveRes); err != nil { + if leaveRes.Code != 0 { + return util.JSONResponse{ + Code: leaveRes.Code, + JSON: jsonerror.LeaveServerNoticeError(), + } + } return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown(err.Error()), diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 112239241..ffe8da136 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -226,27 +226,42 @@ func SendInvite( } } + // We already received the return value, so no need to check for an error here. + response, _ := sendInvite(req.Context(), accountDB, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime) + return response +} + +// sendInvite sends an invitation to a user. Returns a JSONResponse and an error +func sendInvite( + ctx context.Context, + accountDB userdb.Database, + device *userapi.Device, + roomID, userID, reason string, + cfg *config.ClientAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, + asAPI appserviceAPI.AppServiceQueryAPI, evTime time.Time, +) (util.JSONResponse, error) { event, err := buildMembershipEvent( - req.Context(), body.UserID, body.Reason, accountDB, device, "invite", + ctx, userID, reason, accountDB, device, "invite", roomID, false, cfg, evTime, rsAPI, asAPI, ) if err == errMissingUserID { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error()), - } + }, err } else if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound(err.Error()), - } + }, err } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvent failed") - return jsonerror.InternalServerError() + util.GetLogger(ctx).WithError(err).Error("buildMembershipEvent failed") + return jsonerror.InternalServerError(), err } err = roomserverAPI.SendInvite( - req.Context(), rsAPI, + ctx, rsAPI, event, nil, // ask the roomserver to draw up invite room state for us cfg.Matrix.ServerName, @@ -254,18 +269,18 @@ func SendInvite( ) switch e := err.(type) { case *roomserverAPI.PerformError: - return e.JSONResponse() + return e.JSONResponse(), err case nil: return util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, - } + }, nil default: - util.GetLogger(req.Context()).WithError(err).Error("roomserverAPI.SendInvite failed") + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInvite failed") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError(), - } + }, err } } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 63dcaa413..d75f58b81 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -15,6 +15,7 @@ package routing import ( + "context" "encoding/json" "net/http" "strings" @@ -117,6 +118,50 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) } + // server notifications + if cfg.Matrix.ServerNotices.Enabled { + logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") + serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, accountDB, cfg) + if err != nil { + logrus.WithError(err).Fatal("unable to get account for sending sending server notices") + } + + synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}", + httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + // not specced, but ensure we're rate limiting requests to this endpoint + if r := rateLimits.Limit(req); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + txnID := vars["txnID"] + return SendServerNotice( + req, &cfg.Matrix.ServerNotices, + cfg, userAPI, rsAPI, accountDB, asAPI, + device, serverNotificationSender, + &txnID, transactionsCache, + ) + }), + ).Methods(http.MethodPut, http.MethodOptions) + + synapseAdminRouter.Handle("/admin/v1/send_server_notice", + httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + // not specced, but ensure we're rate limiting requests to this endpoint + if r := rateLimits.Limit(req); r != nil { + return *r + } + return SendServerNotice( + req, &cfg.Matrix.ServerNotices, + cfg, userAPI, rsAPI, accountDB, asAPI, + device, serverNotificationSender, + nil, transactionsCache, + ) + }), + ).Methods(http.MethodPost, http.MethodOptions) + } + // You can't just do PathPrefix("/(r0|v3)") because regexps only apply when inside named path variables. // So make a named path variable called 'apiversion' (which we will never read in handlers) and then do // (r0|v3) - BUT this is a captured group, which makes no sense because you cannot extract this group diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 606107b9f..23935b5d9 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -15,10 +15,16 @@ package routing import ( + "context" "net/http" "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" @@ -26,10 +32,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" ) // http://matrix.org/docs/spec/client_server/r0.2.0.html#put-matrix-client-r0-rooms-roomid-send-eventtype-txnid @@ -97,7 +99,22 @@ func SendEvent( defer mutex.(*sync.Mutex).Unlock() startedGeneratingEvent := time.Now() - e, resErr := generateSendEvent(req, device, roomID, eventType, stateKey, cfg, rsAPI) + + var r map[string]interface{} // must be a JSON object + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + + evTime, err := httputil.ParseTSParam(req) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue(err.Error()), + } + } + + e, resErr := generateSendEvent(req.Context(), r, device, roomID, eventType, stateKey, cfg, rsAPI, evTime) if resErr != nil { return *resErr } @@ -153,27 +170,16 @@ func SendEvent( } func generateSendEvent( - req *http.Request, + ctx context.Context, + r map[string]interface{}, device *userapi.Device, roomID, eventType string, stateKey *string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI, + evTime time.Time, ) (*gomatrixserverlib.Event, *util.JSONResponse) { // parse the incoming http request userID := device.UserID - var r map[string]interface{} // must be a JSON object - resErr := httputil.UnmarshalJSONRequest(req, &r) - if resErr != nil { - return nil, resErr - } - - evTime, err := httputil.ParseTSParam(req) - if err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidArgumentValue(err.Error()), - } - } // create the new event and set all the fields we can builder := gomatrixserverlib.EventBuilder{ @@ -182,15 +188,15 @@ func generateSendEvent( Type: eventType, StateKey: stateKey, } - err = builder.SetContent(r) + err := builder.SetContent(r) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("builder.SetContent failed") + util.GetLogger(ctx).WithError(err).Error("builder.SetContent failed") resErr := jsonerror.InternalServerError() return nil, &resErr } var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return nil, &util.JSONResponse{ Code: http.StatusNotFound, @@ -213,7 +219,7 @@ func generateSendEvent( JSON: jsonerror.BadJSON(e.Error()), } } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("eventutil.BuildEvent failed") + util.GetLogger(ctx).WithError(err).Error("eventutil.BuildEvent failed") resErr := jsonerror.InternalServerError() return nil, &resErr } diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go new file mode 100644 index 000000000..42a303a6b --- /dev/null +++ b/clientapi/routing/server_notices.go @@ -0,0 +1,343 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + userdb "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/tokens" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/transactions" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + userapi "github.com/matrix-org/dendrite/userapi/api" +) + +// Unspecced server notice request +// https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/server_notices.md +type sendServerNoticeRequest struct { + UserID string `json:"user_id,omitempty"` + Content struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + } `json:"content,omitempty"` + Type string `json:"type,omitempty"` + StateKey string `json:"state_key,omitempty"` +} + +// SendServerNotice sends a message to a specific user. It can only be invoked by an admin. +func SendServerNotice( + req *http.Request, + cfgNotices *config.ServerNotices, + cfgClient *config.ClientAPI, + userAPI userapi.UserInternalAPI, + rsAPI api.RoomserverInternalAPI, + accountsDB userdb.Database, + asAPI appserviceAPI.AppServiceQueryAPI, + device *userapi.Device, + senderDevice *userapi.Device, + txnID *string, + txnCache *transactions.Cache, +) util.JSONResponse { + if device.AccountType != userapi.AccountTypeAdmin { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("This API can only be used by admin users."), + } + } + + if txnID != nil { + // Try to fetch response from transactionsCache + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + return *res + } + } + + ctx := req.Context() + var r sendServerNoticeRequest + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + + // check that all required fields are set + if !r.valid() { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Invalid request"), + } + } + + // get rooms for specified user + allUserRooms := []string{} + userRooms := api.QueryRoomsForUserResponse{} + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: r.UserID, + WantMembership: "join", + }, &userRooms); err != nil { + return util.ErrorResponse(err) + } + allUserRooms = append(allUserRooms, userRooms.RoomIDs...) + // get invites for specified user + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: r.UserID, + WantMembership: "invite", + }, &userRooms); err != nil { + return util.ErrorResponse(err) + } + allUserRooms = append(allUserRooms, userRooms.RoomIDs...) + // get left rooms for specified user + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: r.UserID, + WantMembership: "leave", + }, &userRooms); err != nil { + return util.ErrorResponse(err) + } + allUserRooms = append(allUserRooms, userRooms.RoomIDs...) + + // get rooms of the sender + senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName) + senderRooms := api.QueryRoomsForUserResponse{} + if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ + UserID: senderUserID, + WantMembership: "join", + }, &senderRooms); err != nil { + return util.ErrorResponse(err) + } + + // check if we have rooms in common + commonRooms := []string{} + for _, userRoomID := range allUserRooms { + for _, senderRoomID := range senderRooms.RoomIDs { + if userRoomID == senderRoomID { + commonRooms = append(commonRooms, senderRoomID) + } + } + } + + if len(commonRooms) > 1 { + return util.ErrorResponse(fmt.Errorf("expected to find one room, but got %d", len(commonRooms))) + } + + var ( + roomID string + roomVersion = gomatrixserverlib.RoomVersionV6 + ) + + // create a new room for the user + if len(commonRooms) == 0 { + powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID) + powerLevelContent.Users[r.UserID] = -10 // taken from Synapse + pl, err := json.Marshal(powerLevelContent) + if err != nil { + return util.ErrorResponse(err) + } + createContent := map[string]interface{}{} + createContent["m.federate"] = false + cc, err := json.Marshal(createContent) + if err != nil { + return util.ErrorResponse(err) + } + crReq := createRoomRequest{ + Invite: []string{r.UserID}, + Name: cfgNotices.RoomName, + Visibility: "private", + Preset: presetPrivateChat, + CreationContent: cc, + GuestCanJoin: false, + RoomVersion: roomVersion, + PowerLevelContentOverride: pl, + } + + roomRes := createRoom(ctx, crReq, senderDevice, cfgClient, accountsDB, rsAPI, asAPI, time.Now()) + + switch data := roomRes.JSON.(type) { + case createRoomResponse: + roomID = data.RoomID + + // tag the room, so we can later check if the user tries to reject an invite + serverAlertTag := gomatrix.TagContent{Tags: map[string]gomatrix.TagProperties{ + "m.server_notice": { + Order: 1.0, + }, + }} + if err = saveTagData(req, r.UserID, roomID, userAPI, serverAlertTag); err != nil { + util.GetLogger(ctx).WithError(err).Error("saveTagData failed") + return jsonerror.InternalServerError() + } + + default: + // if we didn't get a createRoomResponse, we probably received an error, so return that. + return roomRes + } + + } else { + // we've found a room in common, check the membership + roomID = commonRooms[0] + // re-invite the user + res, err := sendInvite(ctx, accountsDB, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) + if err != nil { + return res + } + } + + startedGeneratingEvent := time.Now() + + request := map[string]interface{}{ + "body": r.Content.Body, + "msgtype": r.Content.MsgType, + } + e, resErr := generateSendEvent(ctx, request, senderDevice, roomID, "m.room.message", nil, cfgClient, rsAPI, time.Now()) + if resErr != nil { + logrus.Errorf("failed to send message: %+v", resErr) + return *resErr + } + timeToGenerateEvent := time.Since(startedGeneratingEvent) + + var txnAndSessionID *api.TransactionID + if txnID != nil { + txnAndSessionID = &api.TransactionID{ + TransactionID: *txnID, + SessionID: device.SessionID, + } + } + + // pass the new event to the roomserver and receive the correct event ID + // event ID in case of duplicate transaction is discarded + startedSubmittingEvent := time.Now() + if err := api.SendEvents( + ctx, rsAPI, + api.KindNew, + []*gomatrixserverlib.HeaderedEvent{ + e.Headered(roomVersion), + }, + cfgClient.Matrix.ServerName, + cfgClient.Matrix.ServerName, + txnAndSessionID, + false, + ); err != nil { + util.GetLogger(ctx).WithError(err).Error("SendEvents failed") + return jsonerror.InternalServerError() + } + util.GetLogger(ctx).WithFields(logrus.Fields{ + "event_id": e.EventID(), + "room_id": roomID, + "room_version": roomVersion, + }).Info("Sent event to roomserver") + timeToSubmitEvent := time.Since(startedSubmittingEvent) + + res := util.JSONResponse{ + Code: http.StatusOK, + JSON: sendEventResponse{e.EventID()}, + } + // Add response to transactionsCache + if txnID != nil { + txnCache.AddTransaction(device.AccessToken, *txnID, &res) + } + + // Take a note of how long it took to generate the event vs submit + // it to the roomserver. + sendEventDuration.With(prometheus.Labels{"action": "build"}).Observe(float64(timeToGenerateEvent.Milliseconds())) + sendEventDuration.With(prometheus.Labels{"action": "submit"}).Observe(float64(timeToSubmitEvent.Milliseconds())) + + return res +} + +func (r sendServerNoticeRequest) valid() (ok bool) { + if r.UserID == "" { + return false + } + if r.Content.MsgType == "" || r.Content.Body == "" { + return false + } + return true +} + +// getSenderDevice creates a user account to be used when sending server notices. +// It returns an userapi.Device, which is used for building the event +func getSenderDevice( + ctx context.Context, + userAPI userapi.UserInternalAPI, + accountDB userdb.Database, + cfg *config.ClientAPI, +) (*userapi.Device, error) { + var accRes userapi.PerformAccountCreationResponse + // create account if it doesn't exist + err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeUser, + Localpart: cfg.Matrix.ServerNotices.LocalPart, + OnConflict: userapi.ConflictUpdate, + }, &accRes) + if err != nil { + return nil, err + } + + // set the avatarurl for the user + if err = accountDB.SetAvatarURL(ctx, cfg.Matrix.ServerNotices.LocalPart, cfg.Matrix.ServerNotices.AvatarURL); err != nil { + util.GetLogger(ctx).WithError(err).Error("accountDB.SetAvatarURL failed") + return nil, err + } + + // Check if we got existing devices + deviceRes := &userapi.QueryDevicesResponse{} + err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{ + UserID: accRes.Account.UserID, + }, deviceRes) + if err != nil { + return nil, err + } + + if len(deviceRes.Devices) > 0 { + return &deviceRes.Devices[0], nil + } + + // create an AccessToken + token, err := tokens.GenerateLoginToken(tokens.TokenOptions{ + ServerPrivateKey: cfg.Matrix.PrivateKey.Seed(), + ServerName: string(cfg.Matrix.ServerName), + UserID: accRes.Account.UserID, + }) + if err != nil { + return nil, err + } + + // create a new device, if we didn't find any + var devRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ + Localpart: cfg.Matrix.ServerNotices.LocalPart, + DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart, + AccessToken: token, + NoDeviceListUpdate: true, + }, &devRes) + + if err != nil { + return nil, err + } + return devRes.Device, nil +} diff --git a/clientapi/routing/server_notices_test.go b/clientapi/routing/server_notices_test.go new file mode 100644 index 000000000..2fac072cd --- /dev/null +++ b/clientapi/routing/server_notices_test.go @@ -0,0 +1,83 @@ +package routing + +import ( + "testing" +) + +func Test_sendServerNoticeRequest_validate(t *testing.T) { + type fields struct { + UserID string `json:"user_id,omitempty"` + Content struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + } `json:"content,omitempty"` + Type string `json:"type,omitempty"` + StateKey string `json:"state_key,omitempty"` + } + + content := struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + }{ + MsgType: "m.text", + Body: "Hello world!", + } + + tests := []struct { + name string + fields fields + wantOk bool + }{ + { + name: "empty request", + fields: fields{}, + }, + { + name: "msgtype empty", + fields: fields{ + UserID: "@alice:localhost", + Content: struct { + MsgType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` + }{ + Body: "Hello world!", + }, + }, + }, + { + name: "msg body empty", + fields: fields{ + UserID: "@alice:localhost", + }, + }, + { + name: "statekey empty", + fields: fields{ + UserID: "@alice:localhost", + Content: content, + }, + wantOk: true, + }, + { + name: "type empty", + fields: fields{ + UserID: "@alice:localhost", + Content: content, + }, + wantOk: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := sendServerNoticeRequest{ + UserID: tt.fields.UserID, + Content: tt.fields.Content, + Type: tt.fields.Type, + StateKey: tt.fields.StateKey, + } + if gotOk := r.valid(); gotOk != tt.wantOk { + t.Errorf("valid() = %v, want %v", gotOk, tt.wantOk) + } + }) + } +} diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 4d0598f3f..bb2685208 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -132,6 +132,7 @@ func main() { // dependency. Other components also need updating after their dependencies are up. rsImpl.SetFederationAPI(fsAPI, keyRing) rsImpl.SetAppserviceAPI(asAPI) + rsImpl.SetUserAPI(userAPI) keyImpl.SetUserAPI(userAPI) eduInputAPI := eduserver.NewInternalAPI( diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 35f72222e..6d086ed77 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -68,6 +68,18 @@ global: # to other servers and the federation API will not be exposed. disable_federation: false + # Server notices allows server admins to send messages to all users. + server_notices: + enabled: false + # The server localpart to be used when sending notices, ensure this is not yet taken + local_part: "_server" + # The displayname to be used when sending notices + display_name: "Server alerts" + # The mxid of the avatar to use + avatar_url: "" + # The roomname to be used when creating messages + room_name: "Server Alerts" + # Configuration for NATS JetStream jetstream: # A list of NATS Server addresses to connect to. If none are specified, an diff --git a/roomserver/api/api.go b/roomserver/api/api.go index e6d37e8f1..bcbf0e4f9 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -3,9 +3,11 @@ package api import ( "context" + "github.com/matrix-org/gomatrixserverlib" + asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/gomatrixserverlib" + userapi "github.com/matrix-org/dendrite/userapi/api" ) // RoomserverInputAPI is used to write events to the room server. @@ -14,6 +16,7 @@ type RoomserverInternalAPI interface { // interdependencies between the roomserver and other input APIs SetFederationAPI(fsAPI fsAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) + SetUserAPI(userAPI userapi.UserInternalAPI) InputRoomEvents( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 16f52abb7..88b372154 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -5,10 +5,12 @@ import ( "encoding/json" "fmt" - asAPI "github.com/matrix-org/dendrite/appservice/api" - fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + + asAPI "github.com/matrix-org/dendrite/appservice/api" + fsAPI "github.com/matrix-org/dendrite/federationapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" ) // RoomserverInternalAPITrace wraps a RoomserverInternalAPI and logs the @@ -25,6 +27,10 @@ func (t *RoomserverInternalAPITrace) SetAppserviceAPI(asAPI asAPI.AppServiceQuer t.Impl.SetAppserviceAPI(asAPI) } +func (t *RoomserverInternalAPITrace) SetUserAPI(userAPI userapi.UserInternalAPI) { + t.Impl.SetUserAPI(userAPI) +} + func (t *RoomserverInternalAPITrace) InputRoomEvents( ctx context.Context, req *InputRoomEventsRequest, diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 51cbcb1ad..d640858a6 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -95,6 +95,8 @@ type PerformLeaveRequest struct { } type PerformLeaveResponse struct { + Code int `json:"code,omitempty"` + Message interface{} `json:"message,omitempty"` } type PerformInviteRequest struct { diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index e58f11c13..10c8c844e 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -15,6 +15,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" @@ -159,6 +160,10 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.FederationInternalA } } +func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) { + r.Leaver.UserAPI = userAPI +} + func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) { r.asAPI = asAPI } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 12784e5f5..49ddd4810 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -16,25 +16,29 @@ package perform import ( "context" + "encoding/json" "fmt" "strings" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" + userapi "github.com/matrix-org/dendrite/userapi/api" ) type Leaver struct { - Cfg *config.RoomServer - DB storage.Database - FSAPI fsAPI.FederationInternalAPI - + Cfg *config.RoomServer + DB storage.Database + FSAPI fsAPI.FederationInternalAPI + UserAPI userapi.UserInternalAPI Inputer *input.Inputer } @@ -85,6 +89,31 @@ func (r *Leaver) performLeaveRoomByID( if host != r.Cfg.Matrix.ServerName { return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) } + // check that this is not a "server notice room" + accData := &userapi.QueryAccountDataResponse{} + if err := r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ + UserID: req.UserID, + RoomID: req.RoomID, + DataType: "m.tag", + }, accData); err != nil { + return nil, fmt.Errorf("unable to query account data") + } + + if roomData, ok := accData.RoomAccountData[req.RoomID]; ok { + tagData, ok := roomData["m.tag"] + if ok { + tags := gomatrix.TagContent{} + if err = json.Unmarshal(tagData, &tags); err != nil { + return nil, fmt.Errorf("unable to unmarshal tag content") + } + if _, ok = tags.Tags["m.server_notice"]; ok { + // mimic the returned values from Synapse + res.Message = "You cannot reject this invite" + res.Code = 403 + return nil, fmt.Errorf("You cannot reject this invite") + } + } + } } // There's no invite pending, so first of all we want to find out diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index a61404efe..99c596606 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -11,6 +11,8 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/opentracing/opentracing-go" ) @@ -90,6 +92,10 @@ func (h *httpRoomserverInternalAPI) SetFederationAPI(fsAPI fsInputAPI.Federation func (h *httpRoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceQueryAPI) { } +// SetUserAPI no-ops in HTTP client mode as there is no chicken/egg scenario +func (h *httpRoomserverInternalAPI) SetUserAPI(userAPI userapi.UserInternalAPI) { +} + // SetRoomAlias implements RoomserverAliasAPI func (h *httpRoomserverInternalAPI) SetRoomAlias( ctx context.Context, diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 6f2306a6d..b947f2076 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -57,6 +57,9 @@ type Global struct { // DNS caching options for all outbound HTTP requests DNSCache DNSCacheOptions `yaml:"dns_cache"` + + // ServerNotices configuration used for sending server notices + ServerNotices ServerNotices `yaml:"server_notices"` } func (c *Global) Defaults(generate bool) { @@ -72,6 +75,7 @@ func (c *Global) Defaults(generate bool) { c.Metrics.Defaults(generate) c.DNSCache.Defaults() c.Sentry.Defaults() + c.ServerNotices.Defaults(generate) } func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { @@ -82,6 +86,7 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { c.Metrics.Verify(configErrs, isMonolith) c.Sentry.Verify(configErrs, isMonolith) c.DNSCache.Verify(configErrs, isMonolith) + c.ServerNotices.Verify(configErrs, isMonolith) } type OldVerifyKeys struct { @@ -123,6 +128,31 @@ func (c *Metrics) Defaults(generate bool) { func (c *Metrics) Verify(configErrs *ConfigErrors, isMonolith bool) { } +// ServerNotices defines the configuration used for sending server notices +type ServerNotices struct { + Enabled bool `yaml:"enabled"` + // The localpart to be used when sending notices + LocalPart string `yaml:"local_part"` + // The displayname to be used when sending notices + DisplayName string `yaml:"display_name"` + // The avatar of this user + AvatarURL string `yaml:"avatar"` + // The roomname to be used when creating messages + RoomName string `yaml:"room_name"` +} + +func (c *ServerNotices) Defaults(generate bool) { + if generate { + c.Enabled = true + c.LocalPart = "_server" + c.DisplayName = "Server Alert" + c.RoomName = "Server Alert" + c.AvatarURL = "" + } +} + +func (c *ServerNotices) Verify(errors *ConfigErrors, isMonolith bool) {} + // The configuration to use for Sentry error reporting type Sentry struct { Enabled bool `yaml:"enabled"` diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 97c98e57f..8f7611f0a 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -58,6 +58,11 @@ global: basic_auth: username: metrics password: metrics + server_notices: + local_part: "_server" + display_name: "Server alerts" + avatar: "" + room_name: "Server Alerts" app_service_api: internal_api: listen: http://localhost:7777 From a386fbed2c3696cd28307e7cfe02822dff76e4f9 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 21 Feb 2022 12:30:43 +0000 Subject: [PATCH 80/81] Delete one-time keys when deleting a device (#2208) --- keyserver/storage/postgres/one_time_keys_table.go | 12 ++++++++++++ keyserver/storage/shared/storage.go | 3 +++ keyserver/storage/sqlite3/one_time_keys_table.go | 12 ++++++++++++ keyserver/storage/tables/interface.go | 1 + 4 files changed, 28 insertions(+) diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index cc397ba84..0b143a1aa 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -59,6 +59,9 @@ const deleteOneTimeKeySQL = "" + const selectKeyByAlgorithmSQL = "" + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" +const deleteOneTimeKeysSQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" + type oneTimeKeysStatements struct { db *sql.DB upsertKeysStmt *sql.Stmt @@ -66,6 +69,7 @@ type oneTimeKeysStatements struct { selectKeysCountStmt *sql.Stmt selectKeyByAlgorithmStmt *sql.Stmt deleteOneTimeKeyStmt *sql.Stmt + deleteOneTimeKeysStmt *sql.Stmt } func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { @@ -91,6 +95,9 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { return nil, err } + if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { + return nil, err + } return s, nil } @@ -187,3 +194,8 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err } + +func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index deee76eb4..f2790c8df 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -171,6 +171,9 @@ func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceID if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) } + if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) + } } return nil }) diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index 185b88612..897839aca 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -58,6 +58,9 @@ const deleteOneTimeKeySQL = "" + const selectKeyByAlgorithmSQL = "" + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" +const deleteOneTimeKeysSQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" + type oneTimeKeysStatements struct { db *sql.DB upsertKeysStmt *sql.Stmt @@ -65,6 +68,7 @@ type oneTimeKeysStatements struct { selectKeysCountStmt *sql.Stmt selectKeyByAlgorithmStmt *sql.Stmt deleteOneTimeKeyStmt *sql.Stmt + deleteOneTimeKeysStmt *sql.Stmt } func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { @@ -90,6 +94,9 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { return nil, err } + if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { + return nil, err + } return s, nil } @@ -201,3 +208,8 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err } + +func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index ff70a2366..cd1719598 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -31,6 +31,7 @@ type OneTimeKeys interface { // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. // Returns an empty map if the key does not exist. SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) + DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error } type DeviceKeys interface { From a02dd7721d8555391597444d185d402f94b626ae Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 21 Feb 2022 15:25:54 +0000 Subject: [PATCH 81/81] Reset invalid state snapshots for events during state storage refactor migration (#2209) This should help with #2204. We can't do this for rooms, only events. --- .../2021041615092700_state_blocks_refactor.go | 26 +++++++------------ .../2021041615092700_state_blocks_refactor.go | 12 ++++++++- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go index 06740dc8b..06442a4c3 100644 --- a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go +++ b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go @@ -256,23 +256,17 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { return fmt.Errorf("assertion query failed: %s", err) } if count > 0 { - var debugEventID, debugRoomID string - var debugEventTypeNID, debugStateKeyNID, debugSnapNID, debugDepth int64 - err = tx.QueryRow( - `SELECT event_id, event_type_nid, event_state_key_nid, roomserver_events.state_snapshot_nid, depth, room_id FROM roomserver_events - JOIN roomserver_rooms ON roomserver_rooms.room_nid = roomserver_events.room_nid WHERE roomserver_events.state_snapshot_nid < $1 AND roomserver_events.state_snapshot_nid != 0`, maxsnapshotid, - ).Scan(&debugEventID, &debugEventTypeNID, &debugStateKeyNID, &debugSnapNID, &debugDepth, &debugRoomID) - if err != nil { - logrus.Errorf("cannot extract debug info: %v", err) - } else { - logrus.Errorf( - "Affected row: event_id=%v room_id=%v type=%v state_key=%v snapshot=%v depth=%v", - debugEventID, debugRoomID, debugEventTypeNID, debugStateKeyNID, debugSnapNID, debugDepth, - ) - logrus.Errorf("To fix this manually, run this query first then retry the migration: "+ - "UPDATE roomserver_events SET state_snapshot_nid=0 WHERE event_id='%v'", debugEventID) + var res sql.Result + var c int64 + res, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid) + if err != nil && err != sql.ErrNoRows { + return fmt.Errorf("failed to reset invalid state snapshots: %w", err) + } + if c, err = res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get row count for invalid state snapshots updated: %w", err) + } else if c != count { + return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c) } - return fmt.Errorf("%d events exist in roomserver_events which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) } if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil { return fmt.Errorf("assertion query failed: %s", err) diff --git a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go index 8d0331748..8f5ab8fc5 100644 --- a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go +++ b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go @@ -179,7 +179,17 @@ func UpStateBlocksRefactor(tx *sql.Tx) error { return fmt.Errorf("assertion query failed: %s", err) } if count > 0 { - return fmt.Errorf("%d events exist in roomserver_events which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) + var res sql.Result + var c int64 + res, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID) + if err != nil && err != sql.ErrNoRows { + return fmt.Errorf("failed to reset invalid state snapshots: %w", err) + } + if c, err = res.RowsAffected(); err != nil { + return fmt.Errorf("failed to get row count for invalid state snapshots updated: %w", err) + } else if c != count { + return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c) + } } if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil { return fmt.Errorf("assertion query failed: %s", err)