From 7df5d69a5b606e226c5d83c23aa9dd90785c0b2d Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Tue, 26 Apr 2022 08:07:27 +0200 Subject: [PATCH 01/22] Checkout correct branch for Sytest --- .github/workflows/dendrite.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 4f337a866..5d60301c7 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -250,6 +250,7 @@ jobs: env: POSTGRES: ${{ matrix.postgres && 1}} API: ${{ matrix.api && 1 }} + SYTEST_BRANCH: ${{ github.head_ref }} steps: - uses: actions/checkout@v2 - name: Run Sytest From feac9db43fc459f1efa10424dfc96f8d54b55c64 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 26 Apr 2022 10:28:41 +0200 Subject: [PATCH 02/22] Add transactionsCache to redact endpoint (#2375) --- clientapi/routing/redaction.go | 20 +++++++++++++++++++- clientapi/routing/routing.go | 5 +++-- sytest-whitelist | 3 ++- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 01ea818ab..e8d14ce34 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/transactions" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -40,12 +41,21 @@ type redactionResponse struct { func SendRedaction( req *http.Request, device *userapi.Device, roomID, eventID string, cfg *config.ClientAPI, rsAPI roomserverAPI.RoomserverInternalAPI, + txnID *string, + txnCache *transactions.Cache, ) util.JSONResponse { resErr := checkMemberInRoom(req.Context(), rsAPI, device.UserID, roomID) if resErr != nil { return *resErr } + if txnID != nil { + // Try to fetch response from transactionsCache + if res, ok := txnCache.FetchTransaction(device.AccessToken, *txnID); ok { + return *res + } + } + ev := roomserverAPI.GetEvent(req.Context(), rsAPI, eventID) if ev == nil { return util.JSONResponse{ @@ -124,10 +134,18 @@ func SendRedaction( util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } - return util.JSONResponse{ + + res := util.JSONResponse{ Code: 200, JSON: redactionResponse{ EventID: e.EventID(), }, } + + // Add response to transactionsCache + if txnID != nil { + txnCache.AddTransaction(device.AccessToken, *txnID, &res) + } + + return res } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 37d825b80..f370b4f8c 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -479,7 +479,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) + return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, nil, nil) }), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnId}", @@ -488,7 +488,8 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI) + txnID := vars["txnId"] + return SendRedaction(req, device, vars["roomID"], vars["eventID"], cfg, rsAPI, &txnID, transactionsCache) }), ).Methods(http.MethodPut, http.MethodOptions) diff --git a/sytest-whitelist b/sytest-whitelist index 5d67aee6c..91304bd71 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -713,4 +713,5 @@ Presence can be set from sync /state returns M_NOT_FOUND for a rejected message event /state_ids returns M_NOT_FOUND for a rejected message event /state returns M_NOT_FOUND for a rejected state event -/state_ids returns M_NOT_FOUND for a rejected state event \ No newline at end of file +/state_ids returns M_NOT_FOUND for a rejected state event +PUT /rooms/:room_id/redact/:event_id/:txn_id is idempotent \ No newline at end of file From e8be2b234f616c8422372665c845d9a7a1af245f Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 26 Apr 2022 10:53:17 +0200 Subject: [PATCH 03/22] Add heroes to the room summary (#2373) * Implement room summary heroes * Add passing tests * Move MembershipCount to addRoomSummary * Add comments, close Statement --- syncapi/storage/interface.go | 1 + syncapi/storage/postgres/memberships_table.go | 38 ++++++++++--- syncapi/storage/shared/syncserver.go | 4 ++ syncapi/storage/sqlite3/memberships_table.go | 51 ++++++++++++++--- syncapi/storage/tables/interface.go | 1 + syncapi/streams/stream_pdu.go | 56 +++++++++++++++---- syncapi/streams/streams.go | 1 + sytest-whitelist | 5 +- 8 files changed, 131 insertions(+), 26 deletions(-) diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 14cb08a52..0fea88da6 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -39,6 +39,7 @@ type Database interface { GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) + GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 39fa656cb..8c049977f 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -19,6 +19,8 @@ import ( "database/sql" "fmt" + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -61,9 +63,13 @@ const selectMembershipCountSQL = "" + " SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" + ") t WHERE t.membership = $3" +const selectHeroesSQL = "" + + "SELECT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" + type membershipsStatements struct { upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt + selectHeroesStmt *sql.Stmt } func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -72,13 +78,11 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { if err != nil { return nil, err } - if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { - return nil, err - } - if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertMembershipStmt, upsertMembershipSQL}, + {&s.selectMembershipCountStmt, selectMembershipCountSQL}, + {&s.selectHeroesStmt, selectHeroesSQL}, + }.Prepare(db) } func (s *membershipsStatements) UpsertMembership( @@ -108,3 +112,23 @@ func (s *membershipsStatements) SelectMembershipCount( err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count) return } + +func (s *membershipsStatements) SelectHeroes( + ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string, +) (heroes []string, err error) { + stmt := sqlutil.TxStmt(txn, s.selectHeroesStmt) + var rows *sql.Rows + rows, err = stmt.QueryContext(ctx, roomID, userID, pq.StringArray(memberships)) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed") + var hero string + for rows.Next() { + if err = rows.Scan(&hero); err != nil { + return + } + heroes = append(heroes, hero) + } + return heroes, rows.Err() +} diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 2143fd672..3c431db48 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -124,6 +124,10 @@ func (d *Database) MembershipCount(ctx context.Context, roomID, membership strin return d.Memberships.SelectMembershipCount(ctx, nil, roomID, membership, pos) } +func (d *Database) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { + return d.Memberships.SelectHeroes(ctx, nil, roomID, userID, memberships) +} + func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) } diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 9f3530ccd..e4daa99c1 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -18,7 +18,9 @@ import ( "context" "database/sql" "fmt" + "strings" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -61,10 +63,14 @@ const selectMembershipCountSQL = "" + " SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" + ") t WHERE t.membership = $3" +const selectHeroesSQL = "" + + "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5" + type membershipsStatements struct { db *sql.DB upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt + //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -75,13 +81,11 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { if err != nil { return nil, err } - if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { - return nil, err - } - if s.selectMembershipCountStmt, err = db.Prepare(selectMembershipCountSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertMembershipStmt, upsertMembershipSQL}, + {&s.selectMembershipCountStmt, selectMembershipCountSQL}, + // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic + }.Prepare(db) } func (s *membershipsStatements) UpsertMembership( @@ -111,3 +115,36 @@ func (s *membershipsStatements) SelectMembershipCount( err = stmt.QueryRowContext(ctx, roomID, pos, membership).Scan(&count) return } + +func (s *membershipsStatements) SelectHeroes( + ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string, +) (heroes []string, err error) { + stmtSQL := strings.Replace(selectHeroesSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) + stmt, err := s.db.PrepareContext(ctx, stmtSQL) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectHeroes: stmt.close() failed") + params := []interface{}{ + roomID, userID, + } + for _, membership := range memberships { + params = append(params, membership) + } + + stmt = sqlutil.TxStmt(txn, stmt) + var rows *sql.Rows + rows, err = stmt.QueryContext(ctx, params...) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed") + var hero string + for rows.Next() { + if err = rows.Scan(&hero); err != nil { + return + } + heroes = append(heroes, hero) + } + return heroes, rows.Err() +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 993e2022b..ac713dd5c 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -170,6 +170,7 @@ type Receipts interface { type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) + SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error) } type NotificationData interface { diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index df5fb8e08..0d033095d 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -4,13 +4,16 @@ import ( "context" "database/sql" "fmt" + "sort" "sync" "time" "github.com/matrix-org/dendrite/internal/caching" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/tidwall/gjson" "go.uber.org/atomic" ) @@ -30,6 +33,7 @@ type PDUStreamProvider struct { workers atomic.Int32 // userID+deviceID -> lazy loading cache lazyLoadCache *caching.LazyLoadCache + rsAPI roomserverAPI.RoomserverInternalAPI } func (p *PDUStreamProvider) worker() { @@ -290,16 +294,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } - // Work out how many members are in the room. - joinedCount, _ := p.DB.MembershipCount(ctx, delta.RoomID, gomatrixserverlib.Join, latestPosition) - invitedCount, _ := p.DB.MembershipCount(ctx, delta.RoomID, gomatrixserverlib.Invite, latestPosition) - switch delta.Membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() if hasMembershipChange { - jr.Summary.JoinedMemberCount = &joinedCount - jr.Summary.InvitedMemberCount = &invitedCount + p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition) } jr.Timeline.PrevBatch = &prevBatch jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) @@ -332,6 +331,45 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( return latestPosition, nil } +func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { + // Work out how many members are in the room. + joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) + invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition) + + jr.Summary.JoinedMemberCount = &joinedCount + jr.Summary.InvitedMemberCount = &invitedCount + + fetchStates := []gomatrixserverlib.StateKeyTuple{ + {EventType: gomatrixserverlib.MRoomName}, + {EventType: gomatrixserverlib.MRoomCanonicalAlias}, + } + // Check if the room has a name or a canonical alias + latestState := &roomserverAPI.QueryLatestEventsAndStateResponse{} + err := p.rsAPI.QueryLatestEventsAndState(ctx, &roomserverAPI.QueryLatestEventsAndStateRequest{StateToFetch: fetchStates, RoomID: roomID}, latestState) + if err != nil { + return + } + // Check if the room has a name or canonical alias, if so, return. + for _, ev := range latestState.StateEvents { + switch ev.Type() { + case gomatrixserverlib.MRoomName: + if gjson.GetBytes(ev.Content(), "name").Str != "" { + return + } + case gomatrixserverlib.MRoomCanonicalAlias: + if gjson.GetBytes(ev.Content(), "alias").Str != "" { + return + } + } + } + heroes, err := p.DB.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"}) + if err != nil { + return + } + sort.Strings(heroes) + jr.Summary.Heroes = heroes +} + func (p *PDUStreamProvider) getJoinResponseForCompleteSync( ctx context.Context, roomID string, @@ -416,9 +454,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( prevBatch.Decrement() } - // Work out how many members are in the room. - joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, r.From) - invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, r.From) + p.addRoomSummary(ctx, jr, roomID, device.UserID, r.From) // We don't include a device here as we don't need to send down // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: @@ -439,8 +475,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } } - jr.Summary.JoinedMemberCount = &joinedCount - jr.Summary.InvitedMemberCount = &invitedCount jr.Timeline.PrevBatch = prevBatch jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = limited diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index d3195b78f..a18a0cc41 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -33,6 +33,7 @@ func NewSyncStreamProviders( PDUStreamProvider: &PDUStreamProvider{ StreamProvider: StreamProvider{DB: d}, lazyLoadCache: lazyLoadCache, + rsAPI: rsAPI, }, TypingStreamProvider: &TypingStreamProvider{ StreamProvider: StreamProvider{DB: d}, diff --git a/sytest-whitelist b/sytest-whitelist index 91304bd71..c9829606f 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -714,4 +714,7 @@ Presence can be set from sync /state_ids returns M_NOT_FOUND for a rejected message event /state returns M_NOT_FOUND for a rejected state event /state_ids returns M_NOT_FOUND for a rejected state event -PUT /rooms/:room_id/redact/:event_id/:txn_id is idempotent \ No newline at end of file +PUT /rooms/:room_id/redact/:event_id/:txn_id is idempotent +Unnamed room comes with a name summary +Named room comes with just joined member count summary +Room summary only has 5 heroes \ No newline at end of file From 5306c73b008567d855ca548d195abf3dfaf8917c Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 26 Apr 2022 13:08:54 +0100 Subject: [PATCH 04/22] Fix bug when uploading device signatures (#2377) * Find the complete key ID when uploading signatures * Try that again * Try splitting the right thing * Don't do it for device keys * Refactor `QuerySignatures` * Revert "Refactor `QuerySignatures`" This reverts commit c02832a3e92569f64f180dec1555056dc8f8c3e3. * Both requested key IDs and master/self/user keys * Fix uniqueness * Try tweaking GMSL * Update GMSL again * Revert "Update GMSL again" This reverts commit bd6916cc379dd8d9e3f38d979c6550bd658938aa. * Revert "Try tweaking GMSL" This reverts commit 2a054524da9d64c6a2a5228262fbba5fde28798c. * Database migrations --- keyserver/internal/cross_signing.go | 7 ++ .../postgres/cross_signing_sigs_table.go | 6 +- .../deltas/2022042612000000_xsigning_idx.go | 52 +++++++++++++ keyserver/storage/postgres/storage.go | 1 + .../sqlite3/cross_signing_sigs_table.go | 4 +- .../deltas/2022042612000000_xsigning_idx.go | 76 +++++++++++++++++++ keyserver/storage/sqlite3/storage.go | 1 + 7 files changed, 144 insertions(+), 3 deletions(-) create mode 100644 keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go create mode 100644 keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index 2281f4bbf..08bbfedb8 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -362,6 +362,13 @@ func (a *KeyInternalAPI) processSelfSignatures( for targetKeyID, signature := range forTargetUserID { switch sig := signature.CrossSigningBody.(type) { case *gomatrixserverlib.CrossSigningKey: + for keyID := range sig.Keys { + split := strings.SplitN(string(keyID), ":", 2) + if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID { + targetKeyID = keyID // contains the ed25519: or other scheme + break + } + } for originUserID, forOriginUserID := range sig.Signatures { for originKeyID, originSig := range forOriginUserID { if err := a.DB.StoreCrossSigningSigsForTarget( diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go index 40633c05c..b101e7ce5 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/keyserver/storage/postgres/cross_signing_sigs_table.go @@ -33,8 +33,10 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( target_user_id TEXT NOT NULL, target_key_id TEXT NOT NULL, signature TEXT NOT NULL, - PRIMARY KEY (origin_user_id, target_user_id, target_key_id) + PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) ); + +CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); ` const selectCrossSigningSigsForTargetSQL = "" + @@ -44,7 +46,7 @@ const selectCrossSigningSigsForTargetSQL = "" + const upsertCrossSigningSigsForTargetSQL = "" + "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + " VALUES($1, $2, $3, $4, $5)" + - " ON CONFLICT (origin_user_id, target_user_id, target_key_id) DO UPDATE SET (origin_key_id, signature) = ($2, $5)" + " ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5" const deleteCrossSigningSigsForTargetSQL = "" + "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" diff --git a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go new file mode 100644 index 000000000..12956e3b4 --- /dev/null +++ b/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go @@ -0,0 +1,52 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) { + m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes) +} + +func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error { + _, err := tx.Exec(` + ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; + ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id); + + CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error { + _, err := tx.Exec(` + ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; + ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id); + + DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx; + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index 136986885..d4c7e2cc7 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -54,6 +54,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) } m := sqlutil.NewMigrations() deltas.LoadRefactorKeyChanges(m) + deltas.LoadFixCrossSigningSignatureIndexes(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err } diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go index 29ee889fd..36d562b8a 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/keyserver/storage/sqlite3/cross_signing_sigs_table.go @@ -33,8 +33,10 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( target_user_id TEXT NOT NULL, target_key_id TEXT NOT NULL, signature TEXT NOT NULL, - PRIMARY KEY (origin_user_id, target_user_id, target_key_id) + PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) ); + +CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); ` const selectCrossSigningSigsForTargetSQL = "" + diff --git a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go new file mode 100644 index 000000000..230e39fef --- /dev/null +++ b/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go @@ -0,0 +1,76 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) { + m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes) +} + +func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error { + _, err := tx.Exec(` + CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) + ); + + INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature) + SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs; + + DROP TABLE keyserver_cross_signing_sigs; + ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs; + + CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error { + _, err := tx.Exec(` + CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, target_user_id, target_key_id) + ); + + INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature) + SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs; + + DROP TABLE keyserver_cross_signing_sigs; + ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs; + + DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx; + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index 0e0adceef..84d4cdf55 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -53,6 +53,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) m := sqlutil.NewMigrations() deltas.LoadRefactorKeyChanges(m) + deltas.LoadFixCrossSigningSignatureIndexes(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err } From 4c19f22725b8f534163ad37845650005b32172ad Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 26 Apr 2022 15:50:56 +0200 Subject: [PATCH 05/22] Fix account_data not correctly send in a complete sync (#2379) * Return the StreamPosition from the database and not the latest * Fix linter issue --- syncapi/storage/interface.go | 2 +- syncapi/storage/postgres/account_data_table.go | 18 +++++++++++------- syncapi/storage/shared/syncserver.go | 2 +- syncapi/storage/sqlite3/account_data_table.go | 18 +++++++++++------- syncapi/storage/tables/interface.go | 2 +- syncapi/streams/stream_accountdata.go | 4 ++-- 6 files changed, 27 insertions(+), 19 deletions(-) diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 0fea88da6..13065fa6b 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -81,7 +81,7 @@ type Database interface { // Returns a map following the format data[roomID] = []dataTypes // If no data is retrieved, returns an empty map // If there was an issue with the retrieval, returns an error - GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, error) + GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error) // UpsertAccountData keeps track of new or updated account data, by saving the type // of the new/updated data, and the user ID and room ID the data is related to (empty) // room ID means the data isn't specific to any room) diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index 25bdb1da3..22bb4d7fa 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -57,7 +57,7 @@ const insertAccountDataSQL = "" + " RETURNING id" const selectAccountDataInRangeSQL = "" + - "SELECT room_id, type FROM syncapi_account_data_type" + + "SELECT id, room_id, type FROM syncapi_account_data_type" + " WHERE user_id = $1 AND id > $2 AND id <= $3" + " AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" + " AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" + @@ -103,7 +103,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter, -) (data map[string][]string, err error) { +) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(), @@ -116,11 +116,12 @@ func (s *accountDataStatements) SelectAccountDataInRange( } defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") - for rows.Next() { - var dataType string - var roomID string + var dataType string + var roomID string + var id types.StreamPosition - if err = rows.Scan(&roomID, &dataType); err != nil { + for rows.Next() { + if err = rows.Scan(&id, &roomID, &dataType); err != nil { return } @@ -129,8 +130,11 @@ func (s *accountDataStatements) SelectAccountDataInRange( } else { data[roomID] = []string{dataType} } + if id > pos { + pos = id + } } - return data, rows.Err() + return data, pos, rows.Err() } func (s *accountDataStatements) SelectMaxAccountDataID( diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 3c431db48..69bceb624 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -265,7 +265,7 @@ func (d *Database) DeletePeeks( func (d *Database) GetAccountDataInRange( ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter, -) (map[string][]string, error) { +) (map[string][]string, types.StreamPosition, error) { return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart) } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 71a098177..e0d97ec32 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -43,7 +43,7 @@ const insertAccountDataSQL = "" + // further parameters are added by prepareWithFilters const selectAccountDataInRangeSQL = "" + - "SELECT room_id, type FROM syncapi_account_data_type" + + "SELECT id, room_id, type FROM syncapi_account_data_type" + " WHERE user_id = $1 AND id > $2 AND id <= $3" const selectMaxAccountDataIDSQL = "" + @@ -95,7 +95,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( userID string, r types.Range, filter *gomatrixserverlib.EventFilter, -) (data map[string][]string, err error) { +) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) stmt, params, err := prepareWithFilters( s.db, nil, selectAccountDataInRangeSQL, @@ -112,11 +112,12 @@ func (s *accountDataStatements) SelectAccountDataInRange( } defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed") - for rows.Next() { - var dataType string - var roomID string + var dataType string + var roomID string + var id types.StreamPosition - if err = rows.Scan(&roomID, &dataType); err != nil { + for rows.Next() { + if err = rows.Scan(&id, &roomID, &dataType); err != nil { return } @@ -125,9 +126,12 @@ func (s *accountDataStatements) SelectAccountDataInRange( } else { data[roomID] = []string{dataType} } + if id > pos { + pos = id + } } - return data, nil + return data, pos, nil } func (s *accountDataStatements) SelectMaxAccountDataID( diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index ac713dd5c..32b1c34ef 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -27,7 +27,7 @@ import ( type AccountData interface { InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error) // SelectAccountDataInRange returns a map of room ID to a list of `dataType`. - SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, err error) + SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, pos types.StreamPosition, err error) SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error) } diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 094c51485..99cd4a92a 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -43,7 +43,7 @@ func (p *AccountDataStreamProvider) IncrementalSync( To: to, } - dataTypes, err := p.DB.GetAccountDataInRange( + dataTypes, pos, err := p.DB.GetAccountDataInRange( ctx, req.Device.UserID, r, &req.Filter.AccountData, ) if err != nil { @@ -95,5 +95,5 @@ func (p *AccountDataStreamProvider) IncrementalSync( } } - return to + return pos } From 6892e0f0e02466be3cac6fc6f17267aeecb5961b Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 26 Apr 2022 16:02:21 +0100 Subject: [PATCH 06/22] Start account data ID from `from` --- syncapi/storage/postgres/account_data_table.go | 2 +- syncapi/storage/sqlite3/account_data_table.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index 22bb4d7fa..0a7146913 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -118,7 +118,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( var dataType string var roomID string - var id types.StreamPosition + id := r.From for rows.Next() { if err = rows.Scan(&id, &roomID, &dataType); err != nil { diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index e0d97ec32..d84159ac8 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -114,7 +114,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( var dataType string var roomID string - var id types.StreamPosition + id := r.From for rows.Next() { if err = rows.Scan(&id, &roomID, &dataType); err != nil { From f6d07768a82cdea07c56cf4ae463449292fa9fe4 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 26 Apr 2022 16:07:13 +0100 Subject: [PATCH 07/22] Fix account data position --- syncapi/storage/postgres/account_data_table.go | 3 ++- syncapi/storage/sqlite3/account_data_table.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index 0a7146913..ec1919fca 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -105,6 +105,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( accountDataEventFilter *gomatrixserverlib.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) + pos = r.Low() rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)), @@ -118,7 +119,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( var dataType string var roomID string - id := r.From + var id types.StreamPosition for rows.Next() { if err = rows.Scan(&id, &roomID, &dataType); err != nil { diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index d84159ac8..2c7272ea8 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -96,6 +96,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( r types.Range, filter *gomatrixserverlib.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { + pos = r.Low() data = make(map[string][]string) stmt, params, err := prepareWithFilters( s.db, nil, selectAccountDataInRangeSQL, @@ -114,7 +115,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( var dataType string var roomID string - id := r.From + var id types.StreamPosition for rows.Next() { if err = rows.Scan(&id, &roomID, &dataType); err != nil { From b527e33c16d70ee6f94ac12c077b43283ff1fd86 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 26 Apr 2022 16:58:20 +0100 Subject: [PATCH 08/22] Send all account data on complete sync by default Squashed commit of the following: commit 0ec8de57261d573a5f88577aa9d7a1174d3999b9 Author: Neil Alexander Date: Tue Apr 26 16:56:30 2022 +0100 Select filter onto provided target filter commit da40b6fffbf5737864b223f49900048f557941f9 Author: Neil Alexander Date: Tue Apr 26 16:48:00 2022 +0100 Specify other field too commit ffc0b0801f63bb4d3061b6813e3ce5f3b4c8fbcb Author: Neil Alexander Date: Tue Apr 26 16:45:44 2022 +0100 Send as much account data as possible during complete sync --- syncapi/routing/filter.go | 4 ++-- syncapi/storage/interface.go | 6 +++--- syncapi/storage/postgres/filter_table.go | 13 ++++++------- syncapi/storage/shared/syncserver.go | 6 +++--- syncapi/storage/sqlite3/filter_table.go | 13 ++++++------- syncapi/storage/tables/interface.go | 2 +- syncapi/sync/request.go | 12 +++++++++--- 7 files changed, 30 insertions(+), 26 deletions(-) diff --git a/syncapi/routing/filter.go b/syncapi/routing/filter.go index baa4d841c..1a10bd649 100644 --- a/syncapi/routing/filter.go +++ b/syncapi/routing/filter.go @@ -44,8 +44,8 @@ func GetFilter( return jsonerror.InternalServerError() } - filter, err := syncDB.GetFilter(req.Context(), localpart, filterID) - if err != nil { + filter := gomatrixserverlib.DefaultFilter() + if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterID); err != nil { //TODO better error handling. This error message is *probably* right, // but if there are obscure db errors, this will also be returned, // even though it is not correct. diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 13065fa6b..43aaa3588 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -125,10 +125,10 @@ type Database interface { // CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified // from position, preventing the send-to-device table from growing indefinitely. CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error) - // GetFilter looks up the filter associated with a given local user and filter ID. - // Returns a filter structure. Otherwise returns an error if no such filter exists + // GetFilter looks up the filter associated with a given local user and filter ID + // and populates the target filter. Otherwise returns an error if no such filter exists // or if there was an error talking to the database. - GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) + GetFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error // PutFilter puts the passed filter into the database. // Returns the filterID as a string. Otherwise returns an error if something // goes wrong. diff --git a/syncapi/storage/postgres/filter_table.go b/syncapi/storage/postgres/filter_table.go index dfd3d6963..c82ef092f 100644 --- a/syncapi/storage/postgres/filter_table.go +++ b/syncapi/storage/postgres/filter_table.go @@ -73,21 +73,20 @@ func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) { } func (s *filterStatements) SelectFilter( - ctx context.Context, localpart string, filterID string, -) (*gomatrixserverlib.Filter, error) { + ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, +) error { // Retrieve filter from database (stored as canonical JSON) var filterData []byte err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) if err != nil { - return nil, err + return err } // Unmarshal JSON into Filter struct - filter := gomatrixserverlib.DefaultFilter() - if err = json.Unmarshal(filterData, &filter); err != nil { - return nil, err + if err = json.Unmarshal(filterData, &target); err != nil { + return err } - return &filter, nil + return nil } func (s *filterStatements) InsertFilter( diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 69bceb624..25aca50ae 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -513,9 +513,9 @@ func (d *Database) StreamToTopologicalPosition( } func (d *Database) GetFilter( - ctx context.Context, localpart string, filterID string, -) (*gomatrixserverlib.Filter, error) { - return d.Filter.SelectFilter(ctx, localpart, filterID) + ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, +) error { + return d.Filter.SelectFilter(ctx, target, localpart, filterID) } func (d *Database) PutFilter( diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go index 0cfebef2a..6081a48b1 100644 --- a/syncapi/storage/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -77,21 +77,20 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { } func (s *filterStatements) SelectFilter( - ctx context.Context, localpart string, filterID string, -) (*gomatrixserverlib.Filter, error) { + ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, +) error { // Retrieve filter from database (stored as canonical JSON) var filterData []byte err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) if err != nil { - return nil, err + return err } // Unmarshal JSON into Filter struct - filter := gomatrixserverlib.DefaultFilter() - if err = json.Unmarshal(filterData, &filter); err != nil { - return nil, err + if err = json.Unmarshal(filterData, &target); err != nil { + return err } - return &filter, nil + return nil } func (s *filterStatements) InsertFilter( diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 32b1c34ef..4ff4689ed 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -157,7 +157,7 @@ type SendToDevice interface { } type Filter interface { - SelectFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) + SelectFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) } diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index f04f172d3..c9ee8e4a8 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -18,6 +18,7 @@ import ( "database/sql" "encoding/json" "fmt" + "math" "net/http" "strconv" "time" @@ -47,6 +48,13 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat } // TODO: read from stored filters too filter := gomatrixserverlib.DefaultFilter() + if since.IsEmpty() { + // Send as much account data down for complete syncs as possible + // by default, otherwise clients do weird things while waiting + // for the rest of the data to trickle down. + filter.AccountData.Limit = math.MaxInt + filter.Room.AccountData.Limit = math.MaxInt + } filterQuery := req.URL.Query().Get("filter") if filterQuery != "" { if filterQuery[0] == '{' { @@ -61,11 +69,9 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil && err != sql.ErrNoRows { + if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterQuery); err != nil && err != sql.ErrNoRows { util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed") return nil, fmt.Errorf("syncDB.GetFilter: %w", err) - } else if f != nil { - filter = *f } } } From 6c5c6d73d771e0ea5cc325e4251bcbfc48b7d55e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 26 Apr 2022 17:05:31 +0100 Subject: [PATCH 09/22] Use a value that is Go 1.16-friendly --- syncapi/sync/request.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index c9ee8e4a8..9d4740e93 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -52,8 +52,8 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat // Send as much account data down for complete syncs as possible // by default, otherwise clients do weird things while waiting // for the rest of the data to trickle down. - filter.AccountData.Limit = math.MaxInt - filter.Room.AccountData.Limit = math.MaxInt + filter.AccountData.Limit = math.MaxInt32 + filter.Room.AccountData.Limit = math.MaxInt32 } filterQuery := req.URL.Query().Get("filter") if filterQuery != "" { From 66b397b3c60c51bba36e4bce858733b2fda26f6a Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 27 Apr 2022 11:25:07 +0100 Subject: [PATCH 10/22] Don't create fictitious presence entries (#2381) * Don't create fictitious presence entries for users that don't have any * Update whitelist, since that test probably shouldn't be passing * Fix panics --- syncapi/consumers/presence.go | 9 ++++++++- syncapi/storage/postgres/presence_table.go | 3 +++ syncapi/storage/sqlite3/presence_table.go | 3 +++ syncapi/streams/stream_presence.go | 12 +++++------ syncapi/sync/requestpool.go | 23 ++++++++++++---------- sytest-whitelist | 2 -- 6 files changed, 33 insertions(+), 19 deletions(-) diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index b198b2292..6bcca48f4 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -88,6 +88,11 @@ func (s *PresenceConsumer) Start() error { } return } + if presence == nil { + presence = &types.PresenceInternal{ + UserID: userID, + } + } deviceRes := api.QueryDevicesResponse{} if err = s.deviceAPI.QueryDevices(s.ctx, &api.QueryDevicesRequest{UserID: userID}, &deviceRes); err != nil { @@ -106,7 +111,9 @@ func (s *PresenceConsumer) Start() error { m.Header.Set(jetstream.UserID, presence.UserID) m.Header.Set("presence", presence.ClientFields.Presence) - m.Header.Set("status_msg", *presence.ClientFields.StatusMsg) + if presence.ClientFields.StatusMsg != nil { + m.Header.Set("status_msg", *presence.ClientFields.StatusMsg) + } m.Header.Set("last_active_ts", strconv.Itoa(int(presence.LastActiveTS))) if err = msg.RespondMsg(m); err != nil { diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index 49336c4eb..9f1e37f79 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -127,6 +127,9 @@ func (p *presenceStatements) GetPresenceForUser( } stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) + if err == sql.ErrNoRows { + return nil, nil + } result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index 00b16458d..177a01bf3 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -142,6 +142,9 @@ func (p *presenceStatements) GetPresenceForUser( } stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) + if err == sql.ErrNoRows { + return nil, nil + } result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 9a6c5c130..614b88d48 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -16,7 +16,6 @@ package streams import ( "context" - "database/sql" "encoding/json" "sync" @@ -80,11 +79,10 @@ func (p *PresenceStreamProvider) IncrementalSync( if _, ok := presences[roomUsers[i]]; ok { continue } + // Bear in mind that this might return nil, but at least populating + // a nil means that there's a map entry so we won't repeat this call. presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i]) if err != nil { - if err == sql.ErrNoRows { - continue - } req.Log.WithError(err).Error("unable to query presence for user") return from } @@ -93,8 +91,10 @@ func (p *PresenceStreamProvider) IncrementalSync( } lastPos := to - for i := range presences { - presence := presences[i] + for _, presence := range presences { + if presence == nil { + continue + } // Ignore users we don't share a room with if req.Device.UserID != presence.UserID && !p.notifier.IsSharedUser(req.Device.UserID, presence.UserID) { continue diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 703340997..76d550a65 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -127,14 +127,23 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user if !ok { // this should almost never happen return } + newPresence := types.PresenceInternal{ - ClientFields: types.PresenceClientResponse{ - Presence: presenceID.String(), - }, Presence: presenceID, UserID: userID, LastActiveTS: gomatrixserverlib.AsTimestamp(time.Now()), } + + // ensure we also send the current status_msg to federated servers and not nil + dbPresence, err := db.GetPresence(context.Background(), userID) + if err != nil && err != sql.ErrNoRows { + return + } + if dbPresence != nil { + newPresence.ClientFields = dbPresence.ClientFields + } + newPresence.ClientFields.Presence = presenceID.String() + defer rp.presence.Store(userID, newPresence) // avoid spamming presence updates when syncing existingPresence, ok := rp.presence.LoadOrStore(userID, newPresence) @@ -145,13 +154,7 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user } } - // ensure we also send the current status_msg to federated servers and not nil - dbPresence, err := db.GetPresence(context.Background(), userID) - if err != nil && err != sql.ErrNoRows { - return - } - - if err := rp.producer.SendPresence(userID, presenceID, dbPresence.ClientFields.StatusMsg); err != nil { + if err := rp.producer.SendPresence(userID, presenceID, newPresence.ClientFields.StatusMsg); err != nil { logrus.WithError(err).Error("Unable to publish presence message from sync") return } diff --git a/sytest-whitelist b/sytest-whitelist index c9829606f..6af8d89ff 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -681,8 +681,6 @@ GET /presence/:user_id/status fetches initial status PUT /presence/:user_id/status updates my presence Presence change reports an event to myself Existing members see new members' presence -#Existing members see new member's presence -Newly joined room includes presence in incremental sync Get presence for newly joined members in incremental sync User sees their own presence in a sync User sees updates to presence from other users in the incremental sync. From dca4afd2f0871ed53109121dff17048e69cc4935 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 27 Apr 2022 12:03:34 +0100 Subject: [PATCH 11/22] Don't send account data or receipts for left/forgotten rooms (#2382) * Only include account data and receipts for rooms in a complete sync that we care about * Fix global account data --- syncapi/streams/stream_accountdata.go | 6 ++++++ syncapi/streams/stream_receipt.go | 6 ++++++ syncapi/types/provider.go | 17 +++++++++++++++++ 3 files changed, 29 insertions(+) diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 99cd4a92a..2cddbcf04 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -53,6 +53,12 @@ func (p *AccountDataStreamProvider) IncrementalSync( // Iterate over the rooms for roomID, dataTypes := range dataTypes { + // For a complete sync, make sure we're only including this room if + // that room was present in the joined rooms. + if from == 0 && roomID != "" && !req.IsRoomPresent(roomID) { + continue + } + // Request the missing data from the database for _, dataType := range dataTypes { dataReq := userapi.QueryAccountDataRequest{ diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index 9d7d479a2..f4e84c7d0 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -62,6 +62,12 @@ func (p *ReceiptStreamProvider) IncrementalSync( } for roomID, receipts := range receiptsByRoom { + // For a complete sync, make sure we're only including this room if + // that room was present in the joined rooms. + if from == 0 && !req.IsRoomPresent(roomID) { + continue + } + jr := *types.NewJoinResponse() if existing, ok := req.Response.Rooms.Join[roomID]; ok { jr = existing diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go index e6777f643..a9ea234d0 100644 --- a/syncapi/types/provider.go +++ b/syncapi/types/provider.go @@ -25,6 +25,23 @@ type SyncRequest struct { IgnoredUsers IgnoredUsers } +func (r *SyncRequest) IsRoomPresent(roomID string) bool { + membership, ok := r.Rooms[roomID] + if !ok { + return false + } + switch membership { + case gomatrixserverlib.Join: + return true + case gomatrixserverlib.Invite: + return true + case gomatrixserverlib.Peek: + return true + default: + return false + } +} + type StreamProvider interface { Setup() From 54ff4cf690918886c7e7a59a65ccff970c3aa1fc Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 27 Apr 2022 12:23:55 +0100 Subject: [PATCH 12/22] Don't try to federated-join via ourselves (#2383) --- federationapi/internal/perform.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 8cd944346..aac36cc76 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -75,7 +75,7 @@ func (r *FederationInternalAPI) PerformJoin( seenSet := make(map[gomatrixserverlib.ServerName]bool) var uniqueList []gomatrixserverlib.ServerName for _, srv := range request.ServerNames { - if seenSet[srv] { + if seenSet[srv] || srv == r.cfg.Matrix.ServerName { continue } seenSet[srv] = true From d7cc187ec00410b949ffae1625835f8ac9f36c29 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 27 Apr 2022 13:36:40 +0100 Subject: [PATCH 13/22] Prevent JetStream from handling OS signals, allow running as a Windows service (#2385) * Prevent JetStream from handling OS signals, allow running as a Windows service (fixes #2374) * Remove double import --- go.mod | 1 + go.sum | 1 + setup/base/base.go | 9 +++++++-- setup/jetstream/nats.go | 1 + 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index ba222ed8f..d51c3f75d 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/h2non/filetype v1.1.3 // indirect github.com/hashicorp/golang-lru v0.5.4 github.com/juju/testing v0.0.0-20220203020004-a0ff61f03494 // indirect + github.com/kardianos/minwinsvc v1.0.0 // indirect github.com/lib/pq v1.10.5 github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 diff --git a/go.sum b/go.sum index 8bb306a82..f8daca79e 100644 --- a/go.sum +++ b/go.sum @@ -721,6 +721,7 @@ github.com/julienschmidt/httprouter v1.1.1-0.20151013225520-77a895ad01eb/go.mod github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= +github.com/kardianos/minwinsvc v1.0.0 h1:+JfAi8IBJna0jY2dJGZqi7o15z13JelFIklJCAENALA= github.com/kardianos/minwinsvc v1.0.0/go.mod h1:Bgd0oc+D0Qo3bBytmNtyRKVlp85dAloLKhfxanPFFRc= github.com/kataras/golog v0.0.10/go.mod h1:yJ8YKCmyL+nWjERB90Qwn+bdyBZsaQwU3bTVFgkFIp8= github.com/kataras/iris/v12 v12.1.8/go.mod h1:LMYy4VlP67TQ3Zgriz8RE2h2kMZV2SgMYbq3UhfoFmE= diff --git a/setup/base/base.go b/setup/base/base.go index 43d613b0c..281153444 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -42,6 +42,7 @@ import ( userdb "github.com/matrix-org/dendrite/userapi/storage" "github.com/gorilla/mux" + "github.com/kardianos/minwinsvc" appserviceAPI "github.com/matrix-org/dendrite/appservice/api" asinthttp "github.com/matrix-org/dendrite/appservice/inthttp" @@ -462,7 +463,8 @@ func (b *BaseDendrite) SetupAndServeHTTP( }() } - <-b.ProcessContext.WaitForShutdown() + minwinsvc.SetOnExit(b.ProcessContext.ShutdownDendrite) + b.WaitForShutdown() ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -475,7 +477,10 @@ func (b *BaseDendrite) SetupAndServeHTTP( func (b *BaseDendrite) WaitForShutdown() { sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - <-sigs + select { + case <-sigs: + case <-b.ProcessContext.WaitForShutdown(): + } signal.Reset(syscall.SIGINT, syscall.SIGTERM) logrus.Warnf("Shutdown signal received") diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 1c8a89e8d..8d5289697 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -44,6 +44,7 @@ func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient StoreDir: string(cfg.StoragePath), NoSystemAccount: true, MaxPayload: 16 * 1024 * 1024, + NoSigs: true, }) if err != nil { panic(err) From f023cdf8c42cc1a4bb850b478dbbf7d901b5e1bd Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 27 Apr 2022 15:05:49 +0200 Subject: [PATCH 14/22] Add UserAPI storage tests (#2384) * Add tests for parts of the userapi storage * Add tests for keybackup * Add LoginToken tests * Add OpenID tests * Add profile tests * Add pusher tests * Add ThreePID tests * Add notification tests * Add more device tests, fix numeric localpart query * Fix failing CI * Fix numeric local part query --- go.mod | 1 + setup/base/base.go | 5 +- userapi/storage/interface.go | 91 ++-- userapi/storage/postgres/accounts_table.go | 6 +- userapi/storage/postgres/devices_table.go | 22 +- userapi/storage/shared/storage.go | 15 - userapi/storage/sqlite3/accounts_table.go | 8 +- userapi/storage/sqlite3/devices_table.go | 22 +- userapi/storage/storage.go | 4 +- userapi/storage/storage_test.go | 539 +++++++++++++++++++++ userapi/storage/storage_wasm.go | 2 +- userapi/userapi_test.go | 2 +- 12 files changed, 640 insertions(+), 77 deletions(-) create mode 100644 userapi/storage/storage_test.go diff --git a/go.mod b/go.mod index d51c3f75d..a7caadfb5 100644 --- a/go.mod +++ b/go.mod @@ -47,6 +47,7 @@ require ( github.com/pressly/goose v2.7.0+incompatible github.com/prometheus/client_golang v1.12.1 github.com/sirupsen/logrus v1.8.1 + github.com/stretchr/testify v1.7.0 github.com/tidwall/gjson v1.14.0 github.com/tidwall/sjson v1.2.4 github.com/uber/jaeger-client-go v2.30.0+incompatible diff --git a/setup/base/base.go b/setup/base/base.go index 281153444..dbc5d2394 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -21,6 +21,7 @@ import ( "io" "net" "net/http" + _ "net/http/pprof" "os" "os/signal" "syscall" @@ -56,8 +57,6 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/sirupsen/logrus" - - _ "net/http/pprof" ) // BaseDendrite is a base for creating new instances of dendrite. It parses @@ -273,7 +272,7 @@ func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client { // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. func (b *BaseDendrite) CreateAccountsDB() userdb.Database { - db, err := userdb.NewDatabase( + db, err := userdb.NewUserAPIDatabase( &b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index b15470dd4..a4562cf19 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -27,18 +27,24 @@ import ( type Profile interface { GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetPassword(ctx context.Context, localpart string, plaintextPassword string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error } -type Database interface { - Profile - GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) +type Account interface { // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) + GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) + GetNewNumericLocalpart(ctx context.Context) (int64, error) + CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) + GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + DeactivateAccount(ctx context.Context, localpart string) (err error) + SetPassword(ctx context.Context, localpart string, plaintextPassword string) error +} + +type AccountData interface { SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given @@ -46,26 +52,9 @@ type Database interface { // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) - GetNewNumericLocalpart(ctx context.Context) (int64, error) - SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) - RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) - GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) - GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) - CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) - GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) - DeactivateAccount(ctx context.Context, localpart string) (err error) - CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) - GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) - - // Key backups - CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) - UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error) - DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error) - GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) - UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) - GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) - CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) +} +type Device interface { GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) @@ -79,11 +68,22 @@ type Database interface { CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error - RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) +} +type KeyBackup interface { + CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error) + UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error) + DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error) + GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) + UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error) + GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error) + CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) +} + +type LoginToken interface { // CreateLoginToken generates a token, stores and returns it. The lifetime is // determined by the loginTokenLifetime given to the Database constructor. CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) @@ -94,21 +94,50 @@ type Database interface { // GetLoginTokenDataByToken returns the data associated with the given token. // May return sql.ErrNoRows. GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) +} - InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error - DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) - SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) - GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) - GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) - GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) - DeleteOldNotifications(ctx context.Context) error +type OpenID interface { + CreateOpenIDToken(ctx context.Context, token, userID string) (exp int64, err error) + GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) +} +type Pusher interface { UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) RemovePusher(ctx context.Context, appid, pushkey, localpart string) error RemovePushers(ctx context.Context, appid, pushkey string) error } +type ThreePID interface { + SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) + RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) + GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) + GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) +} + +type Notification interface { + InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error) + GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) + DeleteOldNotifications(ctx context.Context) error +} + +type Database interface { + Account + AccountData + Device + KeyBackup + LoginToken + Notification + OpenID + Profile + Pusher + ThreePID +} + // Err3PIDInUse is the error returned when trying to save an association involving // a third-party identifier which is already associated to a local user. var Err3PIDInUse = errors.New("this third-party identifier is already in use") diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 92311d56d..f86812f17 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -47,8 +47,6 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- TODO: -- upgraded_ts, devices, any email reset stuff? ); --- Create sequence for autogenerated numeric usernames -CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; ` const insertAccountSQL = "" + @@ -67,7 +65,7 @@ const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT nextval('numeric_username_seq')" + "SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'" type accountsStatements struct { insertAccountStmt *sql.Stmt @@ -178,5 +176,5 @@ func (s *accountsStatements) SelectNewNumericLocalpart( stmt = sqlutil.TxStmt(txn, stmt) } err = stmt.QueryRowContext(ctx).Scan(&id) - return + return id + 1, err } diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 7bc5dc69b..fe8c54e04 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -78,7 +78,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -93,7 +93,7 @@ const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)" + "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" @@ -235,16 +235,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s } defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") var devices []api.Device + var dev api.Device + var localpart string + var lastseents sql.NullInt64 + var displayName sql.NullString for rows.Next() { - var dev api.Device - var localpart string - var displayName sql.NullString - if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { dev.DisplayName = displayName.String } + if lastseents.Valid { + dev.LastSeenTS = lastseents.Int64 + } dev.UserID = userutil.MakeUserID(localpart, s.serverName) devices = append(devices, dev) } @@ -262,10 +266,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart( } defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed") + var dev api.Device + var lastseents sql.NullInt64 + var id, displayname, ip, useragent sql.NullString for rows.Next() { - var dev api.Device - var lastseents sql.NullInt64 - var id, displayname, ip, useragent sql.NullString err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) if err != nil { return devices, err diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 72ae96ecc..f7212e030 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -577,21 +577,6 @@ func (d *Database) UpdateDevice( }) } -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - // RemoveDevices revokes one or more devices by deleting the entry in the database // matching with the given device IDs and user ID localpart. // If the devices don't exist, it will not return an error diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index e6c37e58e..6c5fe3071 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -65,7 +65,7 @@ const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" const selectNewNumericLocalpartSQL = "" + - "SELECT COUNT(localpart) FROM account_accounts" + "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0" type accountsStatements struct { db *sql.DB @@ -121,6 +121,7 @@ func (s *accountsStatements) InsertAccount( UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, AppServiceID: appserviceID, + AccountType: accountType, }, nil } @@ -177,5 +178,8 @@ func (s *accountsStatements) SelectNewNumericLocalpart( stmt = sqlutil.TxStmt(txn, stmt) } err = stmt.QueryRowContext(ctx).Scan(&id) - return + if err == sql.ErrNoRows { + return 1, nil + } + return id + 1, err } diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 423640e90..7860bd6a2 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -63,7 +63,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -78,7 +78,7 @@ const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" + "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" @@ -235,10 +235,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart( return devices, err } + var dev api.Device + var lastseents sql.NullInt64 + var id, displayname, ip, useragent sql.NullString for rows.Next() { - var dev api.Device - var lastseents sql.NullInt64 - var id, displayname, ip, useragent sql.NullString err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) if err != nil { return devices, err @@ -279,16 +279,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s } defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") var devices []api.Device + var dev api.Device + var localpart string + var displayName sql.NullString + var lastseents sql.NullInt64 for rows.Next() { - var dev api.Device - var localpart string - var displayName sql.NullString - if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { dev.DisplayName = displayName.String } + if lastseents.Valid { + dev.LastSeenTS = lastseents.Int64 + } dev.UserID = userutil.MakeUserID(localpart, s.serverName) devices = append(devices, dev) } diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go index f372fe7dc..faf1ce75c 100644 --- a/userapi/storage/storage.go +++ b/userapi/storage/storage.go @@ -28,9 +28,9 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3" ) -// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) +// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) { +func NewUserAPIDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go new file mode 100644 index 000000000..e6c7d35fc --- /dev/null +++ b/userapi/storage/storage_test.go @@ -0,0 +1,539 @@ +package storage_test + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/pushrules" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" +) + +const loginTokenLifetime = time.Minute + +var ( + openIDLifetimeMS = time.Minute.Milliseconds() + ctx = context.Background() +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewUserAPIDatabase(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") + if err != nil { + t.Fatalf("NewUserAPIDatabase returned %s", err) + } + return db, close +} + +// Tests storing and getting account data +func Test_AccountData(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + alice := test.NewUser() + localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + assert.NoError(t, err) + + room := test.NewRoom(t, alice) + events := room.Events() + + contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID())) + err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom) + assert.NoError(t, err, "unable to save account data") + + contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID)) + err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal) + assert.NoError(t, err, "unable to save account data") + + accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read") + assert.NoError(t, err, "unable to get account data by type") + assert.Equal(t, contentRoom, accountData) + + globalData, roomData, err := db.GetAccountData(ctx, localpart) + assert.NoError(t, err) + assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"]) + assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"]) + }) +} + +// Tests the creation of accounts +func Test_Accounts(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + alice := test.NewUser() + aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + assert.NoError(t, err) + + accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) + assert.NoError(t, err, "failed to create account") + // verify the newly create account is the same as returned by CreateAccount + var accGet *api.Account + accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing") + assert.NoError(t, err, "failed to get account by password") + assert.Equal(t, accAlice, accGet) + accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart) + assert.NoError(t, err, "failed to get account by localpart") + assert.Equal(t, accAlice, accGet) + + // check account availability + available, err := db.CheckAccountAvailability(ctx, aliceLocalpart) + assert.NoError(t, err, "failed to checkout account availability") + assert.Equal(t, false, available) + + available, err = db.CheckAccountAvailability(ctx, "unusedname") + assert.NoError(t, err, "failed to checkout account availability") + assert.Equal(t, true, available) + + // get guest account numeric aliceLocalpart + first, err := db.GetNewNumericLocalpart(ctx) + assert.NoError(t, err, "failed to get new numeric localpart") + // Create a new account to verify the numeric localpart is updated + _, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest) + assert.NoError(t, err, "failed to create account") + second, err := db.GetNewNumericLocalpart(ctx) + assert.NoError(t, err) + assert.Greater(t, second, first) + + // update password for alice + err = db.SetPassword(ctx, aliceLocalpart, "newPassword") + assert.NoError(t, err, "failed to update password") + accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") + assert.NoError(t, err, "failed to get account by new password") + assert.Equal(t, accAlice, accGet) + + // deactivate account + err = db.DeactivateAccount(ctx, aliceLocalpart) + assert.NoError(t, err, "failed to deactivate account") + // This should fail now, as the account is deactivated + _, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") + assert.Error(t, err, "expected an error, got none") + + _, err = db.GetAccountByLocalpart(ctx, "unusename") + assert.Error(t, err, "expected an error for non existent localpart") + }) +} + +func Test_Devices(t *testing.T) { + alice := test.NewUser() + localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + assert.NoError(t, err) + deviceID := util.RandomString(8) + accessToken := util.RandomString(16) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + + deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "") + assert.NoError(t, err, "unable to create deviceWithoutID") + + gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID) + assert.NoError(t, err, "unable to get device by id") + assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields + + gotDeviceAccessToken, err := db.GetDeviceByAccessToken(ctx, accessToken) + assert.NoError(t, err, "unable to get device by access token") + assert.Equal(t, deviceWithID.ID, gotDeviceAccessToken.ID) // GetDeviceByAccessToken doesn't populate all fields + + // create a device without existing device ID + accessToken = util.RandomString(16) + deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "") + assert.NoError(t, err, "unable to create deviceWithoutID") + gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID) + assert.NoError(t, err, "unable to get device by id") + assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields + + // Get devices + devices, err := db.GetDevicesByLocalpart(ctx, localpart) + assert.NoError(t, err, "unable to get devices by localpart") + assert.Equal(t, 2, len(devices)) + deviceIDs := make([]string, 0, len(devices)) + for _, dev := range devices { + deviceIDs = append(deviceIDs, dev.ID) + } + + devices2, err := db.GetDevicesByID(ctx, deviceIDs) + assert.NoError(t, err, "unable to get devices by id") + assert.Equal(t, devices, devices2) + + // Update device + newName := "new display name" + err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName) + assert.NoError(t, err, "unable to update device displayname") + err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1") + assert.NoError(t, err, "unable to update device last seen") + + deviceWithID.DisplayName = newName + deviceWithID.LastSeenIP = "127.0.0.1" + deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second))) + devices, err = db.GetDevicesByLocalpart(ctx, localpart) + assert.NoError(t, err, "unable to get device by id") + assert.Equal(t, 2, len(devices)) + assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName) + assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP) + truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second) + assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime)) + + // create one more device and remove the devices step by step + newDeviceID := util.RandomString(16) + accessToken = util.RandomString(16) + _, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "") + assert.NoError(t, err, "unable to create new device") + + devices, err = db.GetDevicesByLocalpart(ctx, localpart) + assert.NoError(t, err, "unable to get device by id") + assert.Equal(t, 3, len(devices)) + + err = db.RemoveDevices(ctx, localpart, deviceIDs) + assert.NoError(t, err, "unable to remove devices") + devices, err = db.GetDevicesByLocalpart(ctx, localpart) + assert.NoError(t, err, "unable to get device by id") + assert.Equal(t, 1, len(devices)) + + deleted, err := db.RemoveAllDevices(ctx, localpart, "") + assert.NoError(t, err, "unable to remove all devices") + assert.Equal(t, 1, len(deleted)) + assert.Equal(t, newDeviceID, deleted[0].ID) + }) +} + +func Test_KeyBackup(t *testing.T) { + alice := test.NewUser() + room := test.NewRoom(t, alice) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + + wantAuthData := json.RawMessage("my auth data") + wantVersion, err := db.CreateKeyBackup(ctx, alice.ID, "dummyAlgo", wantAuthData) + assert.NoError(t, err, "unable to create key backup") + // get key backup by version + gotVersion, gotAlgo, gotAuthData, _, _, err := db.GetKeyBackup(ctx, alice.ID, wantVersion) + assert.NoError(t, err, "unable to get key backup") + assert.Equal(t, wantVersion, gotVersion, "backup version mismatch") + assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch") + assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch") + + // get any key backup + gotVersion, gotAlgo, gotAuthData, _, _, err = db.GetKeyBackup(ctx, alice.ID, "") + assert.NoError(t, err, "unable to get key backup") + assert.Equal(t, wantVersion, gotVersion, "backup version mismatch") + assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch") + assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch") + + err = db.UpdateKeyBackupAuthData(ctx, alice.ID, wantVersion, json.RawMessage("my updated auth data")) + assert.NoError(t, err, "unable to update key backup auth data") + + uploads := []api.InternalKeyBackupSession{ + { + KeyBackupSession: api.KeyBackupSession{ + IsVerified: true, + SessionData: wantAuthData, + }, + RoomID: room.ID, + SessionID: "1", + }, + { + KeyBackupSession: api.KeyBackupSession{}, + RoomID: room.ID, + SessionID: "2", + }, + } + count, _, err := db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads) + assert.NoError(t, err, "unable to upsert backup keys") + assert.Equal(t, int64(len(uploads)), count, "unexpected backup count") + + // do it again to update a key + uploads[1].IsVerified = true + count, _, err = db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads[1:]) + assert.NoError(t, err, "unable to upsert backup keys") + assert.Equal(t, int64(len(uploads)), count, "unexpected backup count") + + // get backup keys by session id + gotBackupKeys, err := db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "1") + assert.NoError(t, err, "unable to get backup keys") + assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"]) + + // get backup keys by room id + gotBackupKeys, err = db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "") + assert.NoError(t, err, "unable to get backup keys") + assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"]) + + gotCount, err := db.CountBackupKeys(ctx, wantVersion, alice.ID) + assert.NoError(t, err, "unable to get backup keys count") + assert.Equal(t, count, gotCount, "unexpected backup count") + + // finally delete a key + exists, err := db.DeleteKeyBackup(ctx, alice.ID, wantVersion) + assert.NoError(t, err, "unable to delete key backup") + assert.True(t, exists) + + // this key should not exist + exists, err = db.DeleteKeyBackup(ctx, alice.ID, "3") + assert.NoError(t, err, "unable to delete key backup") + assert.False(t, exists) + }) +} + +func Test_LoginToken(t *testing.T) { + alice := test.NewUser() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + + // create a new token + wantLoginToken := &api.LoginTokenData{UserID: alice.ID} + + gotMetadata, err := db.CreateLoginToken(ctx, wantLoginToken) + assert.NoError(t, err, "unable to create login token") + assert.NotNil(t, gotMetadata) + assert.Equal(t, time.Now().Add(loginTokenLifetime).Truncate(loginTokenLifetime), gotMetadata.Expiration.Truncate(loginTokenLifetime)) + + // get the new token + gotLoginToken, err := db.GetLoginTokenDataByToken(ctx, gotMetadata.Token) + assert.NoError(t, err, "unable to get login token") + assert.NotNil(t, gotLoginToken) + assert.Equal(t, wantLoginToken, gotLoginToken, "unexpected login token") + + // remove the login token again + err = db.RemoveLoginToken(ctx, gotMetadata.Token) + assert.NoError(t, err, "unable to remove login token") + + // check if the token was actually deleted + _, err = db.GetLoginTokenDataByToken(ctx, gotMetadata.Token) + assert.Error(t, err, "expected an error, but got none") + }) +} + +func Test_OpenID(t *testing.T) { + alice := test.NewUser() + token := util.RandomString(24) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + + expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS + expires, err := db.CreateOpenIDToken(ctx, token, alice.ID) + assert.NoError(t, err, "unable to create OpenID token") + assert.Equal(t, expiresAtMS, expires) + + attributes, err := db.GetOpenIDTokenAttributes(ctx, token) + assert.NoError(t, err, "unable to get OpenID token attributes") + assert.Equal(t, alice.ID, attributes.UserID) + assert.Equal(t, expiresAtMS, attributes.ExpiresAtMS) + }) +} + +func Test_Profile(t *testing.T) { + alice := test.NewUser() + aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + + // create account, which also creates a profile + _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) + assert.NoError(t, err, "failed to create account") + + gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) + assert.NoError(t, err, "unable to get profile by localpart") + wantProfile := &authtypes.Profile{Localpart: aliceLocalpart} + assert.Equal(t, wantProfile, gotProfile) + + // set avatar & displayname + wantProfile.DisplayName = "Alice" + wantProfile.AvatarURL = "mxc://aliceAvatar" + err = db.SetDisplayName(ctx, aliceLocalpart, "Alice") + assert.NoError(t, err, "unable to set displayname") + err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + assert.NoError(t, err, "unable to set avatar url") + // verify profile + gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart) + assert.NoError(t, err, "unable to get profile by localpart") + assert.Equal(t, wantProfile, gotProfile) + + // search profiles + searchRes, err := db.SearchProfiles(ctx, "Alice", 2) + assert.NoError(t, err, "unable to search profiles") + assert.Equal(t, 1, len(searchRes)) + assert.Equal(t, *wantProfile, searchRes[0]) + }) +} + +func Test_Pusher(t *testing.T) { + alice := test.NewUser() + aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + + appID := util.RandomString(8) + var pushKeys []string + var gotPushers []api.Pusher + for i := 0; i < 2; i++ { + pushKey := util.RandomString(8) + + wantPusher := api.Pusher{ + PushKey: pushKey, + Kind: api.HTTPKind, + AppID: appID, + AppDisplayName: util.RandomString(8), + DeviceDisplayName: util.RandomString(8), + ProfileTag: util.RandomString(8), + Language: util.RandomString(2), + } + err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart) + assert.NoError(t, err, "unable to upsert pusher") + + // check it was actually persisted + gotPushers, err = db.GetPushers(ctx, aliceLocalpart) + assert.NoError(t, err, "unable to get pushers") + assert.Equal(t, i+1, len(gotPushers)) + assert.Equal(t, wantPusher, gotPushers[i]) + pushKeys = append(pushKeys, pushKey) + } + + // remove single pusher + err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart) + assert.NoError(t, err, "unable to remove pusher") + gotPushers, err := db.GetPushers(ctx, aliceLocalpart) + assert.NoError(t, err, "unable to get pushers") + assert.Equal(t, 1, len(gotPushers)) + + // remove last pusher + err = db.RemovePushers(ctx, appID, pushKeys[1]) + assert.NoError(t, err, "unable to remove pusher") + gotPushers, err = db.GetPushers(ctx, aliceLocalpart) + assert.NoError(t, err, "unable to get pushers") + assert.Equal(t, 0, len(gotPushers)) + }) +} + +func Test_ThreePID(t *testing.T) { + alice := test.NewUser() + aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + assert.NoError(t, err) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + threePID := util.RandomString(8) + medium := util.RandomString(8) + err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium) + assert.NoError(t, err, "unable to save threepid association") + + // get the stored threepid + gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium) + assert.NoError(t, err, "unable to get localpart for threepid") + assert.Equal(t, aliceLocalpart, gotLocalpart) + + threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) + assert.NoError(t, err, "unable to get threepids for localpart") + assert.Equal(t, 1, len(threepids)) + assert.Equal(t, authtypes.ThreePID{ + Address: threePID, + Medium: medium, + }, threepids[0]) + + // remove threepid association + err = db.RemoveThreePIDAssociation(ctx, threePID, medium) + assert.NoError(t, err, "unexpected error") + + // verify it was deleted + threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) + assert.NoError(t, err, "unable to get threepids for localpart") + assert.Equal(t, 0, len(threepids)) + }) +} + +func Test_Notification(t *testing.T) { + alice := test.NewUser() + aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + assert.NoError(t, err) + room := test.NewRoom(t, alice) + room2 := test.NewRoom(t, alice) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + // generate some dummy notifications + for i := 0; i < 10; i++ { + eventID := util.RandomString(16) + roomID := room.ID + ts := time.Now() + if i > 5 { + roomID = room2.ID + // create some old notifications to test DeleteOldNotifications + ts = ts.AddDate(0, -2, 0) + } + notification := &api.Notification{ + Actions: []*pushrules.Action{ + {}, + }, + Event: gomatrixserverlib.ClientEvent{ + Content: gomatrixserverlib.RawJSON("{}"), + }, + Read: false, + RoomID: roomID, + TS: gomatrixserverlib.AsTimestamp(ts), + } + err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification) + assert.NoError(t, err, "unable to insert notification") + } + + // get notifications + count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications) + assert.NoError(t, err, "unable to get notification count") + assert.Equal(t, int64(10), count) + notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications) + assert.NoError(t, err, "unable to get notifications") + assert.Equal(t, int64(10), count) + assert.Equal(t, 10, len(notifs)) + // ... for a specific room + total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + assert.NoError(t, err, "unable to get notifications for room") + assert.Equal(t, int64(4), total) + + // mark notification as read + affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true) + assert.NoError(t, err, "unable to set notifications read") + assert.True(t, affected) + + // this should delete 2 notifications + affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8) + assert.NoError(t, err, "unable to set notifications read") + assert.True(t, affected) + + total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + assert.NoError(t, err, "unable to get notifications for room") + assert.Equal(t, int64(2), total) + + // delete old notifications + err = db.DeleteOldNotifications(ctx) + assert.NoError(t, err) + + // this should now return 0 notifications + total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + assert.NoError(t, err, "unable to get notifications for room") + assert.Equal(t, int64(0), total) + }) +} diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 779f77568..a8e6f031c 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -func NewDatabase( +func NewUserAPIDatabase( dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 8c3608bd8..076b4f3c6 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -52,7 +52,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s MaxOpenConnections: 1, MaxIdleConnections: 1, } - accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") + accountDB, err := storage.NewUserAPIDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { t.Fatalf("failed to create account DB: %s", err) } From 6ee8507955f2b9674649acc928768b1a4d96f7c0 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 27 Apr 2022 14:45:51 +0100 Subject: [PATCH 15/22] Correct account data position mapping --- syncapi/storage/postgres/account_data_table.go | 10 +++++++--- syncapi/storage/sqlite3/account_data_table.go | 11 +++++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index ec1919fca..7c0d03030 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -105,7 +105,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( accountDataEventFilter *gomatrixserverlib.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) - pos = r.Low() + pos = r.High() rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)), @@ -120,6 +120,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( var dataType string var roomID string var id types.StreamPosition + var highest types.StreamPosition for rows.Next() { if err = rows.Scan(&id, &roomID, &dataType); err != nil { @@ -131,10 +132,13 @@ func (s *accountDataStatements) SelectAccountDataInRange( } else { data[roomID] = []string{dataType} } - if id > pos { - pos = id + if id > highest { + highest = id } } + if highest < pos { + pos = highest + } return data, pos, rows.Err() } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 2c7272ea8..1bbfe9c96 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -96,7 +96,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( r types.Range, filter *gomatrixserverlib.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { - pos = r.Low() + pos = r.High() data = make(map[string][]string) stmt, params, err := prepareWithFilters( s.db, nil, selectAccountDataInRangeSQL, @@ -116,6 +116,7 @@ func (s *accountDataStatements) SelectAccountDataInRange( var dataType string var roomID string var id types.StreamPosition + var highest types.StreamPosition for rows.Next() { if err = rows.Scan(&id, &roomID, &dataType); err != nil { @@ -127,11 +128,13 @@ func (s *accountDataStatements) SelectAccountDataInRange( } else { data[roomID] = []string{dataType} } - if id > pos { - pos = id + if id > highest { + highest = id } } - + if highest < pos { + pos = highest + } return data, pos, nil } From 655ac3e8fb83e1cb9b670ab420a0f661dc19786e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 27 Apr 2022 14:53:11 +0100 Subject: [PATCH 16/22] Try that again --- syncapi/storage/postgres/account_data_table.go | 10 ++++------ syncapi/storage/sqlite3/account_data_table.go | 10 ++++------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index 7c0d03030..e9c72058b 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -105,7 +105,6 @@ func (s *accountDataStatements) SelectAccountDataInRange( accountDataEventFilter *gomatrixserverlib.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) - pos = r.High() rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)), @@ -120,7 +119,6 @@ func (s *accountDataStatements) SelectAccountDataInRange( var dataType string var roomID string var id types.StreamPosition - var highest types.StreamPosition for rows.Next() { if err = rows.Scan(&id, &roomID, &dataType); err != nil { @@ -132,12 +130,12 @@ func (s *accountDataStatements) SelectAccountDataInRange( } else { data[roomID] = []string{dataType} } - if id > highest { - highest = id + if id > pos { + pos = id } } - if highest < pos { - pos = highest + if pos == 0 { + pos = r.High() } return data, pos, rows.Err() } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 1bbfe9c96..21a16dcd3 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -96,7 +96,6 @@ func (s *accountDataStatements) SelectAccountDataInRange( r types.Range, filter *gomatrixserverlib.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { - pos = r.High() data = make(map[string][]string) stmt, params, err := prepareWithFilters( s.db, nil, selectAccountDataInRangeSQL, @@ -116,7 +115,6 @@ func (s *accountDataStatements) SelectAccountDataInRange( var dataType string var roomID string var id types.StreamPosition - var highest types.StreamPosition for rows.Next() { if err = rows.Scan(&id, &roomID, &dataType); err != nil { @@ -128,12 +126,12 @@ func (s *accountDataStatements) SelectAccountDataInRange( } else { data[roomID] = []string{dataType} } - if id > highest { - highest = id + if id > pos { + pos = id } } - if highest < pos { - pos = highest + if pos == 0 { + pos = r.High() } return data, pos, nil } From cafa2853c5d67b3dd4d247abdd1ad5806f0c951b Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 27 Apr 2022 15:01:57 +0100 Subject: [PATCH 17/22] Use process context as base context for all HTTP --- setup/base/base.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/setup/base/base.go b/setup/base/base.go index dbc5d2394..51c43198a 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -346,6 +346,9 @@ func (b *BaseDendrite) SetupAndServeHTTP( Addr: string(externalAddr), WriteTimeout: HTTPServerTimeout, Handler: externalRouter, + BaseContext: func(_ net.Listener) context.Context { + return b.ProcessContext.Context() + }, } internalServ := externalServ @@ -361,6 +364,9 @@ func (b *BaseDendrite) SetupAndServeHTTP( internalServ = &http.Server{ Addr: string(internalAddr), Handler: h2c.NewHandler(internalRouter, internalH2S), + BaseContext: func(_ net.Listener) context.Context { + return b.ProcessContext.Context() + }, } } From 103795d33a09728d7619e73014d507505ff121e2 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 27 Apr 2022 15:06:20 +0100 Subject: [PATCH 18/22] Defer cancel on shutdown context --- setup/base/base.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup/base/base.go b/setup/base/base.go index 51c43198a..03ea2ad7e 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -472,7 +472,7 @@ func (b *BaseDendrite) SetupAndServeHTTP( b.WaitForShutdown() ctx, cancel := context.WithCancel(context.Background()) - cancel() + defer cancel() _ = internalServ.Shutdown(ctx) _ = externalServ.Shutdown(ctx) From 923f789ca3174a685bd53ce5e64a5e86cabd38cb Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 27 Apr 2022 15:29:49 +0100 Subject: [PATCH 19/22] Fix graceful shutdown --- federationapi/queue/destinationqueue.go | 13 ++++++++----- federationapi/queue/queue.go | 13 ++++++------- setup/base/base.go | 12 ++++++------ setup/jetstream/helpers.go | 16 +++++++++++++--- 4 files changed, 33 insertions(+), 21 deletions(-) diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index a5f8c03b9..747940403 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -78,7 +78,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re // this destination queue. We'll then be able to retrieve the PDU // later. if err := oq.db.AssociatePDUWithDestination( - context.TODO(), + oq.process.Context(), "", // TODO: remove this, as we don't need to persist the transaction ID oq.destination, // the destination server name receipt, // NIDs from federationapi_queue_json table @@ -122,7 +122,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share // this destination queue. We'll then be able to retrieve the PDU // later. if err := oq.db.AssociateEDUWithDestination( - context.TODO(), + oq.process.Context(), oq.destination, // the destination server name receipt, // NIDs from federationapi_queue_json table event.Type, @@ -177,7 +177,7 @@ func (oq *destinationQueue) getPendingFromDatabase() { // Check to see if there's anything to do for this server // in the database. retrieved := false - ctx := context.Background() + ctx := oq.process.Context() oq.pendingMutex.Lock() defer oq.pendingMutex.Unlock() @@ -271,6 +271,9 @@ func (oq *destinationQueue) backgroundSend() { // restarted automatically the next time we have an event to // send. return + case <-oq.process.Context().Done(): + // The parent process is shutting down, so stop. + return } // If we are backing off this server then wait for the @@ -420,13 +423,13 @@ func (oq *destinationQueue) nextTransaction( // Clean up the transaction in the database. if pduReceipts != nil { //logrus.Infof("Cleaning PDUs %q", pduReceipt.String()) - if err = oq.db.CleanPDUs(context.Background(), oq.destination, pduReceipts); err != nil { + if err = oq.db.CleanPDUs(oq.process.Context(), oq.destination, pduReceipts); err != nil { logrus.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination) } } if eduReceipts != nil { //logrus.Infof("Cleaning EDUs %q", eduReceipt.String()) - if err = oq.db.CleanEDUs(context.Background(), oq.destination, eduReceipts); err != nil { + if err = oq.db.CleanEDUs(oq.process.Context(), oq.destination, eduReceipts); err != nil { logrus.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination) } } diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index c45bbd1d4..d152886f5 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -15,7 +15,6 @@ package queue import ( - "context" "crypto/ed25519" "encoding/json" "fmt" @@ -105,14 +104,14 @@ func NewOutgoingQueues( // Look up which servers we have pending items for and then rehydrate those queues. if !disabled { serverNames := map[gomatrixserverlib.ServerName]struct{}{} - if names, err := db.GetPendingPDUServerNames(context.Background()); err == nil { + if names, err := db.GetPendingPDUServerNames(process.Context()); err == nil { for _, serverName := range names { serverNames[serverName] = struct{}{} } } else { log.WithError(err).Error("Failed to get PDU server names for destination queue hydration") } - if names, err := db.GetPendingEDUServerNames(context.Background()); err == nil { + if names, err := db.GetPendingEDUServerNames(process.Context()); err == nil { for _, serverName := range names { serverNames[serverName] = struct{}{} } @@ -215,7 +214,7 @@ func (oqs *OutgoingQueues) SendEvent( // Check if any of the destinations are prohibited by server ACLs. for destination := range destmap { if api.IsServerBannedFromRoom( - context.TODO(), + oqs.process.Context(), oqs.rsAPI, ev.RoomID(), destination, @@ -238,7 +237,7 @@ func (oqs *OutgoingQueues) SendEvent( return fmt.Errorf("json.Marshal: %w", err) } - nid, err := oqs.db.StoreJSON(context.TODO(), string(headeredJSON)) + nid, err := oqs.db.StoreJSON(oqs.process.Context(), string(headeredJSON)) if err != nil { return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) } @@ -286,7 +285,7 @@ func (oqs *OutgoingQueues) SendEDU( if result := gjson.GetBytes(e.Content, "room_id"); result.Exists() { for destination := range destmap { if api.IsServerBannedFromRoom( - context.TODO(), + oqs.process.Context(), oqs.rsAPI, result.Str, destination, @@ -310,7 +309,7 @@ func (oqs *OutgoingQueues) SendEDU( return fmt.Errorf("json.Marshal: %w", err) } - nid, err := oqs.db.StoreJSON(context.TODO(), string(ephemeralJSON)) + nid, err := oqs.db.StoreJSON(oqs.process.Context(), string(ephemeralJSON)) if err != nil { return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err) } diff --git a/setup/base/base.go b/setup/base/base.go index 03ea2ad7e..e67b034a3 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -469,14 +469,14 @@ func (b *BaseDendrite) SetupAndServeHTTP( } minwinsvc.SetOnExit(b.ProcessContext.ShutdownDendrite) - b.WaitForShutdown() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - _ = internalServ.Shutdown(ctx) - _ = externalServ.Shutdown(ctx) + <-b.ProcessContext.WaitForShutdown() + logrus.Infof("Stopping HTTP listeners") + _ = internalServ.Shutdown(context.Background()) + _ = externalServ.Shutdown(context.Background()) logrus.Infof("Stopped HTTP listeners") + + b.WaitForShutdown() } func (b *BaseDendrite) WaitForShutdown() { diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 78cecb6ae..1c07583e9 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -35,6 +35,16 @@ func JetStreamConsumer( } go func() { for { + // If the parent context has given up then there's no point in + // carrying on doing anything, so stop the listener. + select { + case <-ctx.Done(): + if err := sub.Unsubscribe(); err != nil { + logrus.WithContext(ctx).Warnf("Failed to unsubscribe %q", durable) + } + return + default: + } // The context behaviour here is surprising — we supply a context // so that we can interrupt the fetch if we want, but NATS will still // enforce its own deadline (roughly 5 seconds by default). Therefore @@ -65,18 +75,18 @@ func JetStreamConsumer( continue } msg := msgs[0] - if err = msg.InProgress(); err != nil { + if err = msg.InProgress(nats.Context(ctx)); err != nil { logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) sentry.CaptureException(err) continue } if f(ctx, msg) { - if err = msg.AckSync(); err != nil { + if err = msg.AckSync(nats.Context(ctx)); err != nil { logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err)) sentry.CaptureException(err) } } else { - if err = msg.Nak(); err != nil { + if err = msg.Nak(nats.Context(ctx)); err != nil { logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) sentry.CaptureException(err) } From 34221938ccb1f3a885ac9e5a36b79d3d74850d38 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 27 Apr 2022 16:04:11 +0100 Subject: [PATCH 20/22] Version 0.8.2 (#2386) * Version 0.8.2 * Correct account data position mapping * Try that again * Don't duplicate wait-for-shutdowns --- CHANGES.md | 29 +++++++++++++++++++++++++++++ internal/version.go | 2 +- setup/base/base.go | 4 +--- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 831a8969d..6278bcba4 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,34 @@ # Changelog +## Dendrite 0.8.2 (2022-04-27) + +### Features + +* Lazy-loading has been added to the `/sync` endpoint, which should speed up syncs considerably +* Filtering has been added to the `/messages` endpoint +* The room summary now contains "heroes" (up to 5 users in the room) for clients to display when no room name is set +* The existing lazy-loading caches will now be used by `/messages` and `/context` so that member events will not be sent to clients more times than necessary +* The account data stream now uses the provided filters +* The built-in NATS Server has been updated to version 2.8.0 +* The `/state` and `/state_ids` endpoints will now return `M_NOT_FOUND` for rejected events +* Repeated calls to the `/redact` endpoint will now be idempotent when a transaction ID is given +* Dendrite should now be able to run as a Windows service under Service Control Manager + +### Fixes + +* Fictitious presence updates will no longer be created for users which have not sent us presence updates, which should speed up complete syncs considerably +* Uploading cross-signing device signatures should now be more reliable, fixing a number of bugs with cross-signing +* All account data should now be sent properly on a complete sync, which should eliminate problems with client settings or key backups appearing to be missing +* Account data will now be limited correctly on incremental syncs, returning the stream position of the most recent update rather than the latest stream position +* Account data will not be sent for parted rooms, which should reduce the number of left/forgotten rooms reappearing in clients as empty rooms +* The TURN username hash has been fixed which should help to resolve some problems when using TURN for voice calls (contributed by [fcwoknhenuxdfiyv](https://github.com/fcwoknhenuxdfiyv)) +* Push rules can no longer be modified using the account data endpoints +* Querying account availability should now work properly in polylith deployments +* A number of bugs with sync filters have been fixed +* A default sync filter will now be used if the request contains a filter ID that does not exist +* The `pushkey_ts` field is now using seconds instead of milliseconds +* A race condition when gracefully shutting down has been fixed, so JetStream should no longer cause the process to exit before other Dendrite components are finished shutting down + ## Dendrite 0.8.1 (2022-04-07) ### Fixes diff --git a/internal/version.go b/internal/version.go index 5227a03bf..2477bc9ac 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 8 - VersionPatch = 1 + VersionPatch = 2 VersionTag = "" // example: "rc1" ) diff --git a/setup/base/base.go b/setup/base/base.go index e67b034a3..7091c6ba5 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -469,14 +469,12 @@ func (b *BaseDendrite) SetupAndServeHTTP( } minwinsvc.SetOnExit(b.ProcessContext.ShutdownDendrite) - <-b.ProcessContext.WaitForShutdown() + logrus.Infof("Stopping HTTP listeners") _ = internalServ.Shutdown(context.Background()) _ = externalServ.Shutdown(context.Background()) logrus.Infof("Stopped HTTP listeners") - - b.WaitForShutdown() } func (b *BaseDendrite) WaitForShutdown() { From 8d69e2f0b897dd1ecd99c7bd348ad8bfc8999c4e Mon Sep 17 00:00:00 2001 From: 0x1a8510f2 Date: Wed, 27 Apr 2022 20:19:46 +0100 Subject: [PATCH 21/22] Use Go 1.18 to build Docker images (#2391) Go 1.18 has now been released for a while and the CI already tests Dendrite with Go 1.18 so there should be no issues. Go 1.18 brings some performance improvements for ARM via the register calling convention so it makes sense to switch to it. --- build/docker/Dockerfile.monolith | 4 ++-- build/docker/Dockerfile.polylith | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/build/docker/Dockerfile.monolith b/build/docker/Dockerfile.monolith index 0d2a141ad..891a3a9e0 100644 --- a/build/docker/Dockerfile.monolith +++ b/build/docker/Dockerfile.monolith @@ -1,4 +1,4 @@ -FROM docker.io/golang:1.17-alpine AS base +FROM docker.io/golang:1.18-alpine AS base RUN apk --update --no-cache add bash build-base @@ -23,4 +23,4 @@ COPY --from=base /build/bin/* /usr/bin/ VOLUME /etc/dendrite WORKDIR /etc/dendrite -ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] \ No newline at end of file +ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] diff --git a/build/docker/Dockerfile.polylith b/build/docker/Dockerfile.polylith index c266fd480..ffdc35586 100644 --- a/build/docker/Dockerfile.polylith +++ b/build/docker/Dockerfile.polylith @@ -1,4 +1,4 @@ -FROM docker.io/golang:1.17-alpine AS base +FROM docker.io/golang:1.18-alpine AS base RUN apk --update --no-cache add bash build-base @@ -23,4 +23,4 @@ COPY --from=base /build/bin/* /usr/bin/ VOLUME /etc/dendrite WORKDIR /etc/dendrite -ENTRYPOINT ["/usr/bin/dendrite-polylith-multi"] \ No newline at end of file +ENTRYPOINT ["/usr/bin/dendrite-polylith-multi"] From 74259f296f225510e9fbb6c5aae191c3f86c729e Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 27 Apr 2022 21:31:30 +0200 Subject: [PATCH 22/22] Fix #2390 (#2392) Fix duplicate heroes in `/sync` response. --- syncapi/storage/postgres/memberships_table.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 8c049977f..00223c57a 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -64,7 +64,7 @@ const selectMembershipCountSQL = "" + ") t WHERE t.membership = $3" const selectHeroesSQL = "" + - "SELECT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" + "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" type membershipsStatements struct { upsertMembershipStmt *sql.Stmt