diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go index 17bbaef40..d86ccd6b5 100644 --- a/clientapi/routing/peekroom.go +++ b/clientapi/routing/peekroom.go @@ -17,8 +17,6 @@ package routing import ( "net/http" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/httputil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" @@ -33,17 +31,6 @@ func PeekRoomByIDOrAlias( accountDB accounts.Database, roomIDOrAlias string, ) util.JSONResponse { - // Check to see if any ?server_name= query parameters were - // given in the request. - if serverNames, ok := req.URL.Query()["server_name"]; ok { - for _, serverName := range serverNames { - peekReq.ServerNames = append( - peekReq.ServerNames, - gomatrixserverlib.ServerName(serverName), - ) - } - } - // if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to // to call /peek and /state on the remote server. // TODO: in future we could skip this if we know we're already participating in the room, @@ -57,6 +44,17 @@ func PeekRoomByIDOrAlias( } peekRes := roomserverAPI.PerformPeekResponse{} + // Check to see if any ?server_name= query parameters were + // given in the request. + if serverNames, ok := req.URL.Query()["server_name"]; ok { + for _, serverName := range serverNames { + peekReq.ServerNames = append( + peekReq.ServerNames, + gomatrixserverlib.ServerName(serverName), + ) + } + } + // Ask the roomserver to perform the peek. rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes) if peekRes.Error != nil { diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 9b53aa88c..0e1b645e4 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -38,6 +38,15 @@ func (t *RoomserverInternalAPITrace) PerformInvite( return t.Impl.PerformInvite(ctx, req, res) } +func (t *RoomserverInternalAPITrace) PerformPeek( + ctx context.Context, + req *PerformPeekRequest, + res *PerformPeekResponse, +) { + t.Impl.PerformPeek(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPeek req=%+v res=%+v", js(req), js(res)) +} + func (t *RoomserverInternalAPITrace) PerformJoin( ctx context.Context, req *PerformJoinRequest, diff --git a/roomserver/internal/perform_peek.go b/roomserver/internal/perform_peek.go index 2be224e3b..4f080737e 100644 --- a/roomserver/internal/perform_peek.go +++ b/roomserver/internal/perform_peek.go @@ -16,13 +16,10 @@ package internal import ( "context" - "errors" "fmt" "strings" - "time" fsAPI "github.com/matrix-org/dendrite/federationsender/api" - "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" @@ -80,7 +77,7 @@ func (r *RoomserverInternalAPI) performPeek( func (r *RoomserverInternalAPI) performPeekRoomByAlias( ctx context.Context, - req *api.PerformJoinRequest, + req *api.PerformPeekRequest, ) (string, error) { // Get the domain part of the room alias. _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) @@ -149,7 +146,7 @@ func (r *RoomserverInternalAPI) performPeekRoomByID( // TODO: handle federated peeks - err := r.WriteOutputEvents(roomID, []api.OutputEvent{ + err = r.WriteOutputEvents(roomID, []api.OutputEvent{ { Type: api.OutputTypeNewPeek, NewPeek: &api.OutputNewPeek{ @@ -167,5 +164,5 @@ func (r *RoomserverInternalAPI) performPeekRoomByID( // it will have been overwritten with a room ID by performPeekRoomByAlias. // We should now include this in the response so that the CS API can // return the right room ID. - return + return roomID, nil; } diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 1657bcdeb..4ffc3c8bb 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -26,6 +26,7 @@ const ( // Perform operations RoomserverPerformInvitePath = "/roomserver/performInvite" + RoomserverPerformPeekPath = "/roomserver/performPeek" RoomserverPerformJoinPath = "/roomserver/performJoin" RoomserverPerformLeavePath = "/roomserver/performLeave" RoomserverPerformBackfillPath = "/roomserver/performBackfill" @@ -179,6 +180,23 @@ func (h *httpRoomserverInternalAPI) PerformJoin( } } +func (h *httpRoomserverInternalAPI) PerformPeek( + ctx context.Context, + request *api.PerformPeekRequest, + response *api.PerformPeekResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPeek") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverPerformPeekPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err != nil { + response.Error = &api.PerformError{ + Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err), + } + } +} + func (h *httpRoomserverInternalAPI) PerformLeave( ctx context.Context, request *api.PerformLeaveRequest, diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 952122952..69a92110e 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -31,7 +31,7 @@ type Database interface { // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) // AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. - AllPeekingDevicesInRooms(ctx context.Context) (map[string][]PeekingDevice, error) + AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) // Events lookups a list of event by their event ID. // Returns a list of events matching the requested IDs found in the database. // If an event is not found in the database then it will be omitted from the list. @@ -83,6 +83,9 @@ type Database interface { // RetireInviteEvent removes an old invite event from the database. Returns the new position of the retired invite. // Returns an error if there was a problem communicating with the database. RetireInviteEvent(ctx context.Context, inviteEventID string) (types.StreamPosition, error) + // AddPeek adds a new peek to our DB for a given room by a given user's device. + // Returns an error if there was a problem communicating with the database. + AddPeek(ctx context.Context, RoomID, UserID, DeviceID string) (types.StreamPosition, error) // SetTypingTimeoutCallback sets a callback function that is called right after // a user is removed from the typing user list due to timeout. SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index cf8dd604c..76bb15259 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -39,6 +39,7 @@ type Database struct { DB *sql.DB Writer sqlutil.Writer Invites tables.Invites + Peeks tables.Peeks AccountData tables.AccountData OutputEvents tables.Events Topology tables.Topology @@ -120,7 +121,7 @@ func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]stri return d.CurrentRoomState.SelectJoinedUsers(ctx) } -func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]PeekingDevice, error) { +func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { return d.Peeks.SelectPeekingDevices(ctx) } @@ -198,7 +199,7 @@ func (d *Database) AddPeek( ctx context.Context, roomID, userID, deviceID string, ) (sp types.StreamPosition, err error) { _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - sp, err = d.Peeks.InsertPeek(ctx, nil, inviteEvent) + sp, err = d.Peeks.InsertPeek(ctx, nil, roomID, userID, deviceID) return nil }) return @@ -992,7 +993,7 @@ func (d *Database) getStateDeltas( // find out which rooms this user is peeking, if any. // We do this before joins so joins overwrite peeks - peeks, err := d.Peeks.SelectPeeks(ctx, txn, userID, device.DeviceID) + peeks, err := d.Peeks.SelectPeeks(ctx, txn, userID, device.ID) if err != nil { return nil, nil, err } @@ -1006,7 +1007,7 @@ func (d *Database) getStateDeltas( if err != nil { return nil, nil, err } - state[roomID] = s + state[peek.RoomID] = s } deltas = append(deltas, stateDelta{ @@ -1017,7 +1018,7 @@ func (d *Database) getStateDeltas( } if len(peeks) > 0 { - err := d.Peeks.MarkPeeksAsOld(ctx, txn, userID, device.DeviceID) + err := d.Peeks.MarkPeeksAsOld(ctx, txn, userID, device.ID) if err != nil { return nil, nil, err } @@ -1084,7 +1085,7 @@ func (d *Database) getStateDeltasForFullStateSync( return nil, nil, err } - peeks, err = d.Peeks.SelectPeeks(ctx, txn, userID, device,ID) + peeks, err := d.Peeks.SelectPeeks(ctx, txn, userID, device.ID) if err != nil { return nil, nil, err } @@ -1119,7 +1120,7 @@ func (d *Database) getStateDeltasForFullStateSync( } if len(peeks) > 0 { - err := d.Peeks.MarkPeeksAsOld(ctx, txn, userID, device.DeviceID) + err := d.Peeks.MarkPeeksAsOld(ctx, txn, userID, device.ID) if err != nil { return nil, nil, err } diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index 7ac329f66..9218e0e05 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -17,13 +17,12 @@ package sqlite3 import ( "context" "database/sql" - "encoding/json" + "time" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const peeksSchema = ` @@ -49,30 +48,33 @@ const insertPeekSQL = "" + const deletePeekSQL = "" + "DELETE FROM syncapi_peeks WHERE room_id = $1 AND user_id = $2 and device_id = $3" -const selectPeeksSQL == "" + +const selectPeeksSQL = "" + "SELECT room_id, new FROM syncapi_peeks WHERE user_id = $1 and device_id = $2" -const selectPeekingDevicesSQL == "" + +const selectPeekingDevicesSQL = "" + "SELECT room_id, user_id, device_id FROM syncapi_peeks" -const markPeeksAsOldSQL == "" + +const markPeeksAsOldSQL = "" + "UPDATE syncapi_peeks SET new=false WHERE user_id = $1 and device_id = $2" type peekStatements struct { db *sql.DB + streamIDStatements *streamIDStatements insertPeekStmt *sql.Stmt deletePeekStmt *sql.Stmt selectPeeksStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt + markPeeksAsOldStmt *sql.Stmt } -func NewSqlitePeeksTable(db *sql.DB) (tables.Peeks, error) { +func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) { _, err := db.Exec(filterSchema) if err != nil { return nil, err } s := &peekStatements{ db: db, + streamIDStatements: streamID, } if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { return nil, err @@ -117,7 +119,7 @@ func (s *peekStatements) DeletePeek( func (s *peekStatements) SelectPeeks( ctx context.Context, txn *sql.Tx, userID, deviceID string, -) (peeks []Peek, err error) { +) (peeks []types.Peek, err error) { rows, err := sqlutil.TxStmt(txn, s.selectPeeksStmt).QueryContext(ctx, userID, deviceID) if err != nil { return @@ -125,8 +127,8 @@ func (s *peekStatements) SelectPeeks( defer internal.CloseAndLogIfError(ctx, rows, "SelectPeeks: rows.close() failed") for rows.Next() { - peek = Peek{} - if err = rows.Scan(&peek.roomId, &peek.new); err != nil { + peek := types.Peek{} + if err = rows.Scan(&peek.RoomID, &peek.New); err != nil { return } peeks = append(peeks, peek) @@ -138,27 +140,27 @@ func (s *peekStatements) SelectPeeks( func (s *peekStatements) MarkPeeksAsOld ( ctx context.Context, txn *sql.Tx, userID, deviceID string, ) (err error) { - _, err := sqlutil.TxStmt(txn, s.markPeeksAsOldStmt).ExecContext(ctx, userID, deviceID) + _, err = sqlutil.TxStmt(txn, s.markPeeksAsOldStmt).ExecContext(ctx, userID, deviceID) return } func (s *peekStatements) SelectPeekingDevices( ctx context.Context, -) (peekingDevices map[string][]PeekingDevice, err error) { +) (peekingDevices map[string][]types.PeekingDevice, err error) { rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "SelectPeekingDevices: rows.close() failed") - result := make(map[string][]PeekingDevice) + result := make(map[string][]types.PeekingDevice) for rows.Next() { var roomID, userID, deviceID string if err := rows.Scan(&roomID, &userID, &deviceID); err != nil { return nil, err } devices := result[roomID] - devices = append(devices, PeekingDevice{userID, deviceID}) + devices = append(devices, types.PeekingDevice{userID, deviceID}) result[roomID] = devices } return result, nil diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 81197bb76..5db9939a3 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -75,6 +75,10 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } + peeks, err := NewSqlitePeeksTable(d.db, &d.streamID) + if err != nil { + return err + } topology, err := NewSqliteTopologyTable(d.db) if err != nil { return err @@ -95,6 +99,7 @@ func (d *SyncServerDatasource) prepare() (err error) { DB: d.db, Writer: sqlutil.NewExclusiveWriter(), Invites: invites, + Peeks: peeks, AccountData: accountData, OutputEvents: events, BackwardExtremities: bwExtrem, diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 9566b8b38..b7281f11c 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -42,8 +42,9 @@ type Invites interface { type Peeks interface { InsertPeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error) DeletePeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error) - SelectPeeks(ctxt context.Context, txn *sql.Tx, userID, deviceID string) (peeks []string, err error) - SelectPeekingDevices((ctxt context.Context) (peekingDevices map[string][]PeekingDevice, err error) + SelectPeeks(ctxt context.Context, txn *sql.Tx, userID, deviceID string) (peeks []types.Peek, err error) + SelectPeekingDevices(ctxt context.Context) (peekingDevices map[string][]types.PeekingDevice, err error) + MarkPeeksAsOld(ctxt context.Context, txn *sql.Tx, userID, deviceID string) (err error) } type Events interface { diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index f09c2b4d8..e6f7440e0 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -221,7 +221,7 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { // setPeekingDevices marks the given devices as peeking in the given rooms, such that new events from // these rooms will wake the given devices' /sync requests. This should be called prior to ANY calls to // OnNewEvent (eg on startup) to prevent racing. -func (n *Notifier) setPeekingDevices(roomIDToPeekingDevices map[string][]PeekingDevices) { +func (n *Notifier) setPeekingDevices(roomIDToPeekingDevices map[string][]types.PeekingDevice) { // This is just the bulk form of addPeekingDevice for roomID, peekingDevices := range roomIDToPeekingDevices { if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { @@ -235,7 +235,7 @@ func (n *Notifier) setPeekingDevices(roomIDToPeekingDevices map[string][]Peeking // wakeupUsers will wake up the sync strems for all of the devices for all of the // specified user IDs, and also the specified peekingDevices -func (n *Notifier) wakeupUsers(userIDs []string, peekingDevices []PeekingDevice, newPos types.StreamingToken) { +func (n *Notifier) wakeupUsers(userIDs []string, peekingDevices []types.PeekingDevice, newPos types.StreamingToken) { for _, userID := range userIDs { for _, stream := range n.fetchUserStreams(userID) { if stream == nil { @@ -248,7 +248,7 @@ func (n *Notifier) wakeupUsers(userIDs []string, peekingDevices []PeekingDevice, if peekingDevices != nil { for _, peekingDevice := range peekingDevices { // TODO: don't bother waking up for devices whose users we already woke up - if stream := n.fetchUserDeviceStream(peekingDevice.UserID, peekingDevice.DeviceID, false); stream != nil { + if stream := n.fetchUserDeviceStream(peekingDevice.UserID, peekingDevice.ID, false); stream != nil { stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream } } @@ -337,7 +337,7 @@ func (n *Notifier) addPeekingDevice(roomID, userID, deviceID string) { if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) } - n.roomIDToPeekingDevices[roomID].add(PeekingDevice{deviceID, userID}) + n.roomIDToPeekingDevices[roomID].add(types.PeekingDevice{deviceID, userID}) } // Not thread-safe: must be called on the OnNewEvent goroutine only @@ -346,11 +346,11 @@ func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) { n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) } // XXX: is this going to work as a key? - n.roomIDToPeekingDevices[roomID].remove(PeekingDevice{deviceID, userID}) + n.roomIDToPeekingDevices[roomID].remove(types.PeekingDevice{deviceID, userID}) } // Not thread-safe: must be called on the OnNewEvent goroutine only -func (n *Notifier) PeekingDevices(roomID string) (peekingDevices []PeekingDevices) { +func (n *Notifier) PeekingDevices(roomID string) (peekingDevices []types.PeekingDevice) { if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { return } @@ -407,17 +407,17 @@ func (s userIDSet) values() (vals []string) { // A set of PeekingDevices, similar to userIDSet -type peekingDeviceSet map[PeekingDevice]bool +type peekingDeviceSet map[types.PeekingDevice]bool -func (s peekingDeviceSet) add(d PeekingDevice) { +func (s peekingDeviceSet) add(d types.PeekingDevice) { s[d] = true } -func (s peekingDeviceSet) remove(d PeekingDevice) { +func (s peekingDeviceSet) remove(d types.PeekingDevice) { delete(s, d) } -func (s peekingDeviceSet) values() (vals []PeekingDevice) { +func (s peekingDeviceSet) values() (vals []types.PeekingDevice) { for d := range s { vals = append(vals, d) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 80a010904..b9888a65f 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -517,5 +517,5 @@ type PeekingDevice struct { type Peek struct { RoomID string - New boolean + New bool } \ No newline at end of file