diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 72785cda8..afd193a95 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -97,6 +97,86 @@ type queryKeysRequest struct { DeviceKeys map[string][]string `json:"device_keys"` } +type uploadKeysCryptoIDsRequest struct { + DeviceKeys json.RawMessage `json:"device_keys"` + OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"` + OneTimePseudoIDs map[string]json.RawMessage `json:"one_time_pseudoids"` +} + +func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { + var r uploadKeysCryptoIDsRequest + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + + uploadReq := &api.PerformUploadKeysRequest{ + DeviceID: device.ID, + UserID: device.UserID, + } + if r.DeviceKeys != nil { + uploadReq.DeviceKeys = []api.DeviceKeys{ + { + DeviceID: device.ID, + UserID: device.UserID, + KeyJSON: r.DeviceKeys, + }, + } + } + if r.OneTimeKeys != nil { + uploadReq.OneTimeKeys = []api.OneTimeKeys{ + { + DeviceID: device.ID, + UserID: device.UserID, + KeyJSON: r.OneTimeKeys, + }, + } + } + if r.OneTimePseudoIDs != nil { + uploadReq.OneTimePseudoIDs = []api.OneTimePseudoIDs{ + { + UserID: device.UserID, + KeyJSON: r.OneTimePseudoIDs, + }, + } + } + + var uploadRes api.PerformUploadKeysResponse + if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil { + return util.ErrorResponse(err) + } + if uploadRes.Error != nil { + util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if len(uploadRes.KeyErrors) > 0 { + util.GetLogger(req.Context()).WithField("key_errors", uploadRes.KeyErrors).Error("Failed to upload one or more keys") + return util.JSONResponse{ + Code: 400, + JSON: uploadRes.KeyErrors, + } + } + + keyCount := make(map[string]int) + if len(uploadRes.OneTimeKeyCounts) > 0 { + keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount + } + pseudoIDCount := make(map[string]int) + if len(uploadRes.OneTimePseudoIDCounts) > 0 { + keyCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount + } + return util.JSONResponse{ + Code: 200, + JSON: struct { + OTKCounts interface{} `json:"one_time_key_counts"` + OTPIDCounts interface{} `json:"one_time_pseudoid_counts"` + }{keyCount, pseudoIDCount}, + } +} + func (r *queryKeysRequest) GetTimeout() time.Duration { if r.Timeout == 0 { return 10 * time.Second diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 8f5d69567..4ba1a3782 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -1596,6 +1596,11 @@ func Setup( return UploadKeys(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) + unstableMux.Handle("/org.matrix.msc_cryptoids/keys/upload", + httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return UploadKeysCryptoIDs(req, userAPI, device) + }, httputil.WithAllowGuests()), + ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return QueryKeys(req, userAPI, device) diff --git a/clientapi/routing/send_pdus.go b/clientapi/routing/send_pdus.go index d4cd4439d..2a3e5274f 100644 --- a/clientapi/routing/send_pdus.go +++ b/clientapi/routing/send_pdus.go @@ -132,7 +132,7 @@ func SendPDUs( err = json.Unmarshal(pdu.Content(), &membership) switch { case err != nil: - util.GetLogger(req.Context()).Errorf("m.room.member event content invalid", pdu.Content(), pdu.EventID()) + util.GetLogger(req.Context()).Errorf("m.room.member event (%s) content invalid: %v", pdu.EventID(), pdu.Content()) continue case membership.Membership == spec.Join: deviceUserID, err := spec.NewUserID(device.UserID, true) diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 24ffcc041..a24bf61e8 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -46,6 +46,16 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI return nil } +// OTPseudoIDCounts adds one-time pseudoID counts to the /sync response +func OTPseudoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error { + count, err := keyAPI.QueryOneTimePseudoIDs(ctx, userID) + if err != nil { + return err + } + res.OTPseudoIDsCount = count.KeyCount + return nil +} + // DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response // was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST // be already filled in with join/leave information. diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index 56954cfa0..ec5c9aa84 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -50,7 +50,9 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyC } func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { return nil - +} +func (a *mockKeyAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) { + return userapi.OneTimePseudoIDsCount{}, nil } func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error { return nil diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index e8189c352..fc02311be 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -38,7 +38,12 @@ func (p *DeviceListStreamProvider) IncrementalSync( } err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response) if err != nil { - req.Log.WithError(err).Error("internal.DeviceListCatchup failed") + req.Log.WithError(err).Error("internal.DeviceOTKCounts failed") + return from + } + err = internal.OTPseudoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response) + if err != nil { + req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed") return from } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 5a92c70e1..28862937f 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -280,6 +280,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. if err != nil && err != context.Canceled { syncReq.Log.WithError(err).Warn("failed to get OTK counts") } + err = internal.OTPseudoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response) + if err != nil && err != context.Canceled { + syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts") + } } return util.JSONResponse{ Code: http.StatusOK, diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index ac5268511..a56c16b2a 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -112,6 +112,10 @@ func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOn return nil } +func (a *syncUserAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) { + return userapi.OneTimePseudoIDsCount{}, nil +} + func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error { return nil } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index bca11855c..2f57d5df6 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -365,6 +365,7 @@ type Response struct { ToDevice *ToDeviceResponse `json:"to_device,omitempty"` DeviceLists *DeviceLists `json:"device_lists,omitempty"` DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"` + OTPseudoIDsCount map[string]int `json:"one_time_pseudoIDs_count,omitempty"` } func (r Response) MarshalJSON() ([]byte, error) { @@ -427,6 +428,7 @@ func NewResponse() *Response { res.DeviceLists = &DeviceLists{} res.ToDevice = &ToDeviceResponse{} res.DeviceListsOTKCount = map[string]int{} + res.OTPseudoIDsCount = map[string]int{} return &res } diff --git a/userapi/api/api.go b/userapi/api/api.go index a0dce9758..56a409f6d 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -669,6 +669,7 @@ type UploadDeviceKeysAPI interface { type SyncKeyAPI interface { QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error + QueryOneTimePseudoIDs(ctx context.Context, userID string) (OneTimePseudoIDsCount, *KeyError) PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error } @@ -772,12 +773,25 @@ type OneTimeKeys struct { KeyJSON map[string]json.RawMessage } +type OneTimePseudoIDs struct { + // The user who owns this device + UserID string + // A map of algorithm:key_id => key JSON + KeyJSON map[string]json.RawMessage +} + // Split a key in KeyJSON into algorithm and key ID func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { segments := strings.Split(keyIDWithAlgo, ":") return segments[0], segments[1] } +// Split a key in KeyJSON into algorithm and key ID +func (k *OneTimePseudoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) { + segments := strings.Split(keyIDWithAlgo, ":") + return segments[0], segments[1] +} + // OneTimeKeysCount represents the counts of one-time keys for a single device type OneTimeKeysCount struct { // The user who owns this device @@ -792,12 +806,23 @@ type OneTimeKeysCount struct { KeyCount map[string]int } +type OneTimePseudoIDsCount struct { + // The user who owns this device + UserID string + // algorithm to count e.g: + // { + // "pseudoid_curve25519": 10, + // } + KeyCount map[string]int +} + // PerformUploadKeysRequest is the request to PerformUploadKeys type PerformUploadKeysRequest struct { - UserID string // Required - User performing the request - DeviceID string // Optional - Device performing the request, for fetching OTK count - DeviceKeys []DeviceKeys - OneTimeKeys []OneTimeKeys + UserID string // Required - User performing the request + DeviceID string // Optional - Device performing the request, for fetching OTK count + DeviceKeys []DeviceKeys + OneTimeKeys []OneTimeKeys + OneTimePseudoIDs []OneTimePseudoIDs // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update // the display name for their respective device, and NOT to modify the keys. The key // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. @@ -810,8 +835,9 @@ type PerformUploadKeysResponse struct { // A fatal error when processing e.g database failures Error *KeyError // A map of user_id -> device_id -> Error for tracking failures. - KeyErrors map[string]map[string]*KeyError - OneTimeKeyCounts []OneTimeKeysCount + KeyErrors map[string]map[string]*KeyError + OneTimeKeyCounts []OneTimeKeysCount + OneTimePseudoIDCounts []OneTimePseudoIDsCount } // PerformDeleteKeysRequest asks the keyserver to forget about certain diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index 786a2dcd8..ae342f9fb 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -55,11 +55,19 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor if len(req.OneTimeKeys) > 0 { a.uploadOneTimeKeys(ctx, req, res) } + if len(req.OneTimePseudoIDs) > 0 { + a.uploadOneTimePseudoIDs(ctx, req, res) + } otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { return err } res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks} + otpIDs, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID) + if err != nil { + return err + } + res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs} return nil } @@ -181,6 +189,17 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn return nil } +func (a *UserInternalAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (api.OneTimePseudoIDsCount, *api.KeyError) { + count, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, userID) + if err != nil { + return api.OneTimePseudoIDsCount{}, &api.KeyError{ + Err: fmt.Sprintf("Failed to query OTK counts: %s", err), + } + } + return *count, nil + +} + func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false) if err != nil { @@ -773,6 +792,61 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor } +func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + if req.UserID == "" { + res.Error = &api.KeyError{ + Err: "user ID missing", + } + } + if len(req.OneTimePseudoIDs) == 0 { + counts, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.KeyDatabase.OneTimePseudoIDsCount: %s", err), + } + } + if counts != nil { + res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts) + } + return + } + for _, key := range req.OneTimePseudoIDs { + // grab existing keys based on (user/algorithm/key ID) + keyIDsWithAlgorithms := make([]string, len(key.KeyJSON)) + i := 0 + for keyIDWithAlgo := range key.KeyJSON { + keyIDsWithAlgorithms[i] = keyIDWithAlgo + i++ + } + existingKeys, err := a.KeyDatabase.ExistingOneTimePseudoIDs(ctx, req.UserID, keyIDsWithAlgorithms) + if err != nil { + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: "failed to query existing one-time pseudoIDs: " + err.Error(), + }) + continue + } + for keyIDWithAlgo := range existingKeys { + // if keys exist and the JSON doesn't match, error out as the key already exists + if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) { + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time pseudoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo), + }) + continue + } + } + // store one-time keys + counts, err := a.KeyDatabase.StoreOneTimePseudoIDs(ctx, key) + if err != nil { + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: fmt.Sprintf("%s device %s : failed to store one-time pseudoIDs: %s", req.UserID, req.DeviceID, err.Error()), + }) + continue + } + // collect counts + res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts) + } +} + func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { // if we only want to update the display names, we can skip the checks below if onlyUpdateDisplayName { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 125b31585..af1da509c 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -175,6 +175,10 @@ type KeyDatabase interface { // OneTimeKeysCount returns a count of all OTKs for this device. OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) + OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error diff --git a/userapi/storage/postgres/one_time_pseudoids_table.go b/userapi/storage/postgres/one_time_pseudoids_table.go new file mode 100644 index 000000000..b83770669 --- /dev/null +++ b/userapi/storage/postgres/one_time_pseudoids_table.go @@ -0,0 +1,191 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var oneTimePseudoIDsSchema = ` +-- Stores one-time pseudoIDs for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids ( + user_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- Clobber based on 3-uple of user/key/algorithm. + CONSTRAINT keyserver_one_time_pseudoids_unique UNIQUE (user_id, key_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id); +` + +const upsertPseudoIDsSQL = "" + + "INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" + + " VALUES ($1, $2, $3, $4, $5)" + + " ON CONFLICT ON CONSTRAINT keyserver_one_time_pseudoids_unique" + + " DO UPDATE SET key_json = $5" + +const selectOneTimePseudoIDsSQL = "" + + "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);" + +const selectPseudoIDsCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM " + + " (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" + + " x GROUP BY algorithm" + +const deleteOneTimePseudoIDSQL = "" + + "DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3" + +const selectPseudoIDByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1" + +const deleteOneTimePseudoIDsSQL = "" + + "DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1" + +type oneTimePseudoIDsStatements struct { + db *sql.DB + upsertPseudoIDsStmt *sql.Stmt + selectPseudoIDsStmt *sql.Stmt + selectPseudoIDsCountStmt *sql.Stmt + selectPseudoIDByAlgorithmStmt *sql.Stmt + deleteOneTimePseudoIDStmt *sql.Stmt + deleteOneTimePseudoIDsStmt *sql.Stmt +} + +func NewPostgresOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) { + s := &oneTimePseudoIDsStatements{ + db: db, + } + _, err := db.Exec(oneTimePseudoIDsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL}, + {&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL}, + {&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL}, + {&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL}, + {&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL}, + {&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL}, + }.Prepare(db) +} + +func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed") + + result := make(map[string]json.RawMessage) + var ( + algorithmWithID string + keyJSONStr string + ) + for rows.Next() { + if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil { + return nil, err + } + result[algorithmWithID] = json.RawMessage(keyJSONStr) + } + return result, rows.Err() +} + +func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) { + counts := &api.OneTimePseudoIDsCount{ + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + +func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) { + now := time.Now().Unix() + counts := &api.OneTimePseudoIDsCount{ + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext( + ctx, keys.UserID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return nil, err + } + } + rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + + return counts, rows.Err() +} + +func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID( + ctx context.Context, txn *sql.Tx, userID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID) + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} + +func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID) + return err +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index b4edc80a9..644a2f364 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -149,6 +149,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp if err != nil { return nil, err } + otpid, err := NewPostgresOneTimePseudoIDsTable(db) + if err != nil { + return nil, err + } dk, err := NewPostgresDeviceKeysTable(db) if err != nil { return nil, err @@ -172,6 +176,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp return &shared.KeyDatabase{ OneTimeKeysTable: otk, + OneTimePseudoIDsTable: otpid, DeviceKeysTable: dk, KeyChangesTable: kc, StaleDeviceListsTable: sdl, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index b7acb2035..1a8c2a0d6 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -65,6 +65,7 @@ type Database struct { type KeyDatabase struct { OneTimeKeysTable tables.OneTimeKeys + OneTimePseudoIDsTable tables.OneTimePseudoIDs DeviceKeysTable tables.DeviceKeys KeyChangesTable tables.KeyChanges StaleDeviceListsTable tables.StaleDeviceLists @@ -945,6 +946,22 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) } +func (d *KeyDatabase) ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + return d.OneTimePseudoIDsTable.SelectOneTimePseudoIDs(ctx, userID, keyIDsWithAlgorithms) +} + +func (d *KeyDatabase) StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (counts *api.OneTimePseudoIDsCount, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + counts, err = d.OneTimePseudoIDsTable.InsertOneTimePseudoIDs(ctx, txn, keys) + return err + }) + return +} + +func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) { + return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID) +} + func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } diff --git a/userapi/storage/sqlite3/one_time_pseudoids_table.go b/userapi/storage/sqlite3/one_time_pseudoids_table.go new file mode 100644 index 000000000..abb71e09a --- /dev/null +++ b/userapi/storage/sqlite3/one_time_pseudoids_table.go @@ -0,0 +1,205 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var oneTimePseudoIDsSchema = ` +-- Stores one-time pseudoIDs for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids ( + user_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- Clobber based on 3-uple of user/key/algorithm. + UNIQUE (user_id, key_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id); +` + +const upsertPseudoIDsSQL = "" + + "INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" + + " VALUES ($1, $2, $3, $4, $5)" + + " ON CONFLICT (user_id, key_id, algorithm)" + + " DO UPDATE SET key_json = $5" + +const selectOneTimePseudoIDsSQL = "" + + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1" + +const selectPseudoIDsCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM " + + " (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" + + " x GROUP BY algorithm" + +const deleteOneTimePseudoIDSQL = "" + + "DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3" + +const selectPseudoIDByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1" + +const deleteOneTimePseudoIDsSQL = "" + + "DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1" + +type oneTimePseudoIDsStatements struct { + db *sql.DB + upsertPseudoIDsStmt *sql.Stmt + selectPseudoIDsStmt *sql.Stmt + selectPseudoIDsCountStmt *sql.Stmt + selectPseudoIDByAlgorithmStmt *sql.Stmt + deleteOneTimePseudoIDStmt *sql.Stmt + deleteOneTimePseudoIDsStmt *sql.Stmt +} + +func NewSqliteOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) { + s := &oneTimePseudoIDsStatements{ + db: db, + } + _, err := db.Exec(oneTimePseudoIDsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL}, + {&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL}, + {&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL}, + {&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL}, + {&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL}, + {&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL}, + }.Prepare(db) +} + +func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed") + + wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) + for _, ka := range keyIDsWithAlgorithms { + wantSet[ka] = true + } + + result := make(map[string]json.RawMessage) + for rows.Next() { + var keyID string + var algorithm string + var keyJSONStr string + if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { + return nil, err + } + keyIDWithAlgo := algorithm + ":" + keyID + if wantSet[keyIDWithAlgo] { + result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) + } + } + return result, rows.Err() +} + +func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) { + counts := &api.OneTimePseudoIDsCount{ + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + +func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs( + ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs, +) (*api.OneTimePseudoIDsCount, error) { + now := time.Now().Unix() + counts := &api.OneTimePseudoIDsCount{ + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext( + ctx, keys.UserID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return nil, err + } + } + rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + + return counts, rows.Err() +} + +func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID( + ctx context.Context, txn *sql.Tx, userID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID) + if err != nil { + return nil, err + } + if keyJSON == "" { + return nil, nil + } + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} + +func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID) + return err +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index fc13dde57..356920263 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -146,6 +146,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp if err != nil { return nil, err } + otpid, err := NewSqliteOneTimePseudoIDsTable(db) + if err != nil { + return nil, err + } dk, err := NewSqliteDeviceKeysTable(db) if err != nil { return nil, err @@ -169,6 +173,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp return &shared.KeyDatabase{ OneTimeKeysTable: otk, + OneTimePseudoIDsTable: otpid, DeviceKeysTable: dk, KeyChangesTable: kc, StaleDeviceListsTable: sdl, diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 3a0be73e4..14c04b0f5 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -168,6 +168,14 @@ type OneTimeKeys interface { DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error } +type OneTimePseudoIDs interface { + SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) + InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) + SelectAndDeleteOneTimePseudoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error) + DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error +} + type DeviceKeys interface { SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error