diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 6795498fe..eb2f9e24a 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -31,6 +31,7 @@ type KeyInternalAPI interface { PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) + QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) } // KeyError is returned if there was a problem performing/querying the server @@ -157,3 +158,16 @@ type QueryKeyChangesResponse struct { // Set if there was a problem handling the request. Error *KeyError } + +type QueryOneTimeKeysRequest struct { + // The local user to query OTK counts for + UserID string + // The device to query OTK counts for + DeviceID string +} + +type QueryOneTimeKeysResponse struct { + // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 + Count OneTimeKeysCount + Error *KeyError +} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index bb8286635..3c8dff847 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -168,6 +168,17 @@ func (a *KeyInternalAPI) claimRemoteKeys( util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys") } +func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) { + count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("Failed to query OTK counts: %s", err), + } + return + } + res.Count = *count +} + func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) diff --git a/keyserver/inthttp/client.go b/keyserver/inthttp/client.go index 3f9690b51..b65cbdafb 100644 --- a/keyserver/inthttp/client.go +++ b/keyserver/inthttp/client.go @@ -31,6 +31,7 @@ const ( PerformClaimKeysPath = "/keyserver/performClaimKeys" QueryKeysPath = "/keyserver/queryKeys" QueryKeyChangesPath = "/keyserver/queryKeyChanges" + QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys" ) // NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API. @@ -108,6 +109,23 @@ func (h *httpKeyInternalAPI) QueryKeys( } } +func (h *httpKeyInternalAPI) QueryOneTimeKeys( + ctx context.Context, + request *api.QueryOneTimeKeysRequest, + response *api.QueryOneTimeKeysResponse, +) { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOneTimeKeys") + defer span.Finish() + + apiURL := h.apiURL + QueryOneTimeKeysPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err != nil { + response.Error = &api.KeyError{ + Err: err.Error(), + } + } +} + func (h *httpKeyInternalAPI) QueryKeyChanges( ctx context.Context, request *api.QueryKeyChangesRequest, diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index f3d2882c2..615b6f80e 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -58,6 +58,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryOneTimeKeysPath, + httputil.MakeInternalAPI("queryOneTimeKeys", func(req *http.Request) util.JSONResponse { + request := api.QueryOneTimeKeysRequest{} + response := api.QueryOneTimeKeysResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + s.QueryOneTimeKeys(req.Context(), &request, &response) + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(QueryKeyChangesPath, httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse { request := api.QueryKeyChangesRequest{} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index fade75228..0e0158e58 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -29,6 +29,9 @@ type Database interface { // StoreOneTimeKeys persists the given one-time keys. StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + // OneTimeKeysCount returns a count of all OTKs for this device. + OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index a9d05548b..df215d5a8 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d return result, rows.Err() } +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { now := time.Now().Unix() counts := &api.OneTimeKeysCount{ diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 8c2534f5c..44cb0cc25 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -39,6 +39,10 @@ func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) ( return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys) } +func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) +} + func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index fecf533e6..b35407cd4 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d return result, rows.Err() } +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { now := time.Now().Unix() counts := &api.OneTimeKeysCount{ diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 8b89283f5..c6e43be45 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -24,6 +24,7 @@ import ( type OneTimeKeys interface { SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. // Returns an empty map if the key does not exist. diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 0272ffc06..66134d791 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -29,6 +29,20 @@ import ( const DeviceListLogName = "dl" +// DeviceOTKCounts adds one-time key counts to the /sync response +func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, deviceID string, res *types.Response) error { + var queryRes api.QueryOneTimeKeysResponse + keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{ + UserID: userID, + DeviceID: deviceID, + }, &queryRes) + if queryRes.Error != nil { + return queryRes.Error + } + res.DeviceListsOTKCount = queryRes.Count.KeyCount + return nil +} + // DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // be already filled in with join/leave information. @@ -36,6 +50,7 @@ func DeviceListCatchup( ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, res *types.Response, from, to types.StreamingToken, ) (hasNew bool, err error) { + // Track users who we didn't track before but now do by virtue of sharing a room with them, or not. newlyJoinedRooms := joinedRooms(res, userID) newlyLeftRooms := leftRooms(res) diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 2c3d154d5..af456af41 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -38,6 +38,9 @@ func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformCl func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) { } func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) { +} +func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) { + } type mockCurrentStateAPI struct { diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index b530b34d1..12c597bbe 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -192,8 +192,9 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use } } -func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { - res = types.NewResponse() +// nolint:gocyclo +func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (*types.Response, error) { + res := types.NewResponse() since := types.NewStreamToken(0, 0, nil) if req.since != nil { @@ -213,17 +214,21 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState) } if err != nil { - return + return res, err } accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter) if err != nil { - return + return res, err } res, err = rp.appendDeviceLists(res, req.device.UserID, since, latestPos) if err != nil { - return + return res, err + } + err = internal.DeviceOTKCounts(req.ctx, rp.keyAPI, req.device.UserID, req.device.ID, res) + if err != nil { + return res, err } // Before we return the sync response, make sure that we take action on @@ -233,7 +238,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea // Handle the updates and deletions in the database. err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since) if err != nil { - return + return res, err } } if len(events) > 0 { @@ -250,7 +255,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea } } - return + return res, err } func (rp *RequestPool) appendDeviceLists( diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 4761cce28..f465d9fff 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -393,6 +393,7 @@ type Response struct { Changed []string `json:"changed,omitempty"` Left []string `json:"left,omitempty"` } `json:"device_lists,omitempty"` + DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count"` } // NewResponse creates an empty response with initialised maps. @@ -411,6 +412,7 @@ func NewResponse() *Response { res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0) + res.DeviceListsOTKCount = make(map[string]int) return &res }