From 8d64c24b23518fde968020e5a093564acc46235a Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 26 Sep 2022 09:33:34 +0100 Subject: [PATCH 01/17] Update documentation to state that Dendrite requires PostgreSQL UTF-8 encoding --- docs/installation/4_database.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/installation/4_database.md b/docs/installation/4_database.md index 68f2d44d0..d64ee6615 100644 --- a/docs/installation/4_database.md +++ b/docs/installation/4_database.md @@ -16,6 +16,9 @@ Dendrite can automatically populate the database with the relevant tables and in it is not capable of creating the databases themselves. You will need to create the databases manually. +The databases **must** be created with UTF-8 encoding configured or you will likely run into problems +with your Dendrite deployment. + At this point, you can choose to either use a single database for all Dendrite components, or you can run each component with its own separate database: @@ -65,7 +68,7 @@ sudo -u postgres createuser -P dendrite Create the database itself, using the `dendrite` role from above: ```bash -sudo -u postgres createdb -O dendrite dendrite +sudo -u postgres createdb -O dendrite -E UTF-8 dendrite ``` ### Multiple database creation @@ -85,7 +88,7 @@ The following eight components require a database. In this example they will be ```bash for i in appservice federationapi mediaapi mscs roomserver syncapi keyserver userapi; do - sudo -u postgres createdb -O dendrite dendrite_$i + sudo -u postgres createdb -O dendrite -E UTF-8 dendrite_$i done ``` From 3c416517b0045a27ec0bd03f285600b733dd5d1c Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 26 Sep 2022 10:45:35 +0200 Subject: [PATCH 02/17] Fix possible "Database is locked" issue --- userapi/storage/shared/storage.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index acc65212c..e32a442d0 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -777,7 +777,7 @@ func (d *Database) GetPushers( func (d *Database) RemovePusher( ctx context.Context, appid, pushkey, localpart string, ) error { - return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) if err == sql.ErrNoRows { return nil @@ -792,7 +792,7 @@ func (d *Database) RemovePusher( func (d *Database) RemovePushers( ctx context.Context, appid, pushkey string, ) error { - return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Pushers.DeletePushers(ctx, txn, appid, pushkey) }) } From 3e87096a21729fcc7e074d09ee12da56394dd15d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 26 Sep 2022 09:54:54 +0100 Subject: [PATCH 03/17] Use `TxStmt` in SQLite pusher table --- userapi/storage/sqlite3/pusher_table.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index d5bd1617b..dba97c3d4 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -96,7 +96,7 @@ func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, ) error { - _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) logrus.Debugf("Created pusher %d", session_id) return err } @@ -144,13 +144,13 @@ func (s *pushersStatements) SelectPushers( func (s *pushersStatements) DeletePusher( ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, ) error { - _, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart) + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) return err } func (s *pushersStatements) DeletePushers( ctx context.Context, txn *sql.Tx, appid, pushkey string, ) error { - _, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey) + _, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey) return err } From f022fc1397fda984245ad1611531b37480cf4f46 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 26 Sep 2022 17:35:35 +0100 Subject: [PATCH 04/17] Remove `origin` field from PDUs (#2737) This nukes the `origin` field from PDUs as per matrix-org/matrix-spec#998, matrix-org/gomatrixserverlib#341. --- federationapi/internal/perform.go | 7 ++--- federationapi/routing/invite.go | 9 ++++++- federationapi/routing/join.go | 18 ++++--------- federationapi/routing/leave.go | 27 ++++++++++++------- go.mod | 2 +- go.sum | 4 +-- roomserver/internal/input/input_events.go | 22 ++++++++------- .../internal/perform/perform_backfill.go | 4 ++- roomserver/internal/perform/perform_invite.go | 8 ++++-- roomserver/internal/perform/perform_leave.go | 15 +++++++---- roomserver/storage/shared/storage.go | 2 +- 11 files changed, 69 insertions(+), 49 deletions(-) diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 9100c8f18..84702f4ce 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -217,7 +217,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( var remoteEvent *gomatrixserverlib.Event remoteEvent, err = respSendJoin.Event.UntrustedEvent(respMakeJoin.RoomVersion) if err == nil && isWellFormedMembershipEvent( - remoteEvent, roomID, userID, r.cfg.Matrix.ServerName, + remoteEvent, roomID, userID, ) { event = remoteEvent } @@ -285,7 +285,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // isWellFormedMembershipEvent returns true if the event looks like a legitimate // membership event. -func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID string, origin gomatrixserverlib.ServerName) bool { +func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID string) bool { if membership, err := event.Membership(); err != nil { return false } else if membership != gomatrixserverlib.Join { @@ -294,9 +294,6 @@ func isWellFormedMembershipEvent(event *gomatrixserverlib.Event, roomID, userID if event.RoomID() != roomID { return false } - if event.Origin() != origin { - return false - } if !event.StateKeyEquals(userID) { return false } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 4b795018c..504204504 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -148,8 +148,15 @@ func processInvite( JSON: jsonerror.BadJSON("The event JSON could not be redacted"), } } + _, serverName, err := gomatrixserverlib.SplitID('@', event.Sender()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The event JSON contains an invalid sender"), + } + } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: event.Origin(), + ServerName: serverName, Message: redacted, AtTS: event.OriginServerTS(), StrictValidityChecking: true, diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 1a1219873..74d065e59 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -203,14 +203,6 @@ func SendJoin( } } - // Check that the event is from the server sending the request. - if event.Origin() != request.Origin() { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"), - } - } - // Check that a state key is provided. if event.StateKey() == nil || event.StateKeyEquals("") { return util.JSONResponse{ @@ -228,16 +220,16 @@ func SendJoin( // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - var domain gomatrixserverlib.ServerName - if _, domain, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { + var serverName gomatrixserverlib.ServerName + if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { return util.JSONResponse{ Code: http.StatusForbidden, JSON: jsonerror.Forbidden("The sender of the join is invalid"), } - } else if domain != request.Origin() { + } else if serverName != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The sender of the join must belong to the origin server"), + JSON: jsonerror.Forbidden("The sender does not match the server that originated the request"), } } @@ -292,7 +284,7 @@ func SendJoin( } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: event.Origin(), + ServerName: serverName, Message: redacted, AtTS: event.OriginServerTS(), StrictValidityChecking: true, diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 8e43ce959..a67e4e28b 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -118,6 +118,7 @@ func MakeLeave( } // SendLeave implements the /send_leave API +// nolint:gocyclo func SendLeave( httpReq *http.Request, request *gomatrixserverlib.FederationRequest, @@ -167,14 +168,6 @@ func SendLeave( } } - // Check that the event is from the server sending the request. - if event.Origin() != request.Origin() { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The leave must be sent by the server it originated on"), - } - } - if event.StateKey() == nil || event.StateKeyEquals("") { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -188,6 +181,22 @@ func SendLeave( } } + // Check that the sender belongs to the server that is sending us + // the request. By this point we've already asserted that the sender + // and the state key are equal so we don't need to check both. + var serverName gomatrixserverlib.ServerName + if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("The sender of the join is invalid"), + } + } else if serverName != request.Origin() { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("The sender does not match the server that originated the request"), + } + } + // Check if the user has already left. If so, no-op! queryReq := &api.QueryLatestEventsAndStateRequest{ RoomID: roomID, @@ -240,7 +249,7 @@ func SendLeave( } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: event.Origin(), + ServerName: serverName, Message: redacted, AtTS: event.OriginServerTS(), StrictValidityChecking: true, diff --git a/go.mod b/go.mod index ded7f28b0..3d99a71e1 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220923115829-2217f6c65ce3 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5 github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 570e6172f..32edf8c11 100644 --- a/go.sum +++ b/go.sum @@ -384,8 +384,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220923115829-2217f6c65ce3 h1:u3FKZmXxfhv3XhD8RziBlt96QTt8eHFhg1upCloBh2g= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220923115829-2217f6c65ce3/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5 h1:cQMA9hip0WSp6cv7CUfButa9Jl/9E6kqWmQyOjx5A5s= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89 h1:Ym50Fgn3GiYya4p29k3nJ5nYsalFGev3eIm3DeGNIq4= github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 29af649ad..01fd62010 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -118,6 +118,10 @@ func (r *Inputer) processRoomEvent( if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } + _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()) + if err != nil { + return fmt.Errorf("event has invalid sender %q", input.Event.Sender()) + } // If we already know about this outlier and it hasn't been rejected // then we won't attempt to reprocess it. If it was rejected or has now @@ -145,7 +149,8 @@ func (r *Inputer) processRoomEvent( var missingAuth, missingPrev bool serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} if !isCreateEvent { - missingAuthIDs, missingPrevIDs, err := r.DB.MissingAuthPrevEvents(ctx, event) + var missingAuthIDs, missingPrevIDs []string + missingAuthIDs, missingPrevIDs, err = r.DB.MissingAuthPrevEvents(ctx, event) if err != nil { return fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) } @@ -158,7 +163,7 @@ func (r *Inputer) processRoomEvent( RoomID: event.RoomID(), ExcludeSelf: true, } - if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { + if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) } // Sort all of the servers into a map so that we can randomise @@ -173,9 +178,9 @@ func (r *Inputer) processRoomEvent( serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) delete(servers, input.Origin) } - if origin := event.Origin(); origin != input.Origin { - serverRes.ServerNames = append(serverRes.ServerNames, origin) - delete(servers, origin) + if senderDomain != input.Origin { + serverRes.ServerNames = append(serverRes.ServerNames, senderDomain) + delete(servers, senderDomain) } for server := range servers { serverRes.ServerNames = append(serverRes.ServerNames, server) @@ -188,7 +193,7 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err := r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { return fmt.Errorf("r.fetchAuthEvents: %w", err) } @@ -231,7 +236,6 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindNew { // Check that the event passes authentication checks based on the // current room state. - var err error softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) if err != nil { logger.WithError(err).Warn("Error authing soft-failed event") @@ -265,7 +269,8 @@ func (r *Inputer) processRoomEvent( hadEvents: map[string]bool{}, haveEvents: map[string]*gomatrixserverlib.Event{}, } - if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + var stateSnapshot *parsedRespState + if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { // Something went wrong with retrieving the missing state, so we can't // really do anything with the event other than reject it at this point. isRejected = true @@ -302,7 +307,6 @@ func (r *Inputer) processRoomEvent( // burning CPU time. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected { - var err error historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) if err != nil { return fmt.Errorf("r.processStateBefore: %w", err) diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 51c66415a..69a075733 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -468,7 +468,9 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[gomatrixserverlib.ServerName]bool) for _, event := range memberEvents { - serverSet[event.Origin()] = true + if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil { + serverSet[senderDomain] = true + } } var servers []gomatrixserverlib.ServerName for server := range serverSet { diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 483e78c3f..3fbdf332e 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -50,6 +50,10 @@ func (r *Inviter) PerformInvite( if event.StateKey() == nil { return nil, fmt.Errorf("invite must be a state event") } + _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()) + if err != nil { + return nil, fmt.Errorf("sender %q is invalid", event.Sender()) + } roomID := event.RoomID() targetUserID := *event.StateKey() @@ -67,7 +71,7 @@ func (r *Inviter) PerformInvite( return nil, nil } isTargetLocal := domain == r.Cfg.Matrix.ServerName - isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName + isOriginLocal := senderDomain == r.Cfg.Matrix.ServerName if !isOriginLocal && !isTargetLocal { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -235,7 +239,7 @@ func (r *Inviter) PerformInvite( { Kind: api.KindNew, Event: event, - Origin: event.Origin(), + Origin: senderDomain, SendAsServer: req.SendAsServer, }, }, diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 036404cd2..ada3aab06 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -81,12 +81,11 @@ func (r *Leaver) performLeaveRoomByID( // that. isInvitePending, senderUser, eventID, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) if err == nil && isInvitePending { - var host gomatrixserverlib.ServerName - _, host, err = gomatrixserverlib.SplitID('@', senderUser) - if err != nil { + _, senderDomain, serr := gomatrixserverlib.SplitID('@', senderUser) + if serr != nil { return nil, fmt.Errorf("sender %q is invalid", senderUser) } - if host != r.Cfg.Matrix.ServerName { + if senderDomain != r.Cfg.Matrix.ServerName { return r.performFederatedRejectInvite(ctx, req, res, senderUser, eventID) } // check that this is not a "server notice room" @@ -172,6 +171,12 @@ func (r *Leaver) performLeaveRoomByID( return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) } + // Get the sender domain. + _, senderDomain, serr := gomatrixserverlib.SplitID('@', event.Sender()) + if serr != nil { + return nil, fmt.Errorf("sender %q is invalid", event.Sender()) + } + // Give our leave event to the roomserver input stream. The // roomserver will process the membership change and notify // downstream automatically. @@ -180,7 +185,7 @@ func (r *Leaver) performLeaveRoomByID( { Kind: api.KindNew, Event: event.Headered(buildRes.RoomVersion), - Origin: event.Origin(), + Origin: senderDomain, SendAsServer: string(r.Cfg.Matrix.ServerName), }, }, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 00a17e5cb..593abbea1 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -896,7 +896,7 @@ func (d *Database) handleRedactions( switch { case redactUser >= pl.Redact: // The power level of the redaction event’s sender is greater than or equal to the redact level. - case redactedEvent.Origin() == redactionEvent.Origin() && redactedEvent.Sender() == redactionEvent.Sender(): + case redactedEvent.Sender() == redactionEvent.Sender(): // The domain of the redaction event’s sender matches that of the original event’s sender. default: return nil, "", nil From 65ee181de4d98fbb1a6ee988fee1214d82d50a56 Mon Sep 17 00:00:00 2001 From: Tak Wai Wong <64229756+tak-hntlabs@users.noreply.github.com> Date: Mon, 26 Sep 2022 16:46:52 -0700 Subject: [PATCH 05/17] Authorization - config, interface, and default implementation (#33) * add config yaml for enable_auth * zion_space_manager_localhost.go * Placeholders for authorization * rename func and type * re-run go mod tidy Co-authored-by: Tak Wai Wong --- authorization/authorization.go | 35 ++++++++++++++ authorization/default_authorization.go | 23 +++++++++ clientapi/routing/routing.go | 3 ++ setup/config/config_publickey.go | 7 +-- web3/account.go | 65 ++++++++++++++++++++++++++ web3/client.go | 14 ++++++ 6 files changed, 144 insertions(+), 3 deletions(-) create mode 100644 authorization/authorization.go create mode 100644 authorization/default_authorization.go create mode 100644 web3/account.go create mode 100644 web3/client.go diff --git a/authorization/authorization.go b/authorization/authorization.go new file mode 100644 index 000000000..9f7cbcbc1 --- /dev/null +++ b/authorization/authorization.go @@ -0,0 +1,35 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authorization + +import "github.com/matrix-org/dendrite/setup/config" + +type AuthorizationArgs struct { + RoomId string + UserId string + Permission string +} + +type Authorization interface { + IsAllowed(args AuthorizationArgs) (bool, error) +} + +func NewClientApiAuthorization(cfg *config.ClientAPI) Authorization { + // Load authorization manager for Zion + //if cfg.PublicKeyAuthentication.Ethereum.EnableAuthz { + //} + + return &DefaultAuthorization{} +} diff --git a/authorization/default_authorization.go b/authorization/default_authorization.go new file mode 100644 index 000000000..1baba3f86 --- /dev/null +++ b/authorization/default_authorization.go @@ -0,0 +1,23 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authorization + +type DefaultAuthorization struct { +} + +func (azm *DefaultAuthorization) IsAllowed(args AuthorizationArgs) (bool, error) { + // Default. No authorization logic. + return true, nil +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index a06ef3c12..441fd7b07 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -21,6 +21,7 @@ import ( "github.com/gorilla/mux" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/authorization" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth" clientutil "github.com/matrix-org/dendrite/clientapi/httputil" @@ -73,6 +74,8 @@ func Setup( rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) userInteractiveAuth := auth.NewUserInteractive(userAPI, userAPI, cfg) + authorization := authorization.NewClientApiAuthorization(cfg) + _ = authorization // todo: use this in httputil.MakeAuthAPI unstableFeatures := map[string]bool{ "org.matrix.e2e_cross_signing": true, diff --git a/setup/config/config_publickey.go b/setup/config/config_publickey.go index e214163e2..d834cfefc 100644 --- a/setup/config/config_publickey.go +++ b/setup/config/config_publickey.go @@ -21,9 +21,10 @@ func (p EthereumAuthParams) GetParams() interface{} { } type EthereumAuthConfig struct { - Enabled bool `yaml:"enabled"` - Version uint `yaml:"version"` - ChainIDs []int `yaml:"chain_ids"` + Enabled bool `yaml:"enabled"` + Version uint `yaml:"version"` + ChainIDs []int `yaml:"chain_ids"` + EnableAuthz bool `yaml:"enable_authz"` // Flag to enable / disable authorization during development } type PublicKeyAuthentication struct { diff --git a/web3/account.go b/web3/account.go new file mode 100644 index 000000000..27eda6b5d --- /dev/null +++ b/web3/account.go @@ -0,0 +1,65 @@ +package web3 + +import ( + "context" + "crypto/ecdsa" + "errors" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethclient" +) + +type CreateTransactionSignerArgs struct { + PrivateKey string + ChainId int64 + Client *ethclient.Client + GasValue int64 // in wei + GasLimit int64 // in units +} + +func CreateTransactionSigner(args CreateTransactionSignerArgs) (*bind.TransactOpts, error) { + privateKey, err := crypto.HexToECDSA(args.PrivateKey) + if err != nil { + return nil, err + } + + publicKey := privateKey.Public() + publicKeyECDSA, ok := publicKey.(*ecdsa.PublicKey) + if !ok { + return nil, errors.New("cannot create public key ECDSA") + } + + fromAddress := crypto.PubkeyToAddress(*publicKeyECDSA) + + nonce, err := args.Client.PendingNonceAt(context.Background(), fromAddress) + if err != nil { + return nil, err + } + + gasPrice, err := args.Client.SuggestGasPrice((context.Background())) + if err != nil { + return nil, err + } + + signer, err := bind.NewKeyedTransactorWithChainID(privateKey, big.NewInt(args.ChainId)) + if err != nil { + return nil, err + } + + signer.Nonce = big.NewInt(int64(nonce)) + signer.Value = big.NewInt(args.GasValue) + signer.GasLimit = uint64(args.GasLimit) + signer.GasPrice = gasPrice + + fmt.Printf("{ nonce: %d, value: %d, gasLimit: %d, gasPrice: %d }\n", + signer.Nonce, + signer.Value, + signer.GasLimit, + signer.GasPrice, + ) + + return signer, nil +} diff --git a/web3/client.go b/web3/client.go new file mode 100644 index 000000000..9cd643648 --- /dev/null +++ b/web3/client.go @@ -0,0 +1,14 @@ +package web3 + +import ( + "github.com/ethereum/go-ethereum/ethclient" +) + +func GetEthClient(web3ProviderUrl string) (*ethclient.Client, error) { + client, err := ethclient.Dial(web3ProviderUrl) + if err != nil { + return nil, err + } + + return client, nil +} From 40fec70d1336f97e31e1d23cc0576b543cb119fc Mon Sep 17 00:00:00 2001 From: networkException Date: Tue, 27 Sep 2022 10:39:39 +0200 Subject: [PATCH 06/17] Add pinecone demo container image (#2710) This pull request adds the configuration and CI steps to build and publish a container wrapping the `dendrite-demo-pinecone` command as well as fixes a sentence structure issue in the pull request template. As this does not touch any go source code no tests have been added ### Pull Request Checklist * [x] I have added tests for PR _or_ I have justified why this PR doesn't need tests. * [x] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off) Signed-off-by: networkException (by private sign-off) Co-authored-by: Neil Alexander --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- .github/workflows/docker.yml | 60 +++++++++++++++++++++++++++ build/docker/Dockerfile.demo-pinecone | 25 +++++++++++ 3 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 build/docker/Dockerfile.demo-pinecone diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 9bfb01667..e0b82e2aa 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,7 +2,7 @@ -* [ ] I have added added tests for PR _or_ I have justified why this PR doesn't need tests. +* [ ] I have added tests for PR _or_ I have justified why this PR doesn't need tests. * [ ] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off) Signed-off-by: `Your Name ` diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 642587924..b4e24e52f 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -137,3 +137,63 @@ jobs: ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} + + demo-pinecone: + name: Pinecone demo image + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Get release tag + if: github.event_name == 'release' # Only for GitHub releases + run: echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV + - name: Set up QEMU + uses: docker/setup-qemu-action@v1 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + - name: Login to Docker Hub + uses: docker/login-action@v1 + with: + username: ${{ env.DOCKER_HUB_USER }} + password: ${{ secrets.DOCKER_TOKEN }} + - name: Login to GitHub Containers + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build main pinecone demo image + if: github.ref_name == 'main' + id: docker_build_demo_pinecone + uses: docker/build-push-action@v2 + with: + cache-from: type=gha + cache-to: type=gha,mode=max + context: . + file: ./build/docker/Dockerfile.demo-pinecone + platforms: ${{ env.PLATFORMS }} + push: true + tags: | + ${{ env.DOCKER_NAMESPACE }}/dendrite-demo-pinecone:${{ github.ref_name }} + ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-demo-pinecone:${{ github.ref_name }} + + - name: Build release pinecone demo image + if: github.event_name == 'release' # Only for GitHub releases + id: docker_build_demo_pinecone_release + uses: docker/build-push-action@v2 + with: + cache-from: type=gha + cache-to: type=gha,mode=max + context: . + file: ./build/docker/Dockerfile.demo-pinecone + platforms: ${{ env.PLATFORMS }} + push: true + tags: | + ${{ env.DOCKER_NAMESPACE }}/dendrite-demo-pinecone:latest + ${{ env.DOCKER_NAMESPACE }}/dendrite-demo-pinecone:${{ env.RELEASE_VERSION }} + ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-demo-pinecone:latest + ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-demo-pinecone:${{ env.RELEASE_VERSION }} diff --git a/build/docker/Dockerfile.demo-pinecone b/build/docker/Dockerfile.demo-pinecone new file mode 100644 index 000000000..a5d709c7a --- /dev/null +++ b/build/docker/Dockerfile.demo-pinecone @@ -0,0 +1,25 @@ +FROM docker.io/golang:1.18-alpine AS base + +RUN apk --update --no-cache add bash build-base + +WORKDIR /build + +COPY . /build + +RUN mkdir -p bin +RUN go build -trimpath -o bin/ ./cmd/dendrite-demo-pinecone +RUN go build -trimpath -o bin/ ./cmd/create-account +RUN go build -trimpath -o bin/ ./cmd/generate-keys + +FROM alpine:latest +LABEL org.opencontainers.image.title="Dendrite (Pinecone demo)" +LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" +LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" +LABEL org.opencontainers.image.licenses="Apache-2.0" + +COPY --from=base /build/bin/* /usr/bin/ + +VOLUME /etc/dendrite +WORKDIR /etc/dendrite + +ENTRYPOINT ["/usr/bin/dendrite-demo-pinecone"] From 12649ccedd858cbe075271ea234b1e268f973c5a Mon Sep 17 00:00:00 2001 From: PiotrKozimor <37144818+PiotrKozimor@users.noreply.github.com> Date: Tue, 27 Sep 2022 10:41:36 +0200 Subject: [PATCH 07/17] Improve selectRoomIDsWithAnyMembershipSQL performance (#2738) Recently I have observed that dendrite spends a lot of time (~390s) in `selectRoomIDsWithAnyMembershipSQL` query ``` dendrite_syncapi=# select total_exec_time, left(query,100) from pg_stat_statements order by total_exec_time desc limit 5 ; total_exec_time | left --------------------+------------------------------------------------------------------------------------------------------ 747826.5800519128 | SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_vis 389130.5490339942 | SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = $2 AND state_key = 376104.17514700035 | SELECT psd.datname, xact_commit, xact_rollback, blks_read, blks_hit, tup_returned, tup_fetched, tup_ 363644.164092031 | SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events WHERE event_nid = ANY($ 58570.48104699995 | SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1 AND ( $2::te (5 rows) ``` Explain analyze showed correct usage of `syncapi_room_state_unique` index: ``` dendrite_syncapi=# explain analyze SELECT distinct room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = '@qjfl:dendrite.stg.globekeeper.com'; QUERY PLAN ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ Unique (cost=2749.38..2749.56 rows=24 width=52) (actual time=2.933..2.956 rows=65 loops=1) -> Sort (cost=2749.38..2749.44 rows=24 width=52) (actual time=2.932..2.937 rows=65 loops=1) Sort Key: room_id, membership Sort Method: quicksort Memory: 34kB -> Index Scan using syncapi_room_state_unique on syncapi_current_room_state (cost=0.41..2748.83 rows=24 width=52) (actual time=0.030..2.890 rows=65 loops=1) Index Cond: ((type = 'm.room.member'::text) AND (state_key = '@qjfl:dendrite.stg.globekeeper.com'::text)) Planning Time: 0.140 ms Execution Time: 2.990 ms (8 rows) ``` Multi-column indexes in Postgres shall perform well for leftmost columns, but I gave it a try and created `syncapi_current_room_state_type_state_key_idx` index. I could observe significant performance improvement. Execution time dropped from 2.9 ms to 0.24 ms: ``` explain analyze SELECT distinct room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = '@qjfl:dendrite.stg.globekeeper.com'; QUERY PLAN -------------------------------------------------------------------------------------------------------------------------------------------------------------------- Unique (cost=96.46..96.64 rows=24 width=52) (actual time=0.199..0.218 rows=65 loops=1) -> Sort (cost=96.46..96.52 rows=24 width=52) (actual time=0.199..0.202 rows=65 loops=1) Sort Key: room_id, membership Sort Method: quicksort Memory: 34kB -> Bitmap Heap Scan on syncapi_current_room_state (cost=4.53..95.91 rows=24 width=52) (actual time=0.048..0.139 rows=65 loops=1) Recheck Cond: ((type = 'm.room.member'::text) AND (state_key = '@qjfl:dendrite.stg.globekeeper.com'::text)) Heap Blocks: exact=59 -> Bitmap Index Scan on syncapi_current_room_state_type_state_key_idx (cost=0.00..4.53 rows=24 width=0) (actual time=0.037..0.037 rows=65 loops=1) Index Cond: ((type = 'm.room.member'::text) AND (state_key = '@qjfl:dendrite.stg.globekeeper.com'::text)) Planning Time: 0.236 ms Execution Time: 0.242 ms (11 rows) ``` Next improvement is skipping DISTINCT and rely on map assignment in `SelectRoomIDsWithAnyMembership`. Execution time drops by almost half: ``` explain analyze SELECT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = '@qjfl:dendrite.stg.globekeeper.com'; QUERY PLAN -------------------------------------------------------------------------------------------------------------------------------------------------------- Bitmap Heap Scan on syncapi_current_room_state (cost=4.53..95.91 rows=24 width=52) (actual time=0.032..0.113 rows=65 loops=1) Recheck Cond: ((type = 'm.room.member'::text) AND (state_key = '@qjfl:dendrite.stg.globekeeper.com'::text)) Heap Blocks: exact=59 -> Bitmap Index Scan on syncapi_current_room_state_type_state_key_idx (cost=0.00..4.53 rows=24 width=0) (actual time=0.021..0.021 rows=65 loops=1) Index Cond: ((type = 'm.room.member'::text) AND (state_key = '@qjfl:dendrite.stg.globekeeper.com'::text)) Planning Time: 0.087 ms Execution Time: 0.136 ms (7 rows) ``` In our env we spend only 1s on inserting to table, so the write penalty of creating an index should be small. ``` dendrite_syncapi=# select total_exec_time, left(query,100) from pg_stat_statements where query like '%INSERT%syncapi_current_room_state%' order by total_exec_time desc; total_exec_time | left --------------------+------------------------------------------------------------------------------------------------------ 1139.9057619999971 | INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, he (1 row) ``` This PR does not require test modifications. ### Pull Request Checklist * [x] I have added added tests for PR _or_ I have justified why this PR doesn't need tests. * [x] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off) Signed-off-by: `Piotr Kozimor ` --- syncapi/storage/postgres/current_room_state_table.go | 4 +++- syncapi/storage/sqlite3/current_room_state_table.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 083d10b88..5e6daaaf8 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -62,6 +62,8 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_s CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; -- for querying state by event IDs CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id); +-- for improving selectRoomIDsWithAnyMembershipSQL +CREATE INDEX IF NOT EXISTS syncapi_current_room_state_type_state_key_idx ON syncapi_current_room_state(type, state_key); ` const upsertRoomStateSQL = "" + @@ -80,7 +82,7 @@ const selectRoomIDsWithMembershipSQL = "" + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" const selectRoomIDsWithAnyMembershipSQL = "" + - "SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" + "SELECT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" const selectCurrentStateSQL = "" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index b88c11f88..bd1271dd6 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -51,6 +51,8 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_s -- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave'; -- for querying state by event IDs CREATE UNIQUE INDEX IF NOT EXISTS syncapi_current_room_state_eventid_idx ON syncapi_current_room_state(event_id); +-- for improving selectRoomIDsWithAnyMembershipSQL +CREATE INDEX IF NOT EXISTS syncapi_current_room_state_type_state_key_idx ON syncapi_current_room_state(type, state_key); ` const upsertRoomStateSQL = "" + @@ -69,7 +71,7 @@ const selectRoomIDsWithMembershipSQL = "" + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" const selectRoomIDsWithAnyMembershipSQL = "" + - "SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" + "SELECT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" const selectCurrentStateSQL = "" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" From b5bfff47d296e704fbeb33c5b57829609bed0426 Mon Sep 17 00:00:00 2001 From: Dov Alperin Date: Tue, 27 Sep 2022 04:42:08 -0400 Subject: [PATCH 08/17] Use /usr/bin/env bash in shebangs to make them universal (#2735) Some systems (like nixos) don't have bash living at `/bin/bash` so using `/usr/bin/env bash` we can make these scripts universal. ### Pull Request Checklist * [X] I have added added tests for PR _or_ I have justified why this PR doesn't need tests. * [x] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off) Signed-off-by: `Dov Alperin ` Signed-off-by: `Dov Alperin ` --- build/docker/images-build.sh | 2 +- build/docker/images-pull.sh | 2 +- build/docker/images-push.sh | 2 +- run-sytest.sh | 2 +- show-expected-fail-tests.sh | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/build/docker/images-build.sh b/build/docker/images-build.sh index eaed5f6dc..c2c140685 100755 --- a/build/docker/images-build.sh +++ b/build/docker/images-build.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash cd $(git rev-parse --show-toplevel) diff --git a/build/docker/images-pull.sh b/build/docker/images-pull.sh index 496e80067..f3f98ce7c 100755 --- a/build/docker/images-pull.sh +++ b/build/docker/images-pull.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash TAG=${1:-latest} diff --git a/build/docker/images-push.sh b/build/docker/images-push.sh index fd9b999ea..248fdee2b 100755 --- a/build/docker/images-push.sh +++ b/build/docker/images-push.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash TAG=${1:-latest} diff --git a/run-sytest.sh b/run-sytest.sh index e23982397..4ed1c8d45 100755 --- a/run-sytest.sh +++ b/run-sytest.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Runs SyTest either from Docker Hub, or from ../sytest. If it's run # locally, the Docker image is rebuilt first. diff --git a/show-expected-fail-tests.sh b/show-expected-fail-tests.sh index 3ed937a0f..b7af8f648 100755 --- a/show-expected-fail-tests.sh +++ b/show-expected-fail-tests.sh @@ -1,4 +1,4 @@ -#! /bin/bash +#!/usr/bin/env bash # # Parses a results.tap file from SyTest output and a file containing test names (a test whitelist) # and checks whether a test name that exists in the whitelist (that should pass), failed or not. From d531202b0e877a41765f9a833f6c52b6c8eae919 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 27 Sep 2022 10:52:03 +0100 Subject: [PATCH 09/17] Build Docker images using Go 1.19 (related to #2714) --- build/docker/Dockerfile.demo-pinecone | 2 +- build/docker/Dockerfile.monolith | 2 +- build/docker/Dockerfile.polylith | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/build/docker/Dockerfile.demo-pinecone b/build/docker/Dockerfile.demo-pinecone index a5d709c7a..133c63c53 100644 --- a/build/docker/Dockerfile.demo-pinecone +++ b/build/docker/Dockerfile.demo-pinecone @@ -1,4 +1,4 @@ -FROM docker.io/golang:1.18-alpine AS base +FROM docker.io/golang:1.19-alpine AS base RUN apk --update --no-cache add bash build-base diff --git a/build/docker/Dockerfile.monolith b/build/docker/Dockerfile.monolith index bb02934cd..3180e9626 100644 --- a/build/docker/Dockerfile.monolith +++ b/build/docker/Dockerfile.monolith @@ -1,4 +1,4 @@ -FROM docker.io/golang:1.18-alpine AS base +FROM docker.io/golang:1.19-alpine AS base RUN apk --update --no-cache add bash build-base diff --git a/build/docker/Dockerfile.polylith b/build/docker/Dockerfile.polylith index 166ea99cb..79f8a5f23 100644 --- a/build/docker/Dockerfile.polylith +++ b/build/docker/Dockerfile.polylith @@ -1,4 +1,4 @@ -FROM docker.io/golang:1.18-alpine AS base +FROM docker.io/golang:1.19-alpine AS base RUN apk --update --no-cache add bash build-base From f18bce93cc3e7e5f57ebc55d309360b7f8703553 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 27 Sep 2022 11:15:49 +0100 Subject: [PATCH 10/17] Pinecone hybrid routing (update to matrix-org/pinecone#67) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 3d99a71e1..b682d9bc4 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5 - github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89 + github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 github.com/nats-io/nats-server/v2 v2.9.1-0.20220920152220-52d7b481c4b5 diff --git a/go.sum b/go.sum index 32edf8c11..1afed73a5 100644 --- a/go.sum +++ b/go.sum @@ -386,8 +386,8 @@ github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5 h1:cQMA9hip0WSp6cv7CUfButa9Jl/9E6kqWmQyOjx5A5s= github.com/matrix-org/gomatrixserverlib v0.0.0-20220926161602-759a8ee7c4d5/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89 h1:Ym50Fgn3GiYya4p29k3nJ5nYsalFGev3eIm3DeGNIq4= -github.com/matrix-org/pinecone v0.0.0-20220923151905-0900fceecb89/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= +github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d h1:kGPJ6Rg8nn5an2CbCZrRiuTNyNzE0rRMiqm4UXJYrRs= +github.com/matrix-org/pinecone v0.0.0-20220927101513-d0beb180f44d/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= From 249b32c4f3ee2e01e6f89435e0c7a5786d2ae3a1 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 27 Sep 2022 15:01:34 +0200 Subject: [PATCH 11/17] Refactor notifications (#2688) This PR changes the handling of notifications - removes the `StreamEvent` and `ReadUpdate` stream - listens on the `OutputRoomEvent` stream in the UserAPI to inform the SyncAPI about unread notifications - listens on the `OutputReceiptEvent` stream in the UserAPI to set receipts/update notifications - sets the `read_markers` directly from within the internal UserAPI Co-authored-by: Neil Alexander --- setup/jetstream/nats.go | 5 +- setup/jetstream/streams.go | 10 -- syncapi/consumers/clientapi.go | 46 ------ syncapi/consumers/receipts.go | 48 +----- syncapi/consumers/roomserver.go | 17 +-- syncapi/consumers/userapi.go | 5 +- syncapi/producers/userapi_readupdate.go | 62 -------- syncapi/producers/userapi_streamevent.go | 60 -------- syncapi/storage/interface.go | 15 +- .../postgres/notification_data_table.go | 36 ++--- syncapi/storage/shared/syncserver.go | 11 +- .../sqlite3/notification_data_table.go | 39 +++-- syncapi/storage/tables/interface.go | 2 +- syncapi/streams/stream_accountdata.go | 3 +- syncapi/streams/stream_notificationdata.go | 23 +-- syncapi/syncapi.go | 14 +- syncapi/types/types.go | 23 +-- userapi/consumers/clientapi.go | 127 ++++++++++++++++ .../{syncapi_streamevent.go => roomserver.go} | 76 +++++----- ...streamevent_test.go => roomserver_test.go} | 2 +- userapi/consumers/syncapi_readupdate.go | 137 ------------------ userapi/internal/api.go | 46 ++++++ userapi/producers/syncapi.go | 7 +- userapi/storage/interface.go | 6 +- .../storage/postgres/notifications_table.go | 51 ++----- userapi/storage/postgres/pusher_table.go | 5 +- userapi/storage/shared/storage.go | 6 +- .../storage/sqlite3/notifications_table.go | 51 ++----- userapi/storage/sqlite3/pusher_table.go | 5 +- userapi/storage/storage_test.go | 11 +- userapi/storage/tables/interface.go | 6 +- userapi/userapi.go | 11 +- 32 files changed, 368 insertions(+), 598 deletions(-) delete mode 100644 syncapi/producers/userapi_readupdate.go delete mode 100644 syncapi/producers/userapi_streamevent.go create mode 100644 userapi/consumers/clientapi.go rename userapi/consumers/{syncapi_streamevent.go => roomserver.go} (85%) rename userapi/consumers/{syncapi_streamevent_test.go => roomserver_test.go} (98%) delete mode 100644 userapi/consumers/syncapi_readupdate.go diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 3660e91e3..7409fd6c8 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -9,9 +9,10 @@ import ( "time" "github.com/getsentry/sentry-go" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" - "github.com/sirupsen/logrus" natsserver "github.com/nats-io/nats-server/v2/server" natsclient "github.com/nats-io/nats.go" @@ -184,6 +185,8 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"}, OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"}, OutputRoomEvent: {"AppserviceRoomserverConsumer"}, + OutputStreamEvent: {"UserAPISyncAPIStreamEventConsumer"}, + OutputReadUpdate: {"UserAPISyncAPIReadUpdateConsumer"}, } { streamName := cfg.Matrix.JetStream.Prefixed(stream) for _, consumer := range consumers { diff --git a/setup/jetstream/streams.go b/setup/jetstream/streams.go index c07d3a0b4..ee9810dae 100644 --- a/setup/jetstream/streams.go +++ b/setup/jetstream/streams.go @@ -94,16 +94,6 @@ var streams = []*nats.StreamConfig{ Retention: nats.InterestPolicy, Storage: nats.FileStorage, }, - { - Name: OutputStreamEvent, - Retention: nats.InterestPolicy, - Storage: nats.FileStorage, - }, - { - Name: OutputReadUpdate, - Retention: nats.InterestPolicy, - Storage: nats.FileStorage, - }, { Name: OutputPresenceEvent, Retention: nats.InterestPolicy, diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index f0588cab8..a170a6ec1 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -16,9 +16,7 @@ package consumers import ( "context" - "database/sql" "encoding/json" - "fmt" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" @@ -31,7 +29,6 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" - "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -46,7 +43,6 @@ type OutputClientDataConsumer struct { stream types.StreamProvider notifier *notifier.Notifier serverName gomatrixserverlib.ServerName - producer *producers.UserAPIReadProducer } // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. @@ -57,7 +53,6 @@ func NewOutputClientDataConsumer( store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, - producer *producers.UserAPIReadProducer, ) *OutputClientDataConsumer { return &OutputClientDataConsumer{ ctx: process.Context(), @@ -68,7 +63,6 @@ func NewOutputClientDataConsumer( notifier: notifier, stream: stream, serverName: cfg.Matrix.ServerName, - producer: producer, } } @@ -113,15 +107,6 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M return false } - if err = s.sendReadUpdate(ctx, userID, output); err != nil { - log.WithError(err).WithFields(logrus.Fields{ - "user_id": userID, - "room_id": output.RoomID, - }).Errorf("Failed to generate read update") - sentry.CaptureException(err) - return false - } - if output.IgnoredUsers != nil { if err := s.db.UpdateIgnoresForUser(ctx, userID, output.IgnoredUsers); err != nil { log.WithError(err).WithFields(logrus.Fields{ @@ -136,34 +121,3 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M return true } - -func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error { - if output.Type != "m.fully_read" || output.ReadMarker == nil { - return nil - } - _, serverName, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) - } - if serverName != s.serverName { - return nil - } - var readPos types.StreamPosition - var fullyReadPos types.StreamPosition - if output.ReadMarker.Read != "" { - if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) - } - } - if output.ReadMarker.FullyRead != "" { - if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err) - } - } - if readPos > 0 || fullyReadPos > 0 { - if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil { - return fmt.Errorf("s.producer.SendReadUpdate: %w", err) - } - } - return nil -} diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index a18244c44..4379dd134 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -16,22 +16,19 @@ package consumers import ( "context" - "database/sql" - "fmt" "strconv" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" - "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" - log "github.com/sirupsen/logrus" ) // OutputReceiptEventConsumer consumes events that originated in the EDU server. @@ -44,7 +41,6 @@ type OutputReceiptEventConsumer struct { stream types.StreamProvider notifier *notifier.Notifier serverName gomatrixserverlib.ServerName - producer *producers.UserAPIReadProducer } // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. @@ -56,7 +52,6 @@ func NewOutputReceiptEventConsumer( store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, - producer *producers.UserAPIReadProducer, ) *OutputReceiptEventConsumer { return &OutputReceiptEventConsumer{ ctx: process.Context(), @@ -67,7 +62,6 @@ func NewOutputReceiptEventConsumer( notifier: notifier, stream: stream, serverName: cfg.Matrix.ServerName, - producer: producer, } } @@ -111,42 +105,8 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats return true } - if err = s.sendReadUpdate(ctx, output); err != nil { - log.WithError(err).WithFields(logrus.Fields{ - "user_id": output.UserID, - "room_id": output.RoomID, - }).Errorf("Failed to generate read update") - sentry.CaptureException(err) - return false - } - s.stream.Advance(streamPos) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) return true } - -func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output types.OutputReceiptEvent) error { - if output.Type != "m.read" { - return nil - } - _, serverName, err := gomatrixserverlib.SplitID('@', output.UserID) - if err != nil { - return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) - } - if serverName != s.serverName { - return nil - } - var readPos types.StreamPosition - if output.EventID != "" { - if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) - } - } - if readPos > 0 { - if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil { - return fmt.Errorf("s.producer.SendReadUpdate: %w", err) - } - } - return nil -} diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 6979eb484..0964ae207 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -21,17 +21,17 @@ import ( "fmt" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" - "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" ) // OutputRoomEventConsumer consumes events that originated in the room server. @@ -46,7 +46,6 @@ type OutputRoomEventConsumer struct { pduStream types.StreamProvider inviteStream types.StreamProvider notifier *notifier.Notifier - producer *producers.UserAPIStreamEventProducer } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -59,7 +58,6 @@ func NewOutputRoomEventConsumer( pduStream types.StreamProvider, inviteStream types.StreamProvider, rsAPI api.SyncRoomserverAPI, - producer *producers.UserAPIStreamEventProducer, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), @@ -72,7 +70,6 @@ func NewOutputRoomEventConsumer( pduStream: pduStream, inviteStream: inviteStream, rsAPI: rsAPI, - producer: producer, } } @@ -255,12 +252,6 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( return nil } - if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil { - log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID()) - sentry.CaptureException(err) - return err - } - if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) sentry.CaptureException(err) diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go index 227823522..c9b96f788 100644 --- a/syncapi/consumers/userapi.go +++ b/syncapi/consumers/userapi.go @@ -19,6 +19,9 @@ import ( "encoding/json" "github.com/getsentry/sentry-go" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" @@ -26,8 +29,6 @@ import ( "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" ) // OutputNotificationDataConsumer consumes events that originated in diff --git a/syncapi/producers/userapi_readupdate.go b/syncapi/producers/userapi_readupdate.go deleted file mode 100644 index d56cab776..000000000 --- a/syncapi/producers/userapi_readupdate.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package producers - -import ( - "encoding/json" - - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/syncapi/types" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" -) - -// UserAPIProducer produces events for the user API server to consume -type UserAPIReadProducer struct { - Topic string - JetStream nats.JetStreamContext -} - -// SendData sends account data to the user API server -func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error { - m := &nats.Msg{ - Subject: p.Topic, - Header: nats.Header{}, - } - m.Header.Set(jetstream.UserID, userID) - m.Header.Set(jetstream.RoomID, roomID) - - data := types.ReadUpdate{ - UserID: userID, - RoomID: roomID, - Read: readPos, - FullyRead: fullyReadPos, - } - var err error - m.Data, err = json.Marshal(data) - if err != nil { - return err - } - - log.WithFields(log.Fields{ - "user_id": userID, - "room_id": roomID, - "read_pos": readPos, - "fully_read_pos": fullyReadPos, - }).Tracef("Producing to topic '%s'", p.Topic) - - _, err = p.JetStream.PublishMsg(m) - return err -} diff --git a/syncapi/producers/userapi_streamevent.go b/syncapi/producers/userapi_streamevent.go deleted file mode 100644 index 2bbd19c0b..000000000 --- a/syncapi/producers/userapi_streamevent.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package producers - -import ( - "encoding/json" - - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" -) - -// UserAPIProducer produces events for the user API server to consume -type UserAPIStreamEventProducer struct { - Topic string - JetStream nats.JetStreamContext -} - -// SendData sends account data to the user API server -func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error { - m := &nats.Msg{ - Subject: p.Topic, - Header: nats.Header{}, - } - m.Header.Set(jetstream.RoomID, roomID) - - data := types.StreamedEvent{ - Event: event, - StreamPosition: pos, - } - var err error - m.Data, err = json.Marshal(data) - if err != nil { - return err - } - - log.WithFields(log.Fields{ - "room_id": roomID, - "event_id": event.EventID(), - "event_type": event.Type(), - "stream_pos": pos, - }).Tracef("Producing to topic '%s'", p.Topic) - - _, err = p.JetStream.PublishMsg(m) - return err -} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 0c8ba4e3d..ad3be4206 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -29,6 +29,7 @@ import ( type Database interface { Presence SharedUsers + Notifications MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) @@ -148,12 +149,6 @@ type Database interface { // GetRoomReceipts gets all receipts for a given roomID GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) - // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key. - UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) - - // GetUserUnreadNotificationCounts returns statistics per room a user is interested in. - GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) - SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) @@ -179,3 +174,11 @@ type SharedUsers interface { // SharedUsers returns a subset of otherUserIDs that share a room with userID. SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) } + +type Notifications interface { + // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key. + UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) + + // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms + GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) +} diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go index 708c3a9b4..2c7b24800 100644 --- a/syncapi/storage/postgres/notification_data_table.go +++ b/syncapi/storage/postgres/notification_data_table.go @@ -18,6 +18,8 @@ import ( "context" "database/sql" + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -33,15 +35,15 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro r := ¬ificationDataStatements{} return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, - {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, + {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, {&r.selectMaxID, selectMaxNotificationIDSQL}, }.Prepare(db) } type notificationDataStatements struct { - upsertRoomUnreadCounts *sql.Stmt - selectUserUnreadCounts *sql.Stmt - selectMaxID *sql.Stmt + upsertRoomUnreadCounts *sql.Stmt + selectUserUnreadCountsForRooms *sql.Stmt + selectMaxID *sql.Stmt } const notificationDataSchema = ` @@ -61,12 +63,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_ DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4 RETURNING id` -const selectUserUnreadNotificationCountsSQL = `SELECT - id, room_id, notification_count, highlight_count - FROM syncapi_notification_data - WHERE - user_id = $1 AND - id BETWEEN $2 + 1 AND $3` +const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE user_id = $1 AND + room_id = ANY($2)` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` @@ -75,20 +75,20 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, return } -func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { - rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) +func (r *notificationDataStatements) SelectUserUnreadCountsForRooms( + ctx context.Context, txn *sql.Tx, userID string, roomIDs []string, +) (map[string]*eventutil.NotificationData, error) { + rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCountsForRooms).QueryContext(ctx, userID, pq.Array(roomIDs)) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed") roomCounts := map[string]*eventutil.NotificationData{} + var roomID string + var notificationCount, highlightCount int for rows.Next() { - var id types.StreamPosition - var roomID string - var notificationCount, highlightCount int - - if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + if err = rows.Scan(&roomID, ¬ificationCount, &highlightCount); err != nil { return nil, err } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 778ad8b18..215bad3a8 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1036,8 +1036,15 @@ func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userI return } -func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) { - return d.NotificationData.SelectUserUnreadCounts(ctx, nil, userID, from, to) +func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) { + roomIDs := make([]string, 0, len(rooms)) + for roomID, membership := range rooms { + if membership != gomatrixserverlib.Join { + continue + } + roomIDs = append(roomIDs, roomID) + } + return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, nil, userID, roomIDs) } func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index 66d4d4381..ceff60555 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" @@ -32,19 +33,21 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t } r := ¬ificationDataStatements{ streamIDStatements: streamID, + db: db, } return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, - {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + // {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime }.Prepare(db) } type notificationDataStatements struct { + db *sql.DB streamIDStatements *StreamIDStatements upsertRoomUnreadCounts *sql.Stmt - selectUserUnreadCounts *sql.Stmt selectMaxID *sql.Stmt + //selectUserUnreadCountsForRooms *sql.Stmt } const notificationDataSchema = ` @@ -63,12 +66,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_ ON CONFLICT (user_id, room_id) DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7` -const selectUserUnreadNotificationCountsSQL = `SELECT - id, room_id, notification_count, highlight_count - FROM syncapi_notification_data - WHERE - user_id = $1 AND - id BETWEEN $2 + 1 AND $3` +const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE user_id = $1 AND + room_id IN ($2)` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` @@ -81,20 +82,26 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, return } -func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { - rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) +func (r *notificationDataStatements) SelectUserUnreadCountsForRooms( + ctx context.Context, txn *sql.Tx, userID string, roomIDs []string, +) (map[string]*eventutil.NotificationData, error) { + params := make([]interface{}, len(roomIDs)+1) + params[0] = userID + for i := range roomIDs { + params[i+1] = roomIDs[i] + } + sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($1)", sqlutil.QueryVariadic(len(params)), 1) + rows, err := r.db.QueryContext(ctx, sql, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed") roomCounts := map[string]*eventutil.NotificationData{} + var roomID string + var notificationCount, highlightCount int for rows.Next() { - var id types.StreamPosition - var roomID string - var notificationCount, highlightCount int - - if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + if err = rows.Scan(&roomID, ¬ificationCount, &highlightCount); err != nil { return nil, err } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 193881b44..9a873c2ed 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -190,7 +190,7 @@ type Memberships interface { type NotificationData interface { UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) - SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) + SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) } diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 9c19b846b..0297d5c2f 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -3,9 +3,10 @@ package streams import ( "context" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) type AccountDataStreamProvider struct { diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go index 8ba9e07ca..33872734d 100644 --- a/syncapi/streams/stream_notificationdata.go +++ b/syncapi/streams/stream_notificationdata.go @@ -30,26 +30,29 @@ func (p *NotificationDataStreamProvider) CompleteSync( func (p *NotificationDataStreamProvider) IncrementalSync( ctx context.Context, req *types.SyncRequest, - from, to types.StreamPosition, + from, _ types.StreamPosition, ) types.StreamPosition { - // We want counts for all possible rooms, so always start from zero. - countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) + // Get the unread notifications for rooms in our join response. + // This is to ensure clients always have an unread notification section + // and can display the correct numbers. + countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms) if err != nil { - req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") + req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed") return from } - // We're merely decorating existing rooms. Note that the Join map - // values are not pointers. + // We're merely decorating existing rooms. for roomID, jr := range req.Response.Rooms.Join { counts := countsByRoom[roomID] if counts == nil { continue } - - jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount - jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount + jr.UnreadNotifications = &types.UnreadNotifications{ + HighlightCount: counts.UnreadHighlightCount, + NotificationCount: counts.UnreadNotificationCount, + } req.Response.Rooms.Join[roomID] = jr } - return to + + return p.LatestPosition(ctx) } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 68537bc45..f5d00f367 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -77,16 +77,6 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start presence consumer") } - userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{ - JetStream: js, - Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent), - } - - userAPIReadUpdateProducer := &producers.UserAPIReadProducer{ - JetStream: js, - Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate), - } - keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), js, rsAPI, syncDB, notifier, @@ -98,7 +88,7 @@ func AddPublicRoutes( roomConsumer := consumers.NewOutputRoomEventConsumer( base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider, - streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer, + streams.InviteStreamProvider, rsAPI, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") @@ -106,7 +96,6 @@ func AddPublicRoutes( clientConsumer := consumers.NewOutputClientDataConsumer( base.ProcessContext, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, - userAPIReadUpdateProducer, ) if err = clientConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start client data consumer") @@ -135,7 +124,6 @@ func AddPublicRoutes( receiptConsumer := consumers.NewOutputReceiptEventConsumer( base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, - userAPIReadUpdateProducer, ) if err = receiptConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start receipts consumer") diff --git a/syncapi/types/types.go b/syncapi/types/types.go index d75d53ca9..3b85db4a4 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -398,6 +398,11 @@ func (r *Response) IsEmpty() bool { len(r.ToDevice.Events) == 0 } +type UnreadNotifications struct { + HighlightCount int `json:"highlight_count"` + NotificationCount int `json:"notification_count"` +} + // JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key. type JoinResponse struct { Summary struct { @@ -419,10 +424,7 @@ type JoinResponse struct { AccountData struct { Events []gomatrixserverlib.ClientEvent `json:"events"` } `json:"account_data"` - UnreadNotifications struct { - HighlightCount int `json:"highlight_count"` - NotificationCount int `json:"notification_count"` - } `json:"unread_notifications"` + *UnreadNotifications `json:"unread_notifications,omitempty"` } // NewJoinResponse creates an empty response with initialised arrays. @@ -503,19 +505,6 @@ type Peek struct { Deleted bool } -type ReadUpdate struct { - UserID string `json:"user_id"` - RoomID string `json:"room_id"` - Read StreamPosition `json:"read,omitempty"` - FullyRead StreamPosition `json:"fully_read,omitempty"` -} - -// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. -type StreamedEvent struct { - Event *gomatrixserverlib.HeaderedEvent `json:"event"` - StreamPosition StreamPosition `json:"stream_position"` -} - // OutputReceiptEvent is an entry in the receipt output kafka log type OutputReceiptEvent struct { UserID string `json:"user_id"` diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go new file mode 100644 index 000000000..c220d35cb --- /dev/null +++ b/userapi/consumers/clientapi.go @@ -0,0 +1,127 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/userapi/storage" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/util" +) + +// OutputReceiptEventConsumer consumes events that originated in the clientAPI. +type OutputReceiptEventConsumer struct { + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + serverName gomatrixserverlib.ServerName + syncProducer *producers.SyncAPI + pgClient pushgateway.Client +} + +// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. +// Call Start() to begin consuming from the EDU server. +func NewOutputReceiptEventConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + syncProducer *producers.SyncAPI, + pgClient pushgateway.Client, +) *OutputReceiptEventConsumer { + return &OutputReceiptEventConsumer{ + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), + durable: cfg.Matrix.JetStream.Durable("UserAPIReceiptConsumer"), + db: store, + serverName: cfg.Matrix.ServerName, + syncProducer: syncProducer, + pgClient: pgClient, + } +} + +// Start consuming receipts events. +func (s *OutputReceiptEventConsumer) Start() error { + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), + ) +} + +func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + userID := msg.Header.Get(jetstream.UserID) + roomID := msg.Header.Get(jetstream.RoomID) + readPos := msg.Header.Get(jetstream.EventID) + evType := msg.Header.Get("type") + + if readPos == "" || evType != "m.read" { + return true + } + + log := log.WithFields(log.Fields{ + "room_id": roomID, + "user_id": userID, + }) + + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + log.WithError(err).Error("userapi clientapi consumer: SplitID failure") + return true + } + if domain != s.serverName { + return true + } + + metadata, err := msg.Metadata() + if err != nil { + return false + } + + updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) + if err != nil { + log.WithError(err).Error("userapi EDU consumer") + return false + } + + if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil { + log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed") + return false + } + + if !updated { + return true + } + if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") + return false + } + + return true +} diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/roomserver.go similarity index 85% rename from userapi/consumers/syncapi_streamevent.go rename to userapi/consumers/roomserver.go index f3b2bf27f..952de98f7 100644 --- a/userapi/consumers/syncapi_streamevent.go +++ b/userapi/consumers/roomserver.go @@ -26,7 +26,7 @@ import ( "github.com/matrix-org/dendrite/userapi/util" ) -type OutputStreamEventConsumer struct { +type OutputRoomEventConsumer struct { ctx context.Context cfg *config.UserAPI rsAPI rsapi.UserRoomserverAPI @@ -38,7 +38,7 @@ type OutputStreamEventConsumer struct { syncProducer *producers.SyncAPI } -func NewOutputStreamEventConsumer( +func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, @@ -46,21 +46,21 @@ func NewOutputStreamEventConsumer( pgClient pushgateway.Client, rsAPI rsapi.UserRoomserverAPI, syncProducer *producers.SyncAPI, -) *OutputStreamEventConsumer { - return &OutputStreamEventConsumer{ +) *OutputRoomEventConsumer { + return &OutputRoomEventConsumer{ ctx: process.Context(), cfg: cfg, jetstream: js, db: store, - durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent), + durable: cfg.Matrix.JetStream.Durable("UserAPIRoomServerConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), pgClient: pgClient, rsAPI: rsAPI, syncProducer: syncProducer, } } -func (s *OutputStreamEventConsumer) Start() error { +func (s *OutputRoomEventConsumer) Start() error { if err := jetstream.JetStreamConsumer( s.ctx, s.jetstream, s.topic, s.durable, 1, s.onMessage, nats.DeliverAll(), nats.ManualAck(), @@ -70,35 +70,43 @@ func (s *OutputStreamEventConsumer) Start() error { return nil } -func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { msg := msgs[0] // Guaranteed to exist if onMessage is called - var output types.StreamedEvent - output.Event = &gomatrixserverlib.HeaderedEvent{} + var output rsapi.OutputEvent if err := json.Unmarshal(msg.Data, &output); err != nil { - log.WithError(err).Errorf("userapi consumer: message parse failure") + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") return true } - if output.Event.Event == nil { + if output.Type != rsapi.OutputTypeNewRoomEvent { + return true + } + event := output.NewRoomEvent.Event + if event == nil { log.Errorf("userapi consumer: expected event") return true } log.WithFields(log.Fields{ - "event_id": output.Event.EventID(), - "event_type": output.Event.Type(), - "stream_pos": output.StreamPosition, - }).Tracef("Received message from sync API: %#v", output) + "event_id": event.EventID(), + "event_type": event.Type(), + }).Tracef("Received message from roomserver: %#v", output) - if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil { + metadata, err := msg.Metadata() + if err != nil { + return true + } + + if err := s.processMessage(ctx, event, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp))); err != nil { log.WithFields(log.Fields{ - "event_id": output.Event.EventID(), + "event_id": event.EventID(), }).WithError(err).Errorf("userapi consumer: process room event failure") } return true } -func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error { +func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) if err != nil { return fmt.Errorf("s.localRoomMembers: %w", err) @@ -138,10 +146,10 @@ func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *g // removing it means we can send all notifications to // e.g. Element's Push gateway in one go. for _, mem := range members { - if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { + if err := s.notifyLocal(ctx, event, mem, roomSize, roomName, streamPos); err != nil { log.WithFields(log.Fields{ "localpart": mem.Localpart, - }).WithError(err).Debugf("Unable to push to local user") + }).WithError(err).Error("Unable to push to local user") continue } } @@ -179,7 +187,7 @@ func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, // localRoomMembers fetches the current local members of a room, and // the total number of members. -func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { +func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { req := &rsapi.QueryMembershipsForRoomRequest{ RoomID: roomID, JoinedOnly: true, @@ -219,7 +227,7 @@ func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID // looks it up in roomserver. If there is no name, // m.room.canonical_alias is consulted. Returns an empty string if the // room has no name. -func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { +func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { if event.Type() == gomatrixserverlib.MRoomName { name, err := unmarshalRoomName(event) if err != nil { @@ -287,7 +295,7 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er } // notifyLocal finds the right push actions for a local user, given an event. -func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error { +func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) if err != nil { return err @@ -302,7 +310,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma "event_id": event.EventID(), "room_id": event.RoomID(), "localpart": mem.Localpart, - }).Debugf("Push rule evaluation rejected the event") + }).Tracef("Push rule evaluation rejected the event") return nil } @@ -325,7 +333,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma RoomID: event.RoomID(), TS: gomatrixserverlib.AsTimestamp(time.Now()), } - if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil { + if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil { return err } @@ -345,7 +353,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma "localpart": mem.Localpart, "num_urls": len(devicesByURLAndFormat), "num_unread": userNumUnreadNotifs, - }).Debugf("Notifying single member") + }).Trace("Notifying single member") // Push gateways are out of our control, and we cannot risk // looking up the server on a misbehaving push gateway. Each user @@ -396,7 +404,7 @@ func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *goma // evaluatePushRules fetches and evaluates the push rules of a local // user. Returns actions (including dont_notify). -func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { +func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { if event.Sender() == mem.UserID { // SPEC: Homeservers MUST NOT notify the Push Gateway for // events that the user has sent themselves. @@ -447,7 +455,7 @@ func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event "room_id": event.RoomID(), "localpart": mem.Localpart, "rule_id": rule.RuleID, - }).Tracef("Matched a push rule") + }).Trace("Matched a push rule") return rule.Actions, nil } @@ -491,7 +499,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err // localPushDevices pushes to the configured devices of a local // user. The map keys are [url][format]. -func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { +func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) if err != nil { return nil, "", err @@ -515,7 +523,7 @@ func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localp } // notifyHTTP performs a notificatation to a Push Gateway. -func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { +func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { logger := log.WithFields(log.Fields{ "event_id": event.EventID(), "url": url, @@ -561,13 +569,13 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat } } - logger.Debugf("Notifying push gateway %s", url) + logger.Tracef("Notifying push gateway %s", url) var res pushgateway.NotifyResponse if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil { logger.WithError(err).Errorf("Failed to notify push gateway %s", url) return nil, err } - logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result") + logger.WithField("num_rejected", len(res.Rejected)).Trace("Push gateway result") if len(res.Rejected) == 0 { return nil, nil @@ -589,7 +597,7 @@ func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomat } // deleteRejectedPushers deletes the pushers associated with the given devices. -func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { +func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { log.WithFields(log.Fields{ "localpart": localpart, "app_id0": devices[0].AppID, diff --git a/userapi/consumers/syncapi_streamevent_test.go b/userapi/consumers/roomserver_test.go similarity index 98% rename from userapi/consumers/syncapi_streamevent_test.go rename to userapi/consumers/roomserver_test.go index 48ea0fe11..3bbeb439a 100644 --- a/userapi/consumers/syncapi_streamevent_test.go +++ b/userapi/consumers/roomserver_test.go @@ -40,7 +40,7 @@ func Test_evaluatePushRules(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() - consumer := OutputStreamEventConsumer{db: db} + consumer := OutputRoomEventConsumer{db: db} testCases := []struct { name string diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go deleted file mode 100644 index 54654f757..000000000 --- a/userapi/consumers/syncapi_readupdate.go +++ /dev/null @@ -1,137 +0,0 @@ -package consumers - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/internal/pushgateway" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/syncapi/types" - uapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/producers" - "github.com/matrix-org/dendrite/userapi/storage" - "github.com/matrix-org/dendrite/userapi/util" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" -) - -type OutputReadUpdateConsumer struct { - ctx context.Context - cfg *config.UserAPI - jetstream nats.JetStreamContext - durable string - db storage.Database - pgClient pushgateway.Client - ServerName gomatrixserverlib.ServerName - topic string - userAPI uapi.UserInternalAPI - syncProducer *producers.SyncAPI -} - -func NewOutputReadUpdateConsumer( - process *process.ProcessContext, - cfg *config.UserAPI, - js nats.JetStreamContext, - store storage.Database, - pgClient pushgateway.Client, - userAPI uapi.UserInternalAPI, - syncProducer *producers.SyncAPI, -) *OutputReadUpdateConsumer { - return &OutputReadUpdateConsumer{ - ctx: process.Context(), - cfg: cfg, - jetstream: js, - db: store, - ServerName: cfg.Matrix.ServerName, - durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate), - pgClient: pgClient, - userAPI: userAPI, - syncProducer: syncProducer, - } -} - -func (s *OutputReadUpdateConsumer) Start() error { - if err := jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, 1, - s.onMessage, nats.DeliverAll(), nats.ManualAck(), - ); err != nil { - return err - } - return nil -} - -func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { - msg := msgs[0] // Guaranteed to exist if onMessage is called - var read types.ReadUpdate - if err := json.Unmarshal(msg.Data, &read); err != nil { - log.WithError(err).Error("userapi clientapi consumer: message parse failure") - return true - } - if read.FullyRead == 0 && read.Read == 0 { - return true - } - - userID := string(msg.Header.Get(jetstream.UserID)) - roomID := string(msg.Header.Get(jetstream.RoomID)) - - localpart, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - log.WithError(err).Error("userapi clientapi consumer: SplitID failure") - return true - } - if domain != s.ServerName { - log.Error("userapi clientapi consumer: not a local user") - return true - } - - log := log.WithFields(log.Fields{ - "room_id": roomID, - "user_id": userID, - }) - log.Tracef("Received read update from sync API: %#v", read) - - if read.Read > 0 { - updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true) - if err != nil { - log.WithError(err).Error("userapi EDU consumer") - return false - } - - if updated { - if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil { - log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed") - return false - } - if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { - log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") - return false - } - } - } - - if read.FullyRead > 0 { - deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead)) - if err != nil { - log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed") - return false - } - - if deleted { - if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { - log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed") - return false - } - - if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil { - log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed") - return false - } - } - } - - return true -} diff --git a/userapi/internal/api.go b/userapi/internal/api.go index dcbb73614..3e761a886 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/sqlutil" keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -39,6 +40,7 @@ import ( "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" + userapiUtil "github.com/matrix-org/dendrite/userapi/util" ) type UserInternalAPI struct { @@ -51,6 +53,7 @@ type UserInternalAPI struct { AppServices []config.ApplicationService KeyAPI keyapi.UserKeyAPI RSAPI rsapi.UserRoomserverAPI + PgClient pushgateway.Client } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -73,6 +76,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc ignoredUsers = &synctypes.IgnoredUsers{} _ = json.Unmarshal(req.AccountData, ignoredUsers) } + if req.DataType == "m.fully_read" { + if err := a.setFullyRead(ctx, req); err != nil { + return err + } + } if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{ RoomID: req.RoomID, Type: req.DataType, @@ -84,6 +92,44 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc return nil } +func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccountDataRequest) error { + var output eventutil.ReadMarkerJSON + + if err := json.Unmarshal(req.AccountData, &output); err != nil { + return err + } + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure") + return nil + } + if domain != a.ServerName { + return nil + } + + deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) + if err != nil { + logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") + return err + } + + if err = a.SyncProducer.GetAndSendNotificationData(ctx, req.UserID, req.RoomID); err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: GetAndSendNotificationData failed") + return err + } + + // nothing changed, no need to notify the push gateway + if !deleted { + return nil + } + + if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed") + return err + } + return nil +} + func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index 27cfc2848..f556ea352 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -4,12 +4,13 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/storage" ) type JetStreamPublisher interface { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index fbac463e2..02efe7afe 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -119,9 +119,9 @@ type ThreePID interface { } type Notification interface { - InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error - DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) - SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error) + InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go index a27c1125e..24a30b2f5 100644 --- a/userapi/storage/postgres/notifications_table.go +++ b/userapi/storage/postgres/notifications_table.go @@ -20,12 +20,13 @@ import ( "encoding/json" "time" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" ) type notificationsStatements struct { @@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) if err != nil { return false, err @@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) if err != nil { return false, err @@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) - - if err != nil { - return 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var count int64 - if err := rows.Scan(&count); err != nil { - return 0, err - } - - return count, nil - } - return 0, rows.Err() +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) + return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { - rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) - - if err != nil { - return 0, 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var total, highlight int64 - if err := rows.Scan(&total, &highlight); err != nil { - return 0, 0, err - } - - return total, highlight, nil - } - return 0, 0, rows.Err() +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) + return } diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go index 2eb379ae4..6fb714fba 100644 --- a/userapi/storage/postgres/pusher_table.go +++ b/userapi/storage/postgres/pusher_table.go @@ -19,11 +19,12 @@ import ( "database/sql" "encoding/json" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/sirupsen/logrus" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers( pushers = append(pushers, pusher) } - logrus.Debugf("Database returned %d pushers", len(pushers)) + logrus.Tracef("Database returned %d pushers", len(pushers)) return pushers, rows.Err() } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index e32a442d0..3ff299f1b 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -700,13 +700,13 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) ( return d.LoginTokens.SelectLoginToken(ctx, token) } -func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { +func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) }) } -func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) return err @@ -714,7 +714,7 @@ func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomI return } -func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { +func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) return err diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index df8260251..a35ec7be5 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -20,12 +20,13 @@ import ( "encoding/json" "time" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" ) type notificationsStatements struct { @@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) if err != nil { return false, err @@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) if err != nil { return false, err @@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) - - if err != nil { - return 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var count int64 - if err := rows.Scan(&count); err != nil { - return 0, err - } - - return count, nil - } - return 0, rows.Err() +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) + return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { - rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) - - if err != nil { - return 0, 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var total, highlight int64 - if err := rows.Scan(&total, &highlight); err != nil { - return 0, 0, err - } - - return total, highlight, nil - } - return 0, 0, rows.Err() +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) + return } diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index dba97c3d4..4de0a9f06 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -19,11 +19,12 @@ import ( "database/sql" "encoding/json" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/sirupsen/logrus" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers( pushers = append(pushers, pusher) } - logrus.Debugf("Database returned %d pushers", len(pushers)) + logrus.Tracef("Database returned %d pushers", len(pushers)) return pushers, rows.Err() } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index a26097338..ca7c1bfd2 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -7,6 +7,11 @@ import ( "testing" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" @@ -14,10 +19,6 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/bcrypt" ) const loginTokenLifetime = time.Minute @@ -513,7 +514,7 @@ func Test_Notification(t *testing.T) { RoomID: roomID, TS: gomatrixserverlib.AsTimestamp(ts), } - err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification) + err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification) assert.NoError(t, err, "unable to insert notification") } diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 2fe955670..cc4287997 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -105,9 +105,9 @@ type PusherTable interface { type NotificationTable interface { Clean(ctx context.Context, txn *sql.Tx) error - Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error - DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) - UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) + Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) diff --git a/userapi/userapi.go b/userapi/userapi.go index 23855a89f..d26b4e19a 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -81,16 +81,17 @@ func NewInternalAPI( KeyAPI: keyAPI, RSAPI: rsAPI, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, + PgClient: pgClient, } - readConsumer := consumers.NewOutputReadUpdateConsumer( - base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer, + receiptConsumer := consumers.NewOutputReceiptEventConsumer( + base.ProcessContext, cfg, js, db, syncProducer, pgClient, ) - if err := readConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start user API read update consumer") + if err := receiptConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start user API receipt consumer") } - eventConsumer := consumers.NewOutputStreamEventConsumer( + eventConsumer := consumers.NewOutputRoomEventConsumer( base.ProcessContext, cfg, js, db, pgClient, rsAPI, syncProducer, ) if err := eventConsumer.Start(); err != nil { From 6c67552bf9eee18f656d731adf646aa09c5d7c92 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 27 Sep 2022 15:50:22 +0100 Subject: [PATCH 12/17] Return `M_UNRECOGNIZED` for unknown CS API endpoints/actions (#2740) Fixes #2739. --- setup/base/base.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/setup/base/base.go b/setup/base/base.go index 0c7b222d0..32716c766 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -392,17 +392,26 @@ func (b *BaseDendrite) configureHTTPErrors() { _, _ = w.Write([]byte(fmt.Sprintf("405 %s not allowed on this endpoint", r.Method))) } + clientNotFoundHandler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}`)) // nolint:misspell + } + notFoundCORSHandler := httputil.WrapHandlerInCORS(http.NotFoundHandler()) notAllowedCORSHandler := httputil.WrapHandlerInCORS(http.HandlerFunc(notAllowedHandler)) for _, router := range []*mux.Router{ - b.PublicClientAPIMux, b.PublicMediaAPIMux, - b.DendriteAdminMux, b.SynapseAdminMux, - b.PublicWellKnownAPIMux, + b.PublicMediaAPIMux, b.DendriteAdminMux, + b.SynapseAdminMux, b.PublicWellKnownAPIMux, } { router.NotFoundHandler = notFoundCORSHandler router.MethodNotAllowedHandler = notAllowedCORSHandler } + + // Special case so that we don't upset clients on the CS API. + b.PublicClientAPIMux.NotFoundHandler = http.HandlerFunc(clientNotFoundHandler) + b.PublicClientAPIMux.MethodNotAllowedHandler = http.HandlerFunc(clientNotFoundHandler) } // SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on From 87be32ca2671173a4287a938932e543410a32c3a Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 27 Sep 2022 18:06:49 +0200 Subject: [PATCH 13/17] Fulltext implementation using Bleve (#2675) Based on #2480 This actually indexes events based on their event type. They are removed from the index if we receive a `m.room.redaction` event on the `OutputRoomEvent` stream. An admin endpoint is added to reindex all existing events. Co-authored-by: Neil Alexander --- clientapi/routing/admin.go | 22 ++ clientapi/routing/routing.go | 17 +- cmd/generate-config/main.go | 9 +- dendrite-sample.monolith.yaml | 9 +- dendrite-sample.polylith.yaml | 9 +- docs/administration/4_adminapi.md | 5 + docs/installation/7_configuration.md | 4 +- internal/fulltext/bleve.go | 3 +- internal/fulltext/bleve_test.go | 10 +- setup/base/base.go | 7 +- setup/config/config_syncapi.go | 10 +- syncapi/consumers/clientapi.go | 109 +++++- syncapi/consumers/roomserver.go | 54 +++ syncapi/routing/context.go | 7 +- syncapi/routing/routing.go | 28 +- syncapi/routing/search.go | 344 ++++++++++++++++++ syncapi/storage/interface.go | 1 + .../postgres/output_room_events_table.go | 28 ++ syncapi/storage/shared/syncserver.go | 8 + .../sqlite3/output_room_events_table.go | 41 +++ syncapi/storage/tables/interface.go | 1 + syncapi/syncapi.go | 7 +- 22 files changed, 680 insertions(+), 53 deletions(-) create mode 100644 syncapi/routing/search.go diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 0c5f8c167..5089d7c36 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -3,15 +3,19 @@ package routing import ( "encoding/json" "net/http" + "time" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/httputil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" ) func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { @@ -138,3 +142,21 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap }, } } + +func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, natsClient *nats.Conn) util.JSONResponse { + if device.AccountType != userapi.AccountTypeAdmin { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("This API can only be used by admin users."), + } + } + _, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10) + if err != nil { + logrus.WithError(err).Error("failed to publish nats message") + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index d7a48d228..9c1f8f720 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -20,6 +20,12 @@ import ( "strings" "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/nats-io/nats.go" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth" @@ -34,11 +40,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/nats-io/nats.go" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" ) // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client @@ -161,6 +162,12 @@ func Setup( }), ).Methods(http.MethodPost, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/fulltext/reindex", + httputil.MakeAuthAPI("admin_fultext_reindex", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminReindex(req, cfg, device, natsClient) + }), + ).Methods(http.MethodGet, http.MethodOptions) + // server notifications if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index c24e8153e..8b042c56e 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -5,10 +5,11 @@ import ( "fmt" "path/filepath" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" "gopkg.in/yaml.v2" + + "github.com/matrix-org/dendrite/setup/config" ) func main() { @@ -82,6 +83,12 @@ func main() { EnableInbound: true, EnableOutbound: true, } + cfg.SyncAPI.Fulltext = config.Fulltext{ + Enabled: true, + IndexPath: config.Path(filepath.Join(*dirPath, "fulltextindex")), + InMemory: true, + Language: "en", + } } } else { var err error diff --git a/dendrite-sample.monolith.yaml b/dendrite-sample.monolith.yaml index f1758f54d..3cad17da8 100644 --- a/dendrite-sample.monolith.yaml +++ b/dendrite-sample.monolith.yaml @@ -275,10 +275,15 @@ sync_api: # address of the client. This is likely required if Dendrite is running behind # a reverse proxy server. # real_ip_header: X-Real-IP - fulltext: + + # Configuration for the fulltext search + search: enabled: false + # The path where the fulltext index will be created in. index_path: "./fulltextindex" - language: "en" # more possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang + # The language most likely to be used on the server - used when indexing, to ensure the returned results match the expectations. + # A full list of possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang + language: "en" # Configuration for the User API. user_api: diff --git a/dendrite-sample.polylith.yaml b/dendrite-sample.polylith.yaml index 97d10825f..e58062fe2 100644 --- a/dendrite-sample.polylith.yaml +++ b/dendrite-sample.polylith.yaml @@ -326,10 +326,15 @@ sync_api: max_open_conns: 10 max_idle_conns: 2 conn_max_lifetime: -1 - fulltext: + + # Configuration for the fulltext search + search: enabled: false + # The path where the fulltext index will be created in. index_path: "./fulltextindex" - language: "en" # more possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang + # The language most likely to be used on the server - used when indexing, to ensure the returned results match the expectations. + # A full list of possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang + language: "en" # This option controls which HTTP header to inspect to find the real remote IP # address of the client. This is likely required if Dendrite is running behind diff --git a/docs/administration/4_adminapi.md b/docs/administration/4_adminapi.md index a34bfde1f..1712bb1bf 100644 --- a/docs/administration/4_adminapi.md +++ b/docs/administration/4_adminapi.md @@ -57,6 +57,11 @@ Request body format: Reset the password of a local user. The `localpart` is the username only, i.e. if the full user ID is `@alice:domain.com` then the local part is `alice`. +## GET `/_dendrite/admin/fulltext/reindex` + +This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately. +Indexing is done in the background, the server logs every 1000 events (or below) when they are being indexed. Once reindexing is done, you'll see something along the lines `Indexed 69586 events in 53.68223182s` in your debug logs. + ## POST `/_synapse/admin/v1/send_server_notice` Request body format: diff --git a/docs/installation/7_configuration.md b/docs/installation/7_configuration.md index 8fbe71c40..67cd339cf 100644 --- a/docs/installation/7_configuration.md +++ b/docs/installation/7_configuration.md @@ -140,12 +140,12 @@ room_server: ## Fulltext search -Dendrite supports experimental fulltext indexing using [Bleve](https://github.com/blevesearch/bleve), it is configured in the `sync_api` section as follows. Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expections. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang). +Dendrite supports experimental fulltext indexing using [Bleve](https://github.com/blevesearch/bleve), it is configured in the `sync_api` section as follows. Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expectations. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang). ```yaml sync_api: # ... - fulltext: + search: enabled: false index_path: "./fulltextindex" language: "en" diff --git a/internal/fulltext/bleve.go b/internal/fulltext/bleve.go index b07c0e51d..da8932f5c 100644 --- a/internal/fulltext/bleve.go +++ b/internal/fulltext/bleve.go @@ -22,8 +22,9 @@ import ( "github.com/blevesearch/bleve/v2" "github.com/blevesearch/bleve/v2/mapping" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/setup/config" ) // Search contains all existing bleve.Index diff --git a/internal/fulltext/bleve_test.go b/internal/fulltext/bleve_test.go index 84a282423..d16397a45 100644 --- a/internal/fulltext/bleve_test.go +++ b/internal/fulltext/bleve_test.go @@ -27,11 +27,11 @@ import ( func mustOpenIndex(t *testing.T, tempDir string) *fulltext.Search { t.Helper() - cfg := config.Fulltext{} - cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, - }) + cfg := config.Fulltext{ + Enabled: true, + InMemory: true, + Language: "en", + } if tempDir != "" { cfg.IndexPath = config.Path(tempDir) cfg.InMemory = false diff --git a/setup/base/base.go b/setup/base/base.go index 32716c766..0636c7b8d 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -37,16 +37,13 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" - "github.com/gorilla/mux" "github.com/kardianos/minwinsvc" @@ -61,6 +58,8 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" rsinthttp "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" ) diff --git a/setup/config/config_syncapi.go b/setup/config/config_syncapi.go index c890b0054..edef22c93 100644 --- a/setup/config/config_syncapi.go +++ b/setup/config/config_syncapi.go @@ -10,7 +10,7 @@ type SyncAPI struct { RealIPHeader string `yaml:"real_ip_header"` - Fulltext Fulltext `yaml:"fulltext"` + Fulltext Fulltext `yaml:"search"` } func (c *SyncAPI) Defaults(opts DefaultOpts) { @@ -52,16 +52,12 @@ func (f *Fulltext) Defaults(opts DefaultOpts) { f.Enabled = false f.IndexPath = "./fulltextindex" f.Language = "en" - if opts.Generate { - f.Enabled = true - f.InMemory = true - } } func (f *Fulltext) Verify(configErrs *ConfigErrors, isMonolith bool) { if !f.Enabled { return } - checkNotEmpty(configErrs, "syncapi.fulltext.index_path", string(f.IndexPath)) - checkNotEmpty(configErrs, "syncapi.fulltext.language", f.Language) + checkNotEmpty(configErrs, "syncapi.search.index_path", string(f.IndexPath)) + checkNotEmpty(configErrs, "syncapi.search.language", f.Language) } diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index a170a6ec1..b11ed4f5e 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -17,14 +17,18 @@ package consumers import ( "context" "encoding/json" + "strings" + "time" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" @@ -35,14 +39,18 @@ import ( // OutputClientDataConsumer consumes events that originated in the client API server. type OutputClientDataConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - stream types.StreamProvider - notifier *notifier.Notifier - serverName gomatrixserverlib.ServerName + ctx context.Context + jetstream nats.JetStreamContext + nats *nats.Conn + durable string + topic string + topicReIndex string + db storage.Database + stream types.StreamProvider + notifier *notifier.Notifier + serverName gomatrixserverlib.ServerName + fts *fulltext.Search + cfg *config.SyncAPI } // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. @@ -50,24 +58,93 @@ func NewOutputClientDataConsumer( process *process.ProcessContext, cfg *config.SyncAPI, js nats.JetStreamContext, + nats *nats.Conn, store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, + fts *fulltext.Search, ) *OutputClientDataConsumer { return &OutputClientDataConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData), - durable: cfg.Matrix.JetStream.Durable("SyncAPIAccountDataConsumer"), - db: store, - notifier: notifier, - stream: stream, - serverName: cfg.Matrix.ServerName, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData), + topicReIndex: cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex), + durable: cfg.Matrix.JetStream.Durable("SyncAPIAccountDataConsumer"), + nats: nats, + db: store, + notifier: notifier, + stream: stream, + serverName: cfg.Matrix.ServerName, + fts: fts, + cfg: cfg, } } // Start consuming from room servers func (s *OutputClientDataConsumer) Start() error { + _, err := s.nats.Subscribe(s.topicReIndex, func(msg *nats.Msg) { + if err := msg.Ack(); err != nil { + return + } + if !s.cfg.Fulltext.Enabled { + logrus.Warn("Fulltext indexing is disabled") + return + } + ctx := context.Background() + logrus.Debugf("Starting to index events") + var offset int + start := time.Now() + count := 0 + var id int64 = 0 + for { + evs, err := s.db.ReIndex(ctx, 1000, id) + if err != nil { + logrus.WithError(err).Errorf("unable to get events to index") + return + } + if len(evs) == 0 { + break + } + logrus.Debugf("Indexing %d events", len(evs)) + elements := make([]fulltext.IndexElement, 0, len(evs)) + + for streamPos, ev := range evs { + id = streamPos + e := fulltext.IndexElement{ + EventID: ev.EventID(), + RoomID: ev.RoomID(), + StreamPosition: streamPos, + } + e.SetContentType(ev.Type()) + + switch ev.Type() { + case "m.room.message": + e.Content = gjson.GetBytes(ev.Content(), "body").String() + case gomatrixserverlib.MRoomName: + e.Content = gjson.GetBytes(ev.Content(), "name").String() + case gomatrixserverlib.MRoomTopic: + e.Content = gjson.GetBytes(ev.Content(), "topic").String() + default: + continue + } + + if strings.TrimSpace(e.Content) == "" { + continue + } + elements = append(elements, e) + } + if err = s.fts.Index(elements...); err != nil { + logrus.WithError(err).Error("unable to index events") + continue + } + offset += len(evs) + count += len(elements) + } + logrus.Debugf("Indexed %d events in %v", count, time.Since(start)) + }) + if err != nil { + return err + } return jetstream.JetStreamConsumer( s.ctx, s.jetstream, s.topic, s.durable, 1, s.onMessage, nats.DeliverAll(), nats.ManualAck(), diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 0964ae207..3756ad75c 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -24,7 +24,9 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" @@ -46,6 +48,7 @@ type OutputRoomEventConsumer struct { pduStream types.StreamProvider inviteStream types.StreamProvider notifier *notifier.Notifier + fts *fulltext.Search } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -58,6 +61,7 @@ func NewOutputRoomEventConsumer( pduStream types.StreamProvider, inviteStream types.StreamProvider, rsAPI api.SyncRoomserverAPI, + fts *fulltext.Search, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), @@ -70,6 +74,7 @@ func NewOutputRoomEventConsumer( pduStream: pduStream, inviteStream: inviteStream, rsAPI: rsAPI, + fts: fts, } } @@ -251,6 +256,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( }).Panicf("roomserver output log: write new event failure") return nil } + if err = s.writeFTS(ev, pduPos); err != nil { + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "type": ev.Type(), + }).WithError(err).Warn("failed to index fulltext element") + } if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) @@ -295,6 +306,13 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent( return nil } + if err = s.writeFTS(ev, pduPos); err != nil { + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "type": ev.Type(), + }).WithError(err).Warn("failed to index fulltext element") + } + if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) return err @@ -451,3 +469,39 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head event.Event, err = event.SetUnsigned(prev) return event, err } + +func (s *OutputRoomEventConsumer) writeFTS(ev *gomatrixserverlib.HeaderedEvent, pduPosition types.StreamPosition) error { + if !s.cfg.Fulltext.Enabled { + return nil + } + e := fulltext.IndexElement{ + EventID: ev.EventID(), + RoomID: ev.RoomID(), + StreamPosition: int64(pduPosition), + } + e.SetContentType(ev.Type()) + + switch ev.Type() { + case "m.room.message": + e.Content = gjson.GetBytes(ev.Content(), "body").String() + case gomatrixserverlib.MRoomName: + e.Content = gjson.GetBytes(ev.Content(), "name").String() + case gomatrixserverlib.MRoomTopic: + e.Content = gjson.GetBytes(ev.Content(), "topic").String() + case gomatrixserverlib.MRoomRedaction: + log.Tracef("Redacting event: %s", ev.Redacts()) + if err := s.fts.Delete(ev.Redacts()); err != nil { + return fmt.Errorf("failed to delete entry from fulltext index: %w", err) + } + return nil + default: + return nil + } + if e.Content != "" { + log.Tracef("Indexing element: %+v", e) + if err := s.fts.Index(e); err != nil { + return err + } + } + return nil +} diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 13c4e9d89..1ebdfe604 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -37,11 +37,11 @@ import ( type ContextRespsonse struct { End string `json:"end"` - Event gomatrixserverlib.ClientEvent `json:"event"` + Event *gomatrixserverlib.ClientEvent `json:"event,omitempty"` EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after,omitempty"` EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before,omitempty"` Start string `json:"start"` - State []gomatrixserverlib.ClientEvent `json:"state"` + State []gomatrixserverlib.ClientEvent `json:"state,omitempty"` } func Context( @@ -162,8 +162,9 @@ func Context( eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll) newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache) + ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll) response := ContextRespsonse{ - Event: gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll), + Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, State: gomatrixserverlib.HeaderedToClientEvents(newState, gomatrixserverlib.FormatAll), diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 6bc495d8d..8f84a1341 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -18,15 +18,18 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" ) // Setup configures the given mux with sync-server listeners @@ -40,6 +43,7 @@ func Setup( rsAPI api.SyncRoomserverAPI, cfg *config.SyncAPI, lazyLoadCache caching.LazyLoadCache, + fts *fulltext.Search, ) { v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() @@ -95,4 +99,24 @@ func Setup( ) }), ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/search", + httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if !cfg.Fulltext.Enabled { + return util.JSONResponse{ + Code: http.StatusNotImplemented, + JSON: jsonerror.Unknown("Search has been disabled by the server administrator."), + } + } + var nextBatch *string + if err := req.ParseForm(); err != nil { + return jsonerror.InternalServerError() + } + if req.Form.Has("next_batch") { + nb := req.FormValue("next_batch") + nextBatch = &nb + } + return Search(req, device, syncDB, fts, nextBatch) + }), + ).Methods(http.MethodPost, http.MethodOptions) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go new file mode 100644 index 000000000..341efeb14 --- /dev/null +++ b/syncapi/routing/search.go @@ -0,0 +1,344 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "net/http" + "sort" + "strconv" + "strings" + "time" + + "github.com/blevesearch/bleve/v2/search" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/fulltext" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/userapi/api" +) + +// nolint:gocyclo +func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts *fulltext.Search, from *string) util.JSONResponse { + start := time.Now() + var ( + searchReq SearchRequest + err error + ctx = req.Context() + ) + resErr := httputil.UnmarshalJSONRequest(req, &searchReq) + if resErr != nil { + logrus.Error("failed to unmarshal search request") + return *resErr + } + + nextBatch := 0 + if from != nil && *from != "" { + nextBatch, err = strconv.Atoi(*from) + if err != nil { + return jsonerror.InternalServerError() + } + } + + if searchReq.SearchCategories.RoomEvents.Filter.Limit == 0 { + searchReq.SearchCategories.RoomEvents.Filter.Limit = 5 + } + + // only search rooms the user is actually joined to + joinedRooms, err := syncDB.RoomIDsWithMembership(ctx, device.UserID, "join") + if err != nil { + return jsonerror.InternalServerError() + } + if len(joinedRooms) == 0 { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("User not joined to any rooms."), + } + } + joinedRoomsMap := make(map[string]struct{}, len(joinedRooms)) + for _, roomID := range joinedRooms { + joinedRoomsMap[roomID] = struct{}{} + } + rooms := []string{} + if searchReq.SearchCategories.RoomEvents.Filter.Rooms != nil { + for _, roomID := range *searchReq.SearchCategories.RoomEvents.Filter.Rooms { + if _, ok := joinedRoomsMap[roomID]; ok { + rooms = append(rooms, roomID) + } + } + } else { + rooms = joinedRooms + } + + if len(rooms) == 0 { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Unknown("User not allowed to search in this room(s)."), + } + } + + orderByTime := searchReq.SearchCategories.RoomEvents.OrderBy == "recent" + + result, err := fts.Search( + searchReq.SearchCategories.RoomEvents.SearchTerm, + rooms, + searchReq.SearchCategories.RoomEvents.Keys, + searchReq.SearchCategories.RoomEvents.Filter.Limit, + nextBatch, + orderByTime, + ) + if err != nil { + logrus.WithError(err).Error("failed to search fulltext") + return jsonerror.InternalServerError() + } + logrus.Debugf("Search took %s", result.Took) + + // From was specified but empty, return no results, only the count + if from != nil && *from == "" { + return util.JSONResponse{ + Code: http.StatusOK, + JSON: SearchResponse{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + Count: int(result.Total), + NextBatch: nil, + }, + }, + }, + } + } + + results := []Result{} + + wantEvents := make([]string, 0, len(result.Hits)) + eventScore := make(map[string]*search.DocumentMatch) + + for _, hit := range result.Hits { + wantEvents = append(wantEvents, hit.ID) + eventScore[hit.ID] = hit + } + + // Filter on m.room.message, as otherwise we also get events like m.reaction + // which "breaks" displaying results in Element Web. + types := []string{"m.room.message"} + roomFilter := &gomatrixserverlib.RoomEventFilter{ + Rooms: &rooms, + Types: &types, + } + + evs, err := syncDB.Events(ctx, wantEvents) + if err != nil { + logrus.WithError(err).Error("failed to get events from database") + return jsonerror.InternalServerError() + } + + groups := make(map[string]RoomResult) + knownUsersProfiles := make(map[string]ProfileInfo) + + // Sort the events by depth, as the returned values aren't ordered + if orderByTime { + sort.Slice(evs, func(i, j int) bool { + return evs[i].Depth() > evs[j].Depth() + }) + } + + stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent) + for _, event := range evs { + eventsBefore, eventsAfter, err := contextEvents(ctx, syncDB, event, roomFilter, searchReq) + if err != nil { + logrus.WithError(err).Error("failed to get context events") + return jsonerror.InternalServerError() + } + startToken, endToken, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter) + if err != nil { + logrus.WithError(err).Error("failed to get start/end") + return jsonerror.InternalServerError() + } + + profileInfos := make(map[string]ProfileInfo) + for _, ev := range append(eventsBefore, eventsAfter...) { + profile, ok := knownUsersProfiles[event.Sender()] + if !ok { + stateEvent, err := syncDB.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender()) + if err != nil { + logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile") + continue + } + if stateEvent == nil { + continue + } + profile = ProfileInfo{ + AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str, + DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str, + } + knownUsersProfiles[event.Sender()] = profile + } + profileInfos[ev.Sender()] = profile + } + + results = append(results, Result{ + Context: SearchContextResponse{ + Start: startToken.String(), + End: endToken.String(), + EventsAfter: gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatSync), + EventsBefore: gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatSync), + ProfileInfo: profileInfos, + }, + Rank: eventScore[event.EventID()].Score, + Result: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll), + }) + roomGroup := groups[event.RoomID()] + roomGroup.Results = append(roomGroup.Results, event.EventID()) + groups[event.RoomID()] = roomGroup + if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok { + stateFilter := gomatrixserverlib.DefaultStateFilter() + state, err := syncDB.CurrentState(ctx, event.RoomID(), &stateFilter, nil) + if err != nil { + logrus.WithError(err).Error("unable to get current state") + return jsonerror.InternalServerError() + } + stateForRooms[event.RoomID()] = gomatrixserverlib.HeaderedToClientEvents(state, gomatrixserverlib.FormatSync) + } + } + + var nextBatchResult *string = nil + if int(result.Total) > nextBatch+len(results) { + nb := strconv.Itoa(len(results) + nextBatch) + nextBatchResult = &nb + } else if int(result.Total) == nextBatch+len(results) { + // Sytest expects a next_batch even if we don't actually have any more results + nb := "" + nextBatchResult = &nb + } + + res := SearchResponse{ + SearchCategories: SearchCategories{ + RoomEvents: RoomEvents{ + Count: int(result.Total), + Groups: Groups{RoomID: groups}, + Results: results, + NextBatch: nextBatchResult, + Highlights: strings.Split(searchReq.SearchCategories.RoomEvents.SearchTerm, " "), + State: stateForRooms, + }, + }, + } + + logrus.Debugf("Full search request took %v", time.Since(start)) + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: res, + } +} + +// contextEvents returns the events around a given eventID +func contextEvents( + ctx context.Context, + syncDB storage.Database, + event *gomatrixserverlib.HeaderedEvent, + roomFilter *gomatrixserverlib.RoomEventFilter, + searchReq SearchRequest, +) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) { + id, _, err := syncDB.SelectContextEvent(ctx, event.RoomID(), event.EventID()) + if err != nil { + logrus.WithError(err).Error("failed to query context event") + return nil, nil, err + } + roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit + eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter) + if err != nil { + logrus.WithError(err).Error("failed to query before context event") + return nil, nil, err + } + roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit + _, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter) + if err != nil { + logrus.WithError(err).Error("failed to query after context event") + return nil, nil, err + } + return eventsBefore, eventsAfter, err +} + +type SearchRequest struct { + SearchCategories struct { + RoomEvents struct { + EventContext struct { + AfterLimit int `json:"after_limit,omitempty"` + BeforeLimit int `json:"before_limit,omitempty"` + IncludeProfile bool `json:"include_profile,omitempty"` + } `json:"event_context"` + Filter gomatrixserverlib.StateFilter `json:"filter"` + Groupings struct { + GroupBy []struct { + Key string `json:"key"` + } `json:"group_by"` + } `json:"groupings"` + IncludeState bool `json:"include_state"` + Keys []string `json:"keys"` + OrderBy string `json:"order_by"` + SearchTerm string `json:"search_term"` + } `json:"room_events"` + } `json:"search_categories"` +} + +type SearchResponse struct { + SearchCategories SearchCategories `json:"search_categories"` +} +type RoomResult struct { + NextBatch *string `json:"next_batch,omitempty"` + Order int `json:"order"` + Results []string `json:"results"` +} + +type Groups struct { + RoomID map[string]RoomResult `json:"room_id"` +} + +type Result struct { + Context SearchContextResponse `json:"context"` + Rank float64 `json:"rank"` + Result gomatrixserverlib.ClientEvent `json:"result"` +} + +type SearchContextResponse struct { + End string `json:"end"` + EventsAfter []gomatrixserverlib.ClientEvent `json:"events_after"` + EventsBefore []gomatrixserverlib.ClientEvent `json:"events_before"` + Start string `json:"start"` + ProfileInfo map[string]ProfileInfo `json:"profile_info"` +} + +type ProfileInfo struct { + AvatarURL string `json:"avatar_url"` + DisplayName string `json:"display_name"` +} + +type RoomEvents struct { + Count int `json:"count"` + Groups Groups `json:"groups"` + Highlights []string `json:"highlights"` + NextBatch *string `json:"next_batch,omitempty"` + Results []Result `json:"results"` + State map[string][]gomatrixserverlib.ClientEvent `json:"state,omitempty"` +} +type SearchCategories struct { + RoomEvents RoomEvents `json:"room_events"` +} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index ad3be4206..dd03365e9 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -161,6 +161,7 @@ type Database interface { // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // string as the membership. SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) + ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) } type Presence interface { diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 041f99061..20a9ea428 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -166,6 +166,8 @@ const selectContextAfterEventSQL = "" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " ORDER BY id ASC LIMIT $3" +const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3" + type outputRoomEventsStatements struct { insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt @@ -180,6 +182,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + selectSearchStmt *sql.Stmt } func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { @@ -215,6 +218,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.selectSearchStmt, selectSearchSQL}, }.Prepare(db) } @@ -632,3 +636,27 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { } return result, rows.Err() } + +func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { + rows, err := sqlutil.TxStmt(txn, s.selectSearchStmt).QueryContext(ctx, afterID, pq.StringArray(types), limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "rows.close() failed") + + var eventID string + var id int64 + result := make(map[int64]gomatrixserverlib.HeaderedEvent) + for rows.Next() { + var ev gomatrixserverlib.HeaderedEvent + var eventBytes []byte + if err = rows.Scan(&id, &eventID, &eventBytes); err != nil { + return nil, err + } + if err = ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + return nil, err + } + result[id] = ev + } + return result, rows.Err() +} diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 215bad3a8..47e3a991c 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1093,3 +1093,11 @@ func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.Stre func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) } + +func (s *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) { + return s.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{ + gomatrixserverlib.MRoomName, + gomatrixserverlib.MRoomTopic, + "m.room.message", + }) +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 1626e32ef..6269f4fdf 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -115,6 +115,8 @@ const selectContextAfterEventSQL = "" + // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters +const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 LIMIT $3 ORDER BY id ASC" + type outputRoomEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -125,6 +127,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + //selectSearchStmt *sql.Stmt - prepared at runtime } func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) { @@ -157,6 +160,7 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + //{&s.selectSearchStmt, selectSearchSQL}, - prepared at runtime }.Prepare(db) } @@ -628,3 +632,40 @@ func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs [ } return } + +func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { + params := make([]interface{}, len(types)) + for i := range types { + params[i] = types[i] + } + params = append(params, afterID) + params = append(params, limit) + selectSQL := strings.Replace(selectSearchSQL, "($1)", sqlutil.QueryVariadic(len(types)), 1) + + stmt, err := s.db.Prepare(selectSQL) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "selectEvents: stmt.close() failed") + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "rows.close() failed") + + var eventID string + var id int64 + result := make(map[int64]gomatrixserverlib.HeaderedEvent) + for rows.Next() { + var ev gomatrixserverlib.HeaderedEvent + var eventBytes []byte + if err = rows.Scan(&id, &eventID, &eventBytes); err != nil { + return nil, err + } + if err = ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + return nil, err + } + result[id] = ev + } + return result, rows.Err() +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 9a873c2ed..2a6d6fa82 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -75,6 +75,7 @@ type Events interface { SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) } // Topology keeps track of the depths and stream positions for all events. diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index f5d00f367..be19310f2 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -88,14 +88,15 @@ func AddPublicRoutes( roomConsumer := consumers.NewOutputRoomEventConsumer( base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider, - streams.InviteStreamProvider, rsAPI, + streams.InviteStreamProvider, rsAPI, base.Fulltext, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") } clientConsumer := consumers.NewOutputClientDataConsumer( - base.ProcessContext, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, + base.ProcessContext, cfg, js, natsClient, syncDB, notifier, + streams.AccountDataStreamProvider, base.Fulltext, ) if err = clientConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start client data consumer") @@ -131,6 +132,6 @@ func AddPublicRoutes( routing.Setup( base.PublicClientAPIMux, requestPool, syncDB, userAPI, - rsAPI, cfg, base.Caches, + rsAPI, cfg, base.Caches, base.Fulltext, ) } From 34993717fd702db50a82858a3ad4b660f0c7feac Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 27 Sep 2022 17:10:47 +0100 Subject: [PATCH 14/17] Update search docs --- cmd/generate-config/main.go | 2 +- dendrite-sample.monolith.yaml | 14 +++++++++----- dendrite-sample.polylith.yaml | 14 +++++++++----- docs/installation/7_configuration.md | 8 +++++--- setup/config/config_syncapi.go | 2 +- 5 files changed, 25 insertions(+), 15 deletions(-) diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 8b042c56e..33b18c471 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -85,7 +85,7 @@ func main() { } cfg.SyncAPI.Fulltext = config.Fulltext{ Enabled: true, - IndexPath: config.Path(filepath.Join(*dirPath, "fulltextindex")), + IndexPath: config.Path(filepath.Join(*dirPath, "searchindex")), InMemory: true, Language: "en", } diff --git a/dendrite-sample.monolith.yaml b/dendrite-sample.monolith.yaml index 3cad17da8..e41e83d7c 100644 --- a/dendrite-sample.monolith.yaml +++ b/dendrite-sample.monolith.yaml @@ -276,13 +276,17 @@ sync_api: # a reverse proxy server. # real_ip_header: X-Real-IP - # Configuration for the fulltext search + # Configuration for the full-text search engine. search: + # Whether or not search is enabled. enabled: false - # The path where the fulltext index will be created in. - index_path: "./fulltextindex" - # The language most likely to be used on the server - used when indexing, to ensure the returned results match the expectations. - # A full list of possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang + + # The path where the search index will be created in. + index_path: "./searchindex" + + # The language most likely to be used on the server - used when indexing, to + # ensure the returned results match expectations. A full list of possible languages + # can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang language: "en" # Configuration for the User API. diff --git a/dendrite-sample.polylith.yaml b/dendrite-sample.polylith.yaml index e58062fe2..0ae4cc8fb 100644 --- a/dendrite-sample.polylith.yaml +++ b/dendrite-sample.polylith.yaml @@ -327,13 +327,17 @@ sync_api: max_idle_conns: 2 conn_max_lifetime: -1 - # Configuration for the fulltext search + # Configuration for the full-text search engine. search: + # Whether or not search is enabled. enabled: false - # The path where the fulltext index will be created in. - index_path: "./fulltextindex" - # The language most likely to be used on the server - used when indexing, to ensure the returned results match the expectations. - # A full list of possible languages can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang + + # The path where the search index will be created in. + index_path: "./searchindex" + + # The language most likely to be used on the server - used when indexing, to + # ensure the returned results match expectations. A full list of possible languages + # can be found at https://github.com/blevesearch/bleve/tree/master/analysis/lang language: "en" # This option controls which HTTP header to inspect to find the real remote IP diff --git a/docs/installation/7_configuration.md b/docs/installation/7_configuration.md index 67cd339cf..19958c92f 100644 --- a/docs/installation/7_configuration.md +++ b/docs/installation/7_configuration.md @@ -138,16 +138,18 @@ room_server: conn_max_lifetime: -1 ``` -## Fulltext search +## Full-text search -Dendrite supports experimental fulltext indexing using [Bleve](https://github.com/blevesearch/bleve), it is configured in the `sync_api` section as follows. Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expectations. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang). +Dendrite supports experimental full-text indexing using [Bleve](https://github.com/blevesearch/bleve). It is configured in the `sync_api` section as follows. + +Depending on the language most likely to be used on the server, it might make sense to change the `language` used when indexing, to ensure the returned results match the expectations. A full list of possible languages can be found [here](https://github.com/blevesearch/bleve/tree/master/analysis/lang). ```yaml sync_api: # ... search: enabled: false - index_path: "./fulltextindex" + index_path: "./searchindex" language: "en" ``` diff --git a/setup/config/config_syncapi.go b/setup/config/config_syncapi.go index edef22c93..a87da3732 100644 --- a/setup/config/config_syncapi.go +++ b/setup/config/config_syncapi.go @@ -50,7 +50,7 @@ type Fulltext struct { func (f *Fulltext) Defaults(opts DefaultOpts) { f.Enabled = false - f.IndexPath = "./fulltextindex" + f.IndexPath = "./searchindex" f.Language = "en" } From 083ae01520afadfacc6f0ea4bdd501d41f0d832b Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 27 Sep 2022 17:30:40 +0100 Subject: [PATCH 15/17] Promote reindexing log level --- syncapi/consumers/clientapi.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index b11ed4f5e..796cc61e1 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -91,7 +91,7 @@ func (s *OutputClientDataConsumer) Start() error { return } ctx := context.Background() - logrus.Debugf("Starting to index events") + logrus.Infof("Starting to index events") var offset int start := time.Now() count := 0 @@ -140,7 +140,7 @@ func (s *OutputClientDataConsumer) Start() error { offset += len(evs) count += len(elements) } - logrus.Debugf("Indexed %d events in %v", count, time.Since(start)) + logrus.Infof("Indexed %d events in %v", count, time.Since(start)) }) if err != nil { return err From a574ed53696c06e6be6dbe313af0caaa56a659ec Mon Sep 17 00:00:00 2001 From: texuf Date: Tue, 27 Sep 2022 21:19:34 -0700 Subject: [PATCH 16/17] =?UTF-8?q?Fix=20for=20`sql:=20converting=20argument?= =?UTF-8?q?=20$1=20type:=20unsupported=20type=20[]interfa=E2=80=A6=20(#274?= =?UTF-8?q?3)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ce {}, a slice of interface` in new notifications select The sqlite3 version was just not working, original pr here: https://github.com/matrix-org/dendrite/pull/2688 signed off by: austin ellis This doesn't fix the notification counts, they still only work about 1 out of every 5 times in my tests. I will stick with my other fix locally for reliable notification delivery: https://github.com/matrix-org/dendrite/pull/2701 --- syncapi/storage/sqlite3/notification_data_table.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index ceff60555..a690ffad6 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -90,8 +90,8 @@ func (r *notificationDataStatements) SelectUserUnreadCountsForRooms( for i := range roomIDs { params[i+1] = roomIDs[i] } - sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($1)", sqlutil.QueryVariadic(len(params)), 1) - rows, err := r.db.QueryContext(ctx, sql, params) + sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) + rows, err := r.db.QueryContext(ctx, sql, params...) if err != nil { return nil, err } From 3f9e38e80a7be356aaf1294038888df27e0697a8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 28 Sep 2022 10:18:03 +0100 Subject: [PATCH 17/17] Consistent `*sql.Tx` usage across sync API (#2744) This tidies up the `storage` package so that everything takes a transaction parameter instead of something things that do and some that don't. --- .../storage/postgres/account_data_table.go | 5 ++-- .../postgres/backwards_extremities_table.go | 4 +-- .../postgres/current_room_state_table.go | 12 ++++----- syncapi/storage/postgres/filter_table.go | 14 +++++----- syncapi/storage/postgres/invites_table.go | 2 +- .../postgres/output_room_events_table.go | 4 +-- .../output_room_events_topology_table.go | 8 +++--- syncapi/storage/postgres/peeks_table.go | 4 +-- syncapi/storage/postgres/receipt_table.go | 4 +-- syncapi/storage/shared/syncserver.go | 26 +++++++++---------- syncapi/storage/sqlite3/account_data_table.go | 4 +-- .../sqlite3/backwards_extremities_table.go | 4 +-- .../sqlite3/current_room_state_table.go | 21 ++++++++++----- syncapi/storage/sqlite3/filter_table.go | 14 +++++----- .../sqlite3/notification_data_table.go | 7 ++++- .../sqlite3/output_room_events_table.go | 6 ++--- .../output_room_events_topology_table.go | 4 +-- syncapi/storage/sqlite3/peeks_table.go | 4 +-- syncapi/storage/sqlite3/receipt_table.go | 9 +++++-- syncapi/storage/tables/interface.go | 20 +++++++------- 20 files changed, 99 insertions(+), 77 deletions(-) diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index e9c72058b..aa54cb08f 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -99,14 +99,15 @@ func (s *accountDataStatements) InsertAccountData( } func (s *accountDataStatements) SelectAccountDataInRange( - ctx context.Context, + ctx context.Context, txn *sql.Tx, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(), + rows, err := sqlutil.TxStmt(txn, s.selectAccountDataInRangeStmt).QueryContext( + ctx, userID, r.Low(), r.High(), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.NotTypes)), accountDataEventFilter.Limit, diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index d4515735c..8fc92091f 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -79,9 +79,9 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( } func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (bwExtrems map[string][]string, err error) { - rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + rows, err := sqlutil.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt).QueryContext(ctx, roomID) if err != nil { return } diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 5e6daaaf8..4ffd29610 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -185,9 +185,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. func (s *currentRoomStateStatements) SelectJoinedUsers( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (map[string][]string, error) { - rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) + rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx) if err != nil { return nil, err } @@ -209,9 +209,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( - ctx context.Context, roomIDs []string, + ctx context.Context, txn *sql.Tx, roomIDs []string, ) (map[string][]string, error) { - rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs)) + rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersInRoomStmt).QueryContext(ctx, pq.StringArray(roomIDs)) if err != nil { return nil, err } @@ -387,9 +387,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { } func (s *currentRoomStateStatements) SelectStateEvent( - ctx context.Context, roomID, evType, stateKey string, + ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { - stmt := s.selectStateEventStmt + stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt) var res []byte err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) if err == sql.ErrNoRows { diff --git a/syncapi/storage/postgres/filter_table.go b/syncapi/storage/postgres/filter_table.go index c82ef092f..86cec3625 100644 --- a/syncapi/storage/postgres/filter_table.go +++ b/syncapi/storage/postgres/filter_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -73,11 +74,11 @@ func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) { } func (s *filterStatements) SelectFilter( - ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, + ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string, ) error { // Retrieve filter from database (stored as canonical JSON) var filterData []byte - err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + err := sqlutil.TxStmt(txn, s.selectFilterStmt).QueryRowContext(ctx, localpart, filterID).Scan(&filterData) if err != nil { return err } @@ -90,7 +91,7 @@ func (s *filterStatements) SelectFilter( } func (s *filterStatements) InsertFilter( - ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, + ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string, ) (filterID string, err error) { var existingFilterID string @@ -111,8 +112,9 @@ func (s *filterStatements) InsertFilter( // This can result in a race condition when two clients try to insert the // same filter and localpart at the same time, however this is not a // problem as both calls will result in the same filterID - err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, - localpart, filterJSON).Scan(&existingFilterID) + err = sqlutil.TxStmt(txn, s.selectFilterIDByContentStmt).QueryRowContext( + ctx, localpart, filterJSON, + ).Scan(&existingFilterID) if err != nil && err != sql.ErrNoRows { return "", err } @@ -122,7 +124,7 @@ func (s *filterStatements) InsertFilter( } // Otherwise insert the filter and return the new ID - err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart). + err = sqlutil.TxStmt(txn, s.insertFilterStmt).QueryRowContext(ctx, filterJSON, localpart). Scan(&filterID) return } diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index 97001ae2c..f87ccf965 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -99,7 +99,7 @@ func (s *inviteEventsStatements) InsertInviteEvent( return } - err = s.insertInviteEventStmt.QueryRowContext( + err = sqlutil.TxStmt(txn, s.insertInviteEventStmt).QueryRowContext( ctx, inviteEvent.RoomID(), inviteEvent.EventID(), diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 20a9ea428..cb092150d 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -222,12 +222,12 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { }.Prepare(db) } -func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { +func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error { headeredJSON, err := json.Marshal(event) if err != nil { return err } - _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + _, err = sqlutil.TxStmt(txn, s.updateEventJSONStmt).ExecContext(ctx, headeredJSON, event.EventID()) return err } diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index a1fc9b2a3..6fab900eb 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -173,7 +173,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( ctx context.Context, txn *sql.Tx, eventID string, ) (pos, spos types.StreamPosition, err error) { - err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) + err = sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt).QueryRowContext(ctx, eventID).Scan(&pos, &spos) return } @@ -183,9 +183,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool, ) (topoPos types.StreamPosition, err error) { if backwardOrdering { - err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) + err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionDescStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) } else { - err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) + err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionAscStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) } return } @@ -193,6 +193,6 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( ctx context.Context, txn *sql.Tx, roomID string, ) (pos types.StreamPosition, spos types.StreamPosition, err error) { - err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) + err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } diff --git a/syncapi/storage/postgres/peeks_table.go b/syncapi/storage/postgres/peeks_table.go index 75eeac986..e20a4882f 100644 --- a/syncapi/storage/postgres/peeks_table.go +++ b/syncapi/storage/postgres/peeks_table.go @@ -152,9 +152,9 @@ func (s *peekStatements) SelectPeeksInRange( } func (s *peekStatements) SelectPeekingDevices( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (peekingDevices map[string][]types.PeekingDevice, err error) { - rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) + rows, err := sqlutil.TxStmt(txn, s.selectPeekingDevicesStmt).QueryContext(ctx) if err != nil { return nil, err } diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index bbddaa939..327a7a372 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -104,9 +104,9 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room return } -func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { var lastPos types.StreamPosition - rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) + rows, err := sqlutil.TxStmt(txn, r.selectRoomReceipts).QueryContext(ctx, pq.Array(roomIDs), streamPos) if err != nil { return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 47e3a991c..a05e68804 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -148,7 +148,7 @@ func (d *Database) PeeksInRange(ctx context.Context, userID, deviceID string, r } func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { - return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) + return d.Receipts.SelectRoomReceiptsAfter(ctx, nil, roomIDs, streamPos) } // Events lookups a list of event by their event ID. @@ -168,15 +168,15 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixse } func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { - return d.CurrentRoomState.SelectJoinedUsers(ctx) + return d.CurrentRoomState.SelectJoinedUsers(ctx, nil) } func (d *Database) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) { - return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, roomIDs) + return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, nil, roomIDs) } func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { - return d.Peeks.SelectPeekingDevices(ctx) + return d.Peeks.SelectPeekingDevices(ctx, nil) } func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) { @@ -186,7 +186,7 @@ func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs func (d *Database) GetStateEvent( ctx context.Context, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { - return d.CurrentRoomState.SelectStateEvent(ctx, roomID, evType, stateKey) + return d.CurrentRoomState.SelectStateEvent(ctx, nil, roomID, evType, stateKey) } func (d *Database) GetStateEventsForRoom( @@ -277,7 +277,7 @@ func (d *Database) GetAccountDataInRange( ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter, ) (map[string][]string, types.StreamPosition, error) { - return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart) + return d.AccountData.SelectAccountDataInRange(ctx, nil, userID, r, accountDataFilterPart) } // UpsertAccountData keeps track of new or updated account data, by saving the type @@ -484,7 +484,7 @@ func (d *Database) GetEventsInTopologicalRange( func (d *Database) BackwardExtremitiesForRoom( ctx context.Context, roomID string, ) (backwardExtremities map[string][]string, err error) { - return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, roomID) + return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, nil, roomID) } func (d *Database) MaxTopologicalPosition( @@ -530,7 +530,7 @@ func (d *Database) StreamToTopologicalPosition( func (d *Database) GetFilter( ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, ) error { - return d.Filter.SelectFilter(ctx, target, localpart, filterID) + return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID) } func (d *Database) PutFilter( @@ -538,8 +538,8 @@ func (d *Database) PutFilter( ) (string, error) { var filterID string var err error - err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - filterID, err = d.Filter.InsertFilter(ctx, filter, localpart) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + filterID, err = d.Filter.InsertFilter(ctx, txn, filter, localpart) return err }) return filterID, err @@ -561,8 +561,8 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda } newEvent := eventToRedact.Headered(redactedBecause.RoomVersion) - err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.OutputEvents.UpdateEventJSON(ctx, newEvent) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent) }) return err } @@ -1024,7 +1024,7 @@ func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId } func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) { - _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) + _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, nil, roomIDs, streamPos) return receipts, err } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 21a16dcd3..d8967113a 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -91,14 +91,14 @@ func (s *accountDataStatements) InsertAccountData( } func (s *accountDataStatements) SelectAccountDataInRange( - ctx context.Context, + ctx context.Context, txn *sql.Tx, userID string, r types.Range, filter *gomatrixserverlib.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) stmt, params, err := prepareWithFilters( - s.db, nil, selectAccountDataInRangeSQL, + s.db, txn, selectAccountDataInRangeSQL, []interface{}{ userID, r.Low(), r.High(), }, diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index c5674dded..3a5fd6be3 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -82,9 +82,9 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( } func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (bwExtrems map[string][]string, err error) { - rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + rows, err := sqlutil.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt).QueryContext(ctx, roomID) if err != nil { return } diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index bd1271dd6..ba6d8126c 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -163,9 +163,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. func (s *currentRoomStateStatements) SelectJoinedUsers( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (map[string][]string, error) { - rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) + rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx) if err != nil { return nil, err } @@ -187,7 +187,7 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( - ctx context.Context, roomIDs []string, + ctx context.Context, txn *sql.Tx, roomIDs []string, ) (map[string][]string, error) { query := strings.Replace(selectJoinedUsersInRoomSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) params := make([]interface{}, 0, len(roomIDs)) @@ -200,7 +200,7 @@ func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( } defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsersInRoom: stmt.close() failed") - rows, err := stmt.QueryContext(ctx, params...) + rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) if err != nil { return nil, err } @@ -401,9 +401,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { } func (s *currentRoomStateStatements) SelectStateEvent( - ctx context.Context, roomID, evType, stateKey string, + ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { - stmt := s.selectStateEventStmt + stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt) var res []byte err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) if err == sql.ErrNoRows { @@ -429,10 +429,17 @@ func (s *currentRoomStateStatements) SelectSharedUsers( params[k+1] = v } + var provider sqlutil.QueryProvider + if txn == nil { + provider = s.db + } else { + provider = txn + } + result := make([]string, 0, len(otherUserIDs)) query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1) err := sqlutil.RunLimitedVariablesQuery( - ctx, query, s.db, params, sqlutil.SQLite3MaxVariables, + ctx, query, provider, params, sqlutil.SQLite3MaxVariables, func(rows *sql.Rows) error { var stateKey string for rows.Next() { diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go index 6081a48b1..5f1e980eb 100644 --- a/syncapi/storage/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -77,11 +78,11 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { } func (s *filterStatements) SelectFilter( - ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, + ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string, ) error { // Retrieve filter from database (stored as canonical JSON) var filterData []byte - err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + err := sqlutil.TxStmt(txn, s.selectFilterStmt).QueryRowContext(ctx, localpart, filterID).Scan(&filterData) if err != nil { return err } @@ -94,7 +95,7 @@ func (s *filterStatements) SelectFilter( } func (s *filterStatements) InsertFilter( - ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, + ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string, ) (filterID string, err error) { var existingFilterID string @@ -115,8 +116,9 @@ func (s *filterStatements) InsertFilter( // This can result in a race condition when two clients try to insert the // same filter and localpart at the same time, however this is not a // problem as both calls will result in the same filterID - err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, - localpart, filterJSON).Scan(&existingFilterID) + err = sqlutil.TxStmt(txn, s.selectFilterIDByContentStmt).QueryRowContext( + ctx, localpart, filterJSON, + ).Scan(&existingFilterID) if err != nil && err != sql.ErrNoRows { return "", err } @@ -126,7 +128,7 @@ func (s *filterStatements) InsertFilter( } // Otherwise insert the filter and return the new ID - res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + res, err := sqlutil.TxStmt(txn, s.insertFilterStmt).ExecContext(ctx, filterJSON, localpart) if err != nil { return "", err } diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index a690ffad6..6242898e1 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -91,7 +91,12 @@ func (r *notificationDataStatements) SelectUserUnreadCountsForRooms( params[i+1] = roomIDs[i] } sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) - rows, err := r.db.QueryContext(ctx, sql, params...) + prep, err := r.db.PrepareContext(ctx, sql) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, prep, "SelectUserUnreadCountsForRooms: prep.close() failed") + rows, err := sqlutil.TxStmt(txn, prep).QueryContext(ctx, params...) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 6269f4fdf..165943027 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -164,12 +164,12 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even }.Prepare(db) } -func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { +func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error { headeredJSON, err := json.Marshal(event) if err != nil { return err } - _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + _, err = sqlutil.TxStmt(txn, s.updateEventJSONStmt).ExecContext(ctx, headeredJSON, event.EventID()) return err } @@ -647,7 +647,7 @@ func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, l return nil, err } defer internal.CloseAndLogIfError(ctx, stmt, "selectEvents: stmt.close() failed") - rows, err := stmt.QueryContext(ctx, params...) + rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index b2fb77417..81b264988 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -176,9 +176,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool, ) (topoPos types.StreamPosition, err error) { if backwardOrdering { - err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) + err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionDescStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) } else { - err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) + err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionAscStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) } return } diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index 5ee86448c..4ef51b103 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -172,9 +172,9 @@ func (s *peekStatements) SelectPeeksInRange( } func (s *peekStatements) SelectPeekingDevices( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (peekingDevices map[string][]types.PeekingDevice, err error) { - rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) + rows, err := sqlutil.TxStmt(txn, s.selectPeekingDevicesStmt).QueryContext(ctx) if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index 31adb005b..a4a9b4395 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -108,7 +108,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room } // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp -func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) var lastPos types.StreamPosition params := make([]interface{}, len(roomIDs)+1) @@ -116,7 +116,12 @@ func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs for k, v := range roomIDs { params[k+1] = v } - rows, err := r.db.QueryContext(ctx, selectSQL, params...) + prep, err := r.db.Prepare(selectSQL) + if err != nil { + return 0, nil, fmt.Errorf("unable to prepare statement: %w", err) + } + defer internal.CloseAndLogIfError(ctx, prep, "SelectRoomReceiptsAfter: prep.close() failed") + rows, err := sqlutil.TxStmt(txn, prep).QueryContext(ctx, params...) if err != nil { return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 2a6d6fa82..89cb537af 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -28,7 +28,7 @@ import ( type AccountData interface { InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error) // SelectAccountDataInRange returns a map of room ID to a list of `dataType`. - SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error) + SelectAccountDataInRange(ctx context.Context, txn *sql.Tx, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error) SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error) } @@ -46,7 +46,7 @@ type Peeks interface { DeletePeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error) DeletePeeks(ctx context.Context, txn *sql.Tx, roomID, userID string) (streamPos types.StreamPosition, err error) SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) - SelectPeekingDevices(ctxt context.Context) (peekingDevices map[string][]types.PeekingDevice, err error) + SelectPeekingDevices(ctxt context.Context, txn *sql.Tx) (peekingDevices map[string][]types.PeekingDevice, err error) SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error) } @@ -68,7 +68,7 @@ type Events interface { // SelectEarlyEvents returns the earliest events in the given room. SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) - UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error + UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) @@ -98,7 +98,7 @@ type Topology interface { } type CurrentRoomState interface { - SelectStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + SelectStateEvent(ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) UpsertRoomState(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error @@ -110,9 +110,9 @@ type CurrentRoomState interface { // SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error) // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. - SelectJoinedUsers(ctx context.Context) (map[string][]string, error) + SelectJoinedUsers(ctx context.Context, txn *sql.Tx) (map[string][]string, error) // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. - SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) + SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error) // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID. SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error) } @@ -142,7 +142,7 @@ type BackwardsExtremities interface { // InsertsBackwardExtremity inserts a new backwards extremity. InsertsBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string) (err error) // SelectBackwardExtremitiesForRoom retrieves all backwards extremities for the room, as a map of event_id to list of prev_event_ids. - SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error) + SelectBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (bwExtrems map[string][]string, err error) // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) } @@ -172,13 +172,13 @@ type SendToDevice interface { } type Filter interface { - SelectFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error - InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) + SelectFilter(ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string) error + InsertFilter(ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) } type Receipts interface { UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) - SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) + SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error) }