From e930959e4911453ec51bba0b1469f4b15c594f32 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 9 Aug 2022 10:40:46 +0200 Subject: [PATCH 1/9] Send-to-device/sync tweaks (#2630) * Always delete send to device messages * Omit empty to_device * Tweak /sync response to omit empty values --- syncapi/streams/stream_sendtodevice.go | 20 ++++++-------------- syncapi/sync/requestpool.go | 15 +++++++++++---- syncapi/types/types.go | 23 ++++++++++++----------- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/syncapi/streams/stream_sendtodevice.go b/syncapi/streams/stream_sendtodevice.go index 6a18df506..31c6187cb 100644 --- a/syncapi/streams/stream_sendtodevice.go +++ b/syncapi/streams/stream_sendtodevice.go @@ -39,21 +39,13 @@ func (p *SendToDeviceStreamProvider) IncrementalSync( return from } - if len(events) > 0 { - // Clean up old send-to-device messages from before this stream position. - if err := p.DB.CleanSendToDeviceUpdates(req.Context, req.Device.UserID, req.Device.ID, from); err != nil { - req.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed") - return from - } - - // Add the updates into the sync response. - for _, event := range events { - // skip ignored user events - if _, ok := req.IgnoredUsers.List[event.Sender]; ok { - continue - } - req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent) + // Add the updates into the sync response. + for _, event := range events { + // skip ignored user events + if _, ok := req.IgnoredUsers.List[event.Sender]; ok { + continue } + req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent) } return lastPos diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index b6b4779a5..d908a9629 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -25,6 +25,11 @@ import ( "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/clientapi/jsonerror" keyapi "github.com/matrix-org/dendrite/keyserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" @@ -35,10 +40,6 @@ import ( "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" ) // RequestPool manages HTTP long-poll connections for /sync @@ -251,6 +252,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. waitingSyncRequests.Inc() defer waitingSyncRequests.Dec() + // Clean up old send-to-device messages from before this stream position. + // This is needed to avoid sending the same message multiple times + if err = rp.db.CleanSendToDeviceUpdates(syncReq.Context, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Since.SendToDevicePosition); err != nil { + syncReq.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed") + } + // loop until we get some data for { startTime := time.Now() diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 159fa08b6..39b085d9c 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -21,9 +21,10 @@ import ( "strconv" "strings" - "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/tidwall/gjson" + + "github.com/matrix-org/dendrite/roomserver/api" ) var ( @@ -330,23 +331,23 @@ type Response struct { NextBatch StreamingToken `json:"next_batch"` AccountData struct { Events []gomatrixserverlib.ClientEvent `json:"events,omitempty"` - } `json:"account_data"` + } `json:"account_data,omitempty"` Presence struct { Events []gomatrixserverlib.ClientEvent `json:"events,omitempty"` - } `json:"presence"` + } `json:"presence,omitempty"` Rooms struct { - Join map[string]JoinResponse `json:"join"` - Peek map[string]JoinResponse `json:"peek"` - Invite map[string]InviteResponse `json:"invite"` - Leave map[string]LeaveResponse `json:"leave"` - } `json:"rooms"` + Join map[string]JoinResponse `json:"join,omitempty"` + Peek map[string]JoinResponse `json:"peek,omitempty"` + Invite map[string]InviteResponse `json:"invite,omitempty"` + Leave map[string]LeaveResponse `json:"leave,omitempty"` + } `json:"rooms,omitempty"` ToDevice struct { - Events []gomatrixserverlib.SendToDeviceEvent `json:"events"` - } `json:"to_device"` + Events []gomatrixserverlib.SendToDeviceEvent `json:"events,omitempty"` + } `json:"to_device,omitempty"` DeviceLists struct { Changed []string `json:"changed,omitempty"` Left []string `json:"left,omitempty"` - } `json:"device_lists"` + } `json:"device_lists,omitempty"` DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"` } From 240ae257deb74b7be8a17500b77d5e1bca56e8f5 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 9 Aug 2022 11:15:58 +0200 Subject: [PATCH 2/9] Add housekeeping function to delete old/expired EDUs (#2399) * Add housekeeping function to delete old/expired EDUs * Add migrations * Evict EDUs from cache * Fix queries * Fix upgrade * Use map[string]time.Duration to specify different expiry times * Fix copy & paste mistake * Set expires_at to tomorrow * Don't allow NULL * Add comment * Add tests * Use new testrig package * Fix migrations * Never expire m.direct_to_device Co-authored-by: Neil Alexander Co-authored-by: kegsay --- federationapi/federationapi.go | 13 +++ federationapi/queue/destinationqueue.go | 1 + federationapi/storage/interface.go | 5 +- .../deltas/2022042812473400_addexpiresat.go | 44 ++++++++ .../storage/postgres/queue_edus_table.go | 99 ++++++++++++----- federationapi/storage/postgres/storage.go | 3 + federationapi/storage/shared/storage_edus.go | 49 +++++++++ .../deltas/2022042812473400_addexpiresat.go | 68 ++++++++++++ .../storage/sqlite3/queue_edus_table.go | 100 ++++++++++++++---- federationapi/storage/sqlite3/storage.go | 3 + federationapi/storage/storage_test.go | 81 ++++++++++++++ federationapi/storage/tables/interface.go | 5 +- 12 files changed, 422 insertions(+), 49 deletions(-) create mode 100644 federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go create mode 100644 federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go create mode 100644 federationapi/storage/storage_test.go diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 97bcc12a5..ff01b1952 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -15,6 +15,8 @@ package federationapi import ( + "time" + "github.com/gorilla/mux" "github.com/matrix-org/dendrite/federationapi/api" federationAPI "github.com/matrix-org/dendrite/federationapi/api" @@ -167,5 +169,16 @@ func NewInternalAPI( if err = presenceConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start presence consumer") } + + var cleanExpiredEDUs func() + cleanExpiredEDUs = func() { + logrus.Infof("Cleaning expired EDUs") + if err := federationDB.DeleteExpiredEDUs(base.Context()); err != nil { + logrus.WithError(err).Error("Failed to clean expired EDUs") + } + time.AfterFunc(time.Hour, cleanExpiredEDUs) + } + time.AfterFunc(time.Minute, cleanExpiredEDUs) + return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues, keyRing) } diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index b6edec5da..0d937ffaf 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -127,6 +127,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share oq.destination, // the destination server name receipt, // NIDs from federationapi_queue_json table event.Type, + nil, // this will use the default expireEDUTypes map ); err != nil { logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination) return diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 29254948b..b8109b432 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -16,6 +16,7 @@ package storage import ( "context" + "time" "github.com/matrix-org/dendrite/federationapi/storage/shared" "github.com/matrix-org/dendrite/federationapi/types" @@ -38,7 +39,7 @@ type Database interface { GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error - AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string) error + AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error @@ -70,4 +71,6 @@ type Database interface { // Query the notary for the server keys for the given server. If `optKeyIDs` is not empty, multiple server keys may be returned (between 1 - len(optKeyIDs)) // such that the combination of all server keys will include all the `optKeyIDs`. GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) + // DeleteExpiredEDUs cleans up expired EDUs + DeleteExpiredEDUs(ctx context.Context) error } diff --git a/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go b/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go new file mode 100644 index 000000000..53a7a025e --- /dev/null +++ b/federationapi/storage/postgres/deltas/2022042812473400_addexpiresat.go @@ -0,0 +1,44 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/matrix-org/gomatrixserverlib" +) + +func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus ADD COLUMN IF NOT EXISTS expires_at BIGINT NOT NULL DEFAULT 0;") + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24))) + if err != nil { + return fmt.Errorf("failed to update queue_edus: %w", err) + } + return nil +} + +func DownAddexpiresat(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus DROP COLUMN expires_at;") + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/federationapi/storage/postgres/queue_edus_table.go b/federationapi/storage/postgres/queue_edus_table.go index 1fedf0ef1..d6507e13b 100644 --- a/federationapi/storage/postgres/queue_edus_table.go +++ b/federationapi/storage/postgres/queue_edus_table.go @@ -19,9 +19,11 @@ import ( "database/sql" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" ) const queueEDUsSchema = ` @@ -31,7 +33,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( -- The domain part of the user ID the EDU event is for. server_name TEXT NOT NULL, -- The JSON NID from the federationsender_queue_edus_json table. - json_nid BIGINT NOT NULL + json_nid BIGINT NOT NULL, + -- The expiry time of this edu, if any. + expires_at BIGINT NOT NULL DEFAULT 0 ); CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx @@ -43,8 +47,8 @@ CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx ` const insertQueueEDUSQL = "" + - "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + - " VALUES ($1, $2, $3)" + "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid, expires_at)" + + " VALUES ($1, $2, $3, $4)" const deleteQueueEDUSQL = "" + "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid = ANY($2)" @@ -65,6 +69,12 @@ const selectQueueEDUCountSQL = "" + const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" +const selectExpiredEDUsSQL = "" + + "SELECT DISTINCT json_nid FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1" + +const deleteExpiredEDUsSQL = "" + + "DELETE FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1" + type queueEDUsStatements struct { db *sql.DB insertQueueEDUStmt *sql.Stmt @@ -73,6 +83,8 @@ type queueEDUsStatements struct { selectQueueEDUReferenceJSONCountStmt *sql.Stmt selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt + selectExpiredEDUsStmt *sql.Stmt + deleteExpiredEDUsStmt *sql.Stmt } func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { @@ -81,27 +93,34 @@ func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { } _, err = s.db.Exec(queueEDUsSchema) if err != nil { - return + return s, err } - if s.insertQueueEDUStmt, err = s.db.Prepare(insertQueueEDUSQL); err != nil { - return + + m := sqlutil.NewMigrator(db) + m.AddMigrations( + sqlutil.Migration{ + Version: "federationapi: add expiresat column", + Up: deltas.UpAddexpiresat, + }, + ) + if err := m.Up(context.Background()); err != nil { + return s, err } - if s.deleteQueueEDUStmt, err = s.db.Prepare(deleteQueueEDUSQL); err != nil { - return - } - if s.selectQueueEDUStmt, err = s.db.Prepare(selectQueueEDUSQL); err != nil { - return - } - if s.selectQueueEDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { - return - } - if s.selectQueueEDUCountStmt, err = s.db.Prepare(selectQueueEDUCountSQL); err != nil { - return - } - if s.selectQueueEDUServerNamesStmt, err = s.db.Prepare(selectQueueServerNamesSQL); err != nil { - return - } - return + + return s, nil +} + +func (s *queueEDUsStatements) Prepare() error { + return sqlutil.StatementList{ + {&s.insertQueueEDUStmt, insertQueueEDUSQL}, + {&s.deleteQueueEDUStmt, deleteQueueEDUSQL}, + {&s.selectQueueEDUStmt, selectQueueEDUSQL}, + {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, + {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, + {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, + {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, + {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, + }.Prepare(s.db) } func (s *queueEDUsStatements) InsertQueueEDU( @@ -110,6 +129,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType string, serverName gomatrixserverlib.ServerName, nid int64, + expiresAt gomatrixserverlib.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) _, err := stmt.ExecContext( @@ -117,6 +137,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType, // the EDU type serverName, // destination server name nid, // JSON blob NID + expiresAt, // timestamp of expiry ) return err } @@ -150,7 +171,7 @@ func (s *queueEDUsStatements) SelectQueueEDUs( } result = append(result, nid) } - return result, nil + return result, rows.Err() } func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( @@ -200,3 +221,33 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames( return result, rows.Err() } + +func (s *queueEDUsStatements) SelectExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) + rows, err := stmt.QueryContext(ctx, expiredBefore) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectExpiredEDUs: rows.close() failed") + var result []int64 + var nid int64 + for rows.Next() { + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + return result, rows.Err() +} + +func (s *queueEDUsStatements) DeleteExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt) + _, err := stmt.ExecContext(ctx, expiredBefore) + return err +} diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index 7c2883c1b..6e208d096 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -91,6 +91,9 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + if err = queueEDUs.Prepare(); err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, ServerName: serverName, diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index 02a23338f..b62e5d9c5 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -20,10 +20,21 @@ import ( "encoding/json" "errors" "fmt" + "time" "github.com/matrix-org/gomatrixserverlib" ) +// defaultExpiry for EDUs if not listed below +var defaultExpiry = time.Hour * 24 + +// defaultExpireEDUTypes contains EDUs which can/should be expired after a given time +// if the target server isn't reachable for some reason. +var defaultExpireEDUTypes = map[string]time.Duration{ + gomatrixserverlib.MTyping: time.Minute, + gomatrixserverlib.MPresence: time.Minute * 10, +} + // AssociateEDUWithDestination creates an association that the // destination queues will use to determine which JSON blobs to send // to which servers. @@ -32,7 +43,21 @@ func (d *Database) AssociateEDUWithDestination( serverName gomatrixserverlib.ServerName, receipt *Receipt, eduType string, + expireEDUTypes map[string]time.Duration, ) error { + if expireEDUTypes == nil { + expireEDUTypes = defaultExpireEDUTypes + } + expiresAt := gomatrixserverlib.AsTimestamp(time.Now().Add(defaultExpiry)) + if duration, ok := expireEDUTypes[eduType]; ok { + // Keep EDUs for at least x minutes before deleting them + expiresAt = gomatrixserverlib.AsTimestamp(time.Now().Add(duration)) + } + // We forcibly set m.direct_to_device events to 0, as we always want them + // to be delivered. (required for E2EE) + if eduType == gomatrixserverlib.MDirectToDevice { + expiresAt = 0 + } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if err := d.FederationQueueEDUs.InsertQueueEDU( ctx, // context @@ -40,6 +65,7 @@ func (d *Database) AssociateEDUWithDestination( eduType, // EDU type for coalescing serverName, // destination server name receipt.nid, // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire ); err != nil { return fmt.Errorf("InsertQueueEDU: %w", err) } @@ -150,3 +176,26 @@ func (d *Database) GetPendingEDUServerNames( ) ([]gomatrixserverlib.ServerName, error) { return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil) } + +// DeleteExpiredEDUs deletes expired EDUs +func (d *Database) DeleteExpiredEDUs(ctx context.Context) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + expiredBefore := gomatrixserverlib.AsTimestamp(time.Now()) + jsonNIDs, err := d.FederationQueueEDUs.SelectExpiredEDUs(ctx, txn, expiredBefore) + if err != nil { + return err + } + if len(jsonNIDs) == 0 { + return nil + } + for i := range jsonNIDs { + d.Cache.EvictFederationQueuedEDU(jsonNIDs[i]) + } + + if err = d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, jsonNIDs); err != nil { + return err + } + + return d.FederationQueueEDUs.DeleteExpiredEDUs(ctx, txn, expiredBefore) + }) +} diff --git a/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go b/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go new file mode 100644 index 000000000..c5030163b --- /dev/null +++ b/federationapi/storage/sqlite3/deltas/2022042812473400_addexpiresat.go @@ -0,0 +1,68 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/matrix-org/gomatrixserverlib" +) + +func UpAddexpiresat(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus RENAME TO federationsender_queue_edus_old;") + if err != nil { + return fmt.Errorf("failed to rename table: %w", err) + } + + _, err = tx.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( + edu_type TEXT NOT NULL, + server_name TEXT NOT NULL, + json_nid BIGINT NOT NULL, + expires_at BIGINT NOT NULL DEFAULT 0 +); + +CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx + ON federationsender_queue_edus (json_nid, server_name); +`) + if err != nil { + return fmt.Errorf("failed to create new table: %w", err) + } + _, err = tx.ExecContext(ctx, ` +INSERT + INTO federationsender_queue_edus ( + edu_type, server_name, json_nid, expires_at + ) SELECT edu_type, server_name, json_nid, 0 FROM federationsender_queue_edus_old; +`) + if err != nil { + return fmt.Errorf("failed to update queue_edus: %w", err) + } + _, err = tx.ExecContext(ctx, "UPDATE federationsender_queue_edus SET expires_at = $1 WHERE edu_type != 'm.direct_to_device'", gomatrixserverlib.AsTimestamp(time.Now().Add(time.Hour*24))) + if err != nil { + return fmt.Errorf("failed to update queue_edus: %w", err) + } + return nil +} + +func DownAddexpiresat(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, "ALTER TABLE federationsender_queue_edus DROP COLUMN expires_at;") + if err != nil { + return fmt.Errorf("failed to rename table: %w", err) + } + return nil +} diff --git a/federationapi/storage/sqlite3/queue_edus_table.go b/federationapi/storage/sqlite3/queue_edus_table.go index f4c84f094..8e7e7901f 100644 --- a/federationapi/storage/sqlite3/queue_edus_table.go +++ b/federationapi/storage/sqlite3/queue_edus_table.go @@ -20,9 +20,11 @@ import ( "fmt" "strings" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/gomatrixserverlib" ) const queueEDUsSchema = ` @@ -32,7 +34,9 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_edus ( -- The domain part of the user ID the EDU event is for. server_name TEXT NOT NULL, -- The JSON NID from the federationsender_queue_edus_json table. - json_nid BIGINT NOT NULL + json_nid BIGINT NOT NULL, + -- The expiry time of this edu, if any. + expires_at BIGINT NOT NULL DEFAULT 0 ); CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx @@ -44,8 +48,8 @@ CREATE INDEX IF NOT EXISTS federationsender_queue_edus_server_name_idx ` const insertQueueEDUSQL = "" + - "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" + - " VALUES ($1, $2, $3)" + "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid, expires_at)" + + " VALUES ($1, $2, $3, $4)" const deleteQueueEDUsSQL = "" + "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)" @@ -66,13 +70,22 @@ const selectQueueEDUCountSQL = "" + const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" +const selectExpiredEDUsSQL = "" + + "SELECT DISTINCT json_nid FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1" + +const deleteExpiredEDUsSQL = "" + + "DELETE FROM federationsender_queue_edus WHERE expires_at > 0 AND expires_at <= $1" + type queueEDUsStatements struct { - db *sql.DB - insertQueueEDUStmt *sql.Stmt + db *sql.DB + insertQueueEDUStmt *sql.Stmt + // deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt + selectExpiredEDUsStmt *sql.Stmt + deleteExpiredEDUsStmt *sql.Stmt } func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { @@ -81,24 +94,33 @@ func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { } _, err = db.Exec(queueEDUsSchema) if err != nil { - return + return s, err } - if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil { - return + + m := sqlutil.NewMigrator(db) + m.AddMigrations( + sqlutil.Migration{ + Version: "federationapi: add expiresat column", + Up: deltas.UpAddexpiresat, + }, + ) + if err := m.Up(context.Background()); err != nil { + return s, err } - if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil { - return - } - if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil { - return - } - if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil { - return - } - if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { - return - } - return + + return s, nil +} + +func (s *queueEDUsStatements) Prepare() error { + return sqlutil.StatementList{ + {&s.insertQueueEDUStmt, insertQueueEDUSQL}, + {&s.selectQueueEDUStmt, selectQueueEDUSQL}, + {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, + {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, + {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, + {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, + {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, + }.Prepare(s.db) } func (s *queueEDUsStatements) InsertQueueEDU( @@ -107,6 +129,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType string, serverName gomatrixserverlib.ServerName, nid int64, + expiresAt gomatrixserverlib.Timestamp, ) error { stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) _, err := stmt.ExecContext( @@ -114,6 +137,7 @@ func (s *queueEDUsStatements) InsertQueueEDU( eduType, // the EDU type serverName, // destination server name nid, // JSON blob NID + expiresAt, // timestamp of expiry ) return err } @@ -159,7 +183,7 @@ func (s *queueEDUsStatements) SelectQueueEDUs( } result = append(result, nid) } - return result, nil + return result, rows.Err() } func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( @@ -209,3 +233,33 @@ func (s *queueEDUsStatements) SelectQueueEDUServerNames( return result, rows.Err() } + +func (s *queueEDUsStatements) SelectExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectExpiredEDUsStmt) + rows, err := stmt.QueryContext(ctx, expiredBefore) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectExpiredEDUs: rows.close() failed") + var result []int64 + var nid int64 + for rows.Next() { + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + return result, rows.Err() +} + +func (s *queueEDUsStatements) DeleteExpiredEDUs( + ctx context.Context, txn *sql.Tx, + expiredBefore gomatrixserverlib.Timestamp, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteExpiredEDUsStmt) + _, err := stmt.ExecContext(ctx, expiredBefore) + return err +} diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index 9594aaec6..c89cb6bea 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -90,6 +90,9 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + if err = queueEDUs.Prepare(); err != nil { + return nil, err + } d.Database = shared.Database{ DB: d.db, ServerName: serverName, diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go new file mode 100644 index 000000000..7eba2cbee --- /dev/null +++ b/federationapi/storage/storage_test.go @@ -0,0 +1,81 @@ +package storage_test + +import ( + "context" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/federationapi/storage" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + b, baseClose := testrig.CreateBaseDendrite(t, dbType) + connStr, dbClose := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewDatabase(b, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, b.Caches, b.Cfg.Global.ServerName) + if err != nil { + t.Fatalf("NewDatabase returned %s", err) + } + return db, func() { + dbClose() + baseClose() + } +} + +func TestExpireEDUs(t *testing.T) { + var expireEDUTypes = map[string]time.Duration{ + gomatrixserverlib.MReceipt: time.Millisecond, + } + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateFederationDatabase(t, dbType) + defer close() + // insert some data + for i := 0; i < 100; i++ { + receipt, err := db.StoreJSON(ctx, "{}") + assert.NoError(t, err) + + err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes) + assert.NoError(t, err) + } + // add data without expiry + receipt, err := db.StoreJSON(ctx, "{}") + assert.NoError(t, err) + + // m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test + err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes) + assert.NoError(t, err) + + // Delete expired EDUs + err = db.DeleteExpiredEDUs(ctx) + assert.NoError(t, err) + + // verify the data is gone + data, err := db.GetPendingEDUs(ctx, "localhost", 100) + assert.NoError(t, err) + assert.Equal(t, 1, len(data)) + + // check that m.direct_to_device is never expired + receipt, err = db.StoreJSON(ctx, "{}") + assert.NoError(t, err) + + err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes) + assert.NoError(t, err) + + err = db.DeleteExpiredEDUs(ctx) + assert.NoError(t, err) + + // We should get two EDUs, the m.read_marker and the m.direct_to_device + data, err = db.GetPendingEDUs(ctx, "localhost", 100) + assert.NoError(t, err) + assert.Equal(t, 2, len(data)) + }) +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 19357393d..3c116a1d0 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -34,12 +34,15 @@ type FederationQueuePDUs interface { } type FederationQueueEDUs interface { - InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64) error + InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64, expiresAt gomatrixserverlib.Timestamp) error DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) + SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error) + DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error + Prepare() error } type FederationQueueJSON interface { From c45d0936b59b0eb65f12fe22a3da3690ae0b5494 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 11 Aug 2022 15:29:33 +0100 Subject: [PATCH 3/9] Generic-based internal HTTP API (#2626) * Generic-based internal HTTP API (tested out on a few endpoints in the federation API) * Add `PerformInvite` * More tweaks * Fix metric name * Fix LookupStateIDs * Lots of changes to clients * Some serverside stuff * Some error handling * Use paths as metric names * Revert "Use paths as metric names" This reverts commit a9323a6a343f5ce6461a2e5bd570fe06465f1b15. * Namespace metric names * Remove duplicate entry * Remove another duplicate entry * Tweak error handling * Some more tweaks * Update error behaviour * Some more error tweaking * Fix API path for `PerformDeleteKeys` * Fix another path * Tweak federation client proxying * Fix another path * Don't return typed nils * Some more tweaks, not that it makes any difference * Tweak federation client proxying * Maybe fix the key backup test --- appservice/inthttp/client.go | 19 +- appservice/inthttp/server.go | 29 +- clientapi/jsonerror/jsonerror.go | 14 + clientapi/routing/admin.go | 12 +- clientapi/routing/createroom.go | 6 +- clientapi/routing/directory.go | 6 +- clientapi/routing/joinroom.go | 5 +- clientapi/routing/key_backup.go | 12 +- clientapi/routing/key_crosssigning.go | 8 +- clientapi/routing/keys.go | 16 +- clientapi/routing/peekroom.go | 9 +- clientapi/routing/upgrade_room.go | 4 +- federationapi/api/api.go | 2 +- federationapi/federationapi_test.go | 5 +- federationapi/inthttp/client.go | 437 +++++--------- federationapi/inthttp/server.go | 461 +++++---------- federationapi/routing/devices.go | 10 +- federationapi/routing/join.go | 6 +- federationapi/routing/keys.go | 16 +- federationapi/routing/leave.go | 6 +- federationapi/routing/send.go | 4 +- federationapi/routing/send_test.go | 3 +- internal/httputil/http.go | 40 +- internal/httputil/internalapi.go | 93 +++ keyserver/api/api.go | 28 +- keyserver/internal/cross_signing.go | 41 +- keyserver/internal/device_list_update.go | 4 +- keyserver/internal/device_list_update_test.go | 4 +- keyserver/internal/internal.go | 32 +- keyserver/inthttp/client.go | 161 ++--- keyserver/inthttp/server.go | 143 ++--- roomserver/api/api.go | 18 +- roomserver/api/api_trace.go | 63 +- roomserver/api/wrapper.go | 4 +- roomserver/internal/input/input.go | 10 +- roomserver/internal/perform/perform_admin.go | 37 +- roomserver/internal/perform/perform_invite.go | 4 +- roomserver/internal/perform/perform_join.go | 12 +- roomserver/internal/perform/perform_leave.go | 4 +- roomserver/internal/perform/perform_peek.go | 3 +- .../internal/perform/perform_publish.go | 3 +- roomserver/internal/perform/perform_unpeek.go | 3 +- .../internal/perform/perform_upgrade.go | 17 +- roomserver/inthttp/client.go | 431 ++++++-------- roomserver/inthttp/server.go | 525 ++++------------- setup/mscs/msc2836/msc2836_test.go | 6 +- syncapi/internal/keychange.go | 4 +- syncapi/internal/keychange_test.go | 30 +- syncapi/syncapi_test.go | 7 +- userapi/api/api.go | 2 +- userapi/api/api_trace.go | 5 +- userapi/internal/api.go | 46 +- userapi/inthttp/client.go | 440 +++++++------- userapi/inthttp/client_logintoken.go | 28 +- userapi/inthttp/server.go | 550 +++++------------- userapi/inthttp/server_logintoken.go | 51 +- userapi/userapi_test.go | 14 +- 57 files changed, 1535 insertions(+), 2418 deletions(-) create mode 100644 internal/httputil/internalapi.go diff --git a/appservice/inthttp/client.go b/appservice/inthttp/client.go index 0a8baea99..3ae2c9278 100644 --- a/appservice/inthttp/client.go +++ b/appservice/inthttp/client.go @@ -7,7 +7,6 @@ import ( "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/opentracing/opentracing-go" ) // HTTP paths for the internal HTTP APIs @@ -42,11 +41,10 @@ func (h *httpAppServiceQueryAPI) RoomAliasExists( request *api.RoomAliasExistsRequest, response *api.RoomAliasExistsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceRoomAliasExists") - defer span.Finish() - - apiURL := h.appserviceURL + AppServiceRoomAliasExistsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "RoomAliasExists", h.appserviceURL+AppServiceRoomAliasExistsPath, + h.httpClient, ctx, request, response, + ) } // UserIDExists implements AppServiceQueryAPI @@ -55,9 +53,8 @@ func (h *httpAppServiceQueryAPI) UserIDExists( request *api.UserIDExistsRequest, response *api.UserIDExistsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceUserIDExists") - defer span.Finish() - - apiURL := h.appserviceURL + AppServiceUserIDExistsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "UserIDExists", h.appserviceURL+AppServiceUserIDExistsPath, + h.httpClient, ctx, request, response, + ) } diff --git a/appservice/inthttp/server.go b/appservice/inthttp/server.go index 645b43871..01d9f9895 100644 --- a/appservice/inthttp/server.go +++ b/appservice/inthttp/server.go @@ -1,43 +1,20 @@ package inthttp import ( - "encoding/json" - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/util" ) // AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux. func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( AppServiceRoomAliasExistsPath, - httputil.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse { - var request api.RoomAliasExistsRequest - var response api.RoomAliasExistsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := a.RoomAliasExists(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", a.RoomAliasExists), ) + internalAPIMux.Handle( AppServiceUserIDExistsPath, - httputil.MakeInternalAPI("appserviceUserIDExists", func(req *http.Request) util.JSONResponse { - var request api.UserIDExistsRequest - var response api.UserIDExistsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := a.UserIDExists(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("AppserviceUserIDExists", a.UserIDExists), ) } diff --git a/clientapi/jsonerror/jsonerror.go b/clientapi/jsonerror/jsonerror.go index 70bac61dc..be7d13a96 100644 --- a/clientapi/jsonerror/jsonerror.go +++ b/clientapi/jsonerror/jsonerror.go @@ -15,11 +15,13 @@ package jsonerror import ( + "context" "fmt" "net/http" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) // MatrixError represents the "standard error response" in Matrix. @@ -213,3 +215,15 @@ func NotTrusted(serverName string) *MatrixError { Err: fmt.Sprintf("Untrusted server '%s'", serverName), } } + +// InternalAPIError is returned when Dendrite failed to reach an internal API. +func InternalAPIError(ctx context.Context, err error) util.JSONResponse { + logrus.WithContext(ctx).WithError(err).Error("Error reaching an internal API") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: &MatrixError{ + ErrCode: "M_INTERNAL_SERVER_ERROR", + Err: "Dendrite encountered an error reaching an internal API.", + }, + } +} diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 523b88c99..a8dd0e64f 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -30,13 +30,15 @@ func AdminEvacuateRoom(req *http.Request, device *userapi.Device, rsAPI roomserv } } res := &roomserverAPI.PerformAdminEvacuateRoomResponse{} - rsAPI.PerformAdminEvacuateRoom( + if err := rsAPI.PerformAdminEvacuateRoom( req.Context(), &roomserverAPI.PerformAdminEvacuateRoomRequest{ RoomID: roomID, }, res, - ) + ); err != nil { + return util.ErrorResponse(err) + } if err := res.Error; err != nil { return err.JSONResponse() } @@ -67,13 +69,15 @@ func AdminEvacuateUser(req *http.Request, device *userapi.Device, rsAPI roomserv } } res := &roomserverAPI.PerformAdminEvacuateUserResponse{} - rsAPI.PerformAdminEvacuateUser( + if err := rsAPI.PerformAdminEvacuateUser( req.Context(), &roomserverAPI.PerformAdminEvacuateUserRequest{ UserID: userID, }, res, - ) + ); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if err := res.Error; err != nil { return err.JSONResponse() } diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 3f92b7ba6..874908639 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -556,10 +556,12 @@ func createRoom( if r.Visibility == "public" { // expose this room in the published room list var pubRes roomserverAPI.PerformPublishResponse - rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{ + if err := rsAPI.PerformPublish(ctx, &roomserverAPI.PerformPublishRequest{ RoomID: roomID, Visibility: "public", - }, &pubRes) + }, &pubRes); err != nil { + return jsonerror.InternalAPIError(ctx, err) + } if pubRes.Error != nil { // treat as non-fatal since the room is already made by this point util.GetLogger(ctx).WithError(pubRes.Error).Error("failed to visibility:public") diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index 53ba3f190..836d9e152 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -302,10 +302,12 @@ func SetVisibility( } var publishRes roomserverAPI.PerformPublishResponse - rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ + if err := rsAPI.PerformPublish(req.Context(), &roomserverAPI.PerformPublishRequest{ RoomID: roomID, Visibility: v.Visibility, - }, &publishRes) + }, &publishRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if publishRes.Error != nil { util.GetLogger(req.Context()).WithError(publishRes.Error).Error("PerformPublish failed") return publishRes.Error.JSONResponse() diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 4e6acebc3..c50e552bd 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -81,8 +81,9 @@ func JoinRoomByIDOrAlias( done := make(chan util.JSONResponse, 1) go func() { defer close(done) - rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes) - if joinRes.Error != nil { + if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { + done <- jsonerror.InternalAPIError(req.Context(), err) + } else if joinRes.Error != nil { done <- joinRes.Error.JSONResponse() } else { done <- util.JSONResponse{ diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go index 28c80415b..b6f8fe1b9 100644 --- a/clientapi/routing/key_backup.go +++ b/clientapi/routing/key_backup.go @@ -91,10 +91,12 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, de // Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version} func KeyBackupVersion(req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version string) util.JSONResponse { var queryResp userapi.QueryKeyBackupResponse - userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ UserID: device.UserID, Version: version, - }, &queryResp) + }, &queryResp); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if queryResp.Error != "" { return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) } @@ -233,13 +235,15 @@ func GetBackupKeys( req *http.Request, userAPI userapi.ClientUserAPI, device *userapi.Device, version, roomID, sessionID string, ) util.JSONResponse { var queryResp userapi.QueryKeyBackupResponse - userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ + if err := userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{ UserID: device.UserID, Version: version, ReturnKeys: true, KeysForRoomID: roomID, KeysForSessionID: sessionID, - }, &queryResp) + }, &queryResp); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if queryResp.Error != "" { return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error)) } diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 8fbb86f7a..2570db09c 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -72,7 +72,9 @@ func UploadCrossSigningDeviceKeys( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) uploadReq.UserID = device.UserID - keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) + if err := keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if err := uploadRes.Error; err != nil { switch { @@ -114,7 +116,9 @@ func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.Clie } uploadReq.UserID = device.UserID - keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes) + if err := keyserverAPI.PerformUploadDeviceSignatures(req.Context(), uploadReq, uploadRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if err := uploadRes.Error; err != nil { switch { diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index fdda34a53..b7a76b47e 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -62,7 +62,9 @@ func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devi } var uploadRes api.PerformUploadKeysResponse - keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes) + if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil { + return util.ErrorResponse(err) + } if uploadRes.Error != nil { util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys") return jsonerror.InternalServerError() @@ -107,12 +109,14 @@ func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Devic return *resErr } queryRes := api.QueryKeysResponse{} - keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ + if err := keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ UserID: device.UserID, UserToDevices: r.DeviceKeys, Timeout: r.GetTimeout(), // TODO: Token? - }, &queryRes) + }, &queryRes); err != nil { + return util.ErrorResponse(err) + } return util.JSONResponse{ Code: 200, JSON: map[string]interface{}{ @@ -145,10 +149,12 @@ func ClaimKeys(req *http.Request, keyAPI api.ClientKeyAPI) util.JSONResponse { return *resErr } claimRes := api.PerformClaimKeysResponse{} - keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{ + if err := keyAPI.PerformClaimKeys(req.Context(), &api.PerformClaimKeysRequest{ OneTimeKeys: r.OneTimeKeys, Timeout: r.GetTimeout(), - }, &claimRes) + }, &claimRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if claimRes.Error != nil { util.GetLogger(req.Context()).WithError(claimRes.Error).Error("failed to PerformClaimKeys") return jsonerror.InternalServerError() diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index d0eeccf17..9b2592eb5 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -17,6 +17,7 @@ package routing import ( "net/http" + "github.com/matrix-org/dendrite/clientapi/jsonerror" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -54,7 +55,9 @@ func PeekRoomByIDOrAlias( } // Ask the roomserver to perform the peek. - rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes) + if err := rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes); err != nil { + return util.ErrorResponse(err) + } if peekRes.Error != nil { return peekRes.Error.JSONResponse() } @@ -89,7 +92,9 @@ func UnpeekRoomByID( } unpeekRes := roomserverAPI.PerformUnpeekResponse{} - rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes) + if err := rsAPI.PerformUnpeek(req.Context(), &unpeekReq, &unpeekRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if unpeekRes.Error != nil { return unpeekRes.Error.JSONResponse() } diff --git a/clientapi/routing/upgrade_room.go b/clientapi/routing/upgrade_room.go index 744e2d889..34c7eb004 100644 --- a/clientapi/routing/upgrade_room.go +++ b/clientapi/routing/upgrade_room.go @@ -64,7 +64,9 @@ func UpgradeRoom( } upgradeResp := roomserverAPI.PerformRoomUpgradeResponse{} - rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp) + if err := rsAPI.PerformRoomUpgrade(req.Context(), &upgradeReq, &upgradeResp); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } if upgradeResp.Error != nil { if upgradeResp.Error.Code == roomserverAPI.PerformErrorNoRoom { diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 53d4701f3..292ed55ad 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -110,7 +110,7 @@ type FederationClientError struct { Blacklisted bool } -func (e *FederationClientError) Error() string { +func (e FederationClientError) Error() string { return fmt.Sprintf("%s - (retry_after=%s, blacklisted=%v)", e.Err, e.RetryAfter.String(), e.Blacklisted) } diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 8884e34c6..bdcb9f57c 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -32,11 +32,12 @@ type fedRoomserverAPI struct { } // PerformJoin will call this function -func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { +func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) error { if f.inputRoomEvents == nil { - return + return nil } f.inputRoomEvents(ctx, req, res) + return nil } // keychange consumer calls this diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 295ddc495..812d3c6da 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -10,7 +10,6 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" - "github.com/opentracing/opentracing-go" ) // HTTP paths for the internal HTTP API @@ -48,7 +47,11 @@ func NewFederationAPIClient(federationSenderURL string, httpClient *http.Client, if httpClient == nil { return nil, errors.New("NewFederationInternalAPIHTTP: httpClient is ") } - return &httpFederationInternalAPI{federationSenderURL, httpClient, cache}, nil + return &httpFederationInternalAPI{ + federationAPIURL: federationSenderURL, + httpClient: httpClient, + cache: cache, + }, nil } type httpFederationInternalAPI struct { @@ -63,11 +66,10 @@ func (h *httpFederationInternalAPI) PerformLeave( request *api.PerformLeaveRequest, response *api.PerformLeaveResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeaveRequest") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformLeaveRequestPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformLeave", h.federationAPIURL+FederationAPIPerformLeaveRequestPath, + h.httpClient, ctx, request, response, + ) } // Handle sending an invite to a remote server. @@ -76,11 +78,10 @@ func (h *httpFederationInternalAPI) PerformInvite( request *api.PerformInviteRequest, response *api.PerformInviteResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInviteRequest") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformInviteRequestPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformInvite", h.federationAPIURL+FederationAPIPerformInviteRequestPath, + h.httpClient, ctx, request, response, + ) } // Handle starting a peek on a remote server. @@ -89,11 +90,10 @@ func (h *httpFederationInternalAPI) PerformOutboundPeek( request *api.PerformOutboundPeekRequest, response *api.PerformOutboundPeekResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOutboundPeekRequest") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformOutboundPeekRequestPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformOutboundPeek", h.federationAPIURL+FederationAPIPerformOutboundPeekRequestPath, + h.httpClient, ctx, request, response, + ) } // QueryJoinedHostServerNamesInRoom implements FederationInternalAPI @@ -102,11 +102,10 @@ func (h *httpFederationInternalAPI) QueryJoinedHostServerNamesInRoom( request *api.QueryJoinedHostServerNamesInRoomRequest, response *api.QueryJoinedHostServerNamesInRoomResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIQueryJoinedHostServerNamesInRoomPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryJoinedHostServerNamesInRoom", h.federationAPIURL+FederationAPIQueryJoinedHostServerNamesInRoomPath, + h.httpClient, ctx, request, response, + ) } // Handle an instruction to make_join & send_join with a remote server. @@ -115,12 +114,10 @@ func (h *httpFederationInternalAPI) PerformJoin( request *api.PerformJoinRequest, response *api.PerformJoinResponse, ) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoinRequest") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformJoinRequestPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { + if err := httputil.CallInternalRPCAPI( + "PerformJoinRequest", h.federationAPIURL+FederationAPIPerformJoinRequestPath, + h.httpClient, ctx, request, response, + ); err != nil { response.LastError = &gomatrix.HTTPError{ Message: err.Error(), Code: 0, @@ -135,11 +132,10 @@ func (h *httpFederationInternalAPI) PerformDirectoryLookup( request *api.PerformDirectoryLookupRequest, response *api.PerformDirectoryLookupResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDirectoryLookup") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformDirectoryLookupRequestPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformDirectoryLookup", h.federationAPIURL+FederationAPIPerformDirectoryLookupRequestPath, + h.httpClient, ctx, request, response, + ) } // Handle an instruction to broadcast an EDU to all servers in rooms we are joined to. @@ -148,101 +144,61 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU( request *api.PerformBroadcastEDURequest, response *api.PerformBroadcastEDUResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBroadcastEDU") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIPerformBroadcastEDUPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformBroadcastEDU", h.federationAPIURL+FederationAPIPerformBroadcastEDUPath, + h.httpClient, ctx, request, response, + ) } type getUserDevices struct { S gomatrixserverlib.ServerName UserID string - Res *gomatrixserverlib.RespUserDevices - Err *api.FederationClientError } func (h *httpFederationInternalAPI) GetUserDevices( ctx context.Context, s gomatrixserverlib.ServerName, userID string, ) (gomatrixserverlib.RespUserDevices, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetUserDevices") - defer span.Finish() - - var result gomatrixserverlib.RespUserDevices - request := getUserDevices{ - S: s, - UserID: userID, - } - var response getUserDevices - apiURL := h.federationAPIURL + FederationAPIGetUserDevicesPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return result, err - } - if response.Err != nil { - return result, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( + "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, + ctx, &getUserDevices{ + S: s, + UserID: userID, + }, + ) } type claimKeys struct { S gomatrixserverlib.ServerName OneTimeKeys map[string]map[string]string - Res *gomatrixserverlib.RespClaimKeys - Err *api.FederationClientError } func (h *httpFederationInternalAPI) ClaimKeys( ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ) (gomatrixserverlib.RespClaimKeys, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "ClaimKeys") - defer span.Finish() - - var result gomatrixserverlib.RespClaimKeys - request := claimKeys{ - S: s, - OneTimeKeys: oneTimeKeys, - } - var response claimKeys - apiURL := h.federationAPIURL + FederationAPIClaimKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return result, err - } - if response.Err != nil { - return result, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( + "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, + ctx, &claimKeys{ + S: s, + OneTimeKeys: oneTimeKeys, + }, + ) } type queryKeys struct { S gomatrixserverlib.ServerName Keys map[string][]string - Res *gomatrixserverlib.RespQueryKeys - Err *api.FederationClientError } func (h *httpFederationInternalAPI) QueryKeys( ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, ) (gomatrixserverlib.RespQueryKeys, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys") - defer span.Finish() - - var result gomatrixserverlib.RespQueryKeys - request := queryKeys{ - S: s, - Keys: keys, - } - var response queryKeys - apiURL := h.federationAPIURL + FederationAPIQueryKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return result, err - } - if response.Err != nil { - return result, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( + "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, + ctx, &queryKeys{ + S: s, + Keys: keys, + }, + ) } type backfill struct { @@ -250,32 +206,20 @@ type backfill struct { RoomID string Limit int EventIDs []string - Res *gomatrixserverlib.Transaction - Err *api.FederationClientError } func (h *httpFederationInternalAPI) Backfill( ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ) (gomatrixserverlib.Transaction, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "Backfill") - defer span.Finish() - - request := backfill{ - S: s, - RoomID: roomID, - Limit: limit, - EventIDs: eventIDs, - } - var response backfill - apiURL := h.federationAPIURL + FederationAPIBackfillPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return gomatrixserverlib.Transaction{}, err - } - if response.Err != nil { - return gomatrixserverlib.Transaction{}, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( + "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, + ctx, &backfill{ + S: s, + RoomID: roomID, + Limit: limit, + EventIDs: eventIDs, + }, + ) } type lookupState struct { @@ -283,63 +227,39 @@ type lookupState struct { RoomID string EventID string RoomVersion gomatrixserverlib.RoomVersion - Res *gomatrixserverlib.RespState - Err *api.FederationClientError } func (h *httpFederationInternalAPI) LookupState( ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (gomatrixserverlib.RespState, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "LookupState") - defer span.Finish() - - request := lookupState{ - S: s, - RoomID: roomID, - EventID: eventID, - RoomVersion: roomVersion, - } - var response lookupState - apiURL := h.federationAPIURL + FederationAPILookupStatePath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return gomatrixserverlib.RespState{}, err - } - if response.Err != nil { - return gomatrixserverlib.RespState{}, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( + "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, + ctx, &lookupState{ + S: s, + RoomID: roomID, + EventID: eventID, + RoomVersion: roomVersion, + }, + ) } type lookupStateIDs struct { S gomatrixserverlib.ServerName RoomID string EventID string - Res *gomatrixserverlib.RespStateIDs - Err *api.FederationClientError } func (h *httpFederationInternalAPI) LookupStateIDs( ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, ) (gomatrixserverlib.RespStateIDs, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "LookupStateIDs") - defer span.Finish() - - request := lookupStateIDs{ - S: s, - RoomID: roomID, - EventID: eventID, - } - var response lookupStateIDs - apiURL := h.federationAPIURL + FederationAPILookupStateIDsPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return gomatrixserverlib.RespStateIDs{}, err - } - if response.Err != nil { - return gomatrixserverlib.RespStateIDs{}, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( + "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, + ctx, &lookupStateIDs{ + S: s, + RoomID: roomID, + EventID: eventID, + }, + ) } type lookupMissingEvents struct { @@ -347,64 +267,38 @@ type lookupMissingEvents struct { RoomID string Missing gomatrixserverlib.MissingEvents RoomVersion gomatrixserverlib.RoomVersion - Res struct { - Events []gomatrixserverlib.RawJSON `json:"events"` - } - Err *api.FederationClientError } func (h *httpFederationInternalAPI) LookupMissingEvents( ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespMissingEvents, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "LookupMissingEvents") - defer span.Finish() - - request := lookupMissingEvents{ - S: s, - RoomID: roomID, - Missing: missing, - RoomVersion: roomVersion, - } - apiURL := h.federationAPIURL + FederationAPILookupMissingEventsPath - err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &request) - if err != nil { - return res, err - } - if request.Err != nil { - return res, request.Err - } - res.Events = request.Res.Events - return res, nil + return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( + "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, + ctx, &lookupMissingEvents{ + S: s, + RoomID: roomID, + Missing: missing, + RoomVersion: roomVersion, + }, + ) } type getEvent struct { S gomatrixserverlib.ServerName EventID string - Res *gomatrixserverlib.Transaction - Err *api.FederationClientError } func (h *httpFederationInternalAPI) GetEvent( ctx context.Context, s gomatrixserverlib.ServerName, eventID string, ) (gomatrixserverlib.Transaction, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetEvent") - defer span.Finish() - - request := getEvent{ - S: s, - EventID: eventID, - } - var response getEvent - apiURL := h.federationAPIURL + FederationAPIGetEventPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return gomatrixserverlib.Transaction{}, err - } - if response.Err != nil { - return gomatrixserverlib.Transaction{}, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( + "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, + ctx, &getEvent{ + S: s, + EventID: eventID, + }, + ) } type getEventAuth struct { @@ -412,135 +306,86 @@ type getEventAuth struct { RoomVersion gomatrixserverlib.RoomVersion RoomID string EventID string - Res *gomatrixserverlib.RespEventAuth - Err *api.FederationClientError } func (h *httpFederationInternalAPI) GetEventAuth( ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (gomatrixserverlib.RespEventAuth, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetEventAuth") - defer span.Finish() - - request := getEventAuth{ - S: s, - RoomVersion: roomVersion, - RoomID: roomID, - EventID: eventID, - } - var response getEventAuth - apiURL := h.federationAPIURL + FederationAPIGetEventAuthPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return gomatrixserverlib.RespEventAuth{}, err - } - if response.Err != nil { - return gomatrixserverlib.RespEventAuth{}, response.Err - } - return *response.Res, nil + return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( + "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, + ctx, &getEventAuth{ + S: s, + RoomVersion: roomVersion, + RoomID: roomID, + EventID: eventID, + }, + ) } func (h *httpFederationInternalAPI) QueryServerKeys( ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerKeys") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIQueryServerKeysPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QueryServerKeys", h.federationAPIURL+FederationAPIQueryServerKeysPath, + h.httpClient, ctx, req, res, + ) } type lookupServerKeys struct { S gomatrixserverlib.ServerName KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp - ServerKeys []gomatrixserverlib.ServerKeys - Err *api.FederationClientError } func (h *httpFederationInternalAPI) LookupServerKeys( ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ) ([]gomatrixserverlib.ServerKeys, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys") - defer span.Finish() - - request := lookupServerKeys{ - S: s, - KeyRequests: keyRequests, - } - var response lookupServerKeys - apiURL := h.federationAPIURL + FederationAPILookupServerKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return []gomatrixserverlib.ServerKeys{}, err - } - if response.Err != nil { - return []gomatrixserverlib.ServerKeys{}, response.Err - } - return response.ServerKeys, nil + return httputil.CallInternalProxyAPI[lookupServerKeys, []gomatrixserverlib.ServerKeys, *api.FederationClientError]( + "LookupServerKeys", h.federationAPIURL+FederationAPILookupServerKeysPath, h.httpClient, + ctx, &lookupServerKeys{ + S: s, + KeyRequests: keyRequests, + }, + ) } type eventRelationships struct { S gomatrixserverlib.ServerName Req gomatrixserverlib.MSC2836EventRelationshipsRequest RoomVer gomatrixserverlib.RoomVersion - Res gomatrixserverlib.MSC2836EventRelationshipsResponse - Err *api.FederationClientError } func (h *httpFederationInternalAPI) MSC2836EventRelationships( ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2836EventRelationships") - defer span.Finish() - - request := eventRelationships{ - S: s, - Req: r, - RoomVer: roomVersion, - } - var response eventRelationships - apiURL := h.federationAPIURL + FederationAPIEventRelationshipsPath - err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return res, err - } - if response.Err != nil { - return res, response.Err - } - return response.Res, nil + return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( + "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, + ctx, &eventRelationships{ + S: s, + Req: r, + RoomVer: roomVersion, + }, + ) } type spacesReq struct { S gomatrixserverlib.ServerName SuggestedOnly bool RoomID string - Res gomatrixserverlib.MSC2946SpacesResponse - Err *api.FederationClientError } func (h *httpFederationInternalAPI) MSC2946Spaces( ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces") - defer span.Finish() - - request := spacesReq{ - S: dst, - SuggestedOnly: suggestedOnly, - RoomID: roomID, - } - var response spacesReq - apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath - err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) - if err != nil { - return res, err - } - if response.Err != nil { - return res, response.Err - } - return response.Res, nil + return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( + "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, + ctx, &spacesReq{ + S: dst, + SuggestedOnly: suggestedOnly, + RoomID: roomID, + }, + ) } func (s *httpFederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing { @@ -614,11 +459,10 @@ func (h *httpFederationInternalAPI) InputPublicKeys( request *api.InputPublicKeysRequest, response *api.InputPublicKeysResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "InputPublicKey") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIInputPublicKeyPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "InputPublicKey", h.federationAPIURL+FederationAPIInputPublicKeyPath, + h.httpClient, ctx, request, response, + ) } func (h *httpFederationInternalAPI) QueryPublicKeys( @@ -626,9 +470,8 @@ func (h *httpFederationInternalAPI) QueryPublicKeys( request *api.QueryPublicKeysRequest, response *api.QueryPublicKeysResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublicKey") - defer span.Finish() - - apiURL := h.federationAPIURL + FederationAPIQueryPublicKeyPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryPublicKeys", h.federationAPIURL+FederationAPIQueryPublicKeyPath, + h.httpClient, ctx, request, response, + ) } diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 28e52b32d..a8b829a71 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -1,12 +1,14 @@ package inthttp import ( + "context" "encoding/json" "net/http" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -15,372 +17,180 @@ import ( func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIQueryJoinedHostServerNamesInRoomPath, - httputil.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryJoinedHostServerNamesInRoomRequest - var response api.QueryJoinedHostServerNamesInRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := intAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle( - FederationAPIPerformJoinRequestPath, - httputil.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse { - var request api.PerformJoinRequest - var response api.PerformJoinResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - intAPI.PerformJoin(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle( - FederationAPIPerformLeaveRequestPath, - httputil.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse { - var request api.PerformLeaveRequest - var response api.PerformLeaveResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.PerformLeave(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", intAPI.QueryJoinedHostServerNamesInRoom), ) + internalAPIMux.Handle( FederationAPIPerformInviteRequestPath, - httputil.MakeInternalAPI("PerformInviteRequest", func(req *http.Request) util.JSONResponse { - var request api.PerformInviteRequest - var response api.PerformInviteResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.PerformInvite(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", intAPI.PerformInvite), ) + + internalAPIMux.Handle( + FederationAPIPerformLeaveRequestPath, + httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", intAPI.PerformLeave), + ) + internalAPIMux.Handle( FederationAPIPerformDirectoryLookupRequestPath, - httputil.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse { - var request api.PerformDirectoryLookupRequest - var response api.PerformDirectoryLookupResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.PerformDirectoryLookup(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", intAPI.PerformDirectoryLookup), ) + internalAPIMux.Handle( FederationAPIPerformBroadcastEDUPath, - httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse { - var request api.PerformBroadcastEDURequest - var response api.PerformBroadcastEDUResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.PerformBroadcastEDU(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), ) + + internalAPIMux.Handle( + FederationAPIPerformJoinRequestPath, + httputil.MakeInternalRPCAPI( + "FederationAPIPerformJoinRequest", + func(ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse) error { + intAPI.PerformJoin(ctx, req, res) + return nil + }, + ), + ) + internalAPIMux.Handle( FederationAPIGetUserDevicesPath, - httputil.MakeInternalAPI("GetUserDevices", func(req *http.Request) util.JSONResponse { - var request getUserDevices - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.GetUserDevices(req.Context(), request.S, request.UserID) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIGetUserDevices", + func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { + res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIClaimKeysPath, - httputil.MakeInternalAPI("ClaimKeys", func(req *http.Request) util.JSONResponse { - var request claimKeys - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.ClaimKeys(req.Context(), request.S, request.OneTimeKeys) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIClaimKeys", + func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { + res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIQueryKeysPath, - httputil.MakeInternalAPI("QueryKeys", func(req *http.Request) util.JSONResponse { - var request queryKeys - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.QueryKeys(req.Context(), request.S, request.Keys) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIQueryKeys", + func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { + res, err := intAPI.QueryKeys(ctx, req.S, req.Keys) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIBackfillPath, - httputil.MakeInternalAPI("Backfill", func(req *http.Request) util.JSONResponse { - var request backfill - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.Backfill(req.Context(), request.S, request.RoomID, request.Limit, request.EventIDs) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIBackfill", + func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { + res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPILookupStatePath, - httputil.MakeInternalAPI("LookupState", func(req *http.Request) util.JSONResponse { - var request lookupState - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.LookupState(req.Context(), request.S, request.RoomID, request.EventID, request.RoomVersion) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPILookupState", + func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { + res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPILookupStateIDsPath, - httputil.MakeInternalAPI("LookupStateIDs", func(req *http.Request) util.JSONResponse { - var request lookupStateIDs - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.LookupStateIDs(req.Context(), request.S, request.RoomID, request.EventID) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPILookupStateIDs", + func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { + res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPILookupMissingEventsPath, - httputil.MakeInternalAPI("LookupMissingEvents", func(req *http.Request) util.JSONResponse { - var request lookupMissingEvents - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.LookupMissingEvents(req.Context(), request.S, request.RoomID, request.Missing, request.RoomVersion) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - for _, event := range res.Events { - js, err := json.Marshal(event) - if err != nil { - return util.MessageResponse(http.StatusInternalServerError, err.Error()) - } - request.Res.Events = append(request.Res.Events, js) - } - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPILookupMissingEvents", + func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { + res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIGetEventPath, - httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse { - var request getEvent - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.GetEvent(req.Context(), request.S, request.EventID) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIGetEvent", + func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { + res, err := intAPI.GetEvent(ctx, req.S, req.EventID) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIGetEventAuthPath, - httputil.MakeInternalAPI("GetEventAuth", func(req *http.Request) util.JSONResponse { - var request getEventAuth - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.GetEventAuth(req.Context(), request.S, request.RoomVersion, request.RoomID, request.EventID) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = &res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIGetEventAuth", + func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { + res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIQueryServerKeysPath, - httputil.MakeInternalAPI("QueryServerKeys", func(req *http.Request) util.JSONResponse { - var request api.QueryServerKeysRequest - var response api.QueryServerKeysResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.QueryServerKeys(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", intAPI.QueryServerKeys), ) + internalAPIMux.Handle( FederationAPILookupServerKeysPath, - httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse { - var request lookupServerKeys - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.ServerKeys = res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPILookupServerKeys", + func(ctx context.Context, req *lookupServerKeys) (*[]gomatrixserverlib.ServerKeys, error) { + res, err := intAPI.LookupServerKeys(ctx, req.S, req.KeyRequests) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPIEventRelationshipsPath, - httputil.MakeInternalAPI("MSC2836EventRelationships", func(req *http.Request) util.JSONResponse { - var request eventRelationships - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.MSC2836EventRelationships(req.Context(), request.S, request.Req, request.RoomVer) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIMSC2836EventRelationships", + func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { + res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer) + return &res, federationClientError(err) + }, + ), ) + internalAPIMux.Handle( FederationAPISpacesSummaryPath, - httputil.MakeInternalAPI("MSC2946SpacesSummary", func(req *http.Request) util.JSONResponse { - var request spacesReq - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.SuggestedOnly) - if err != nil { - ferr, ok := err.(*api.FederationClientError) - if ok { - request.Err = ferr - } else { - request.Err = &api.FederationClientError{ - Err: err.Error(), - } - } - } - request.Res = res - return util.JSONResponse{Code: http.StatusOK, JSON: request} - }), + httputil.MakeInternalProxyAPI( + "FederationAPIMSC2946SpacesSummary", + func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { + res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly) + return &res, federationClientError(err) + }, + ), ) + + // TODO: Look at this shape internalAPIMux.Handle(FederationAPIQueryPublicKeyPath, - httputil.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", func(req *http.Request) util.JSONResponse { request := api.QueryPublicKeysRequest{} response := api.QueryPublicKeysResponse{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -394,8 +204,10 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + + // TODO: Look at this shape internalAPIMux.Handle(FederationAPIInputPublicKeyPath, - httputil.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("FederationAPIInputPublicKeys", func(req *http.Request) util.JSONResponse { request := api.InputPublicKeysRequest{} response := api.InputPublicKeysResponse{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -408,3 +220,18 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { }), ) } + +func federationClientError(err error) error { + switch ferr := err.(type) { + case nil: + return nil + case api.FederationClientError: + return &ferr + case *api.FederationClientError: + return ferr + default: + return &api.FederationClientError{ + Err: err.Error(), + } + } +} diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index 2f9da1f25..ce8b06b70 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -30,9 +30,11 @@ func GetUserDevices( userID string, ) util.JSONResponse { var res keyapi.QueryDeviceMessagesResponse - keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{ + if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{ UserID: userID, - }, &res) + }, &res); err != nil { + return util.ErrorResponse(err) + } if res.Error != nil { util.GetLogger(req.Context()).WithError(res.Error).Error("keyAPI.QueryDeviceMessages failed") return jsonerror.InternalServerError() @@ -47,7 +49,9 @@ func GetUserDevices( for _, dev := range res.Devices { sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID)) } - keyAPI.QuerySignatures(req.Context(), sigReq, sigRes) + if err := keyAPI.QuerySignatures(req.Context(), sigReq, sigRes); err != nil { + return jsonerror.InternalAPIError(req.Context(), err) + } response := gomatrixserverlib.RespUserDevices{ UserID: userID, diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 9ad3bd8eb..b48eaf78e 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -392,7 +392,7 @@ func SendJoin( // the room, so set SendAsServer to cfg.Matrix.ServerName if !alreadyJoined { var response api.InputRoomEventsResponse - rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ + if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, @@ -401,7 +401,9 @@ func SendJoin( TransactionID: nil, }, }, - }, &response) + }, &response); err != nil { + return jsonerror.InternalAPIError(httpReq.Context(), err) + } if response.ErrMsg != "" { util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed") if response.NotAllowed { diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index b1a9b6710..b03d4c1d6 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -19,7 +19,7 @@ import ( "net/http" "time" - "github.com/matrix-org/dendrite/clientapi/httputil" + clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/keyserver/api" @@ -61,9 +61,11 @@ func QueryDeviceKeys( } var queryRes api.QueryKeysResponse - keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{ + if err := keyAPI.QueryKeys(httpReq.Context(), &api.QueryKeysRequest{ UserToDevices: qkr.DeviceKeys, - }, &queryRes) + }, &queryRes); err != nil { + return jsonerror.InternalAPIError(httpReq.Context(), err) + } if queryRes.Error != nil { util.GetLogger(httpReq.Context()).WithError(queryRes.Error).Error("Failed to QueryKeys") return jsonerror.InternalServerError() @@ -113,9 +115,11 @@ func ClaimOneTimeKeys( } var claimRes api.PerformClaimKeysResponse - keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{ + if err := keyAPI.PerformClaimKeys(httpReq.Context(), &api.PerformClaimKeysRequest{ OneTimeKeys: cor.OneTimeKeys, - }, &claimRes) + }, &claimRes); err != nil { + return jsonerror.InternalAPIError(httpReq.Context(), err) + } if claimRes.Error != nil { util.GetLogger(httpReq.Context()).WithError(claimRes.Error).Error("Failed to PerformClaimKeys") return jsonerror.InternalServerError() @@ -184,7 +188,7 @@ func NotaryKeys( ) util.JSONResponse { if req == nil { req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{} - if reqErr := httputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { + if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { return *reqErr } } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index dbaf68e5b..8e43ce959 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -277,7 +277,7 @@ func SendLeave( // We are responsible for notifying other servers that the user has left // the room, so set SendAsServer to cfg.Matrix.ServerName var response api.InputRoomEventsResponse - rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ + if err := rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{ InputRoomEvents: []api.InputRoomEvent{ { Kind: api.KindNew, @@ -286,7 +286,9 @@ func SendLeave( TransactionID: nil, }, }, - }, &response) + }, &response); err != nil { + return jsonerror.InternalAPIError(httpReq.Context(), err) + } if response.ErrMsg != "" { util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).WithField("not_allowed", response.NotAllowed).Error("producer.SendEvents failed") diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 43003be38..87b6fa33e 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -458,7 +458,9 @@ func (t *txnReq) processSigningKeyUpdate(ctx context.Context, e gomatrixserverli UserID: updatePayload.UserID, } uploadRes := &keyapi.PerformUploadDeviceKeysResponse{} - t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) + if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { + return err + } if uploadRes.Error != nil { return uploadRes.Error } diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index a111580c7..1c796f542 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -64,11 +64,12 @@ func (t *testRoomserverAPI) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) { +) error { t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) for _, ire := range request.InputRoomEvents { fmt.Println("InputRoomEvents: ", ire.Event.EventID()) } + return nil } // Query the latest events and state for a room from the room server. diff --git a/internal/httputil/http.go b/internal/httputil/http.go index 4527e2b95..1e07ee33c 100644 --- a/internal/httputil/http.go +++ b/internal/httputil/http.go @@ -19,19 +19,21 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/url" "strings" - "github.com/matrix-org/dendrite/userapi/api" opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" ) -// PostJSON performs a POST request with JSON on an internal HTTP API -func PostJSON( +// PostJSON performs a POST request with JSON on an internal HTTP API. +// The error will match the errtype if returned from the remote API, or +// will be a different type if there was a problem reaching the API. +func PostJSON[reqtype, restype any, errtype error]( ctx context.Context, span opentracing.Span, httpClient *http.Client, - apiURL string, request, response interface{}, + apiURL string, request *reqtype, response *restype, ) error { jsonBytes, err := json.Marshal(request) if err != nil { @@ -69,17 +71,23 @@ func PostJSON( if err != nil { return err } - if res.StatusCode != http.StatusOK { - var errorBody struct { - Message string `json:"message"` - } - if _, ok := response.(*api.PerformKeyBackupResponse); ok { // TODO: remove this, once cross-boundary errors are a thing - return nil - } - if msgerr := json.NewDecoder(res.Body).Decode(&errorBody); msgerr == nil { - return fmt.Errorf("internal API: %d from %s: %s", res.StatusCode, apiURL, errorBody.Message) - } - return fmt.Errorf("internal API: %d from %s", res.StatusCode, apiURL) + var body []byte + body, err = io.ReadAll(res.Body) + if err != nil { + return err } - return json.NewDecoder(res.Body).Decode(response) + if res.StatusCode != http.StatusOK { + if len(body) == 0 { + return fmt.Errorf("HTTP %d from %s (no response body)", res.StatusCode, apiURL) + } + var reserr errtype + if err = json.Unmarshal(body, reserr); err != nil { + return fmt.Errorf("HTTP %d from %s", res.StatusCode, apiURL) + } + return reserr + } + if err = json.Unmarshal(body, response); err != nil { + return fmt.Errorf("json.Unmarshal: %w", err) + } + return nil } diff --git a/internal/httputil/internalapi.go b/internal/httputil/internalapi.go new file mode 100644 index 000000000..385092d9c --- /dev/null +++ b/internal/httputil/internalapi.go @@ -0,0 +1,93 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httputil + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "reflect" + + "github.com/matrix-org/util" + opentracing "github.com/opentracing/opentracing-go" +) + +type InternalAPIError struct { + Type string + Message string +} + +func (e InternalAPIError) Error() string { + return fmt.Sprintf("internal API returned %q error: %s", e.Type, e.Message) +} + +func MakeInternalRPCAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype, *restype) error) http.Handler { + return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse { + var request reqtype + var response restype + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := f(req.Context(), &request, &response); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: &InternalAPIError{ + Type: reflect.TypeOf(err).String(), + Message: fmt.Sprintf("%s", err), + }, + } + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: &response, + } + }) +} + +func MakeInternalProxyAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype) (*restype, error)) http.Handler { + return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse { + var request reqtype + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + response, err := f(req.Context(), &request) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: response, + } + }) +} + +func CallInternalRPCAPI[reqtype, restype any](name, url string, client *http.Client, ctx context.Context, request *reqtype, response *restype) error { + span, ctx := opentracing.StartSpanFromContext(ctx, name) + defer span.Finish() + + return PostJSON[reqtype, restype, InternalAPIError](ctx, span, client, url, request, response) +} + +func CallInternalProxyAPI[reqtype, restype any, errtype error](name, url string, client *http.Client, ctx context.Context, request *reqtype) (restype, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, name) + defer span.Finish() + + var response restype + return response, PostJSON[reqtype, restype, errtype](ctx, span, client, url, request, &response) +} diff --git a/keyserver/api/api.go b/keyserver/api/api.go index c0a1eedbb..9ba3988b9 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -38,32 +38,32 @@ type KeyInternalAPI interface { // API functions required by the clientapi type ClientKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) - PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error + PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error // PerformClaimKeys claims one-time keys for use in pre-key messages - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error } // API functions required by the userapi type UserKeyAPI interface { - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) - PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) + PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error + PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error } // API functions required by the syncapi type SyncKeyAPI interface { - QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) - QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) + QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error + QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error } type FederationKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) - QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) - QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error } // KeyError is returned if there was a problem performing/querying the server diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index 08bbfedb8..99859dff6 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -103,7 +103,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos } // nolint:gocyclo -func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { +func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { // Find the keys to store. byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} toStore := types.CrossSigningKeyMap{} @@ -115,7 +115,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P Err: "Master key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return + return nil } byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey @@ -131,7 +131,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P Err: "Self-signing key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return + return nil } byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey @@ -146,7 +146,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P Err: "User-signing key sanity check failed: " + err.Error(), IsInvalidParam: true, } - return + return nil } byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey @@ -161,7 +161,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P Err: "No keys were supplied in the request", IsMissingParam: true, } - return + return nil } // We can't have a self-signing or user-signing key without a master @@ -174,7 +174,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P res.Error = &api.KeyError{ Err: "Retrieving cross-signing keys from database failed: " + err.Error(), } - return + return nil } // If we still can't find a master key for the user then stop the upload. @@ -185,7 +185,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P Err: "No master key was found", IsMissingParam: true, } - return + return nil } } @@ -212,7 +212,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P } } if !changed { - return + return nil } // Store the keys. @@ -220,7 +220,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), } - return + return nil } // Now upload any signatures that were included with the keys. @@ -238,7 +238,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), } - return + return nil } } } @@ -255,17 +255,18 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P update.SelfSigningKey = &ssk } if update.MasterKey == nil && update.SelfSigningKey == nil { - return + return nil } if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } - return + return nil } + return nil } -func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) { +func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { // Before we do anything, we need the master and self-signing keys for this user. // Then we can verify the signatures make sense. queryReq := &api.QueryKeysRequest{ @@ -276,7 +277,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req for userID := range req.Signatures { queryReq.UserToDevices[userID] = []string{} } - a.QueryKeys(ctx, queryReq, queryRes) + _ = a.QueryKeys(ctx, queryReq, queryRes) selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} @@ -322,14 +323,14 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req res.Error = &api.KeyError{ Err: fmt.Sprintf("a.processSelfSignatures: %s", err), } - return + return nil } if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("a.processOtherSignatures: %s", err), } - return + return nil } // Finally, generate a notification that we updated the signatures. @@ -345,9 +346,10 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req res.Error = &api.KeyError{ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), } - return + return nil } } + return nil } func (a *KeyInternalAPI) processSelfSignatures( @@ -520,7 +522,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase( } } -func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) { +func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { for targetUserID, forTargetUser := range req.TargetIDs { keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) if err != nil && err != sql.ErrNoRows { @@ -559,7 +561,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign res.Error = &api.KeyError{ Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), } - return + return nil } for sourceUserID, forSourceUser := range sigMap { @@ -581,4 +583,5 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign } } } + return nil } diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index e317755e4..80efbec51 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -119,7 +119,7 @@ type DeviceListUpdaterDatabase interface { } type DeviceListUpdaterAPI interface { - PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) + PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error } // KeyChangeProducer is the interface for producers.KeyChange useful for testing. @@ -421,7 +421,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam uploadReq.SelfSigningKey = *res.SelfSigningKey } } - u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) + _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) } err = u.updateDeviceList(&res) if err != nil { diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 0c2405f34..0520a9e66 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -113,8 +113,8 @@ func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys type mockDeviceListUpdaterAPI struct { } -func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) { - +func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { + return nil } type roundTripper struct { diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 91f011517..41b4d44a4 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -48,18 +48,20 @@ func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { a.UserAPI = i } -func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { +func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset) if err != nil { res.Error = &api.KeyError{ Err: err.Error(), } + return nil } res.Offset = latest res.UserIDs = userIDs + return nil } -func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { +func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { res.KeyErrors = make(map[string]map[string]*api.KeyError) if len(req.DeviceKeys) > 0 { a.uploadLocalDeviceKeys(ctx, req, res) @@ -67,9 +69,10 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform if len(req.OneTimeKeys) > 0 { a.uploadOneTimeKeys(ctx, req, res) } + return nil } -func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) { +func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) // wrap request map in a top-level by-domain map @@ -113,6 +116,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC if len(domainToDeviceKeys) > 0 { a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) } + return nil } func (a *KeyInternalAPI) claimRemoteKeys( @@ -172,32 +176,34 @@ func (a *KeyInternalAPI) claimRemoteKeys( util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys") } -func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) { +func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to delete device keys: %s", err), } } + return nil } -func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) { +func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("Failed to query OTK counts: %s", err), } - return + return nil } res.Count = *count + return nil } -func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) { +func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query DB for device keys: %s", err), } - return + return nil } maxStreamID := int64(0) for _, m := range msgs { @@ -215,10 +221,11 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query } res.Devices = result res.StreamID = maxStreamID + return nil } // nolint:gocyclo -func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { +func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) @@ -244,7 +251,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to query local device keys: %s", err), } - return + return nil } // pull out display names after we have the keys so we handle wildcards correctly @@ -318,7 +325,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return + return nil } logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") continue @@ -344,7 +351,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques // Stop executing the function if the context was canceled/the deadline was exceeded, // as we can't continue without a valid context. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return + return nil } logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") continue @@ -372,6 +379,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } } } + return nil } func (a *KeyInternalAPI) remoteKeysFromDatabase( diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index dac61d1ea..7a7131145 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/opentracing/opentracing-go" ) // HTTP paths for the internal HTTP APIs @@ -68,168 +67,108 @@ func (h *httpKeyInternalAPI) PerformClaimKeys( ctx context.Context, request *api.PerformClaimKeysRequest, response *api.PerformClaimKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys") - defer span.Finish() - - apiURL := h.apiURL + PerformClaimKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformClaimKeys", h.apiURL+PerformClaimKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) PerformDeleteKeys( ctx context.Context, request *api.PerformDeleteKeysRequest, response *api.PerformDeleteKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys") - defer span.Finish() - - apiURL := h.apiURL + PerformClaimKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformDeleteKeys", h.apiURL+PerformDeleteKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) PerformUploadKeys( ctx context.Context, request *api.PerformUploadKeysRequest, response *api.PerformUploadKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadKeys") - defer span.Finish() - - apiURL := h.apiURL + PerformUploadKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformUploadKeys", h.apiURL+PerformUploadKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) QueryKeys( ctx context.Context, request *api.QueryKeysRequest, response *api.QueryKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys") - defer span.Finish() - - apiURL := h.apiURL + QueryKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "QueryKeys", h.apiURL+QueryKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) QueryOneTimeKeys( ctx context.Context, request *api.QueryOneTimeKeysRequest, response *api.QueryOneTimeKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys") - defer span.Finish() - - apiURL := h.apiURL + QueryOneTimeKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "QueryOneTimeKeys", h.apiURL+QueryOneTimeKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) QueryDeviceMessages( ctx context.Context, request *api.QueryDeviceMessagesRequest, response *api.QueryDeviceMessagesResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceMessages") - defer span.Finish() - - apiURL := h.apiURL + QueryDeviceMessagesPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "QueryDeviceMessages", h.apiURL+QueryDeviceMessagesPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) QueryKeyChanges( ctx context.Context, request *api.QueryKeyChangesRequest, response *api.QueryKeyChangesResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyChanges") - defer span.Finish() - - apiURL := h.apiURL + QueryKeyChangesPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "QueryKeyChanges", h.apiURL+QueryKeyChangesPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) PerformUploadDeviceKeys( ctx context.Context, request *api.PerformUploadDeviceKeysRequest, response *api.PerformUploadDeviceKeysResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceKeys") - defer span.Finish() - - apiURL := h.apiURL + PerformUploadDeviceKeysPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformUploadDeviceKeys", h.apiURL+PerformUploadDeviceKeysPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) PerformUploadDeviceSignatures( ctx context.Context, request *api.PerformUploadDeviceSignaturesRequest, response *api.PerformUploadDeviceSignaturesResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceSignatures") - defer span.Finish() - - apiURL := h.apiURL + PerformUploadDeviceSignaturesPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformUploadDeviceSignatures", h.apiURL+PerformUploadDeviceSignaturesPath, + h.httpClient, ctx, request, response, + ) } func (h *httpKeyInternalAPI) QuerySignatures( ctx context.Context, request *api.QuerySignaturesRequest, response *api.QuerySignaturesResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySignatures") - defer span.Finish() - - apiURL := h.apiURL + QuerySignaturesPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.KeyError{ - Err: err.Error(), - } - } +) error { + return httputil.CallInternalRPCAPI( + "QuerySignatures", h.apiURL+QuerySignaturesPath, + h.httpClient, ctx, request, response, + ) } diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index 5bf5976a8..4e5f9fba4 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -15,124 +15,59 @@ package inthttp import ( - "encoding/json" - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/util" ) func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { - internalAPIMux.Handle(PerformClaimKeysPath, - httputil.MakeInternalAPI("performClaimKeys", func(req *http.Request) util.JSONResponse { - request := api.PerformClaimKeysRequest{} - response := api.PerformClaimKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.PerformClaimKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformClaimKeysPath, + httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", s.PerformClaimKeys), ) - internalAPIMux.Handle(PerformDeleteKeysPath, - httputil.MakeInternalAPI("performDeleteKeys", func(req *http.Request) util.JSONResponse { - request := api.PerformDeleteKeysRequest{} - response := api.PerformDeleteKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.PerformDeleteKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformDeleteKeysPath, + httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", s.PerformDeleteKeys), ) - internalAPIMux.Handle(PerformUploadKeysPath, - httputil.MakeInternalAPI("performUploadKeys", func(req *http.Request) util.JSONResponse { - request := api.PerformUploadKeysRequest{} - response := api.PerformUploadKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.PerformUploadKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformUploadKeysPath, + httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", s.PerformUploadKeys), ) - internalAPIMux.Handle(PerformUploadDeviceKeysPath, - httputil.MakeInternalAPI("performUploadDeviceKeys", func(req *http.Request) util.JSONResponse { - request := api.PerformUploadDeviceKeysRequest{} - response := api.PerformUploadDeviceKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.PerformUploadDeviceKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformUploadDeviceKeysPath, + httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", s.PerformUploadDeviceKeys), ) - internalAPIMux.Handle(PerformUploadDeviceSignaturesPath, - httputil.MakeInternalAPI("performUploadDeviceSignatures", func(req *http.Request) util.JSONResponse { - request := api.PerformUploadDeviceSignaturesRequest{} - response := api.PerformUploadDeviceSignaturesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.PerformUploadDeviceSignatures(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformUploadDeviceSignaturesPath, + httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", s.PerformUploadDeviceSignatures), ) - internalAPIMux.Handle(QueryKeysPath, - httputil.MakeInternalAPI("queryKeys", func(req *http.Request) util.JSONResponse { - request := api.QueryKeysRequest{} - response := api.QueryKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QueryKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryKeysPath, + httputil.MakeInternalRPCAPI("KeyserverQueryKeys", s.QueryKeys), ) - internalAPIMux.Handle(QueryOneTimeKeysPath, - httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse { - request := api.QueryOneTimeKeysRequest{} - response := api.QueryOneTimeKeysResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QueryOneTimeKeys(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryOneTimeKeysPath, + httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", s.QueryOneTimeKeys), ) - internalAPIMux.Handle(QueryDeviceMessagesPath, - httputil.MakeInternalAPI("queryDeviceMessages", func(req *http.Request) util.JSONResponse { - request := api.QueryDeviceMessagesRequest{} - response := api.QueryDeviceMessagesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QueryDeviceMessages(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryDeviceMessagesPath, + httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", s.QueryDeviceMessages), ) - internalAPIMux.Handle(QueryKeyChangesPath, - httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse { - request := api.QueryKeyChangesRequest{} - response := api.QueryKeyChangesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QueryKeyChanges(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryKeyChangesPath, + httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", s.QueryKeyChanges), ) - internalAPIMux.Handle(QuerySignaturesPath, - httputil.MakeInternalAPI("querySignatures", func(req *http.Request) util.JSONResponse { - request := api.QuerySignaturesRequest{} - response := api.QuerySignaturesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QuerySignatures(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QuerySignaturesPath, + httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", s.QuerySignatures), ) } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 38baa617f..ee0212ecf 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -40,7 +40,7 @@ type InputRoomEventsAPI interface { ctx context.Context, req *InputRoomEventsRequest, res *InputRoomEventsResponse, - ) + ) error } // Query the latest events and state for a room from the room server. @@ -139,15 +139,15 @@ type ClientRoomserverAPI interface { GetAliasesForRoomID(ctx context.Context, req *GetAliasesForRoomIDRequest, res *GetAliasesForRoomIDResponse) error // PerformRoomUpgrade upgrades a room to a newer version - PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) - PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) - PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) - PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) - PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) + PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error + PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error + PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error + PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error + PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error PerformInvite(ctx context.Context, req *PerformInviteRequest, res *PerformInviteResponse) error - PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) + PerformJoin(ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse) error PerformLeave(ctx context.Context, req *PerformLeaveRequest, res *PerformLeaveResponse) error - PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) + PerformPublish(ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse) error // PerformForget forgets a rooms history for a specific user PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error SetRoomAlias(ctx context.Context, req *SetRoomAliasRequest, res *SetRoomAliasResponse) error @@ -158,7 +158,7 @@ type UserRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error - PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) + PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error } type FederationRoomserverAPI interface { diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 211f320ff..18a617331 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -35,9 +35,10 @@ func (t *RoomserverInternalAPITrace) InputRoomEvents( ctx context.Context, req *InputRoomEventsRequest, res *InputRoomEventsResponse, -) { - t.Impl.InputRoomEvents(ctx, req, res) - util.GetLogger(ctx).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.InputRoomEvents(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformInvite( @@ -45,44 +46,49 @@ func (t *RoomserverInternalAPITrace) PerformInvite( req *PerformInviteRequest, res *PerformInviteResponse, ) error { - util.GetLogger(ctx).Infof("PerformInvite req=%+v res=%+v", js(req), js(res)) - return t.Impl.PerformInvite(ctx, req, res) + err := t.Impl.PerformInvite(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformInvite req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformPeek( ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse, -) { - t.Impl.PerformPeek(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPeek req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformPeek(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformPeek req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformUnpeek( ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse, -) { - t.Impl.PerformUnpeek(ctx, req, res) - util.GetLogger(ctx).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformUnpeek(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformUnpeek req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformRoomUpgrade( ctx context.Context, req *PerformRoomUpgradeRequest, res *PerformRoomUpgradeResponse, -) { - t.Impl.PerformRoomUpgrade(ctx, req, res) - util.GetLogger(ctx).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformRoomUpgrade(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformRoomUpgrade req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformJoin( ctx context.Context, req *PerformJoinRequest, res *PerformJoinResponse, -) { - t.Impl.PerformJoin(ctx, req, res) - util.GetLogger(ctx).Infof("PerformJoin req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformJoin(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformJoin req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformLeave( @@ -99,27 +105,30 @@ func (t *RoomserverInternalAPITrace) PerformPublish( ctx context.Context, req *PerformPublishRequest, res *PerformPublishResponse, -) { - t.Impl.PerformPublish(ctx, req, res) - util.GetLogger(ctx).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformPublish(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformPublish req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformAdminEvacuateRoom( ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse, -) { - t.Impl.PerformAdminEvacuateRoom(ctx, req, res) - util.GetLogger(ctx).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformAdminEvacuateRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateRoom req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser( ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse, -) { - t.Impl.PerformAdminEvacuateUser(ctx, req, res) - util.GetLogger(ctx).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res)) +) error { + err := t.Impl.PerformAdminEvacuateUser(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformAdminEvacuateUser req=%+v res=%+v", js(req), js(res)) + return err } func (t *RoomserverInternalAPITrace) PerformInboundPeek( @@ -128,7 +137,7 @@ func (t *RoomserverInternalAPITrace) PerformInboundPeek( res *PerformInboundPeekResponse, ) error { err := t.Impl.PerformInboundPeek(ctx, req, res) - util.GetLogger(ctx).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res)) + util.GetLogger(ctx).WithError(err).Infof("PerformInboundPeek req=%+v res=%+v", js(req), js(res)) return err } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 344e9b079..bc2f28176 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -90,7 +90,9 @@ func SendInputRoomEvents( Asynchronous: async, } var response InputRoomEventsResponse - rsAPI.InputRoomEvents(ctx, &request, &response) + if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil { + return err + } return response.Err() } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 339c9796c..8d24f3c59 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -337,18 +337,18 @@ func (r *Inputer) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) { +) error { // Queue up the event into the roomserver. replySub, err := r.queueInputRoomEvents(ctx, request) if err != nil { response.ErrMsg = err.Error() - return + return nil } // If we aren't waiting for synchronous responses then we can // give up here, there is nothing further to do. if replySub == nil { - return + return nil } // Otherwise, we'll want to sit and wait for the responses @@ -360,12 +360,14 @@ func (r *Inputer) InputRoomEvents( msg, err := replySub.NextMsgWithContext(ctx) if err != nil { response.ErrMsg = err.Error() - return + return nil } if len(msg.Data) > 0 { response.ErrMsg = string(msg.Data) } } + + return nil } var roomserverInputBackpressure = prometheus.NewGaugeVec( diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 6c7c6c98b..cb6b22d32 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -43,21 +43,21 @@ func (r *Admin) PerformAdminEvacuateRoom( ctx context.Context, req *api.PerformAdminEvacuateRoomRequest, res *api.PerformAdminEvacuateRoomResponse, -) { +) error { roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.DB.RoomInfo: %s", err), } - return + return nil } if roomInfo == nil || roomInfo.IsStub() { res.Error = &api.PerformError{ Code: api.PerformErrorNoRoom, Msg: fmt.Sprintf("Room %s not found", req.RoomID), } - return + return nil } memberNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) @@ -66,7 +66,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.DB.GetMembershipEventNIDsForRoom: %s", err), } - return + return nil } memberEvents, err := r.DB.Events(ctx, memberNIDs) @@ -75,7 +75,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.DB.Events: %s", err), } - return + return nil } inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) @@ -89,7 +89,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.Queryer.QueryLatestEventsAndState: %s", err), } - return + return nil } prevEvents := latestRes.LatestEvents @@ -104,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("json.Unmarshal: %s", err), } - return + return nil } memberContent.Membership = gomatrixserverlib.Leave @@ -122,7 +122,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("json.Marshal: %s", err), } - return + return nil } eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) @@ -131,7 +131,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("gomatrixserverlib.StateNeededForEventBuilder: %s", err), } - return + return nil } event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) @@ -140,7 +140,7 @@ func (r *Admin) PerformAdminEvacuateRoom( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("eventutil.BuildEvent: %s", err), } - return + return nil } inputEvents = append(inputEvents, api.InputRoomEvent{ @@ -160,28 +160,28 @@ func (r *Admin) PerformAdminEvacuateRoom( Asynchronous: true, } inputRes := &api.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) + return r.Inputer.InputRoomEvents(ctx, inputReq, inputRes) } func (r *Admin) PerformAdminEvacuateUser( ctx context.Context, req *api.PerformAdminEvacuateUserRequest, res *api.PerformAdminEvacuateUserResponse, -) { +) error { _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("Malformed user ID: %s", err), } - return + return nil } if domain != r.Cfg.Matrix.ServerName { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, Msg: "Can only evacuate local users using this endpoint", } - return + return nil } roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Join) @@ -190,7 +190,7 @@ func (r *Admin) PerformAdminEvacuateUser( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err), } - return + return nil } inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, gomatrixserverlib.Invite) @@ -199,7 +199,7 @@ func (r *Admin) PerformAdminEvacuateUser( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.DB.GetRoomsByMembership: %s", err), } - return + return nil } for _, roomID := range append(roomIDs, inviteRoomIDs...) { @@ -214,7 +214,7 @@ func (r *Admin) PerformAdminEvacuateUser( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.Leaver.PerformLeave: %s", err), } - return + return nil } if len(outputEvents) == 0 { continue @@ -224,9 +224,10 @@ func (r *Admin) PerformAdminEvacuateUser( Code: api.PerformErrorBadRequest, Msg: fmt.Sprintf("r.Inputer.WriteOutputEvents: %s", err), } - return + return nil } res.Affected = append(res.Affected, roomID) } + return nil } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index e1ff4eabb..483e78c3f 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -241,7 +241,9 @@ func (r *Inviter) PerformInvite( }, } inputRes := &api.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) + if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { + return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err) + } if err = inputRes.Err(); err != nil { res.Error = &api.PerformError{ Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 1445b4088..43be54beb 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -52,7 +52,7 @@ func (r *Joiner) PerformJoin( ctx context.Context, req *rsAPI.PerformJoinRequest, res *rsAPI.PerformJoinResponse, -) { +) error { logger := logrus.WithContext(ctx).WithFields(logrus.Fields{ "room_id": req.RoomIDOrAlias, "user_id": req.UserID, @@ -71,11 +71,12 @@ func (r *Joiner) PerformJoin( Msg: err.Error(), } } - return + return nil } logger.Info("User joined room successfully") res.RoomID = roomID res.JoinedVia = joinedVia + return nil } func (r *Joiner) performJoin( @@ -291,7 +292,12 @@ func (r *Joiner) performJoinRoomByID( }, } inputRes := rsAPI.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) + if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + return "", "", &rsAPI.PerformError{ + Code: rsAPI.PerformErrorNoOperation, + Msg: fmt.Sprintf("InputRoomEvents failed: %s", err), + } + } if err = inputRes.Err(); err != nil { return "", "", &rsAPI.PerformError{ Code: rsAPI.PerformErrorNotAllowed, diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 56e7240d0..036404cd2 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -186,7 +186,9 @@ func (r *Leaver) performLeaveRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) + if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + return nil, fmt.Errorf("r.Inputer.InputRoomEvents: %w", err) + } if err = inputRes.Err(); err != nil { return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } diff --git a/roomserver/internal/perform/perform_peek.go b/roomserver/internal/perform/perform_peek.go index 5560916b2..74d87a5b4 100644 --- a/roomserver/internal/perform/perform_peek.go +++ b/roomserver/internal/perform/perform_peek.go @@ -44,7 +44,7 @@ func (r *Peeker) PerformPeek( ctx context.Context, req *api.PerformPeekRequest, res *api.PerformPeekResponse, -) { +) error { roomID, err := r.performPeek(ctx, req) if err != nil { perr, ok := err.(*api.PerformError) @@ -57,6 +57,7 @@ func (r *Peeker) PerformPeek( } } res.RoomID = roomID + return nil } func (r *Peeker) performPeek( diff --git a/roomserver/internal/perform/perform_publish.go b/roomserver/internal/perform/perform_publish.go index 6ff42ac1a..1631fc657 100644 --- a/roomserver/internal/perform/perform_publish.go +++ b/roomserver/internal/perform/perform_publish.go @@ -29,11 +29,12 @@ func (r *Publisher) PerformPublish( ctx context.Context, req *api.PerformPublishRequest, res *api.PerformPublishResponse, -) { +) error { err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public") if err != nil { res.Error = &api.PerformError{ Msg: err.Error(), } } + return nil } diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go index 1fe8d5a0f..49e9067c9 100644 --- a/roomserver/internal/perform/perform_unpeek.go +++ b/roomserver/internal/perform/perform_unpeek.go @@ -41,7 +41,7 @@ func (r *Unpeeker) PerformUnpeek( ctx context.Context, req *api.PerformUnpeekRequest, res *api.PerformUnpeekResponse, -) { +) error { if err := r.performUnpeek(ctx, req); err != nil { perr, ok := err.(*api.PerformError) if ok { @@ -52,6 +52,7 @@ func (r *Unpeeker) PerformUnpeek( } } } + return nil } func (r *Unpeeker) performUnpeek( diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 393d7dd14..d6dc9708c 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -45,12 +45,13 @@ func (r *Upgrader) PerformRoomUpgrade( ctx context.Context, req *api.PerformRoomUpgradeRequest, res *api.PerformRoomUpgradeResponse, -) { +) error { res.NewRoomID, res.Error = r.performRoomUpgrade(ctx, req) if res.Error != nil { res.NewRoomID = "" logrus.WithContext(ctx).WithError(res.Error).Error("Room upgrade failed") } + return nil } func (r *Upgrader) performRoomUpgrade( @@ -286,22 +287,24 @@ func publishNewRoomAndUnpublishOldRoom( ) { // expose this room in the published room list var pubNewRoomRes api.PerformPublishResponse - URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ + if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ RoomID: newRoomID, Visibility: "public", - }, &pubNewRoomRes) - if pubNewRoomRes.Error != nil { + }, &pubNewRoomRes); err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to reach internal API") + } else if pubNewRoomRes.Error != nil { // treat as non-fatal since the room is already made by this point util.GetLogger(ctx).WithError(pubNewRoomRes.Error).Error("failed to visibility:public") } var unpubOldRoomRes api.PerformPublishResponse // remove the old room from the published room list - URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ + if err := URSAPI.PerformPublish(ctx, &api.PerformPublishRequest{ RoomID: oldRoomID, Visibility: "private", - }, &unpubOldRoomRes) - if unpubOldRoomRes.Error != nil { + }, &unpubOldRoomRes); err != nil { + util.GetLogger(ctx).WithError(err).Error("failed to reach internal API") + } else if unpubOldRoomRes.Error != nil { // treat as non-fatal since the room is already made by this point util.GetLogger(ctx).WithError(unpubOldRoomRes.Error).Error("failed to visibility:private") } diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 2fa8afc49..d16f67c69 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -3,7 +3,6 @@ package inthttp import ( "context" "errors" - "fmt" "net/http" asAPI "github.com/matrix-org/dendrite/appservice/api" @@ -14,7 +13,6 @@ import ( userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" - "github.com/opentracing/opentracing-go" ) const ( @@ -106,11 +104,10 @@ func (h *httpRoomserverInternalAPI) SetRoomAlias( request *api.SetRoomAliasRequest, response *api.SetRoomAliasResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "SetRoomAlias") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverSetRoomAliasPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "SetRoomAlias", h.roomserverURL+RoomserverSetRoomAliasPath, + h.httpClient, ctx, request, response, + ) } // GetRoomIDForAlias implements RoomserverAliasAPI @@ -119,11 +116,10 @@ func (h *httpRoomserverInternalAPI) GetRoomIDForAlias( request *api.GetRoomIDForAliasRequest, response *api.GetRoomIDForAliasResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetRoomIDForAlias") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverGetRoomIDForAliasPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "GetRoomIDForAlias", h.roomserverURL+RoomserverGetRoomIDForAliasPath, + h.httpClient, ctx, request, response, + ) } // GetAliasesForRoomID implements RoomserverAliasAPI @@ -132,11 +128,10 @@ func (h *httpRoomserverInternalAPI) GetAliasesForRoomID( request *api.GetAliasesForRoomIDRequest, response *api.GetAliasesForRoomIDResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "GetAliasesForRoomID") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverGetAliasesForRoomIDPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "GetAliasesForRoomID", h.roomserverURL+RoomserverGetAliasesForRoomIDPath, + h.httpClient, ctx, request, response, + ) } // RemoveRoomAlias implements RoomserverAliasAPI @@ -145,11 +140,10 @@ func (h *httpRoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "RemoveRoomAlias") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverRemoveRoomAliasPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "RemoveRoomAlias", h.roomserverURL+RoomserverRemoveRoomAliasPath, + h.httpClient, ctx, request, response, + ) } // InputRoomEvents implements RoomserverInputAPI @@ -157,15 +151,14 @@ func (h *httpRoomserverInternalAPI) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverInputRoomEventsPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { +) error { + if err := httputil.CallInternalRPCAPI( + "InputRoomEvents", h.roomserverURL+RoomserverInputRoomEventsPath, + h.httpClient, ctx, request, response, + ); err != nil { response.ErrMsg = err.Error() } + return nil } func (h *httpRoomserverInternalAPI) PerformInvite( @@ -173,45 +166,32 @@ func (h *httpRoomserverInternalAPI) PerformInvite( request *api.PerformInviteRequest, response *api.PerformInviteResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInvite") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformInvitePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformInvite", h.roomserverURL+RoomserverPerformInvitePath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformJoin( ctx context.Context, request *api.PerformJoinRequest, response *api.PerformJoinResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoin") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformJoinPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformJoin", h.roomserverURL+RoomserverPerformJoinPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformPeek( ctx context.Context, request *api.PerformPeekRequest, response *api.PerformPeekResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPeek") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformPeekPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformPeek", h.roomserverURL+RoomserverPerformPeekPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformInboundPeek( @@ -219,45 +199,32 @@ func (h *httpRoomserverInternalAPI) PerformInboundPeek( request *api.PerformInboundPeekRequest, response *api.PerformInboundPeekResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInboundPeek") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformInboundPeekPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformInboundPeek", h.roomserverURL+RoomserverPerformInboundPeekPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformUnpeek( ctx context.Context, request *api.PerformUnpeekRequest, response *api.PerformUnpeekResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUnpeek") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformUnpeekPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformUnpeek", h.roomserverURL+RoomserverPerformUnpeekPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformRoomUpgrade( ctx context.Context, request *api.PerformRoomUpgradeRequest, response *api.PerformRoomUpgradeResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformRoomUpgrade") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformRoomUpgradePath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err != nil { - response.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } +) error { + return httputil.CallInternalRPCAPI( + "PerformRoomUpgrade", h.roomserverURL+RoomserverPerformRoomUpgradePath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformLeave( @@ -265,62 +232,43 @@ func (h *httpRoomserverInternalAPI) PerformLeave( request *api.PerformLeaveRequest, response *api.PerformLeaveResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeave") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformLeavePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformLeave", h.roomserverURL+RoomserverPerformLeavePath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformPublish( ctx context.Context, - req *api.PerformPublishRequest, - res *api.PerformPublishResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPublish") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformPublishPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) - if err != nil { - res.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } + request *api.PerformPublishRequest, + response *api.PerformPublishResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformPublish", h.roomserverURL+RoomserverPerformPublishPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformAdminEvacuateRoom( ctx context.Context, - req *api.PerformAdminEvacuateRoomRequest, - res *api.PerformAdminEvacuateRoomResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateRoomPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) - if err != nil { - res.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } + request *api.PerformAdminEvacuateRoomRequest, + response *api.PerformAdminEvacuateRoomResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformAdminEvacuateRoom", h.roomserverURL+RoomserverPerformAdminEvacuateRoomPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser( ctx context.Context, - req *api.PerformAdminEvacuateUserRequest, - res *api.PerformAdminEvacuateUserResponse, -) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAdminEvacuateUser") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformAdminEvacuateUserPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) - if err != nil { - res.Error = &api.PerformError{ - Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), - } - } + request *api.PerformAdminEvacuateUserRequest, + response *api.PerformAdminEvacuateUserResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformAdminEvacuateUser", h.roomserverURL+RoomserverPerformAdminEvacuateUserPath, + h.httpClient, ctx, request, response, + ) } // QueryLatestEventsAndState implements RoomserverQueryAPI @@ -329,11 +277,10 @@ func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLatestEventsAndState") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryLatestEventsAndStatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryLatestEventsAndState", h.roomserverURL+RoomserverQueryLatestEventsAndStatePath, + h.httpClient, ctx, request, response, + ) } // QueryStateAfterEvents implements RoomserverQueryAPI @@ -342,11 +289,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents( request *api.QueryStateAfterEventsRequest, response *api.QueryStateAfterEventsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAfterEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryStateAfterEventsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryStateAfterEvents", h.roomserverURL+RoomserverQueryStateAfterEventsPath, + h.httpClient, ctx, request, response, + ) } // QueryEventsByID implements RoomserverQueryAPI @@ -355,11 +301,10 @@ func (h *httpRoomserverInternalAPI) QueryEventsByID( request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsByID") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryEventsByID", h.roomserverURL+RoomserverQueryEventsByIDPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryPublishedRooms( @@ -367,11 +312,10 @@ func (h *httpRoomserverInternalAPI) QueryPublishedRooms( request *api.QueryPublishedRoomsRequest, response *api.QueryPublishedRoomsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublishedRooms") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryPublishedRoomsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryPublishedRooms", h.roomserverURL+RoomserverQueryPublishedRoomsPath, + h.httpClient, ctx, request, response, + ) } // QueryMembershipForUser implements RoomserverQueryAPI @@ -380,11 +324,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipForUser( request *api.QueryMembershipForUserRequest, response *api.QueryMembershipForUserResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipForUser") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryMembershipForUser", h.roomserverURL+RoomserverQueryMembershipForUserPath, + h.httpClient, ctx, request, response, + ) } // QueryMembershipsForRoom implements RoomserverQueryAPI @@ -393,11 +336,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipsForRoom( request *api.QueryMembershipsForRoomRequest, response *api.QueryMembershipsForRoomResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMembershipsForRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMembershipsForRoomPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryMembershipsForRoom", h.roomserverURL+RoomserverQueryMembershipsForRoomPath, + h.httpClient, ctx, request, response, + ) } // QueryMembershipsForRoom implements RoomserverQueryAPI @@ -406,11 +348,10 @@ func (h *httpRoomserverInternalAPI) QueryServerJoinedToRoom( request *api.QueryServerJoinedToRoomRequest, response *api.QueryServerJoinedToRoomResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerJoinedToRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryServerJoinedToRoomPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryServerJoinedToRoom", h.roomserverURL+RoomserverQueryServerJoinedToRoomPath, + h.httpClient, ctx, request, response, + ) } // QueryServerAllowedToSeeEvent implements RoomserverQueryAPI @@ -419,11 +360,10 @@ func (h *httpRoomserverInternalAPI) QueryServerAllowedToSeeEvent( request *api.QueryServerAllowedToSeeEventRequest, response *api.QueryServerAllowedToSeeEventResponse, ) (err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerAllowedToSeeEvent") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryServerAllowedToSeeEventPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryServerAllowedToSeeEvent", h.roomserverURL+RoomserverQueryServerAllowedToSeeEventPath, + h.httpClient, ctx, request, response, + ) } // QueryMissingEvents implements RoomServerQueryAPI @@ -432,11 +372,10 @@ func (h *httpRoomserverInternalAPI) QueryMissingEvents( request *api.QueryMissingEventsRequest, response *api.QueryMissingEventsResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMissingEventsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryMissingEvents", h.roomserverURL+RoomserverQueryMissingEventsPath, + h.httpClient, ctx, request, response, + ) } // QueryStateAndAuthChain implements RoomserverQueryAPI @@ -445,11 +384,10 @@ func (h *httpRoomserverInternalAPI) QueryStateAndAuthChain( request *api.QueryStateAndAuthChainRequest, response *api.QueryStateAndAuthChainResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryStateAndAuthChain") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryStateAndAuthChainPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryStateAndAuthChain", h.roomserverURL+RoomserverQueryStateAndAuthChainPath, + h.httpClient, ctx, request, response, + ) } // PerformBackfill implements RoomServerQueryAPI @@ -458,11 +396,10 @@ func (h *httpRoomserverInternalAPI) PerformBackfill( request *api.PerformBackfillRequest, response *api.PerformBackfillResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBackfill") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformBackfillPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformBackfill", h.roomserverURL+RoomserverPerformBackfillPath, + h.httpClient, ctx, request, response, + ) } // QueryRoomVersionCapabilities implements RoomServerQueryAPI @@ -471,11 +408,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionCapabilities( request *api.QueryRoomVersionCapabilitiesRequest, response *api.QueryRoomVersionCapabilitiesResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionCapabilities") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryRoomVersionCapabilitiesPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryRoomVersionCapabilities", h.roomserverURL+RoomserverQueryRoomVersionCapabilitiesPath, + h.httpClient, ctx, request, response, + ) } // QueryRoomVersionForRoom implements RoomServerQueryAPI @@ -488,16 +424,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom( response.RoomVersion = roomVersion return nil } - - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomVersionForRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryRoomVersionForRoomPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) - if err == nil { - h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion) - } - return err + return httputil.CallInternalRPCAPI( + "QueryRoomVersionForRoom", h.roomserverURL+RoomserverQueryRoomVersionForRoomPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryCurrentState( @@ -505,11 +435,10 @@ func (h *httpRoomserverInternalAPI) QueryCurrentState( request *api.QueryCurrentStateRequest, response *api.QueryCurrentStateResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryCurrentState", h.roomserverURL+RoomserverQueryCurrentStatePath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryRoomsForUser( @@ -517,11 +446,10 @@ func (h *httpRoomserverInternalAPI) QueryRoomsForUser( request *api.QueryRoomsForUserRequest, response *api.QueryRoomsForUserResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryRoomsForUser", h.roomserverURL+RoomserverQueryRoomsForUserPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryBulkStateContent( @@ -529,68 +457,75 @@ func (h *httpRoomserverInternalAPI) QueryBulkStateContent( request *api.QueryBulkStateContentRequest, response *api.QueryBulkStateContentResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryBulkStateContent", h.roomserverURL+RoomserverQueryBulkStateContentPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QuerySharedUsers( - ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse, + ctx context.Context, + request *api.QuerySharedUsersRequest, + response *api.QuerySharedUsersResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QuerySharedUsers", h.roomserverURL+RoomserverQuerySharedUsersPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryKnownUsers( - ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse, + ctx context.Context, + request *api.QueryKnownUsersRequest, + response *api.QueryKnownUsersResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QueryKnownUsers", h.roomserverURL+RoomserverQueryKnownUsersPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryAuthChain( - ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse, + ctx context.Context, + request *api.QueryAuthChainRequest, + response *api.QueryAuthChainResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAuthChain") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryAuthChainPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QueryAuthChain", h.roomserverURL+RoomserverQueryAuthChainPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( - ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, + ctx context.Context, + request *api.QueryServerBannedFromRoomRequest, + response *api.QueryServerBannedFromRoomResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QueryServerBannedFromRoom", h.roomserverURL+RoomserverQueryServerBannedFromRoomPath, + h.httpClient, ctx, request, response, + ) } func (h *httpRoomserverInternalAPI) QueryRestrictedJoinAllowed( - ctx context.Context, req *api.QueryRestrictedJoinAllowedRequest, res *api.QueryRestrictedJoinAllowedResponse, + ctx context.Context, + request *api.QueryRestrictedJoinAllowedRequest, + response *api.QueryRestrictedJoinAllowedResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRestrictedJoinAllowed") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryRestrictedJoinAllowed - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "QueryRestrictedJoinAllowed", h.roomserverURL+RoomserverQueryRestrictedJoinAllowed, + h.httpClient, ctx, request, response, + ) } -func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api.PerformForgetRequest, res *api.PerformForgetResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformForget") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverPerformForgetPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpRoomserverInternalAPI) PerformForget( + ctx context.Context, + request *api.PerformForgetRequest, + response *api.PerformForgetResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformForget", h.roomserverURL+RoomserverPerformForgetPath, + h.httpClient, ctx, request, response, + ) } diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 993381585..e325d76a5 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -1,499 +1,196 @@ package inthttp import ( - "encoding/json" - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/util" ) // AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux. // nolint: gocyclo func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { - internalAPIMux.Handle(RoomserverInputRoomEventsPath, - httputil.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse { - var request api.InputRoomEventsRequest - var response api.InputRoomEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.InputRoomEvents(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + RoomserverInputRoomEventsPath, + httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", r.InputRoomEvents), ) - internalAPIMux.Handle(RoomserverPerformInvitePath, - httputil.MakeInternalAPI("performInvite", func(req *http.Request) util.JSONResponse { - var request api.PerformInviteRequest - var response api.PerformInviteResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.PerformInvite(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformInvitePath, + httputil.MakeInternalRPCAPI("RoomserverPerformInvite", r.PerformInvite), ) - internalAPIMux.Handle(RoomserverPerformJoinPath, - httputil.MakeInternalAPI("performJoin", func(req *http.Request) util.JSONResponse { - var request api.PerformJoinRequest - var response api.PerformJoinResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformJoin(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformJoinPath, + httputil.MakeInternalRPCAPI("RoomserverPerformJoin", r.PerformJoin), ) - internalAPIMux.Handle(RoomserverPerformLeavePath, - httputil.MakeInternalAPI("performLeave", func(req *http.Request) util.JSONResponse { - var request api.PerformLeaveRequest - var response api.PerformLeaveResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.PerformLeave(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformLeavePath, + httputil.MakeInternalRPCAPI("RoomserverPerformLeave", r.PerformLeave), ) - internalAPIMux.Handle(RoomserverPerformPeekPath, - httputil.MakeInternalAPI("performPeek", func(req *http.Request) util.JSONResponse { - var request api.PerformPeekRequest - var response api.PerformPeekResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformPeek(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformPeekPath, + httputil.MakeInternalRPCAPI("RoomserverPerformPeek", r.PerformPeek), ) - internalAPIMux.Handle(RoomserverPerformInboundPeekPath, - httputil.MakeInternalAPI("performInboundPeek", func(req *http.Request) util.JSONResponse { - var request api.PerformInboundPeekRequest - var response api.PerformInboundPeekResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.PerformInboundPeek(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformInboundPeekPath, + httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", r.PerformInboundPeek), ) - internalAPIMux.Handle(RoomserverPerformPeekPath, - httputil.MakeInternalAPI("performUnpeek", func(req *http.Request) util.JSONResponse { - var request api.PerformUnpeekRequest - var response api.PerformUnpeekResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformUnpeek(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformUnpeekPath, + httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", r.PerformUnpeek), ) - internalAPIMux.Handle(RoomserverPerformRoomUpgradePath, - httputil.MakeInternalAPI("performRoomUpgrade", func(req *http.Request) util.JSONResponse { - var request api.PerformRoomUpgradeRequest - var response api.PerformRoomUpgradeResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformRoomUpgrade(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformRoomUpgradePath, + httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", r.PerformRoomUpgrade), ) - internalAPIMux.Handle(RoomserverPerformPublishPath, - httputil.MakeInternalAPI("performPublish", func(req *http.Request) util.JSONResponse { - var request api.PerformPublishRequest - var response api.PerformPublishResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformPublish(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformPublishPath, + httputil.MakeInternalRPCAPI("RoomserverPerformPublish", r.PerformPublish), ) - internalAPIMux.Handle(RoomserverPerformAdminEvacuateRoomPath, - httputil.MakeInternalAPI("performAdminEvacuateRoom", func(req *http.Request) util.JSONResponse { - var request api.PerformAdminEvacuateRoomRequest - var response api.PerformAdminEvacuateRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformAdminEvacuateRoom(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformAdminEvacuateRoomPath, + httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", r.PerformAdminEvacuateRoom), ) - internalAPIMux.Handle(RoomserverPerformAdminEvacuateUserPath, - httputil.MakeInternalAPI("performAdminEvacuateUser", func(req *http.Request) util.JSONResponse { - var request api.PerformAdminEvacuateUserRequest - var response api.PerformAdminEvacuateUserResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - r.PerformAdminEvacuateUser(req.Context(), &request, &response) - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverPerformAdminEvacuateUserPath, + httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", r.PerformAdminEvacuateUser), ) + internalAPIMux.Handle( RoomserverQueryPublishedRoomsPath, - httputil.MakeInternalAPI("queryPublishedRooms", func(req *http.Request) util.JSONResponse { - var request api.QueryPublishedRoomsRequest - var response api.QueryPublishedRoomsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryPublishedRooms(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", r.QueryPublishedRooms), ) + internalAPIMux.Handle( RoomserverQueryLatestEventsAndStatePath, - httputil.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse { - var request api.QueryLatestEventsAndStateRequest - var response api.QueryLatestEventsAndStateResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryLatestEventsAndState(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", r.QueryLatestEventsAndState), ) + internalAPIMux.Handle( RoomserverQueryStateAfterEventsPath, - httputil.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse { - var request api.QueryStateAfterEventsRequest - var response api.QueryStateAfterEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryStateAfterEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", r.QueryStateAfterEvents), ) + internalAPIMux.Handle( RoomserverQueryEventsByIDPath, - httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { - var request api.QueryEventsByIDRequest - var response api.QueryEventsByIDResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryEventsByID(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", r.QueryEventsByID), ) + internalAPIMux.Handle( RoomserverQueryMembershipForUserPath, - httputil.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse { - var request api.QueryMembershipForUserRequest - var response api.QueryMembershipForUserResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMembershipForUser(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", r.QueryMembershipForUser), ) + internalAPIMux.Handle( RoomserverQueryMembershipsForRoomPath, - httputil.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryMembershipsForRoomRequest - var response api.QueryMembershipsForRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMembershipsForRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", r.QueryMembershipsForRoom), ) + internalAPIMux.Handle( RoomserverQueryServerJoinedToRoomPath, - httputil.MakeInternalAPI("queryServerJoinedToRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryServerJoinedToRoomRequest - var response api.QueryServerJoinedToRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryServerJoinedToRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", r.QueryServerJoinedToRoom), ) + internalAPIMux.Handle( RoomserverQueryServerAllowedToSeeEventPath, - httputil.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse { - var request api.QueryServerAllowedToSeeEventRequest - var response api.QueryServerAllowedToSeeEventResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryServerAllowedToSeeEvent(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", r.QueryServerAllowedToSeeEvent), ) + internalAPIMux.Handle( RoomserverQueryMissingEventsPath, - httputil.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse { - var request api.QueryMissingEventsRequest - var response api.QueryMissingEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMissingEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", r.QueryMissingEvents), ) + internalAPIMux.Handle( RoomserverQueryStateAndAuthChainPath, - httputil.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse { - var request api.QueryStateAndAuthChainRequest - var response api.QueryStateAndAuthChainResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", r.QueryStateAndAuthChain), ) + internalAPIMux.Handle( RoomserverPerformBackfillPath, - httputil.MakeInternalAPI("PerformBackfill", func(req *http.Request) util.JSONResponse { - var request api.PerformBackfillRequest - var response api.PerformBackfillResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.PerformBackfill(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", r.PerformBackfill), ) + internalAPIMux.Handle( RoomserverPerformForgetPath, - httputil.MakeInternalAPI("PerformForget", func(req *http.Request) util.JSONResponse { - var request api.PerformForgetRequest - var response api.PerformForgetResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.PerformForget(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverPerformForget", r.PerformForget), ) + internalAPIMux.Handle( RoomserverQueryRoomVersionCapabilitiesPath, - httputil.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse { - var request api.QueryRoomVersionCapabilitiesRequest - var response api.QueryRoomVersionCapabilitiesResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryRoomVersionCapabilities(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", r.QueryRoomVersionCapabilities), ) + internalAPIMux.Handle( RoomserverQueryRoomVersionForRoomPath, - httputil.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse { - var request api.QueryRoomVersionForRoomRequest - var response api.QueryRoomVersionForRoomResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryRoomVersionForRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", r.QueryRoomVersionForRoom), ) + internalAPIMux.Handle( RoomserverSetRoomAliasPath, - httputil.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse { - var request api.SetRoomAliasRequest - var response api.SetRoomAliasResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.SetRoomAlias(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", r.SetRoomAlias), ) + internalAPIMux.Handle( RoomserverGetRoomIDForAliasPath, - httputil.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse { - var request api.GetRoomIDForAliasRequest - var response api.GetRoomIDForAliasResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.GetRoomIDForAlias(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", r.GetRoomIDForAlias), ) + internalAPIMux.Handle( RoomserverGetAliasesForRoomIDPath, - httputil.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse { - var request api.GetAliasesForRoomIDRequest - var response api.GetAliasesForRoomIDResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.GetAliasesForRoomID(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", r.GetAliasesForRoomID), ) + internalAPIMux.Handle( RoomserverRemoveRoomAliasPath, - httputil.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse { - var request api.RemoveRoomAliasRequest - var response api.RemoveRoomAliasResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.RemoveRoomAlias(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", r.RemoveRoomAlias), ) - internalAPIMux.Handle(RoomserverQueryCurrentStatePath, - httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse { - request := api.QueryCurrentStateRequest{} - response := api.QueryCurrentStateResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryCurrentStatePath, + httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", r.QueryCurrentState), ) - internalAPIMux.Handle(RoomserverQueryRoomsForUserPath, - httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse { - request := api.QueryRoomsForUserRequest{} - response := api.QueryRoomsForUserResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryRoomsForUserPath, + httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", r.QueryRoomsForUser), ) - internalAPIMux.Handle(RoomserverQueryBulkStateContentPath, - httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse { - request := api.QueryBulkStateContentRequest{} - response := api.QueryBulkStateContentResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryBulkStateContentPath, + httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", r.QueryBulkStateContent), ) - internalAPIMux.Handle(RoomserverQuerySharedUsersPath, - httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse { - request := api.QuerySharedUsersRequest{} - response := api.QuerySharedUsersResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQuerySharedUsersPath, + httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", r.QuerySharedUsers), ) - internalAPIMux.Handle(RoomserverQueryKnownUsersPath, - httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse { - request := api.QueryKnownUsersRequest{} - response := api.QueryKnownUsersResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryKnownUsersPath, + httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", r.QueryKnownUsers), ) - internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath, - httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse { - request := api.QueryServerBannedFromRoomRequest{} - response := api.QueryServerBannedFromRoomResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryServerBannedFromRoomPath, + httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", r.QueryServerBannedFromRoom), ) - internalAPIMux.Handle(RoomserverQueryAuthChainPath, - httputil.MakeInternalAPI("queryAuthChain", func(req *http.Request) util.JSONResponse { - request := api.QueryAuthChainRequest{} - response := api.QueryAuthChainResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryAuthChain(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryAuthChainPath, + httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", r.QueryAuthChain), ) - internalAPIMux.Handle(RoomserverQueryRestrictedJoinAllowed, - httputil.MakeInternalAPI("queryRestrictedJoinAllowed", func(req *http.Request) util.JSONResponse { - request := api.QueryRestrictedJoinAllowedRequest{} - response := api.QueryRestrictedJoinAllowedResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := r.QueryRestrictedJoinAllowed(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + RoomserverQueryRestrictedJoinAllowed, + httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", r.QueryRestrictedJoinAllowed), ) } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index eeded4275..3e9d90a1f 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -164,9 +164,9 @@ func TestMSC2836(t *testing.T) { // make everyone joined to each other's rooms nopRsAPI := &testRoomserverAPI{ userToJoinedRooms: map[string][]string{ - alice: []string{roomID}, - bob: []string{roomID}, - charlie: []string{roomID}, + alice: {roomID}, + bob: {roomID}, + charlie: {roomID}, }, events: map[string]*gomatrixserverlib.HeaderedEvent{ eventA.EventID(): eventA, diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 4bf54cae0..23824e366 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -31,7 +31,7 @@ import ( // DeviceOTKCounts adds one-time key counts to the /sync response func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error { var queryRes keyapi.QueryOneTimeKeysResponse - keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{ + _ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{ UserID: userID, DeviceID: deviceID, }, &queryRes) @@ -73,7 +73,7 @@ func DeviceListCatchup( offset = int64(from) } var queryRes keyapi.QueryKeyChangesResponse - keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ + _ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ Offset: offset, ToOffset: toOffset, }, &queryRes) diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index c7d8df740..80d2811be 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -22,31 +22,41 @@ var ( type mockKeyAPI struct{} -func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) { +func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error { + return nil } func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {} // PerformClaimKeys claims one-time keys for use in pre-key messages -func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) { +func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error { + return nil } -func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) { +func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error { + return nil } -func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) { +func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error { + return nil } -func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) { +func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error { + return nil } -func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) { +func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error { + return nil } -func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) { +func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { + return nil } -func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) { +func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { + return nil } -func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) { +func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error { + return nil } -func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) { +func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error { + return nil } type mockRoomserverAPI struct { diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index b10864ff5..931fef883 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -78,10 +78,11 @@ type syncKeyAPI struct { keyapi.SyncKeyAPI } -func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) { +func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error { + return nil } -func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) { - +func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error { + return nil } func TestSyncAPIAccessTokens(t *testing.T) { diff --git a/userapi/api/api.go b/userapi/api/api.go index df9408acb..388f97cb4 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -100,7 +100,7 @@ type ClientUserAPI interface { QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error - QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error QueryThreePIDsForLocalpart(ctx context.Context, req *QueryThreePIDsForLocalpartRequest, res *QueryThreePIDsForLocalpartResponse) error QueryLocalpartForThreePID(ctx context.Context, req *QueryLocalpartForThreePIDRequest, res *QueryLocalpartForThreePIDResponse) error diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 6d8d28007..7e2f69615 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -94,9 +94,10 @@ func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *Per util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res)) return err } -func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) { - t.Impl.QueryKeyBackup(ctx, req, res) +func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) error { + err := t.Impl.QueryKeyBackup(ctx, req, res) util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res)) + return err } func (t *UserInternalAPITrace) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error { err := t.Impl.QueryProfile(ctx, req, res) diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 422eb076e..78b226d46 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -192,7 +192,9 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID)) } deleteRes := &keyapi.PerformDeleteKeysResponse{} - a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes) + if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { + return err + } if err := deleteRes.Error; err != nil { return fmt.Errorf("a.KeyAPI.PerformDeleteKeys: %w", err) } @@ -211,10 +213,12 @@ func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) er } var uploadRes keyapi.PerformUploadKeysResponse - a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ UserID: userID, DeviceKeys: deviceKeys, - }, &uploadRes) + }, &uploadRes); err != nil { + return err + } if uploadRes.Error != nil { return fmt.Errorf("failed to delete device keys: %v", uploadRes.Error) } @@ -268,7 +272,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf if req.DisplayName != nil && dev.DisplayName != *req.DisplayName { // display name has changed: update the device key var uploadRes keyapi.PerformUploadKeysResponse - a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ UserID: req.RequestingUserID, DeviceKeys: []keyapi.DeviceKeys{ { @@ -279,7 +283,9 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf }, }, OnlyDisplayNameUpdates: true, - }, &uploadRes) + }, &uploadRes); err != nil { + return err + } if uploadRes.Error != nil { return fmt.Errorf("failed to update device key display name: %v", uploadRes.Error) } @@ -479,7 +485,9 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), } evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} - a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes) + if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil { + return err + } if err := evacuateRes.Error; err != nil { logrus.WithError(err).Errorf("Failed to evacuate user after account deactivation") } @@ -538,9 +546,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform if req.Version == "" { res.BadInput = true res.Error = "must specify a version to delete" - if res.Error != "" { - return fmt.Errorf(res.Error) - } return nil } exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version) @@ -549,9 +554,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } res.Exists = exists res.Version = req.Version - if res.Error != "" { - return fmt.Errorf(res.Error) - } return nil } // Create metadata @@ -562,9 +564,6 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } res.Exists = err == nil res.Version = version - if res.Error != "" { - return fmt.Errorf(res.Error) - } return nil } // Update metadata @@ -575,16 +574,10 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } res.Exists = err == nil res.Version = req.Version - if res.Error != "" { - return fmt.Errorf(res.Error) - } return nil } // Upload Keys for a specific version metadata a.uploadBackupKeys(ctx, req, res) - if res.Error != "" { - return fmt.Errorf(res.Error) - } return nil } @@ -627,16 +620,16 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform res.KeyETag = etag } -func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { +func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error { version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version) res.Version = version if err != nil { if err == sql.ErrNoRows { res.Exists = false - return + return nil } res.Error = fmt.Sprintf("failed to query key backup: %s", err) - return + return nil } res.Algorithm = algorithm res.AuthData = authData @@ -648,15 +641,16 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB if err != nil { res.Error = fmt.Sprintf("failed to count keys: %s", err) } - return + return nil } result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID) if err != nil { res.Error = fmt.Sprintf("failed to query keys: %s", err) - return + return nil } res.Keys = result + return nil } func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 23c335cf2..a375d6caa 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -21,7 +21,6 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/opentracing/opentracing-go" ) // HTTP paths for the internal HTTP APIs @@ -84,11 +83,10 @@ type httpUserInternalAPI struct { } func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData") - defer span.Finish() - - apiURL := h.apiURL + InputAccountDataPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "InputAccountData", h.apiURL+InputAccountDataPath, + h.httpClient, ctx, req, res, + ) } func (h *httpUserInternalAPI) PerformAccountCreation( @@ -96,11 +94,10 @@ func (h *httpUserInternalAPI) PerformAccountCreation( request *api.PerformAccountCreationRequest, response *api.PerformAccountCreationResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountCreation") - defer span.Finish() - - apiURL := h.apiURL + PerformAccountCreationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformAccountCreation", h.apiURL+PerformAccountCreationPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformPasswordUpdate( @@ -108,11 +105,10 @@ func (h *httpUserInternalAPI) PerformPasswordUpdate( request *api.PerformPasswordUpdateRequest, response *api.PerformPasswordUpdateResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate") - defer span.Finish() - - apiURL := h.apiURL + PerformPasswordUpdatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformPasswordUpdate", h.apiURL+PerformPasswordUpdatePath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformDeviceCreation( @@ -120,11 +116,10 @@ func (h *httpUserInternalAPI) PerformDeviceCreation( request *api.PerformDeviceCreationRequest, response *api.PerformDeviceCreationResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceCreation") - defer span.Finish() - - apiURL := h.apiURL + PerformDeviceCreationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformDeviceCreation", h.apiURL+PerformDeviceCreationPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformDeviceDeletion( @@ -132,47 +127,54 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion( request *api.PerformDeviceDeletionRequest, response *api.PerformDeviceDeletionResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceDeletion") - defer span.Finish() - - apiURL := h.apiURL + PerformDeviceDeletionPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformDeviceDeletion", h.apiURL+PerformDeviceDeletionPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformLastSeenUpdate( ctx context.Context, - req *api.PerformLastSeenUpdateRequest, - res *api.PerformLastSeenUpdateResponse, + request *api.PerformLastSeenUpdateRequest, + response *api.PerformLastSeenUpdateResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLastSeen") - defer span.Finish() - - apiURL := h.apiURL + PerformLastSeenUpdatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) + return httputil.CallInternalRPCAPI( + "PerformLastSeen", h.apiURL+PerformLastSeenUpdatePath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDeviceUpdate") - defer span.Finish() - - apiURL := h.apiURL + PerformDeviceUpdatePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) PerformDeviceUpdate( + ctx context.Context, + request *api.PerformDeviceUpdateRequest, + response *api.PerformDeviceUpdateResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformDeviceUpdate", h.apiURL+PerformDeviceUpdatePath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformAccountDeactivation") - defer span.Finish() - - apiURL := h.apiURL + PerformAccountDeactivationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) PerformAccountDeactivation( + ctx context.Context, + request *api.PerformAccountDeactivationRequest, + response *api.PerformAccountDeactivationResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformAccountDeactivation", h.apiURL+PerformAccountDeactivationPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation") - defer span.Finish() - - apiURL := h.apiURL + PerformOpenIDTokenCreationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +func (h *httpUserInternalAPI) PerformOpenIDTokenCreation( + ctx context.Context, + request *api.PerformOpenIDTokenCreationRequest, + response *api.PerformOpenIDTokenCreationResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformOpenIDTokenCreation", h.apiURL+PerformOpenIDTokenCreationPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) QueryProfile( @@ -180,11 +182,10 @@ func (h *httpUserInternalAPI) QueryProfile( request *api.QueryProfileRequest, response *api.QueryProfileResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryProfile") - defer span.Finish() - - apiURL := h.apiURL + QueryProfilePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryProfile", h.apiURL+QueryProfilePath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) QueryDeviceInfos( @@ -192,11 +193,10 @@ func (h *httpUserInternalAPI) QueryDeviceInfos( request *api.QueryDeviceInfosRequest, response *api.QueryDeviceInfosResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos") - defer span.Finish() - - apiURL := h.apiURL + QueryDeviceInfosPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryDeviceInfos", h.apiURL+QueryDeviceInfosPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) QueryAccessToken( @@ -204,72 +204,87 @@ func (h *httpUserInternalAPI) QueryAccessToken( request *api.QueryAccessTokenRequest, response *api.QueryAccessTokenResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccessToken") - defer span.Finish() - - apiURL := h.apiURL + QueryAccessTokenPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryAccessToken", h.apiURL+QueryAccessTokenPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDevices") - defer span.Finish() - - apiURL := h.apiURL + QueryDevicesPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryDevices( + ctx context.Context, + request *api.QueryDevicesRequest, + response *api.QueryDevicesResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryDevices", h.apiURL+QueryDevicesPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAccountDataRequest, res *api.QueryAccountDataResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryAccountData") - defer span.Finish() - - apiURL := h.apiURL + QueryAccountDataPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryAccountData( + ctx context.Context, + request *api.QueryAccountDataRequest, + response *api.QueryAccountDataResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountData", h.apiURL+QueryAccountDataPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySearchProfiles") - defer span.Finish() - - apiURL := h.apiURL + QuerySearchProfilesPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QuerySearchProfiles( + ctx context.Context, + request *api.QuerySearchProfilesRequest, + response *api.QuerySearchProfilesResponse, +) error { + return httputil.CallInternalRPCAPI( + "QuerySearchProfiles", h.apiURL+QuerySearchProfilesPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken") - defer span.Finish() - - apiURL := h.apiURL + QueryOpenIDTokenPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryOpenIDToken( + ctx context.Context, + request *api.QueryOpenIDTokenRequest, + response *api.QueryOpenIDTokenResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryOpenIDToken", h.apiURL+QueryOpenIDTokenPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup") - defer span.Finish() - - apiURL := h.apiURL + PerformKeyBackupPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) - if err != nil { - res.Error = err.Error() - } - return nil -} -func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") - defer span.Finish() - - apiURL := h.apiURL + QueryKeyBackupPath - err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) - if err != nil { - res.Error = err.Error() - } +func (h *httpUserInternalAPI) PerformKeyBackup( + ctx context.Context, + request *api.PerformKeyBackupRequest, + response *api.PerformKeyBackupResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformKeyBackup", h.apiURL+PerformKeyBackupPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications") - defer span.Finish() +func (h *httpUserInternalAPI) QueryKeyBackup( + ctx context.Context, + request *api.QueryKeyBackupRequest, + response *api.QueryKeyBackupResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryKeyBackup", h.apiURL+QueryKeyBackupPath, + h.httpClient, ctx, request, response, + ) +} - return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res) +func (h *httpUserInternalAPI) QueryNotifications( + ctx context.Context, + request *api.QueryNotificationsRequest, + response *api.QueryNotificationsResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryNotifications", h.apiURL+QueryNotificationsPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformPusherSet( @@ -277,27 +292,32 @@ func (h *httpUserInternalAPI) PerformPusherSet( request *api.PerformPusherSetRequest, response *struct{}, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet") - defer span.Finish() - - apiURL := h.apiURL + PerformPusherSetPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformPusherSet", h.apiURL+PerformPusherSetPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion") - defer span.Finish() - - apiURL := h.apiURL + PerformPusherDeletionPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) PerformPusherDeletion( + ctx context.Context, + request *api.PerformPusherDeletionRequest, + response *struct{}, +) error { + return httputil.CallInternalRPCAPI( + "PerformPusherDeletion", h.apiURL+PerformPusherDeletionPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers") - defer span.Finish() - - apiURL := h.apiURL + QueryPushersPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryPushers( + ctx context.Context, + request *api.QueryPushersRequest, + response *api.QueryPushersResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryPushers", h.apiURL+QueryPushersPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformPushRulesPut( @@ -305,89 +325,117 @@ func (h *httpUserInternalAPI) PerformPushRulesPut( request *api.PerformPushRulesPutRequest, response *struct{}, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut") - defer span.Finish() - - apiURL := h.apiURL + PerformPushRulesPutPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformPushRulesPut", h.apiURL+PerformPushRulesPutPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules") - defer span.Finish() - - apiURL := h.apiURL + QueryPushRulesPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryPushRules( + ctx context.Context, + request *api.QueryPushRulesRequest, + response *api.QueryPushRulesResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryPushRules", h.apiURL+QueryPushRulesPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetAvatarURLPath) - defer span.Finish() - - apiURL := h.apiURL + PerformSetAvatarURLPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) SetAvatarURL( + ctx context.Context, + request *api.PerformSetAvatarURLRequest, + response *api.PerformSetAvatarURLResponse, +) error { + return httputil.CallInternalRPCAPI( + "SetAvatarURL", h.apiURL+PerformSetAvatarURLPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, QueryNumericLocalpartPath) - defer span.Finish() - - apiURL := h.apiURL + QueryNumericLocalpartPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, struct{}{}, res) +func (h *httpUserInternalAPI) QueryNumericLocalpart( + ctx context.Context, + response *api.QueryNumericLocalpartResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath, + h.httpClient, ctx, &struct{}{}, response, + ) } -func (h *httpUserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountAvailabilityPath) - defer span.Finish() - - apiURL := h.apiURL + QueryAccountAvailabilityPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryAccountAvailability( + ctx context.Context, + request *api.QueryAccountAvailabilityRequest, + response *api.QueryAccountAvailabilityResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountAvailability", h.apiURL+QueryAccountAvailabilityPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, QueryAccountByPasswordPath) - defer span.Finish() - - apiURL := h.apiURL + QueryAccountByPasswordPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryAccountByPassword( + ctx context.Context, + request *api.QueryAccountByPasswordRequest, + response *api.QueryAccountByPasswordResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountByPassword", h.apiURL+QueryAccountByPasswordPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *struct{}) error { - span, ctx := opentracing.StartSpanFromContext(ctx, PerformSetDisplayNamePath) - defer span.Finish() - - apiURL := h.apiURL + PerformSetDisplayNamePath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) SetDisplayName( + ctx context.Context, + request *api.PerformUpdateDisplayNameRequest, + response *struct{}, +) error { + return httputil.CallInternalRPCAPI( + "SetDisplayName", h.apiURL+PerformSetDisplayNamePath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, QueryLocalpartForThreePIDPath) - defer span.Finish() - - apiURL := h.apiURL + QueryLocalpartForThreePIDPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryLocalpartForThreePID( + ctx context.Context, + request *api.QueryLocalpartForThreePIDRequest, + response *api.QueryLocalpartForThreePIDResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryLocalpartForThreePID", h.apiURL+QueryLocalpartForThreePIDPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { - span, ctx := opentracing.StartSpanFromContext(ctx, QueryThreePIDsForLocalpartPath) - defer span.Finish() - - apiURL := h.apiURL + QueryThreePIDsForLocalpartPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) QueryThreePIDsForLocalpart( + ctx context.Context, + request *api.QueryThreePIDsForLocalpartRequest, + response *api.QueryThreePIDsForLocalpartResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryThreePIDsForLocalpart", h.apiURL+QueryThreePIDsForLocalpartPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.PerformForgetThreePIDRequest, res *struct{}) error { - span, ctx := opentracing.StartSpanFromContext(ctx, PerformForgetThreePIDPath) - defer span.Finish() - - apiURL := h.apiURL + PerformForgetThreePIDPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) PerformForgetThreePID( + ctx context.Context, + request *api.PerformForgetThreePIDRequest, + response *struct{}, +) error { + return httputil.CallInternalRPCAPI( + "PerformForgetThreePID", h.apiURL+PerformForgetThreePIDPath, + h.httpClient, ctx, request, response, + ) } -func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { - span, ctx := opentracing.StartSpanFromContext(ctx, PerformSaveThreePIDAssociationPath) - defer span.Finish() - - apiURL := h.apiURL + PerformSaveThreePIDAssociationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation( + ctx context.Context, + request *api.PerformSaveThreePIDAssociationRequest, + response *struct{}, +) error { + return httputil.CallInternalRPCAPI( + "PerformSaveThreePIDAssociation", h.apiURL+PerformSaveThreePIDAssociationPath, + h.httpClient, ctx, request, response, + ) } diff --git a/userapi/inthttp/client_logintoken.go b/userapi/inthttp/client_logintoken.go index 366a97099..211b1b7a1 100644 --- a/userapi/inthttp/client_logintoken.go +++ b/userapi/inthttp/client_logintoken.go @@ -19,7 +19,6 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/opentracing/opentracing-go" ) const ( @@ -33,11 +32,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenCreation( request *api.PerformLoginTokenCreationRequest, response *api.PerformLoginTokenCreationResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation") - defer span.Finish() - - apiURL := h.apiURL + PerformLoginTokenCreationPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformLoginTokenCreation", h.apiURL+PerformLoginTokenCreationPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) PerformLoginTokenDeletion( @@ -45,11 +43,10 @@ func (h *httpUserInternalAPI) PerformLoginTokenDeletion( request *api.PerformLoginTokenDeletionRequest, response *api.PerformLoginTokenDeletionResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion") - defer span.Finish() - - apiURL := h.apiURL + PerformLoginTokenDeletionPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "PerformLoginTokenDeletion", h.apiURL+PerformLoginTokenDeletionPath, + h.httpClient, ctx, request, response, + ) } func (h *httpUserInternalAPI) QueryLoginToken( @@ -57,9 +54,8 @@ func (h *httpUserInternalAPI) QueryLoginToken( request *api.QueryLoginTokenRequest, response *api.QueryLoginTokenResponse, ) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken") - defer span.Finish() - - apiURL := h.apiURL + QueryLoginTokenPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + return httputil.CallInternalRPCAPI( + "QueryLoginToken", h.apiURL+QueryLoginTokenPath, + h.httpClient, ctx, request, response, + ) } diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index ad532b901..99148b760 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -15,8 +15,6 @@ package inthttp import ( - "encoding/json" - "fmt" "net/http" "github.com/gorilla/mux" @@ -29,339 +27,134 @@ import ( func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { addRoutesLoginToken(internalAPIMux, s) - internalAPIMux.Handle(PerformAccountCreationPath, - httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { - request := api.PerformAccountCreationRequest{} - response := api.PerformAccountCreationResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformAccountCreation(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformPasswordUpdatePath, - httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse { - request := api.PerformPasswordUpdateRequest{} - response := api.PerformPasswordUpdateResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformDeviceCreationPath, - httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse { - request := api.PerformDeviceCreationRequest{} - response := api.PerformDeviceCreationResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformDeviceCreation(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformLastSeenUpdatePath, - httputil.MakeInternalAPI("performLastSeenUpdate", func(req *http.Request) util.JSONResponse { - request := api.PerformLastSeenUpdateRequest{} - response := api.PerformLastSeenUpdateResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformLastSeenUpdate(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformDeviceUpdatePath, - httputil.MakeInternalAPI("performDeviceUpdate", func(req *http.Request) util.JSONResponse { - request := api.PerformDeviceUpdateRequest{} - response := api.PerformDeviceUpdateResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformDeviceUpdate(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformDeviceDeletionPath, - httputil.MakeInternalAPI("performDeviceDeletion", func(req *http.Request) util.JSONResponse { - request := api.PerformDeviceDeletionRequest{} - response := api.PerformDeviceDeletionResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformDeviceDeletion(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformAccountDeactivationPath, - httputil.MakeInternalAPI("performAccountDeactivation", func(req *http.Request) util.JSONResponse { - request := api.PerformAccountDeactivationRequest{} - response := api.PerformAccountDeactivationResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformAccountDeactivation(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformOpenIDTokenCreationPath, - httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse { - request := api.PerformOpenIDTokenCreationRequest{} - response := api.PerformOpenIDTokenCreationResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryProfilePath, - httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse { - request := api.QueryProfileRequest{} - response := api.QueryProfileResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryProfile(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryAccessTokenPath, - httputil.MakeInternalAPI("queryAccessToken", func(req *http.Request) util.JSONResponse { - request := api.QueryAccessTokenRequest{} - response := api.QueryAccessTokenResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryAccessToken(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryDevicesPath, - httputil.MakeInternalAPI("queryDevices", func(req *http.Request) util.JSONResponse { - request := api.QueryDevicesRequest{} - response := api.QueryDevicesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryDevices(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryAccountDataPath, - httputil.MakeInternalAPI("queryAccountData", func(req *http.Request) util.JSONResponse { - request := api.QueryAccountDataRequest{} - response := api.QueryAccountDataResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryAccountData(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryDeviceInfosPath, - httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse { - request := api.QueryDeviceInfosRequest{} - response := api.QueryDeviceInfosResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QuerySearchProfilesPath, - httputil.MakeInternalAPI("querySearchProfiles", func(req *http.Request) util.JSONResponse { - request := api.QuerySearchProfilesRequest{} - response := api.QuerySearchProfilesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QuerySearchProfiles(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryOpenIDTokenPath, - httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse { - request := api.QueryOpenIDTokenRequest{} - response := api.QueryOpenIDTokenResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(InputAccountDataPath, - httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse { - request := api.InputAccountDataRequest{} - response := api.InputAccountDataResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.InputAccountData(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryKeyBackupPath, - httputil.MakeInternalAPI("queryKeyBackup", func(req *http.Request) util.JSONResponse { - request := api.QueryKeyBackupRequest{} - response := api.QueryKeyBackupResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - s.QueryKeyBackup(req.Context(), &request, &response) - if response.Error != "" { - return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", response.Error)) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformKeyBackupPath, - httputil.MakeInternalAPI("performKeyBackup", func(req *http.Request) util.JSONResponse { - request := api.PerformKeyBackupRequest{} - response := api.PerformKeyBackupResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - err := s.PerformKeyBackup(req.Context(), &request, &response) - if err != nil { - return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(QueryNotificationsPath, - httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse { - var request api.QueryNotificationsRequest - var response api.QueryNotificationsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryNotifications(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformAccountCreationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformAccountCreation", s.PerformAccountCreation), ) - internalAPIMux.Handle(PerformPusherSetPath, - httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse { - request := api.PerformPusherSetRequest{} - response := struct{}{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) - internalAPIMux.Handle(PerformPusherDeletionPath, - httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse { - request := api.PerformPusherDeletionRequest{} - response := struct{}{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformPasswordUpdatePath, + httputil.MakeInternalRPCAPI("UserAPIPerformPasswordUpdate", s.PerformPasswordUpdate), ) - internalAPIMux.Handle(QueryPushersPath, - httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse { - request := api.QueryPushersRequest{} - response := api.QueryPushersResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryPushers(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformDeviceCreationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceCreation", s.PerformDeviceCreation), ) - internalAPIMux.Handle(PerformPushRulesPutPath, - httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse { - request := api.PerformPushRulesPutRequest{} - response := struct{}{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformLastSeenUpdatePath, + httputil.MakeInternalRPCAPI("UserAPIPerformLastSeenUpdate", s.PerformLastSeenUpdate), ) - internalAPIMux.Handle(QueryPushRulesPath, - httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse { - request := api.QueryPushRulesRequest{} - response := api.QueryPushRulesResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryPushRules(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformDeviceUpdatePath, + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceUpdate", s.PerformDeviceUpdate), ) - internalAPIMux.Handle(PerformSetAvatarURLPath, - httputil.MakeInternalAPI("performSetAvatarURL", func(req *http.Request) util.JSONResponse { - request := api.PerformSetAvatarURLRequest{} - response := api.PerformSetAvatarURLResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.SetAvatarURL(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformDeviceDeletionPath, + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceDeletion", s.PerformDeviceDeletion), ) + + internalAPIMux.Handle( + PerformAccountDeactivationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformAccountDeactivation", s.PerformAccountDeactivation), + ) + + internalAPIMux.Handle( + PerformOpenIDTokenCreationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformOpenIDTokenCreation", s.PerformOpenIDTokenCreation), + ) + + internalAPIMux.Handle( + QueryProfilePath, + httputil.MakeInternalRPCAPI("UserAPIQueryProfile", s.QueryProfile), + ) + + internalAPIMux.Handle( + QueryAccessTokenPath, + httputil.MakeInternalRPCAPI("UserAPIQueryAccessToken", s.QueryAccessToken), + ) + + internalAPIMux.Handle( + QueryDevicesPath, + httputil.MakeInternalRPCAPI("UserAPIQueryDevices", s.QueryDevices), + ) + + internalAPIMux.Handle( + QueryAccountDataPath, + httputil.MakeInternalRPCAPI("UserAPIQueryAccountData", s.QueryAccountData), + ) + + internalAPIMux.Handle( + QueryDeviceInfosPath, + httputil.MakeInternalRPCAPI("UserAPIQueryDeviceInfos", s.QueryDeviceInfos), + ) + + internalAPIMux.Handle( + QuerySearchProfilesPath, + httputil.MakeInternalRPCAPI("UserAPIQuerySearchProfiles", s.QuerySearchProfiles), + ) + + internalAPIMux.Handle( + QueryOpenIDTokenPath, + httputil.MakeInternalRPCAPI("UserAPIQueryOpenIDToken", s.QueryOpenIDToken), + ) + + internalAPIMux.Handle( + InputAccountDataPath, + httputil.MakeInternalRPCAPI("UserAPIInputAccountData", s.InputAccountData), + ) + + internalAPIMux.Handle( + QueryKeyBackupPath, + httputil.MakeInternalRPCAPI("UserAPIQueryKeyBackup", s.QueryKeyBackup), + ) + + internalAPIMux.Handle( + PerformKeyBackupPath, + httputil.MakeInternalRPCAPI("UserAPIPerformKeyBackup", s.PerformKeyBackup), + ) + + internalAPIMux.Handle( + QueryNotificationsPath, + httputil.MakeInternalRPCAPI("UserAPIQueryNotifications", s.QueryNotifications), + ) + + internalAPIMux.Handle( + PerformPusherSetPath, + httputil.MakeInternalRPCAPI("UserAPIPerformPusherSet", s.PerformPusherSet), + ) + + internalAPIMux.Handle( + PerformPusherDeletionPath, + httputil.MakeInternalRPCAPI("UserAPIPerformPusherDeletion", s.PerformPusherDeletion), + ) + + internalAPIMux.Handle( + QueryPushersPath, + httputil.MakeInternalRPCAPI("UserAPIQueryPushers", s.QueryPushers), + ) + + internalAPIMux.Handle( + PerformPushRulesPutPath, + httputil.MakeInternalRPCAPI("UserAPIPerformPushRulesPut", s.PerformPushRulesPut), + ) + + internalAPIMux.Handle( + QueryPushRulesPath, + httputil.MakeInternalRPCAPI("UserAPIQueryPushRules", s.QueryPushRules), + ) + + internalAPIMux.Handle( + PerformSetAvatarURLPath, + httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL), + ) + + // TODO: Look at the shape of this internalAPIMux.Handle(QueryNumericLocalpartPath, - httputil.MakeInternalAPI("queryNumericLocalpart", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse { response := api.QueryNumericLocalpartResponse{} if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil { return util.ErrorResponse(err) @@ -369,92 +162,39 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle(QueryAccountAvailabilityPath, - httputil.MakeInternalAPI("queryAccountAvailability", func(req *http.Request) util.JSONResponse { - request := api.QueryAccountAvailabilityRequest{} - response := api.QueryAccountAvailabilityResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryAccountAvailability(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryAccountAvailabilityPath, + httputil.MakeInternalRPCAPI("UserAPIQueryAccountAvailability", s.QueryAccountAvailability), ) - internalAPIMux.Handle(QueryAccountByPasswordPath, - httputil.MakeInternalAPI("queryAccountByPassword", func(req *http.Request) util.JSONResponse { - request := api.QueryAccountByPasswordRequest{} - response := api.QueryAccountByPasswordResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryAccountByPassword(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryAccountByPasswordPath, + httputil.MakeInternalRPCAPI("UserAPIQueryAccountByPassword", s.QueryAccountByPassword), ) - internalAPIMux.Handle(PerformSetDisplayNamePath, - httputil.MakeInternalAPI("performSetDisplayName", func(req *http.Request) util.JSONResponse { - request := api.PerformUpdateDisplayNameRequest{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.SetDisplayName(req.Context(), &request, &struct{}{}); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}} - }), + + internalAPIMux.Handle( + PerformSetDisplayNamePath, + httputil.MakeInternalRPCAPI("UserAPISetDisplayName", s.SetDisplayName), ) - internalAPIMux.Handle(QueryLocalpartForThreePIDPath, - httputil.MakeInternalAPI("queryLocalpartForThreePID", func(req *http.Request) util.JSONResponse { - request := api.QueryLocalpartForThreePIDRequest{} - response := api.QueryLocalpartForThreePIDResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryLocalpartForThreePID(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryLocalpartForThreePIDPath, + httputil.MakeInternalRPCAPI("UserAPIQueryLocalpartForThreePID", s.QueryLocalpartForThreePID), ) - internalAPIMux.Handle(QueryThreePIDsForLocalpartPath, - httputil.MakeInternalAPI("queryThreePIDsForLocalpart", func(req *http.Request) util.JSONResponse { - request := api.QueryThreePIDsForLocalpartRequest{} - response := api.QueryThreePIDsForLocalpartResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryThreePIDsForLocalpart(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryThreePIDsForLocalpartPath, + httputil.MakeInternalRPCAPI("UserAPIQueryThreePIDsForLocalpart", s.QueryThreePIDsForLocalpart), ) - internalAPIMux.Handle(PerformForgetThreePIDPath, - httputil.MakeInternalAPI("performForgetThreePID", func(req *http.Request) util.JSONResponse { - request := api.PerformForgetThreePIDRequest{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformForgetThreePID(req.Context(), &request, &struct{}{}); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}} - }), + + internalAPIMux.Handle( + PerformForgetThreePIDPath, + httputil.MakeInternalRPCAPI("UserAPIPerformForgetThreePID", s.PerformForgetThreePID), ) - internalAPIMux.Handle(PerformSaveThreePIDAssociationPath, - httputil.MakeInternalAPI("performSaveThreePIDAssociation", func(req *http.Request) util.JSONResponse { - request := api.PerformSaveThreePIDAssociationRequest{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformSaveThreePIDAssociation(req.Context(), &request, &struct{}{}); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &struct{}{}} - }), + + internalAPIMux.Handle( + PerformSaveThreePIDAssociationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", s.PerformSaveThreePIDAssociation), ) } diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go index 1f2eb34b9..b57348413 100644 --- a/userapi/inthttp/server_logintoken.go +++ b/userapi/inthttp/server_logintoken.go @@ -15,54 +15,25 @@ package inthttp import ( - "encoding/json" - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/util" ) // addRoutesLoginToken adds routes for all login token API calls. func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) { - internalAPIMux.Handle(PerformLoginTokenCreationPath, - httputil.MakeInternalAPI("performLoginTokenCreation", func(req *http.Request) util.JSONResponse { - request := api.PerformLoginTokenCreationRequest{} - response := api.PerformLoginTokenCreationResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformLoginTokenCreation(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + PerformLoginTokenCreationPath, + httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenCreation", s.PerformLoginTokenCreation), ) - internalAPIMux.Handle(PerformLoginTokenDeletionPath, - httputil.MakeInternalAPI("performLoginTokenDeletion", func(req *http.Request) util.JSONResponse { - request := api.PerformLoginTokenDeletionRequest{} - response := api.PerformLoginTokenDeletionResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.PerformLoginTokenDeletion(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + PerformLoginTokenDeletionPath, + httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenDeletion", s.PerformLoginTokenDeletion), ) - internalAPIMux.Handle(QueryLoginTokenPath, - httputil.MakeInternalAPI("queryLoginToken", func(req *http.Request) util.JSONResponse { - request := api.QueryLoginTokenRequest{} - response := api.QueryLoginTokenResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := s.QueryLoginToken(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + + internalAPIMux.Handle( + QueryLoginTokenPath, + httputil.MakeInternalRPCAPI("UserAPIQueryLoginToken", s.QueryLoginToken), ) } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 40e37c5d6..31a69793b 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -117,16 +117,20 @@ func TestQueryProfile(t *testing.T) { }, } - runCases := func(testAPI api.UserInternalAPI) { + runCases := func(testAPI api.UserInternalAPI, http bool) { + mode := "monolith" + if http { + mode = "HTTP" + } for _, tc := range testCases { var gotRes api.QueryProfileResponse gotErr := testAPI.QueryProfile(context.TODO(), &tc.req, &gotRes) if tc.wantErr == nil && gotErr != nil || tc.wantErr != nil && gotErr == nil { - t.Errorf("QueryProfile error, got %s want %s", gotErr, tc.wantErr) + t.Errorf("QueryProfile %s error, got %s want %s", mode, gotErr, tc.wantErr) continue } if !reflect.DeepEqual(tc.wantRes, gotRes) { - t.Errorf("QueryProfile response got %+v want %+v", gotRes, tc.wantRes) + t.Errorf("QueryProfile %s response got %+v want %+v", mode, gotRes, tc.wantRes) } } } @@ -140,10 +144,10 @@ func TestQueryProfile(t *testing.T) { if err != nil { t.Fatalf("failed to create HTTP client") } - runCases(httpAPI) + runCases(httpAPI, true) }) t.Run("Monolith", func(t *testing.T) { - runCases(userAPI) + runCases(userAPI, false) }) } From 25ad40f5e56ea3c84a2bacba3327d6ac65d80e01 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 11 Aug 2022 16:06:13 +0100 Subject: [PATCH 4/9] Remove test from `sytest-blacklist` --- sytest-blacklist | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sytest-blacklist b/sytest-blacklist index e0b2767b1..bcc345f6e 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -49,7 +49,3 @@ Notifications can be viewed with GET /notifications If remote user leaves room we no longer receive device updates Guest users can join guest_access rooms - -# You'll be shocked to discover this is flakey too - -Inbound /v1/send_join rejects joins from other servers From 371336c6b5ffd510802d06b193a48b01a5e78d0c Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 11 Aug 2022 16:31:44 +0100 Subject: [PATCH 5/9] Set default room version to 9 --- roomserver/version/version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roomserver/version/version.go b/roomserver/version/version.go index 1f66995d8..729d00a80 100644 --- a/roomserver/version/version.go +++ b/roomserver/version/version.go @@ -23,7 +23,7 @@ import ( // DefaultRoomVersion contains the room version that will, by // default, be used to create new rooms on this server. func DefaultRoomVersion() gomatrixserverlib.RoomVersion { - return gomatrixserverlib.RoomVersionV6 + return gomatrixserverlib.RoomVersionV9 } // RoomVersions returns a map of all known room versions to this From 05cafbd197c99c0e116c9b61447e70ba5af992a3 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 11 Aug 2022 18:23:35 +0200 Subject: [PATCH 6/9] Implement history visibility on `/messages`, `/context`, `/sync` (#2511) * Add possibility to set history_visibility and user AccountType * Add new DB queries * Add actual history_visibility changes for /messages * Add passing tests * Extract check function * Cleanup * Cleanup * Fix build on 386 * Move ApplyHistoryVisibilityFilter to internal * Move queries to topology table * Add filtering to /sync and /context Some cleanup * Add passing tests; Remove failing tests :( * Re-add passing tests * Move filtering to own function to avoid duplication * Re-add passing test * Use newly added GMSL HistoryVisibility * Update gomatrixserverlib * Set the visibility when creating events * Default to shared history visibility * Remove unused query * Update history visibility checks to use gmsl Update tests * Remove unused statement * Update migrations to set "correct" history visibility * Add method to fetch the membership at a given event * Tweaks and logging * Use actual internal rsAPI, default to shared visibility in tests * Revert "Move queries to topology table" This reverts commit 4f0d41be9c194a46379796435ce73e79203edbd6. * Remove noise/unneeded code * More cleanup * Try to optimize database requests * Fix imports * PR peview fixes/changes * Move setting history visibility to own migration, be more restrictive * Fix unit tests * Lint * Fix missing entries * Tweaks for incremental syncs * Adapt generic changes Co-authored-by: Neil Alexander Co-authored-by: kegsay --- go.mod | 8 +- go.sum | 16 +- roomserver/api/api.go | 8 + roomserver/api/api_trace.go | 10 + roomserver/api/query.go | 14 ++ roomserver/internal/helpers/helpers.go | 6 + roomserver/internal/input/input_events.go | 4 +- roomserver/internal/query/query.go | 48 ++++ roomserver/inthttp/client.go | 12 +- roomserver/inthttp/server.go | 5 + roomserver/state/state.go | 60 ++++- syncapi/internal/history_visibility.go | 217 ++++++++++++++++++ syncapi/routing/context.go | 99 ++++++-- syncapi/routing/messages.go | 100 ++------ syncapi/storage/interface.go | 4 + ...2061412000000_history_visibility_column.go | 55 +++++ syncapi/storage/postgres/memberships_table.go | 28 ++- .../postgres/output_room_events_table.go | 10 +- syncapi/storage/postgres/syncserver.go | 15 ++ syncapi/storage/shared/syncserver.go | 49 ++-- ...2061412000000_history_visibility_column.go | 55 +++++ syncapi/storage/sqlite3/memberships_table.go | 22 ++ .../sqlite3/output_room_events_table.go | 10 +- syncapi/storage/sqlite3/syncserver.go | 19 +- syncapi/storage/storage_test.go | 20 +- syncapi/storage/tables/interface.go | 1 + syncapi/streams/stream_pdu.go | 123 ++++++---- syncapi/syncapi_test.go | 186 ++++++++++++++- sytest-whitelist | 25 +- test/room.go | 28 ++- test/user.go | 10 +- 31 files changed, 1043 insertions(+), 224 deletions(-) create mode 100644 syncapi/internal/history_visibility.go diff --git a/go.mod b/go.mod index 3559c5bb1..79ef5c3e8 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771 + github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839 github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.13 @@ -34,7 +34,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.12.2 - github.com/sirupsen/logrus v1.8.1 + github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.7.1 github.com/tidwall/gjson v1.14.1 github.com/tidwall/sjson v1.2.4 @@ -42,7 +42,7 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.3 go.uber.org/atomic v1.9.0 - golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9 golang.org/x/mobile v0.0.0-20220518205345-8578da9835fd golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e @@ -99,7 +99,7 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect - golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect + golang.org/x/sys v0.0.0-20220731174439-a90be440212d // indirect golang.org/x/text v0.3.8-0.20211004125949-5bd84dd9b33b // indirect golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect golang.org/x/tools v0.1.10 // indirect diff --git a/go.sum b/go.sum index 2c8bb4f18..9d5c50d2c 100644 --- a/go.sum +++ b/go.sum @@ -343,8 +343,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771 h1:ZIPHFIPNDS9dmEbPEiJbNmyCGJtn9exfpLC7JOcn/bE= -github.com/matrix-org/gomatrixserverlib v0.0.0-20220725104114-b6003e522771/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839 h1:QEFxKWH8PlEt3ZQKl31yJNAm8lvpNUwT51IMNTl9v1k= +github.com/matrix-org/gomatrixserverlib v0.0.0-20220801083850-5ff38e2c2839/go.mod h1:jX38yp3SSLJNftBg3PXU1ayd0PCLIiDHQ4xAc9DIixk= github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9 h1:ed8yvWhTLk7+sNeK/eOZRTvESFTOHDRevoRoyeqPtvY= github.com/matrix-org/pinecone v0.0.0-20220803093810-b7a830c08fb9/go.mod h1:P4MqPf+u83OPulPJ+XTbSDbbWrdFYNY4LZ/B1PIduFE= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= @@ -493,8 +493,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= -github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= @@ -569,8 +569,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= -golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -748,8 +748,10 @@ golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211102192858-4dd72447c267/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220405052023-b1e9470b6e64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM= diff --git a/roomserver/api/api.go b/roomserver/api/api.go index ee0212ecf..baf63aa31 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -97,6 +97,14 @@ type SyncRoomserverAPI interface { req *PerformBackfillRequest, res *PerformBackfillResponse, ) error + + // QueryMembershipAtEvent queries the memberships at the given events. + // Returns a map from eventID to a slice of gomatrixserverlib.HeaderedEvent. + QueryMembershipAtEvent( + ctx context.Context, + request *QueryMembershipAtEventRequest, + response *QueryMembershipAtEventResponse, + ) error } type AppserviceRoomserverAPI interface { diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 18a617331..8bef35379 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -382,6 +382,16 @@ func (t *RoomserverInternalAPITrace) QueryRestrictedJoinAllowed( return err } +func (t *RoomserverInternalAPITrace) QueryMembershipAtEvent( + ctx context.Context, + request *QueryMembershipAtEventRequest, + response *QueryMembershipAtEventResponse, +) error { + err := t.Impl.QueryMembershipAtEvent(ctx, request, response) + util.GetLogger(ctx).WithError(err).Infof("QueryMembershipAtEvent req=%+v res=%+v", js(request), js(response)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index f157a9025..c8e6f9dc6 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -427,3 +427,17 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { } return nil } + +// QueryMembershipAtEventRequest requests the membership events for a user +// for a list of eventIDs. +type QueryMembershipAtEventRequest struct { + RoomID string + EventIDs []string + UserID string +} + +// QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest. +type QueryMembershipAtEventResponse struct { + // Memberships is a map from eventID to a list of events (if any). + Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` +} diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index d61aa08cb..6091f8ec2 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -208,6 +208,12 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room return roomState.LoadCombinedStateAfterEvents(ctx, prevState) } +func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info) + // Fetch the state as it was when this event was fired + return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID) +} + func LoadEvents( ctx context.Context, db storage.Database, eventNIDs []types.EventNID, ) ([]*gomatrixserverlib.Event, error) { diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 866670d7a..81541260c 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -299,7 +299,7 @@ func (r *Inputer) processRoomEvent( // allowed at the time, and also to get the history visibility. We won't // bother doing this if the event was already rejected as it just ends up // burning CPU time. - historyVisibility := gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive. + historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. if rejectionErr == nil && !isRejected && !softfail { var err error historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) @@ -429,7 +429,7 @@ func (r *Inputer) processStateBefore( input *api.InputRoomEvent, missingPrev bool, ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { - historyVisibility = gomatrixserverlib.HistoryVisibilityJoined // Default to restrictive. + historyVisibility = gomatrixserverlib.HistoryVisibilityShared // Default to shared. event := input.Event.Unwrap() isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") var stateBeforeEvent []*gomatrixserverlib.Event diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 5ba59b8f5..c41e1ea67 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -204,6 +204,54 @@ func (r *Queryer) QueryMembershipForUser( return err } +func (r *Queryer) QueryMembershipAtEvent( + ctx context.Context, + request *api.QueryMembershipAtEventRequest, + response *api.QueryMembershipAtEventResponse, +) error { + response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent) + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return fmt.Errorf("unable to get roomInfo: %w", err) + } + if info == nil { + return fmt.Errorf("no roomInfo found") + } + + // get the users stateKeyNID + stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID}) + if err != nil { + return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err) + } + if _, ok := stateKeyNIDs[request.UserID]; !ok { + return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) + } + + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID]) + if err != nil { + return fmt.Errorf("unable to get state before event: %w", err) + } + + for _, eventID := range request.EventIDs { + stateEntry := stateEntries[eventID] + memberships, err := helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + if err != nil { + return fmt.Errorf("unable to get memberships at state: %w", err) + } + res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) + + for i := range memberships { + ev := memberships[i] + if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { + res = append(res, ev.Headered(info.RoomVersion)) + } + } + response.Memberships[eventID] = res + } + + return nil +} + // QueryMembershipsForRoom implements api.RoomserverInternalAPI func (r *Queryer) QueryMembershipsForRoom( ctx context.Context, diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index d16f67c69..e9387bd97 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -5,14 +5,14 @@ import ( "errors" "net/http" + "github.com/matrix-org/gomatrixserverlib" + asAPI "github.com/matrix-org/dendrite/appservice/api" fsInputAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" - - "github.com/matrix-org/gomatrixserverlib" ) const ( @@ -61,6 +61,7 @@ const ( RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" + RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent" ) type httpRoomserverInternalAPI struct { @@ -529,3 +530,10 @@ func (h *httpRoomserverInternalAPI) PerformForget( ) } + +func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse) error { + return httputil.CallInternalRPCAPI( + "QueryMembershiptAtEvent", h.roomserverURL+RoomserverQueryMembershipAtEventPath, + h.httpClient, ctx, request, response, + ) +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index e325d76a5..3b688174a 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -2,6 +2,7 @@ package inthttp import ( "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver/api" ) @@ -193,4 +194,8 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { RoomserverQueryRestrictedJoinAllowed, httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", r.QueryRestrictedJoinAllowed), ) + internalAPIMux.Handle( + RoomserverQueryMembershipAtEventPath, + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent), + ) } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index ca0c69f27..a40a2e9ba 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -23,12 +23,11 @@ import ( "sync" "time" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" - - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) type StateResolutionStorage interface { @@ -124,6 +123,61 @@ func (v *StateResolution) LoadStateAtEvent( return stateEntries, nil } +func (v *StateResolution) LoadMembershipAtEvent( + ctx context.Context, eventIDs []string, stateKeyNID types.EventStateKeyNID, +) (map[string][]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent") + defer span.Finish() + + // De-dupe snapshotNIDs + snapshotNIDMap := make(map[types.StateSnapshotNID][]string) // map from snapshot NID to eventIDs + for i := range eventIDs { + eventID := eventIDs[i] + snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) + if err != nil { + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err) + } + if snapshotNID == 0 { + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID) + } + snapshotNIDMap[snapshotNID] = append(snapshotNIDMap[snapshotNID], eventID) + } + + snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap)) + for nid := range snapshotNIDMap { + snapshotNIDs = append(snapshotNIDs, nid) + } + + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, snapshotNIDs) + if err != nil { + return nil, err + } + + result := make(map[string][]types.StateEntry) + for _, stateBlockNIDList := range stateBlockNIDLists { + // Query the membership event for the user at the given stateblocks + stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{ + { + EventTypeNID: types.MRoomMemberNID, + EventStateKeyNID: stateKeyNID, + }, + }) + if err != nil { + return nil, err + } + + evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID] + + for _, evID := range evIDs { + for _, x := range stateEntryLists { + result[evID] = append(result[evID], x.StateEntries...) + } + } + } + + return result, nil +} + // LoadStateAtEvent loads the full state of a room before a particular event. func (v *StateResolution) LoadStateAtEventForHistoryVisibility( ctx context.Context, eventID string, diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go new file mode 100644 index 000000000..e73c004e5 --- /dev/null +++ b/syncapi/internal/history_visibility.go @@ -0,0 +1,217 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "math" + "time" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/gomatrixserverlib" + "github.com/prometheus/client_golang/prometheus" + "github.com/tidwall/gjson" +) + +func init() { + prometheus.MustRegister(calculateHistoryVisibilityDuration) +} + +// calculateHistoryVisibilityDuration stores the time it takes to +// calculate the history visibility. In polylith mode the roundtrip +// to the roomserver is included in this time. +var calculateHistoryVisibilityDuration = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "dendrite", + Subsystem: "syncapi", + Name: "calculateHistoryVisibility_duration_millis", + Help: "How long it takes to calculate the history visibility", + Buckets: []float64{ // milliseconds + 5, 10, 25, 50, 75, 100, 250, 500, + 1000, 2000, 3000, 4000, 5000, 6000, + 7000, 8000, 9000, 10000, 15000, 20000, + }, + }, + []string{"api"}, +) + +var historyVisibilityPriority = map[gomatrixserverlib.HistoryVisibility]uint8{ + gomatrixserverlib.WorldReadable: 0, + gomatrixserverlib.HistoryVisibilityShared: 1, + gomatrixserverlib.HistoryVisibilityInvited: 2, + gomatrixserverlib.HistoryVisibilityJoined: 3, +} + +// eventVisibility contains the history visibility and membership state at a given event +type eventVisibility struct { + visibility gomatrixserverlib.HistoryVisibility + membershipAtEvent string + membershipCurrent string +} + +// allowed checks the eventVisibility if the user is allowed to see the event. +// Rules as defined by https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5 +func (ev eventVisibility) allowed() (allowed bool) { + switch ev.visibility { + case gomatrixserverlib.HistoryVisibilityWorldReadable: + // If the history_visibility was set to world_readable, allow. + return true + case gomatrixserverlib.HistoryVisibilityJoined: + // If the user’s membership was join, allow. + if ev.membershipAtEvent == gomatrixserverlib.Join { + return true + } + return false + case gomatrixserverlib.HistoryVisibilityShared: + // If the user’s membership was join, allow. + // If history_visibility was set to shared, and the user joined the room at any point after the event was sent, allow. + if ev.membershipAtEvent == gomatrixserverlib.Join || ev.membershipCurrent == gomatrixserverlib.Join { + return true + } + return false + case gomatrixserverlib.HistoryVisibilityInvited: + // If the user’s membership was join, allow. + if ev.membershipAtEvent == gomatrixserverlib.Join { + return true + } + if ev.membershipAtEvent == gomatrixserverlib.Invite { + return true + } + return false + default: + return false + } +} + +// ApplyHistoryVisibilityFilter applies the room history visibility filter on gomatrixserverlib.HeaderedEvents. +// Returns the filtered events and an error, if any. +func ApplyHistoryVisibilityFilter( + ctx context.Context, + syncDB storage.Database, + rsAPI api.SyncRoomserverAPI, + events []*gomatrixserverlib.HeaderedEvent, + alwaysIncludeEventIDs map[string]struct{}, + userID, endpoint string, +) ([]*gomatrixserverlib.HeaderedEvent, error) { + if len(events) == 0 { + return events, nil + } + start := time.Now() + + // try to get the current membership of the user + membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64) + if err != nil { + return nil, err + } + + // Get the mapping from eventID -> eventVisibility + eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events)) + visibilities, err := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) + if err != nil { + return eventsFiltered, err + } + for _, ev := range events { + evVis := visibilities[ev.EventID()] + evVis.membershipCurrent = membershipCurrent + // Always include specific state events for /sync responses + if alwaysIncludeEventIDs != nil { + if _, ok := alwaysIncludeEventIDs[ev.EventID()]; ok { + eventsFiltered = append(eventsFiltered, ev) + continue + } + } + // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules") + if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(userID) { + eventsFiltered = append(eventsFiltered, ev) + continue + } + // Always allow history evVis events on boundaries. This is done + // by setting the effective evVis to the least restrictive + // of the old vs new. + // https://spec.matrix.org/v1.3/client-server-api/#server-behaviour-5 + if hisVis, err := ev.HistoryVisibility(); err == nil { + prevHisVis := gjson.GetBytes(ev.Unsigned(), "prev_content.history_visibility").String() + oldPrio, ok := historyVisibilityPriority[gomatrixserverlib.HistoryVisibility(prevHisVis)] + // if we can't get the previous history visibility, default to shared. + if !ok { + oldPrio = historyVisibilityPriority[gomatrixserverlib.HistoryVisibilityShared] + } + // no OK check, since this should have been validated when setting the value + newPrio := historyVisibilityPriority[hisVis] + if oldPrio < newPrio { + evVis.visibility = gomatrixserverlib.HistoryVisibility(prevHisVis) + } + } + // do the actual check + allowed := evVis.allowed() + if allowed { + eventsFiltered = append(eventsFiltered, ev) + } + } + calculateHistoryVisibilityDuration.With(prometheus.Labels{"api": endpoint}).Observe(float64(time.Since(start).Milliseconds())) + return eventsFiltered, nil +} + +// visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership +// of `userID` at the given event. +// Returns an error if the roomserver can't calculate the memberships. +func visibilityForEvents( + ctx context.Context, + rsAPI api.SyncRoomserverAPI, + events []*gomatrixserverlib.HeaderedEvent, + userID, roomID string, +) (map[string]eventVisibility, error) { + eventIDs := make([]string, len(events)) + for i := range events { + eventIDs[i] = events[i].EventID() + } + + result := make(map[string]eventVisibility, len(eventIDs)) + + // get the membership events for all eventIDs + membershipResp := &api.QueryMembershipAtEventResponse{} + err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{ + RoomID: roomID, + EventIDs: eventIDs, + UserID: userID, + }, membershipResp) + if err != nil { + return result, err + } + + // Create a map from eventID -> eventVisibility + for _, event := range events { + eventID := event.EventID() + vis := eventVisibility{ + membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident + visibility: event.Visibility, + } + membershipEvs, ok := membershipResp.Memberships[eventID] + if !ok { + result[eventID] = vis + continue + } + for _, ev := range membershipEvs { + membership, err := ev.Membership() + if err != nil { + return result, err + } + vis.membershipAtEvent = membership + } + result[eventID] = vis + } + return result, nil +} diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index f6b4d15e0..13c4e9d89 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -21,10 +21,12 @@ import ( "fmt" "net/http" "strconv" + "time" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -95,24 +97,6 @@ func Context( ContainsURL: filter.ContainsURL, } - // TODO: Get the actual state at the last event returned by SelectContextAfterEvent - state, _ := syncDB.CurrentState(ctx, roomID, &stateFilter, nil) - // verify the user is allowed to see the context for this room/event - for _, x := range state { - var hisVis gomatrixserverlib.HistoryVisibility - hisVis, err = x.HistoryVisibility() - if err != nil { - continue - } - allowed := hisVis == gomatrixserverlib.WorldReadable || membershipRes.Membership == gomatrixserverlib.Join - if !allowed { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("User is not allowed to query context"), - } - } - } - id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) if err != nil { if err == sql.ErrNoRows { @@ -125,6 +109,24 @@ func Context( return jsonerror.InternalServerError() } + // verify the user is allowed to see the context for this room/event + startTime := time.Now() + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") + if err != nil { + logrus.WithError(err).Error("unable to apply history visibility filter") + return jsonerror.InternalServerError() + } + logrus.WithFields(logrus.Fields{ + "duration": time.Since(startTime), + "room_id": roomID, + }).Debug("applied history visibility (context)") + if len(filteredEvents) == 0 { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("User is not allowed to query context"), + } + } + eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch before events") @@ -137,8 +139,27 @@ func Context( return jsonerror.InternalServerError() } - eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBefore, gomatrixserverlib.FormatAll) - eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfter, gomatrixserverlib.FormatAll) + startTime = time.Now() + eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID) + if err != nil { + logrus.WithError(err).Error("unable to apply history visibility filter") + return jsonerror.InternalServerError() + } + + logrus.WithFields(logrus.Fields{ + "duration": time.Since(startTime), + "room_id": roomID, + }).Debug("applied history visibility (context eventsBefore/eventsAfter)") + + // TODO: Get the actual state at the last event returned by SelectContextAfterEvent + state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil) + if err != nil { + logrus.WithError(err).Error("unable to fetch current room state") + return jsonerror.InternalServerError() + } + + eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll) + eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll) newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache) response := ContextRespsonse{ @@ -162,6 +183,44 @@ func Context( } } +// applyHistoryVisibilityOnContextEvents is a helper function to avoid roundtrips to the roomserver +// by combining the events before and after the context event. Returns the filtered events, +// and an error, if any. +func applyHistoryVisibilityOnContextEvents( + ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI, + eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent, + userID string, +) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) { + eventIDsBefore := make(map[string]struct{}, len(eventsBefore)) + eventIDsAfter := make(map[string]struct{}, len(eventsAfter)) + + // Remember before/after eventIDs, so we can restore them + // after applying history visibility checks + for _, ev := range eventsBefore { + eventIDsBefore[ev.EventID()] = struct{}{} + } + for _, ev := range eventsAfter { + eventIDsAfter[ev.EventID()] = struct{}{} + } + + allEvents := append(eventsBefore, eventsAfter...) + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context") + if err != nil { + return nil, nil, err + } + + // "Restore" events in the correct context + for _, ev := range filteredEvents { + if _, ok := eventIDsBefore[ev.EventID()]; ok { + filteredBefore = append(filteredBefore, ev) + } + if _, ok := eventIDsAfter[ev.EventID()]; ok { + filteredAfter = append(filteredAfter, ev) + } + } + return filteredBefore, filteredAfter, nil +} + func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { if len(startEvents) > 0 { start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID()) diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index b4c9a5423..9db3d8e17 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -19,6 +19,7 @@ import ( "fmt" "net/http" "sort" + "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -28,6 +29,7 @@ import ( "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/internal" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/types" @@ -324,6 +326,9 @@ func (r *messagesReq) retrieveEvents() ( // reliable way to define it), it would be easier and less troublesome to // only have to change it in one place, i.e. the database. start, end, err = r.getStartEnd(events) + if err != nil { + return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, err + } // Sort the events to ensure we send them in the right order. if r.backwardOrdering { @@ -337,97 +342,18 @@ func (r *messagesReq) retrieveEvents() ( } events = reversed(events) } - events = r.filterHistoryVisible(events) if len(events) == 0 { return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil } - // Convert all of the events into client events. - clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll) - return clientEvents, start, end, err -} - -func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent { - // TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the - // user shouldn't see, we check the recent events and remove any prior to the join event of the user - // which is equiv to history_visibility: joined - joinEventIndex := -1 - for i, ev := range events { - if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(r.device.UserID) { - membership, _ := ev.Membership() - if membership == "join" { - joinEventIndex = i - break - } - } - } - - var result []*gomatrixserverlib.HeaderedEvent - var eventsToCheck []*gomatrixserverlib.HeaderedEvent - if joinEventIndex != -1 { - if r.backwardOrdering { - result = events[:joinEventIndex+1] - eventsToCheck = append(eventsToCheck, result[0]) - } else { - result = events[joinEventIndex:] - eventsToCheck = append(eventsToCheck, result[len(result)-1]) - } - } else { - eventsToCheck = []*gomatrixserverlib.HeaderedEvent{events[0], events[len(events)-1]} - result = events - } - // make sure the user was in the room for both the earliest and latest events, we need this because - // some backpagination results will not have the join event (e.g if they hit /messages at the join event itself) - wasJoined := true - for _, ev := range eventsToCheck { - var queryRes api.QueryStateAfterEventsResponse - err := r.rsAPI.QueryStateAfterEvents(r.ctx, &api.QueryStateAfterEventsRequest{ - RoomID: ev.RoomID(), - PrevEventIDs: ev.PrevEventIDs(), - StateToFetch: []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomMember, StateKey: r.device.UserID}, - {EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""}, - }, - }, &queryRes) - if err != nil { - wasJoined = false - break - } - var hisVisEvent, membershipEvent *gomatrixserverlib.HeaderedEvent - for i := range queryRes.StateEvents { - switch queryRes.StateEvents[i].Type() { - case gomatrixserverlib.MRoomMember: - membershipEvent = queryRes.StateEvents[i] - case gomatrixserverlib.MRoomHistoryVisibility: - hisVisEvent = queryRes.StateEvents[i] - } - } - if hisVisEvent == nil { - return events // apply no filtering as it defaults to Shared. - } - hisVis, _ := hisVisEvent.HistoryVisibility() - if hisVis == "shared" || hisVis == "world_readable" { - return events // apply no filtering - } - if membershipEvent == nil { - wasJoined = false - break - } - membership, err := membershipEvent.Membership() - if err != nil { - wasJoined = false - break - } - if membership != "join" { - wasJoined = false - break - } - } - if !wasJoined { - util.GetLogger(r.ctx).WithField("num_events", len(events)).Warnf("%s was not joined to room during these events, omitting them", r.device.UserID) - return []*gomatrixserverlib.HeaderedEvent{} - } - return result + // Apply room history visibility filter + startTime := time.Now() + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages") + logrus.WithFields(logrus.Fields{ + "duration": time.Since(startTime), + "room_id": r.roomID, + }).Debug("applied history visibility (messages)") + return gomatrixserverlib.HeaderedToClientEvents(filteredEvents, gomatrixserverlib.FormatAll), start, end, err } func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 9751670b2..43a75da95 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -161,6 +161,10 @@ type Database interface { IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error + // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found + // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty + // string as the membership. + SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) } type Presence interface { diff --git a/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go b/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go index 29008adef..d68ed8d5f 100644 --- a/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go +++ b/syncapi/storage/postgres/deltas/2022061412000000_history_visibility_column.go @@ -17,7 +17,10 @@ package deltas import ( "context" "database/sql" + "encoding/json" "fmt" + + "github.com/matrix-org/gomatrixserverlib" ) func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error { @@ -31,6 +34,27 @@ func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.T return nil } +// UpSetHistoryVisibility sets the history visibility for already stored events. +// Requires current_room_state and output_room_events to be created. +func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error { + // get the current room history visibilities + historyVisibilities, err := currentHistoryVisibilities(ctx, tx) + if err != nil { + return err + } + + // update the history visibility + for roomID, hisVis := range historyVisibilities { + _, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1 + WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID) + if err != nil { + return fmt.Errorf("failed to update history visibility: %w", err) + } + } + + return nil +} + func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` ALTER TABLE syncapi_current_room_state ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2; @@ -39,9 +63,40 @@ func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.T if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } + return nil } +// currentHistoryVisibilities returns a map from roomID to current history visibility. +// If the history visibility was changed after room creation, defaults to joined. +func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) { + rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state + WHERE type = 'm.room.history_visibility' AND state_key = ''; +`) + if err != nil { + return nil, fmt.Errorf("failed to query current room state: %w", err) + } + defer rows.Close() // nolint: errcheck + var eventBytes []byte + var roomID string + var event gomatrixserverlib.HeaderedEvent + var hisVis gomatrixserverlib.HistoryVisibility + historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility) + for rows.Next() { + if err = rows.Scan(&roomID, &eventBytes); err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + if err = json.Unmarshal(eventBytes, &event); err != nil { + return nil, fmt.Errorf("failed to unmarshal event: %w", err) + } + historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined + if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 { + historyVisibilities[roomID] = hisVis + } + } + return historyVisibilities, nil +} + func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` ALTER TABLE syncapi_output_room_events DROP COLUMN IF EXISTS history_visibility; diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 00223c57a..939d6b3f5 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -66,10 +66,14 @@ const selectMembershipCountSQL = "" + const selectHeroesSQL = "" + "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" +const selectMembershipBeforeSQL = "" + + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" + type membershipsStatements struct { - upsertMembershipStmt *sql.Stmt - selectMembershipCountStmt *sql.Stmt - selectHeroesStmt *sql.Stmt + upsertMembershipStmt *sql.Stmt + selectMembershipCountStmt *sql.Stmt + selectHeroesStmt *sql.Stmt + selectMembershipForUserStmt *sql.Stmt } func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -82,6 +86,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectHeroesStmt, selectHeroesSQL}, + {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, }.Prepare(db) } @@ -132,3 +137,20 @@ func (s *membershipsStatements) SelectHeroes( } return heroes, rows.Err() } + +// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found +// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty +// string as the membership. +func (s *membershipsStatements) SelectMembershipForUser( + ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64, +) (membership string, topologyPos int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt) + err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos) + if err != nil { + if err == sql.ErrNoRows { + return "leave", 0, nil + } + return "", 0, err + } + return membership, topologyPos, nil +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 34ff6700f..8f633640e 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -191,10 +191,12 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { } m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "syncapi: add history visibility column (output_room_events)", - Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, - }) + m.AddMigrations( + sqlutil.Migration{ + Version: "syncapi: add history visibility column (output_room_events)", + Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, + }, + ) err = m.Up(context.Background()) if err != nil { return nil, err diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index a044716ce..979ff6647 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/shared" ) @@ -97,6 +98,20 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) if err != nil { return nil, err } + + // apply migrations which need multiple tables + m := sqlutil.NewMigrator(d.db) + m.AddMigrations( + sqlutil.Migration{ + Version: "syncapi: set history visibility for existing events", + Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created. + }, + ) + err = m.Up(base.Context()) + if err != nil { + return nil, err + } + d.Database = shared.Database{ DB: d.db, Writer: d.writer, diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 932bda164..a46e55256 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -231,7 +231,7 @@ func (d *Database) AddPeek( return } -// DeletePeeks tracks the fact that a user has stopped peeking from the specified +// DeletePeek tracks the fact that a user has stopped peeking from the specified // device. If the peeks was successfully deleted this returns the stream ID it was // stored at. Returns an error if there was a problem communicating with the database. func (d *Database) DeletePeek( @@ -372,6 +372,7 @@ func (d *Database) WriteEvent( ) (pduPosition types.StreamPosition, returnErr error) { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error + ev.Visibility = historyVisibility pos, err := d.OutputEvents.InsertEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility, ) @@ -563,7 +564,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda return err } -// Retrieve the backward topology position, i.e. the position of the +// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the // oldest event in the room's topology. func (d *Database) GetBackwardTopologyPos( ctx context.Context, @@ -674,7 +675,7 @@ func (d *Database) fetchMissingStateEvents( return events, nil } -// getStateDeltas returns the state deltas between fromPos and toPos, +// GetStateDeltas returns the state deltas between fromPos and toPos, // exclusive of oldPos, inclusive of newPos, for the rooms in which // the user has new membership events. // A list of joined room IDs is also returned in case the caller needs it. @@ -812,7 +813,7 @@ func (d *Database) GetStateDeltas( return deltas, joinedRoomIDs, nil } -// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync +// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync // requests with full_state=true. // Fetches full state for all joined rooms and uses selectStateInRange to get // updates for other rooms. @@ -1039,37 +1040,41 @@ func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID s return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to) } -func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { - return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) +func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) } -func (s *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { - return s.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) +func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) } -func (s *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { - return s.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) +func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) } -func (s *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) { - return s.Ignores.SelectIgnores(ctx, userID) +func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) { + return d.Ignores.SelectIgnores(ctx, userID) } -func (s *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error { - return s.Ignores.UpsertIgnores(ctx, userID, ignores) +func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error { + return d.Ignores.UpsertIgnores(ctx, userID, ignores) } -func (s *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { - return s.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync) +func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { + return d.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync) } -func (s *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return s.Presence.GetPresenceForUser(ctx, nil, userID) +func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUser(ctx, nil, userID) } -func (s *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { - return s.Presence.GetPresenceAfter(ctx, nil, after, filter) +func (d *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { + return d.Presence.GetPresenceAfter(ctx, nil, after, filter) } -func (s *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { - return s.Presence.GetMaxPresenceID(ctx, nil) +func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { + return d.Presence.GetMaxPresenceID(ctx, nil) +} + +func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { + return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) } diff --git a/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go b/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go index 079177217..d23f07566 100644 --- a/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go +++ b/syncapi/storage/sqlite3/deltas/2022061412000000_history_visibility_column.go @@ -17,7 +17,10 @@ package deltas import ( "context" "database/sql" + "encoding/json" "fmt" + + "github.com/matrix-org/gomatrixserverlib" ) func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error { @@ -37,6 +40,27 @@ func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.T return nil } +// UpSetHistoryVisibility sets the history visibility for already stored events. +// Requires current_room_state and output_room_events to be created. +func UpSetHistoryVisibility(ctx context.Context, tx *sql.Tx) error { + // get the current room history visibilities + historyVisibilities, err := currentHistoryVisibilities(ctx, tx) + if err != nil { + return err + } + + // update the history visibility + for roomID, hisVis := range historyVisibilities { + _, err = tx.ExecContext(ctx, `UPDATE syncapi_output_room_events SET history_visibility = $1 + WHERE type IN ('m.room.message', 'm.room.encrypted') AND room_id = $2 AND history_visibility <> $1`, hisVis, roomID) + if err != nil { + return fmt.Errorf("failed to update history visibility: %w", err) + } + } + + return nil +} + func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error { // SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists. // Required for unit tests, as otherwise a duplicate column error will show up. @@ -51,9 +75,40 @@ func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.T if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } + return nil } +// currentHistoryVisibilities returns a map from roomID to current history visibility. +// If the history visibility was changed after room creation, defaults to joined. +func currentHistoryVisibilities(ctx context.Context, tx *sql.Tx) (map[string]gomatrixserverlib.HistoryVisibility, error) { + rows, err := tx.QueryContext(ctx, `SELECT DISTINCT room_id, headered_event_json FROM syncapi_current_room_state + WHERE type = 'm.room.history_visibility' AND state_key = ''; +`) + if err != nil { + return nil, fmt.Errorf("failed to query current room state: %w", err) + } + defer rows.Close() // nolint: errcheck + var eventBytes []byte + var roomID string + var event gomatrixserverlib.HeaderedEvent + var hisVis gomatrixserverlib.HistoryVisibility + historyVisibilities := make(map[string]gomatrixserverlib.HistoryVisibility) + for rows.Next() { + if err = rows.Scan(&roomID, &eventBytes); err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + if err = json.Unmarshal(eventBytes, &event); err != nil { + return nil, fmt.Errorf("failed to unmarshal event: %w", err) + } + historyVisibilities[roomID] = gomatrixserverlib.HistoryVisibilityJoined + if hisVis, err = event.HistoryVisibility(); err == nil && event.Depth() < 10 { + historyVisibilities[roomID] = hisVis + } + } + return historyVisibilities, nil +} + func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error { // SQLite doesn't have "if exists", so check if the column exists. _, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1") diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index e4daa99c1..0c966fca0 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -66,11 +66,15 @@ const selectMembershipCountSQL = "" + const selectHeroesSQL = "" + "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5" +const selectMembershipBeforeSQL = "" + + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" + type membershipsStatements struct { db *sql.DB upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic + selectMembershipForUserStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -84,6 +88,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { return s, sqlutil.StatementList{ {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, + {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic }.Prepare(db) } @@ -148,3 +153,20 @@ func (s *membershipsStatements) SelectHeroes( } return heroes, rows.Err() } + +// SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found +// returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty +// string as the membership. +func (s *membershipsStatements) SelectMembershipForUser( + ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64, +) (membership string, topologyPos int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipForUserStmt) + err = stmt.QueryRowContext(ctx, roomID, userID, pos).Scan(&membership, &topologyPos) + if err != nil { + if err == sql.ErrNoRows { + return "leave", 0, nil + } + return "", 0, err + } + return membership, topologyPos, nil +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index de389fa9b..91fd35b5b 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -139,10 +139,12 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even } m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "syncapi: add history visibility column (output_room_events)", - Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, - }) + m.AddMigrations( + sqlutil.Migration{ + Version: "syncapi: add history visibility column (output_room_events)", + Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, + }, + ) err = m.Up(context.Background()) if err != nil { return nil, err diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 5c5eb0f55..a84e2bd16 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -16,12 +16,14 @@ package sqlite3 import ( + "context" "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/storage/shared" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas" ) // SyncServerDatasource represents a sync server datasource which manages @@ -41,13 +43,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { return nil, err } - if err = d.prepare(); err != nil { + if err = d.prepare(base.Context()); err != nil { return nil, err } return &d, nil } -func (d *SyncServerDatasource) prepare() (err error) { +func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) { if err = d.streamID.Prepare(d.db); err != nil { return err } @@ -107,6 +109,19 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } + + // apply migrations which need multiple tables + m := sqlutil.NewMigrator(d.db) + m.AddMigrations( + sqlutil.Migration{ + Version: "syncapi: set history visibility for existing events", + Up: deltas.UpSetHistoryVisibility, // Requires current_room_state and output_room_events to be created. + }, + ) + err = m.Up(ctx) + if err != nil { + return err + } d.Database = shared.Database{ DB: d.db, Writer: d.writer, diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index eda5ef3e6..a62818e9b 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -12,20 +12,22 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" ) var ctx = context.Background() -func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func(), func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewSyncServerDatasource(nil, &config.DatabaseOptions{ + base, closeBase := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.NewSyncServerDatasource(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }) if err != nil { t.Fatalf("NewSyncServerDatasource returned %s", err) } - return db, close + return db, close, closeBase } func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) { @@ -51,8 +53,9 @@ func TestWriteEvents(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { alice := test.NewUser(t) r := test.NewRoom(t, alice) - db, close := MustCreateDatabase(t, dbType) + db, close, closeBase := MustCreateDatabase(t, dbType) defer close() + defer closeBase() MustWriteEvents(t, db, r.Events()) }) } @@ -60,8 +63,9 @@ func TestWriteEvents(t *testing.T) { // These tests assert basic functionality of RecentEvents for PDUs func TestRecentEventsPDU(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := MustCreateDatabase(t, dbType) + db, close, closeBase := MustCreateDatabase(t, dbType) defer close() + defer closeBase() alice := test.NewUser(t) // dummy room to make sure SQL queries are filtering on room ID MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) @@ -163,8 +167,9 @@ func TestRecentEventsPDU(t *testing.T) { // The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token func TestGetEventsInRangeWithTopologyToken(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := MustCreateDatabase(t, dbType) + db, close, closeBase := MustCreateDatabase(t, dbType) defer close() + defer closeBase() alice := test.NewUser(t) r := test.NewRoom(t, alice) for i := 0; i < 10; i++ { @@ -404,8 +409,9 @@ func TestSendToDeviceBehaviour(t *testing.T) { bob := test.NewUser(t) deviceID := "one" test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := MustCreateDatabase(t, dbType) + db, close, closeBase := MustCreateDatabase(t, dbType) defer close() + defer closeBase() // At this point there should be no messages. We haven't sent anything // yet. _, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100) diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index d68351d4c..468d26aca 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -185,6 +185,7 @@ type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error) + SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) } type NotificationData interface { diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 1ad3adc4c..136cbea5a 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -10,10 +10,13 @@ import ( "github.com/matrix-org/dendrite/internal/caching" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/internal" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "go.uber.org/atomic" @@ -123,7 +126,7 @@ func (p *PDUStreamProvider) CompleteSync( defer reqWaitGroup.Done() jr, jerr := p.getJoinResponseForCompleteSync( - ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, + ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false, ) if jerr != nil { req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") @@ -149,7 +152,7 @@ func (p *PDUStreamProvider) CompleteSync( if !peek.Deleted { var jr *types.JoinResponse jr, err = p.getJoinResponseForCompleteSync( - ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, + ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true, ) if err != nil { req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") @@ -281,12 +284,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } } - if len(recentEvents) > 0 { - updateLatestPosition(recentEvents[len(recentEvents)-1].EventID()) - } - if len(delta.StateEvents) > 0 { - updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) - } if stateFilter.LazyLoadMembers { delta.StateEvents, err = p.lazyLoadMembers( @@ -306,6 +303,19 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } + // Applies the history visibility rules + events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents) + if err != nil { + logrus.WithError(err).Error("unable to apply history visibility filter") + } + + if len(events) > 0 { + updateLatestPosition(events[len(events)-1].EventID()) + } + if len(delta.StateEvents) > 0 { + updateLatestPosition(delta.StateEvents[len(delta.StateEvents)-1].EventID()) + } + switch delta.Membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() @@ -313,14 +323,17 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition) } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - jr.Timeline.Limited = limited + jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) + // If we are limited by the filter AND the history visibility filter + // didn't "remove" events, return that the response is limited. + jr.Timeline.Limited = limited && len(events) == len(recentEvents) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) res.Rooms.Join[delta.RoomID] = *jr case gomatrixserverlib.Peek: jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch + // TODO: Apply history visibility on peeked rooms jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = limited jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) @@ -330,12 +343,12 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( fallthrough // transitions to leave are the same as ban case gomatrixserverlib.Ban: - // TODO: recentEvents may contain events that this user is not allowed to see because they are - // no longer in the room. lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true + lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) + // If we are limited by the filter AND the history visibility filter + // didn't "remove" events, return that the response is limited. + lr.Timeline.Limited = limited && len(events) == len(recentEvents) lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.StateEvents, gomatrixserverlib.FormatSync) res.Rooms.Leave[delta.RoomID] = *lr } @@ -343,6 +356,41 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( return latestPosition, nil } +// applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make +// sure we always return the required events in the timeline. +func applyHistoryVisibilityFilter( + ctx context.Context, + db storage.Database, + rsAPI roomserverAPI.SyncRoomserverAPI, + roomID, userID string, + limit int, + recentEvents []*gomatrixserverlib.HeaderedEvent, +) ([]*gomatrixserverlib.HeaderedEvent, error) { + // We need to make sure we always include the latest states events, if they are in the timeline. + // We grep at least limit * 2 events, to ensure we really get the needed events. + stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil) + if err != nil { + // Not a fatal error, we can continue without the stateEvents, + // they are only needed if there are state events in the timeline. + logrus.WithError(err).Warnf("failed to get current room state") + } + alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents)) + for _, ev := range stateEvents { + alwaysIncludeIDs[ev.EventID()] = struct{}{} + } + startTime := time.Now() + events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") + if err != nil { + + return nil, err + } + logrus.WithFields(logrus.Fields{ + "duration": time.Since(startTime), + "room_id": roomID, + }).Debug("applied history visibility (sync)") + return events, nil +} + func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { // Work out how many members are in the room. joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) @@ -390,6 +438,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( eventFilter *gomatrixserverlib.RoomEventFilter, wantFullState bool, device *userapi.Device, + isPeek bool, ) (jr *types.JoinResponse, err error) { jr = types.NewJoinResponse() // TODO: When filters are added, we may need to call this multiple times to get enough events. @@ -404,33 +453,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( return } - // TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the - // user shouldn't see, we check the recent events and remove any prior to the join event of the user - // which is equiv to history_visibility: joined - joinEventIndex := -1 - for i := len(recentStreamEvents) - 1; i >= 0; i-- { - ev := recentStreamEvents[i] - if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(device.UserID) { - membership, _ := ev.Membership() - if membership == "join" { - joinEventIndex = i - if i > 0 { - // the create event happens before the first join, so we should cut it at that point instead - if recentStreamEvents[i-1].Type() == gomatrixserverlib.MRoomCreate && recentStreamEvents[i-1].StateKeyEquals("") { - joinEventIndex = i - 1 - break - } - } - break - } - } - } - if joinEventIndex != -1 { - // cut all events earlier than the join (but not the join itself) - recentStreamEvents = recentStreamEvents[joinEventIndex:] - limited = false // so clients know not to try to backpaginate - } - // Work our way through the timeline events and pick out the event IDs // of any state events that appear in the timeline. We'll specifically // exclude them at the next step, so that we don't get duplicate state @@ -474,6 +496,19 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) + events := recentEvents + // Only apply history visibility checks if the response is for joined rooms + if !isPeek { + events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents) + if err != nil { + logrus.WithError(err).Error("unable to apply history visibility filter") + } + } + + // If we are limited by the filter AND the history visibility filter + // didn't "remove" events, return that the response is limited. + limited = limited && len(events) == len(recentEvents) + if stateFilter.LazyLoadMembers { if err != nil { return nil, err @@ -488,8 +523,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - jr.Timeline.Limited = limited + jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) + // If we are limited by the filter AND the history visibility filter + // didn't "remove" events, return that the response is limited. + jr.Timeline.Limited = limited && len(events) == len(recentEvents) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) return jr, nil } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 931fef883..dc073a16e 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -12,6 +12,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/producers" keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" @@ -54,6 +55,16 @@ func (s *syncRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *rsap return nil } +func (s *syncRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *rsapi.QueryMembershipForUserRequest, res *rsapi.QueryMembershipForUserResponse) error { + res.IsRoomForgotten = false + res.RoomExists = true + return nil +} + +func (s *syncRoomserverAPI) QueryMembershipAtEvent(ctx context.Context, req *rsapi.QueryMembershipAtEventRequest, res *rsapi.QueryMembershipAtEventResponse) error { + return nil +} + type syncUserAPI struct { userapi.SyncUserAPI accounts []userapi.Device @@ -107,7 +118,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) { jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) - msgs := toNATSMsgs(t, base, room.Events()) + msgs := toNATSMsgs(t, base, room.Events()...) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) testrig.MustPublishMsgs(t, jsctx, msgs...) @@ -200,7 +211,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) { // m.room.power_levels // m.room.join_rules // m.room.history_visibility - msgs := toNATSMsgs(t, base, room.Events()) + msgs := toNATSMsgs(t, base, room.Events()...) sinceTokens := make([]string, len(msgs)) AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{}) for i, msg := range msgs { @@ -315,6 +326,174 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) { } +// This is mainly what Sytest is doing in "test_history_visibility" +func TestMessageHistoryVisibility(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + testHistoryVisibility(t, dbType) + }) +} + +func testHistoryVisibility(t *testing.T, dbType test.DBType) { + type result struct { + seeWithoutJoin bool + seeBeforeJoin bool + seeAfterInvite bool + } + + // create the users + alice := test.NewUser(t) + bob := test.NewUser(t) + + bobDev := userapi.Device{ + ID: "BOBID", + UserID: bob.ID, + AccessToken: "BOD_BEARER_TOKEN", + DisplayName: "BOB", + } + + ctx := context.Background() + // check guest and normal user accounts + for _, accType := range []userapi.AccountType{userapi.AccountTypeGuest, userapi.AccountTypeUser} { + testCases := []struct { + historyVisibility gomatrixserverlib.HistoryVisibility + wantResult result + }{ + { + historyVisibility: gomatrixserverlib.HistoryVisibilityWorldReadable, + wantResult: result{ + seeWithoutJoin: true, + seeBeforeJoin: true, + seeAfterInvite: true, + }, + }, + { + historyVisibility: gomatrixserverlib.HistoryVisibilityShared, + wantResult: result{ + seeWithoutJoin: false, + seeBeforeJoin: true, + seeAfterInvite: true, + }, + }, + { + historyVisibility: gomatrixserverlib.HistoryVisibilityInvited, + wantResult: result{ + seeWithoutJoin: false, + seeBeforeJoin: false, + seeAfterInvite: true, + }, + }, + { + historyVisibility: gomatrixserverlib.HistoryVisibilityJoined, + wantResult: result{ + seeWithoutJoin: false, + seeBeforeJoin: false, + seeAfterInvite: false, + }, + }, + } + + bobDev.AccountType = accType + userType := "guest" + if accType == userapi.AccountTypeUser { + userType = "real user" + } + + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + // Use the actual internal roomserver API + rsAPI := roomserver.NewInternalAPI(base) + rsAPI.SetFederationAPI(nil, nil) + + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{bobDev}}, rsAPI, &syncKeyAPI{}) + + for _, tc := range testCases { + testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType) + t.Run(testname, func(t *testing.T) { + // create a room with the given visibility + room := test.NewRoom(t, alice, test.RoomHistoryVisibility(tc.historyVisibility)) + + // send the events/messages to NATS to create the rooms + beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("Before invite in a %s room", tc.historyVisibility)}) + eventsToSend := append(room.Events(), beforeJoinEv) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // There is only one event, we expect only to be able to see this, if the room is world_readable + w := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + "dir": "b", + }))) + if w.Code != 200 { + t.Logf("%s", w.Body.String()) + t.Fatalf("got HTTP %d want %d", w.Code, 200) + } + // We only care about the returned events at this point + var res struct { + Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` + } + if err := json.NewDecoder(w.Body).Decode(&res); err != nil { + t.Errorf("failed to decode response body: %s", err) + } + + verifyEventVisible(t, tc.wantResult.seeWithoutJoin, beforeJoinEv, res.Chunk) + + // Create invite, a message, join the room and create another message. + inviteEv := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "invite"}, test.WithStateKey(bob.ID)) + afterInviteEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After invite in a %s room", tc.historyVisibility)}) + joinEv := room.CreateAndInsert(t, bob, "m.room.member", map[string]interface{}{"membership": "join"}, test.WithStateKey(bob.ID)) + msgEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("After join in a %s room", tc.historyVisibility)}) + + eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv) + + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // Verify the messages after/before invite are visible or not + w = httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + "dir": "b", + }))) + if w.Code != 200 { + t.Logf("%s", w.Body.String()) + t.Fatalf("got HTTP %d want %d", w.Code, 200) + } + if err := json.NewDecoder(w.Body).Decode(&res); err != nil { + t.Errorf("failed to decode response body: %s", err) + } + // verify results + verifyEventVisible(t, tc.wantResult.seeBeforeJoin, beforeJoinEv, res.Chunk) + verifyEventVisible(t, tc.wantResult.seeAfterInvite, afterInviteEv, res.Chunk) + }) + } + } +} + +func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatrixserverlib.HeaderedEvent, chunk []gomatrixserverlib.ClientEvent) { + t.Helper() + if wantVisible { + for _, ev := range chunk { + if ev.EventID == wantVisibleEvent.EventID() { + return + } + } + t.Fatalf("expected to see event %s but didn't: %+v", wantVisibleEvent.EventID(), chunk) + } else { + for _, ev := range chunk { + if ev.EventID == wantVisibleEvent.EventID() { + t.Fatalf("expected not to see event %s: %+v", wantVisibleEvent.EventID(), string(ev.Content)) + } + } + } +} + func TestSendToDevice(t *testing.T) { test.WithAllDatabases(t, testSendToDevice) } @@ -448,7 +627,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { } } -func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverlib.HeaderedEvent) []*nats.Msg { +func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input ...*gomatrixserverlib.HeaderedEvent) []*nats.Msg { result := make([]*nats.Msg, len(input)) for i, ev := range input { var addsStateIDs []string @@ -460,6 +639,7 @@ func toNATSMsgs(t *testing.T, base *base.BaseDendrite, input []*gomatrixserverli NewRoomEvent: &rsapi.OutputNewRoomEvent{ Event: ev, AddsStateEventIDs: addsStateIDs, + HistoryVisibility: ev.Visibility, }, }) } diff --git a/sytest-whitelist b/sytest-whitelist index 2a145291f..88dfe920b 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -110,8 +110,6 @@ Newly joined room is included in an incremental sync User is offline if they set_presence=offline in their sync Changes to state are included in an incremental sync A change to displayname should appear in incremental /sync -Current state appears in timeline in private history -Current state appears in timeline in private history with many messages before Rooms a user is invited to appear in an initial sync Rooms a user is invited to appear in an incremental sync Sync can be polled for updates @@ -458,7 +456,6 @@ After changing password, a different session no longer works by default Read markers appear in incremental v2 /sync Read markers appear in initial v2 /sync Read markers can be updated -Local users can peek into world_readable rooms by room ID We can't peek into rooms with shared history_visibility We can't peek into rooms with invited history_visibility We can't peek into rooms with joined history_visibility @@ -720,4 +717,26 @@ Setting state twice is idempotent Joining room twice is idempotent Inbound federation can return missing events for shared visibility Inbound federation ignores redactions from invalid servers room > v3 +Joining room twice is idempotent +Getting messages going forward is limited for a departed room (SPEC-216) +m.room.history_visibility == "shared" allows/forbids appropriately for Guest users +m.room.history_visibility == "invited" allows/forbids appropriately for Guest users +m.room.history_visibility == "default" allows/forbids appropriately for Guest users +m.room.history_visibility == "shared" allows/forbids appropriately for Real users +m.room.history_visibility == "invited" allows/forbids appropriately for Real users +m.room.history_visibility == "default" allows/forbids appropriately for Real users +Guest users can sync from world_readable guest_access rooms if joined +Guest users can sync from shared guest_access rooms if joined +Guest users can sync from invited guest_access rooms if joined +Guest users can sync from joined guest_access rooms if joined +Guest users can sync from default guest_access rooms if joined +Real users can sync from world_readable guest_access rooms if joined +Real users can sync from shared guest_access rooms if joined +Real users can sync from invited guest_access rooms if joined +Real users can sync from joined guest_access rooms if joined +Real users can sync from default guest_access rooms if joined +Only see history_visibility changes on boundaries +Current state appears in timeline in private history +Current state appears in timeline in private history with many messages before +Local users can peek into world_readable rooms by room ID Newly joined room includes presence in incremental sync \ No newline at end of file diff --git a/test/room.go b/test/room.go index 6ae403b3f..94eb51bbe 100644 --- a/test/room.go +++ b/test/room.go @@ -37,10 +37,11 @@ var ( ) type Room struct { - ID string - Version gomatrixserverlib.RoomVersion - preset Preset - creator *User + ID string + Version gomatrixserverlib.RoomVersion + preset Preset + visibility gomatrixserverlib.HistoryVisibility + creator *User authEvents gomatrixserverlib.AuthEvents currentState map[string]*gomatrixserverlib.HeaderedEvent @@ -61,6 +62,7 @@ func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room { preset: PresetPublicChat, Version: gomatrixserverlib.RoomVersionV9, currentState: make(map[string]*gomatrixserverlib.HeaderedEvent), + visibility: gomatrixserverlib.HistoryVisibilityShared, } for _, m := range modifiers { m(t, r) @@ -97,10 +99,14 @@ func (r *Room) insertCreateEvents(t *testing.T) { fallthrough case PresetPrivateChat: joinRule.JoinRule = "invite" - hisVis.HistoryVisibility = "shared" + hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared case PresetPublicChat: joinRule.JoinRule = "public" - hisVis.HistoryVisibility = "shared" + hisVis.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared + } + + if r.visibility != "" { + hisVis.HistoryVisibility = r.visibility } r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{ @@ -183,7 +189,9 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil { t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err) } - return ev.Headered(r.Version) + headeredEvent := ev.Headered(r.Version) + headeredEvent.Visibility = r.visibility + return headeredEvent } // Add a new event to this room DAG. Not thread-safe. @@ -242,6 +250,12 @@ func RoomPreset(p Preset) roomModifier { } } +func RoomHistoryVisibility(vis gomatrixserverlib.HistoryVisibility) roomModifier { + return func(t *testing.T, r *Room) { + r.visibility = vis + } +} + func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier { return func(t *testing.T, r *Room) { r.Version = ver diff --git a/test/user.go b/test/user.go index 0020098a5..692eae351 100644 --- a/test/user.go +++ b/test/user.go @@ -20,6 +20,7 @@ import ( "sync/atomic" "testing" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) @@ -45,7 +46,8 @@ var ( ) type User struct { - ID string + ID string + accountType api.AccountType // key ID and private key of the server who has this user, if known. keyID gomatrixserverlib.KeyID privKey ed25519.PrivateKey @@ -62,6 +64,12 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve } } +func WithAccountType(accountType api.AccountType) UserOpt { + return func(u *User) { + u.accountType = accountType + } +} + func NewUser(t *testing.T, opts ...UserOpt) *User { counter := atomic.AddInt64(&userIDCounter, 1) var u User From a01af55ec694f33d68ecf5e2dc6cfc3e4ff3c2b9 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 11 Aug 2022 17:34:09 +0100 Subject: [PATCH 7/9] Restore the room version cache in the roomserver internal API HTTP client --- roomserver/inthttp/client.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index e9387bd97..a1dfc6aac 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -425,10 +425,14 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom( response.RoomVersion = roomVersion return nil } - return httputil.CallInternalRPCAPI( + err := httputil.CallInternalRPCAPI( "QueryRoomVersionForRoom", h.roomserverURL+RoomserverQueryRoomVersionForRoomPath, h.httpClient, ctx, request, response, ) + if err == nil { + h.cache.StoreRoomVersion(request.RoomID, response.RoomVersion) + } + return err } func (h *httpRoomserverInternalAPI) QueryCurrentState( From 2b352915a1722f4971283272f34d219ac6ee32af Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 11 Aug 2022 17:40:57 +0100 Subject: [PATCH 8/9] Update `golangci-lint` component in GHA workflow --- .github/workflows/dendrite.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 19ebd760c..9fa6cf197 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -67,7 +67,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: golangci-lint - uses: golangci/golangci-lint-action@v2 + uses: golangci/golangci-lint-action@v3 # run go test with different go versions test: From fad3ac8e787cebd231d6fe115980f6335bc94353 Mon Sep 17 00:00:00 2001 From: Tak Wai Wong <64229756+tak-hntlabs@users.noreply.github.com> Date: Fri, 12 Aug 2022 01:12:05 -0700 Subject: [PATCH 9/9] Protect user_interactive reads and writes with locks (#2635) * Protect user_interactive reads and writes with locks * Ignore golangci-lint false positive * fix lint Co-authored-by: Tak Wai Wong --- .github/workflows/dendrite.yml | 18 ++++++++++------ clientapi/auth/user_interactive.go | 34 ++++++++++++++++++++++++------ internal/caching/impl_ristretto.go | 2 +- 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 9fa6cf197..3044b0b9d 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -19,10 +19,10 @@ jobs: runs-on: ubuntu-latest if: ${{ false }} # disable for now steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: 1.18 @@ -66,6 +66,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: 1.18 - name: golangci-lint uses: golangci/golangci-lint-action@v3 @@ -101,7 +105,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: Setup go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - uses: actions/cache@v3 @@ -133,7 +137,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: Setup go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - name: Install dependencies x86 @@ -167,7 +171,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: Setup Go ${{ matrix.go }} - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - name: Install dependencies @@ -208,7 +212,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: Setup go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: "1.18" - uses: actions/cache@v3 @@ -233,7 +237,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: Setup go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: go-version: "1.18" - uses: actions/cache@v3 diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 82ecf674c..9971bf8a4 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "net/http" + "sync" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" @@ -102,6 +103,7 @@ type userInteractiveFlow struct { // the user already has a valid access token, but we want to double-check // that it isn't stolen by re-authenticating them. type UserInteractive struct { + sync.RWMutex Flows []userInteractiveFlow // Map of login type to implementation Types map[string]Type @@ -128,6 +130,8 @@ func NewUserInteractive(userAccountAPI api.UserLoginAPI, cfg *config.ClientAPI) } func (u *UserInteractive) IsSingleStageFlow(authType string) bool { + u.RLock() + defer u.RUnlock() for _, f := range u.Flows { if len(f.Stages) == 1 && f.Stages[0] == authType { return true @@ -137,8 +141,10 @@ func (u *UserInteractive) IsSingleStageFlow(authType string) bool { } func (u *UserInteractive) AddCompletedStage(sessionID, authType string) { + u.Lock() // TODO: Handle multi-stage flows delete(u.Sessions, sessionID) + u.Unlock() } type Challenge struct { @@ -150,12 +156,17 @@ type Challenge struct { } // Challenge returns an HTTP 401 with the supported flows for authenticating -func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse { +func (u *UserInteractive) challenge(sessionID string) *util.JSONResponse { + u.RLock() + completed := u.Sessions[sessionID] + flows := u.Flows + u.RUnlock() + return &util.JSONResponse{ Code: 401, JSON: Challenge{ - Completed: u.Sessions[sessionID], - Flows: u.Flows, + Completed: completed, + Flows: flows, Session: sessionID, Params: make(map[string]interface{}), }, @@ -170,8 +181,10 @@ func (u *UserInteractive) NewSession() *util.JSONResponse { res := jsonerror.InternalServerError() return &res } + u.Lock() u.Sessions[sessionID] = []string{} - return u.Challenge(sessionID) + u.Unlock() + return u.challenge(sessionID) } // ResponseWithChallenge mixes together a JSON body (e.g an error with errcode/message) with the @@ -184,7 +197,7 @@ func (u *UserInteractive) ResponseWithChallenge(sessionID string, response inter return &ise } _ = json.Unmarshal(b, &mixedObjects) - challenge := u.Challenge(sessionID) + challenge := u.challenge(sessionID) b, err = json.Marshal(challenge.JSON) if err != nil { ise := jsonerror.InternalServerError() @@ -213,7 +226,11 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device * // extract the type so we know which login type to use authType := gjson.GetBytes(bodyBytes, "auth.type").Str + + u.RLock() loginType, ok := u.Types[authType] + u.RUnlock() + if !ok { return nil, &util.JSONResponse{ Code: http.StatusBadRequest, @@ -223,7 +240,12 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device * // retrieve the session sessionID := gjson.GetBytes(bodyBytes, "auth.session").Str - if _, ok = u.Sessions[sessionID]; !ok { + + u.RLock() + _, ok = u.Sessions[sessionID] + u.RUnlock() + + if !ok { // if the login type is part of a single stage flow then allow them to omit the session ID if !u.IsSingleStageFlow(authType) { return nil, &util.JSONResponse{ diff --git a/internal/caching/impl_ristretto.go b/internal/caching/impl_ristretto.go index fc0c8cc0f..49292d0dc 100644 --- a/internal/caching/impl_ristretto.go +++ b/internal/caching/impl_ristretto.go @@ -146,7 +146,7 @@ func (c *RistrettoCostedCachePartition[K, V]) Set(key K, value V) { } type RistrettoCachePartition[K keyable, V any] struct { - cache *ristretto.Cache + cache *ristretto.Cache //nolint:all,unused Prefix byte Mutable bool MaxAge time.Duration