diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index cb68fe196..3c7421bb2 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -41,7 +41,7 @@ func JoinRoomByIDOrAlias( } joinRes := roomserverAPI.PerformJoinResponse{} - // If content was provided in the request then incude that + // If content was provided in the request then include that // in the request. It'll get used as a part of the membership // event content. _ = httputil.UnmarshalJSONRequest(req, &joinReq.Content) diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go new file mode 100644 index 000000000..98983c60a --- /dev/null +++ b/clientapi/routing/peekroom.go @@ -0,0 +1,71 @@ +// Copyright 2020 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func PeekRoomByIDOrAlias( + req *http.Request, + device *api.Device, + rsAPI roomserverAPI.RoomserverInternalAPI, + accountDB accounts.Database, + roomIDOrAlias string, +) util.JSONResponse { + // if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to + // to call /peek and /state on the remote server. + // TODO: in future we could skip this if we know we're already participating in the room, + // but this is fiddly in case we stop participating in the room. + + // then we create a local peek. + peekReq := roomserverAPI.PerformPeekRequest{ + RoomIDOrAlias: roomIDOrAlias, + UserID: device.UserID, + DeviceID: device.ID, + } + peekRes := roomserverAPI.PerformPeekResponse{} + + // Ask the roomserver to perform the join. + rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes) + if peekRes.Error != nil { + return peekRes.Error.JSONResponse() + } + + // if this user is already joined to the room, we let them peek anyway + // (given they might be about to part the room, and it makes things less fiddly) + + + // Peeking stops if none of the devices who started peeking have been + // /syncing for a while, or if everyone who was peeking calls /leave + // (or /unpeek with a server_name param? or DELETE /peek?) + // on the peeked room. + + return util.JSONResponse{ + Code: http.StatusOK, + // TODO: Put the response struct somewhere internal. + JSON: struct { + RoomID string `json:"room_id"` + }{joinRes.RoomID}, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index c259e5293..69c76cf93 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -103,6 +103,17 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/peek/{roomIDOrAlias}", + httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return PeekRoomByIDOrAlias( + req, device, rsAPI, accountDB, vars["roomIDOrAlias"], + ) + }), + ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/joined_rooms", httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetJoinedRooms(req, device, stateAPI) diff --git a/docs/peeking.md b/docs/peeking.md new file mode 100644 index 000000000..35f1d9d83 --- /dev/null +++ b/docs/peeking.md @@ -0,0 +1,19 @@ +## Peeking + +Peeking is implemented as per [MSC2753](https://github.com/matrix-org/matrix-doc/pull/2753). + +Implementationwise, this means: + * Users call `/peek` and `/unpeek` on the clientapi from a given device. + * The clientapi delegates these via HTTP to the roomserver, which coordinates peeking in general for a given room + * The roomserver writes an NewPeek event into the kafka log headed to the syncserver + * The syncserver tracks the existence of the local peek in its DB, and then starts waking up the peeking devices for the room in question, putting it in the `peeking` section of the /sync response. + +Questions (given this is [my](https://github.com/ara4n) first time hacking on Dendrite): + * The whole clientapi -> roomserver -> syncapi flow to initiate a peek seems very indirect. Is there a reason not to just let syncapi itself host the implementation of `/peek`? + +In future, peeking over federation will be added as per [MSC2444](https://github.com/matrix-org/matrix-doc/pull/2444). + * The `roomserver` will kick the `federationsender` much as it does for a federated `/join` in order to trigger a federated `/peek` + * The `federationsender` tracks the existence of the remote peek in question + * The `federationsender` regularly renews the remote peek as long as there are still peeking devices syncing for it. + * TBD: how do we tell if there are no devices currently syncing for a given peeked room? The syncserver needs to tell the roomserver + somehow who then needs to warn the federationsender. \ No newline at end of file diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 0fe30b8b5..3b2d4bd77 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -36,6 +36,12 @@ type RoomserverInternalAPI interface { res *PerformLeaveResponse, ) error + PerformPeek( + ctx context.Context, + req *PerformPeekRequest, + res *PerformPeekResponse, + ) + PerformPublish( ctx context.Context, req *PerformPublishRequest, diff --git a/roomserver/api/output.go b/roomserver/api/output.go index d6c09f9e8..d74a37b36 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -46,6 +46,9 @@ const ( // - Redact the event and set the corresponding `unsigned` fields to indicate it as redacted. // - Replace the event in the database. OutputTypeRedactedEvent OutputType = "redacted_event" + + // OutputTypeNewPeek indicates that the kafka event is an OutputNewPeek + OutputTypeNewPeek OutputType = "new_peek" ) // An OutputEvent is an entry in the roomserver output kafka log. @@ -59,8 +62,10 @@ type OutputEvent struct { NewInviteEvent *OutputNewInviteEvent `json:"new_invite_event,omitempty"` // The content of event with type OutputTypeRetireInviteEvent RetireInviteEvent *OutputRetireInviteEvent `json:"retire_invite_event,omitempty"` - // The content of event with type OutputTypeRedactedEvent + // The content of event with type OutputTypeRedactedEvent RedactedEvent *OutputRedactedEvent `json:"redacted_event,omitempty"` + // The content of event with type OutputTypeNewPeek + NewPeek *OutputNewPeek `json:"new_peek,omitempty"` } // An OutputNewRoomEvent is written when the roomserver receives a new event. @@ -195,3 +200,11 @@ type OutputRedactedEvent struct { // The value of `unsigned.redacted_because` - the redaction event itself RedactedBecause gomatrixserverlib.HeaderedEvent } + +// An OutputNewPeek is written whenever a user starts peeking into a room +// using a given device. +type OutputNewPeek struct { + RoomID string + UserID string + DeviceID string +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 24e958bb4..0c2d96a7d 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -108,6 +108,20 @@ type PerformInviteResponse struct { Error *PerformError } +type PerformPeekRequest struct { + RoomIDOrAlias string `json:"room_id_or_alias"` + UserID string `json:"user_id"` + DeviceID string `json:"device_id"` + ServerNames []gomatrixserverlib.ServerName `json:"server_names"` +} + +type PerformPeekResponse struct { + // The room ID, populated on success. + RoomID string `json:"room_id"` + // If non-nil, the join request failed. Contains more information why it failed. + Error *PerformError +} + // PerformBackfillRequest is a request to PerformBackfill. type PerformBackfillRequest struct { // The room to backfill diff --git a/roomserver/internal/perform_peek.go b/roomserver/internal/perform_peek.go new file mode 100644 index 000000000..2be224e3b --- /dev/null +++ b/roomserver/internal/perform_peek.go @@ -0,0 +1,171 @@ +// Copyright 2020 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// PerformPeek handles peeking into matrix rooms, including over federation by talking to the federationsender. +func (r *RoomserverInternalAPI) PerformPeek( + ctx context.Context, + req *api.PerformPeekRequest, + res *api.PerformPeekResponse, +) { + roomID, err := r.performPeek(ctx, req) + if err != nil { + perr, ok := err.(*api.PerformError) + if ok { + res.Error = perr + } else { + res.Error = &api.PerformError{ + Msg: err.Error(), + } + } + } + res.RoomID = roomID +} + +func (r *RoomserverInternalAPI) performPeek( + ctx context.Context, + req *api.PerformPeekRequest, +) (string, error) { + // FIXME: there's way too much duplication with performJoin + _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return "", &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID), + } + } + if domain != r.Cfg.Matrix.ServerName { + return "", &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID), + } + } + if strings.HasPrefix(req.RoomIDOrAlias, "!") { + return r.performPeekRoomByID(ctx, req) + } + if strings.HasPrefix(req.RoomIDOrAlias, "#") { + return r.performPeekRoomByAlias(ctx, req) + } + return "", &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias), + } +} + +func (r *RoomserverInternalAPI) performPeekRoomByAlias( + ctx context.Context, + req *api.PerformJoinRequest, +) (string, error) { + // Get the domain part of the room alias. + _, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias) + if err != nil { + return "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias) + } + req.ServerNames = append(req.ServerNames, domain) + + // Check if this alias matches our own server configuration. If it + // doesn't then we'll need to try a federated peek. + var roomID string + if domain != r.Cfg.Matrix.ServerName { + // The alias isn't owned by us, so we will need to try peeking using + // a remote server. + dirReq := fsAPI.PerformDirectoryLookupRequest{ + RoomAlias: req.RoomIDOrAlias, // the room alias to lookup + ServerName: domain, // the server to ask + } + dirRes := fsAPI.PerformDirectoryLookupResponse{} + err = r.fsAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes) + if err != nil { + logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias) + return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err) + } + roomID = dirRes.RoomID + req.ServerNames = append(req.ServerNames, dirRes.ServerNames...) + } else { + // Otherwise, look up if we know this room alias locally. + roomID, err = r.DB.GetRoomIDForAlias(ctx, req.RoomIDOrAlias) + if err != nil { + return "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err) + } + } + + // If the room ID is empty then we failed to look up the alias. + if roomID == "" { + return "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias) + } + + // If we do, then pluck out the room ID and continue the peek. + req.RoomIDOrAlias = roomID + return r.performPeekRoomByID(ctx, req) +} + +func (r *RoomserverInternalAPI) performPeekRoomByID( + ctx context.Context, + req *api.PerformPeekRequest, +) (roomID string, err error) { + roomID = req.RoomIDOrAlias + + // Get the domain part of the room ID. + _, domain, err := gomatrixserverlib.SplitID('!', roomID) + if err != nil { + return "", &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("Room ID %q is invalid: %s", roomID, err), + } + } + + // If the server name in the room ID isn't ours then it's a + // possible candidate for finding the room via federation. Add + // it to the list of servers to try. + if domain != r.Cfg.Matrix.ServerName { + req.ServerNames = append(req.ServerNames, domain) + } + + // TODO: handle federated peeks + + err := r.WriteOutputEvents(roomID, []api.OutputEvent{ + { + Type: api.OutputTypeNewPeek, + NewPeek: &api.OutputNewPeek{ + RoomID: roomID, + UserID: req.UserID, + DeviceID: req.DeviceID, + }, + }, + }) + if err != nil { + return + } + + // By this point, if req.RoomIDOrAlias contained an alias, then + // it will have been overwritten with a room ID by performPeekRoomByAlias. + // We should now include this in the response so that the CS API can + // return the right room ID. + return +} diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index bf231d099..da656a4cf 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -99,6 +99,8 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error { return s.onNewInviteEvent(context.TODO(), *output.NewInviteEvent) case api.OutputTypeRetireInviteEvent: return s.onRetireInviteEvent(context.TODO(), *output.RetireInviteEvent) + case api.OutputTypeNewPeek: + return s.onNewPeek(context.TODO(), *output.NewPeek) case api.OutputTypeRedactedEvent: return s.onRedactEvent(context.TODO(), *output.RedactedEvent) default: @@ -218,6 +220,26 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent( return nil } +func (s *OutputRoomEventConsumer) onNewPeek( + ctx context.Context, msg api.OutputNewPeek, +) error { + sp, err := s.db.AddPeek(ctx, msg.RoomID, msg.UserID, msg.DeviceID) + if err != nil { + // panic rather than continue with an inconsistent database + log.WithFields(log.Fields{ + log.ErrorKey: err, + }).Panicf("roomserver output log: write peek failure") + return nil + } + // tell the notifier about the new peek so it knows to wake up new devices + s.notifier.OnNewPeek(msg.RoomID, msg.UserID, msg.DeviceID) + + // we need to wake up the users who might need to now be peeking into this room, + // so we send in a dummy event to trigger a wakeup + s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.NewStreamToken(sp, 0, nil)) + return nil +} + func (s *OutputRoomEventConsumer) updateStateEvent(event gomatrixserverlib.HeaderedEvent) (gomatrixserverlib.HeaderedEvent, error) { if event.StateKey() == nil { return event, nil diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index a5e13b674..952122952 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -30,6 +30,8 @@ type Database interface { internal.PartitionStorer // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) + // AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. + AllPeekingDevicesInRooms(ctx context.Context) (map[string][]PeekingDevice, error) // Events lookups a list of event by their event ID. // Returns a list of events matching the requested IDs found in the database. // If an event is not found in the database then it will be omitted from the list. diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index ad0c1d996..6f7efefb7 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -120,6 +120,10 @@ func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]stri return d.CurrentRoomState.SelectJoinedUsers(ctx) } +func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]PeekingDevice, error) { + return d.Peeks.SelectPeekingDevices(ctx) +} + func (d *Database) GetStateEvent( ctx context.Context, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { @@ -187,6 +191,19 @@ func (d *Database) RetireInviteEvent( return } +// AddPeek tracks the fact that a user has started peeking. +// If the peek was successfully stored this returns the stream ID it was stored at. +// Returns an error if there was a problem communicating with the database. +func (d *Database) AddPeek( + ctx context.Context, roomID, userID, deviceID string, +) (sp types.StreamPosition, err error) { + _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + sp, err = d.Peeks.InsertPeek(ctx, nil, inviteEvent) + return nil + }) + return +} + // GetAccountDataInRange returns all account data for a given user inserted or // updated between two given positions // Returns a map following the format data[roomID] = []dataTypes diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go new file mode 100644 index 000000000..76df6dee3 --- /dev/null +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -0,0 +1,151 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const peeksSchema = ` +CREATE TABLE IF NOT EXISTS syncapi_peeks ( + id INTEGER PRIMARY KEY, + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + -- When the peek was created in UNIX epoch ms. + creation_ts INTEGER NOT NULL, +); + +CREATE INDEX IF NOT EXISTS syncapi_peeks_room_id_idx ON syncapi_peeks(room_id); +CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks(user_Id, device_id); +` + +const insertPeekSQL = "" + + "INSERT INTO syncapi_peeks" + + " (id, room_id, user_id, device_id, creation_ts" + + " VALUES ($1, $2, $3, $4, $5)" + +const deletePeekSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1 AND user_id = $2 and device_id = $3" + +const selectPeeksSQL == "" + + "SELECT room_id FROM syncapi_peeks WHERE user_id = $1 and device_id = $2" + +const selectPeekingDevicesSQL == "" + + "SELECT room_id, user_id, device_id FROM syncapi_peeks" + +type peekStatements struct { + db *sql.DB + insertPeekStmt *sql.Stmt + deletePeekStmt *sql.Stmt + selectPeeksStmt *sql.Stmt + selectPeekingDevicesStmt *sql.Stmt +} + +func NewSqlitePeeksTable(db *sql.DB) (tables.Peeks, error) { + _, err := db.Exec(filterSchema) + if err != nil { + return nil, err + } + s := &peekStatements{ + db: db, + } + if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { + return nil, err + } + if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { + return nil, err + } + if s.selectPeeksStmt, err = db.Prepare(selectPeeksSQL); err != nil { + return nil, err + } + if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *peekStatements) InsertPeek( + ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string, +) (streamPos types.StreamPosition, err error) { + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } + nowMilli := time.Now().UnixNano() / int64(time.Millisecond) + _, err = sqlutil.TxStmt(txn, s.insertPeekStmt).ExecContext(ctx, roomID, userID, deviceID, nowMilli) + return +} + +func (s *peekStatements) DeletePeek( + ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string, +) (streamPos types.StreamPosition, err error) { + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } + _, err = sqlutil.TxStmt(txn, s.deletePeekStmt).ExecContext(ctx, roomID, userID, deviceID) + return +} + +func (s *peekStatements) SelectPeeks( + ctx context.Context, txn *sql.Tx, userID, deviceID string, +) (roomIDs []string, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectPeeksStmt).QueryContext(ctx, userID, deviceID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPeeks: rows.close() failed") + + for rows.Next() { + var roomID string + if err = rows.Scan(&roomId); err != nil { + return + } + roomIDs = append(roomIDs, roomID) + } + + return roomIDs, rows.Err() +} + +func (s *peekStatements) SelectPeekingDevices( + ctx context.Context, +) (peekingDevices map[string][]PeekingDevice, err error) { + rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPeekingDevices: rows.close() failed") + + result := make(map[string][]PeekingDevice) + for rows.Next() { + var roomID, userID, deviceID string + if err := rows.Scan(&roomID, &userID, &deviceID); err != nil { + return nil, err + } + devices := result[roomID] + devices = append(devices, PeekingDevice{userID, deviceID}) + result[roomID] = devices + } + return result, nil +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 2ff229cbc..9566b8b38 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -39,6 +39,13 @@ type Invites interface { SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) } +type Peeks interface { + InsertPeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error) + DeletePeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error) + SelectPeeks(ctxt context.Context, txn *sql.Tx, userID, deviceID string) (peeks []string, err error) + SelectPeekingDevices((ctxt context.Context) (peekingDevices map[string][]PeekingDevice, err error) +} + type Events interface { SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter) (map[string]map[string]bool, map[string]types.StreamEvent, error) SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error) diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index df23a2f4a..e5de78f6b 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -33,6 +33,8 @@ import ( type Notifier struct { // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine roomIDToJoinedUsers map[string]userIDSet + // A map of RoomID => Set : Must only be accessed by the OnNewEvent goroutine + roomIDToPeekingDevices map[string]PeekingDeviceSet // Protects currPos and userStreams. streamLock *sync.Mutex // The latest sync position @@ -48,11 +50,11 @@ type Notifier struct { // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). func NewNotifier(pos types.StreamingToken) *Notifier { return &Notifier{ - currPos: pos, - roomIDToJoinedUsers: make(map[string]userIDSet), - userDeviceStreams: make(map[string]map[string]*UserDeviceStream), - streamLock: &sync.Mutex{}, - lastCleanUpTime: time.Now(), + currPos: pos, + roomIDToJoinedUsers: make(map[string]userIDSet), + roomIDToPeekingDevices: make(map[string]PeekingDeviceSet), + userDeviceStreams: make(map[string]map[string]*UserDeviceStream), + streamLock: &sync.Mutex{}, } } @@ -82,6 +84,8 @@ func (n *Notifier) OnNewEvent( if ev != nil { // Map this event's room_id to a list of joined users, and wake them up. usersToNotify := n.joinedUsers(ev.RoomID()) + // Map this event's room_id to a list of peeking devices, and wake them up. + peekingDevicesToNotify := n.PeekingDevices(ev.RoomID()) // If this is an invite, also add in the invitee to this list. if ev.Type() == "m.room.member" && ev.StateKey() != nil { targetUserID := *ev.StateKey() @@ -108,11 +112,11 @@ func (n *Notifier) OnNewEvent( } } - n.wakeupUsers(usersToNotify, latestPos) + n.wakeupUsers(usersToNotify, peekingDevicesToNotify, latestPos) } else if roomID != "" { - n.wakeupUsers(n.joinedUsers(roomID), latestPos) + n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), latestPos) } else if len(userIDs) > 0 { - n.wakeupUsers(userIDs, latestPos) + n.wakeupUsers(userIDs, nil, latestPos) } else { log.WithFields(log.Fields{ "posUpdate": posUpdate.String, @@ -120,6 +124,18 @@ func (n *Notifier) OnNewEvent( } } +func (n *Notifier) OnNewPeek( + roomID, userID, deviceID string, +) { + n.streamLock.Lock() + defer n.streamLock.Unlock() + + n.addPeekingDevice(roomID, userID, deviceID) + + // we don't wake up devices here given the roomserver consumer will do this shortly afterwards + // by calling OnNewEvent. +} + func (n *Notifier) OnNewSendToDevice( userID string, deviceIDs []string, posUpdate types.StreamingToken, @@ -139,7 +155,7 @@ func (n *Notifier) OnNewKeyChange( defer n.streamLock.Unlock() latestPos := n.currPos.WithUpdates(posUpdate) n.currPos = latestPos - n.wakeupUsers([]string{wakeUserID}, latestPos) + n.wakeupUsers([]string{wakeUserID}, nil, latestPos) } // GetListener returns a UserStreamListener that can be used to wait for @@ -169,6 +185,13 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error { return err } n.setUsersJoinedToRooms(roomToUsers) + + roomToPeekingDevices, err := db.AllPeekingDevicesInRooms(ctx) + if err != nil { + return err + } + n.setPeekingDevices(roomToPeekingDevices) + return nil } @@ -195,9 +218,24 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { } } +// setPeekingDevices marks the given devices as peeking in the given rooms, such that new events from +// these rooms will wake the given devices' /sync requests. This should be called prior to ANY calls to +// OnNewEvent (eg on startup) to prevent racing. +func (n *Notifier) setPeekingDevices(roomIDToPeekingDevices map[string][]PeekingDevices) { + // This is just the bulk form of addPeekingDevice + for roomID, peekingDevices := range roomIDToPeekingDevices { + if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { + n.roomIDToPeekingDevices[roomID] = make(PeekingDeviceSet) + } + for _, peekingDevice := range peekingDevices { + n.roomIDToPeekingDevices[roomID].add(peekingDevice) + } + } +} + // wakeupUsers will wake up the sync strems for all of the devices for all of the -// specified user IDs. -func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { +// specified user IDs, and also the specified peekingDevices +func (n *Notifier) wakeupUsers(userIDs []string, peekingDevices []PeekingDevice, newPos types.StreamingToken) { for _, userID := range userIDs { for _, stream := range n.fetchUserStreams(userID) { if stream == nil { @@ -206,6 +244,15 @@ func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream } } + + if peekingDevices != nil { + for _, peekingDevice := range peekingDevices { + // TODO: don't bother waking up for devices whose users we already woke up + if stream := n.fetchUserDeviceStream(peekingDevice.UserID, peekingDevice.DeviceID, false); stream != nil { + stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream + } + } + } } // wakeupUserDevice will wake up the sync stream for a specific user device. Other @@ -284,6 +331,34 @@ func (n *Notifier) joinedUsers(roomID string) (userIDs []string) { return n.roomIDToJoinedUsers[roomID].values() } + +// Not thread-safe: must be called on the OnNewEvent goroutine only +func (n *Notifier) addPeekingDevice(roomID, userID, deviceID string) { + if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { + n.roomIDToPeekingDevices[roomID] = make(PeekingDeviceSet) + } + n.roomIDToPeekingDevices[roomID].add(PeekingDevice{deviceID, userID}) +} + +// Not thread-safe: must be called on the OnNewEvent goroutine only +func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) { + if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { + n.roomIDToPeekingDevices[roomID] = make(PeekingDeviceSet) + } + // XXX: is this going to work as a key? + n.roomIDToPeekingDevices[roomID].remove(PeekingDevice{deviceID, userID}) +} + +// Not thread-safe: must be called on the OnNewEvent goroutine only +func (n *Notifier) PeekingDevices(roomID string) (peekingDevices []PeekingDevices) { + if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { + return + } + return n.roomIDToPeekingDevices[roomID].values() +} + + + // removeEmptyUserStreams iterates through the user stream map and removes any // that have been empty for a certain amount of time. This is a crude way of // ensuring that the userStreams map doesn't grow forver. diff --git a/syncapi/types/types.go b/syncapi/types/types.go index f3324800f..0d3be2eff 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -507,3 +507,27 @@ type SendToDeviceEvent struct { DeviceID string SentByToken *StreamingToken } + +// For tracking peeking devices + +type PeekingDevice struct { + ID string + UserID string +} + +type PeekingDeviceSet map[PeekingDevice]bool + +func (s PeekingDeviceSet) add(d PeekingDevice) { + s[d] = true +} + +func (s PeekingDeviceSet) remove(d PeekingDevice) { + delete(s, d) +} + +func (s PeekingDeviceSet) values() (vals []PeekingDevice) { + for d := range s { + vals = append(vals, d) + } + return +}