Add one time pseudoids to upload keys & sync endpoints

This commit is contained in:
Devon Hudson 2023-10-18 22:01:16 -06:00
parent 29cd14baf5
commit 038103ac7f
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
18 changed files with 656 additions and 9 deletions

View file

@ -97,6 +97,86 @@ type queryKeysRequest struct {
DeviceKeys map[string][]string `json:"device_keys"`
}
type uploadKeysCryptoIDsRequest struct {
DeviceKeys json.RawMessage `json:"device_keys"`
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
OneTimePseudoIDs map[string]json.RawMessage `json:"one_time_pseudoids"`
}
func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
var r uploadKeysCryptoIDsRequest
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
uploadReq := &api.PerformUploadKeysRequest{
DeviceID: device.ID,
UserID: device.UserID,
}
if r.DeviceKeys != nil {
uploadReq.DeviceKeys = []api.DeviceKeys{
{
DeviceID: device.ID,
UserID: device.UserID,
KeyJSON: r.DeviceKeys,
},
}
}
if r.OneTimeKeys != nil {
uploadReq.OneTimeKeys = []api.OneTimeKeys{
{
DeviceID: device.ID,
UserID: device.UserID,
KeyJSON: r.OneTimeKeys,
},
}
}
if r.OneTimePseudoIDs != nil {
uploadReq.OneTimePseudoIDs = []api.OneTimePseudoIDs{
{
UserID: device.UserID,
KeyJSON: r.OneTimePseudoIDs,
},
}
}
var uploadRes api.PerformUploadKeysResponse
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
return util.ErrorResponse(err)
}
if uploadRes.Error != nil {
util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
if len(uploadRes.KeyErrors) > 0 {
util.GetLogger(req.Context()).WithField("key_errors", uploadRes.KeyErrors).Error("Failed to upload one or more keys")
return util.JSONResponse{
Code: 400,
JSON: uploadRes.KeyErrors,
}
}
keyCount := make(map[string]int)
if len(uploadRes.OneTimeKeyCounts) > 0 {
keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount
}
pseudoIDCount := make(map[string]int)
if len(uploadRes.OneTimePseudoIDCounts) > 0 {
keyCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
}
return util.JSONResponse{
Code: 200,
JSON: struct {
OTKCounts interface{} `json:"one_time_key_counts"`
OTPIDCounts interface{} `json:"one_time_pseudoid_counts"`
}{keyCount, pseudoIDCount},
}
}
func (r *queryKeysRequest) GetTimeout() time.Duration {
if r.Timeout == 0 {
return 10 * time.Second

View file

@ -1596,6 +1596,11 @@ func Setup(
return UploadKeys(req, userAPI, device)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc_cryptoids/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeysCryptoIDs(req, userAPI, device)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return QueryKeys(req, userAPI, device)

View file

@ -132,7 +132,7 @@ func SendPDUs(
err = json.Unmarshal(pdu.Content(), &membership)
switch {
case err != nil:
util.GetLogger(req.Context()).Errorf("m.room.member event content invalid", pdu.Content(), pdu.EventID())
util.GetLogger(req.Context()).Errorf("m.room.member event (%s) content invalid: %v", pdu.EventID(), pdu.Content())
continue
case membership.Membership == spec.Join:
deviceUserID, err := spec.NewUserID(device.UserID, true)

View file

@ -46,6 +46,16 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
return nil
}
// OTPseudoIDCounts adds one-time pseudoID counts to the /sync response
func OTPseudoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error {
count, err := keyAPI.QueryOneTimePseudoIDs(ctx, userID)
if err != nil {
return err
}
res.OTPseudoIDsCount = count.KeyCount
return nil
}
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information.

View file

@ -50,7 +50,9 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyC
}
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
return nil
}
func (a *mockKeyAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) {
return userapi.OneTimePseudoIDsCount{}, nil
}
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
return nil

View file

@ -38,7 +38,12 @@ func (p *DeviceListStreamProvider) IncrementalSync(
}
err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response)
if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
req.Log.WithError(err).Error("internal.DeviceOTKCounts failed")
return from
}
err = internal.OTPseudoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response)
if err != nil {
req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed")
return from
}

View file

@ -280,6 +280,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTK counts")
}
err = internal.OTPseudoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response)
if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
}
}
return util.JSONResponse{
Code: http.StatusOK,

View file

@ -112,6 +112,10 @@ func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOn
return nil
}
func (a *syncUserAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) {
return userapi.OneTimePseudoIDsCount{}, nil
}
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
return nil
}

View file

@ -365,6 +365,7 @@ type Response struct {
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
OTPseudoIDsCount map[string]int `json:"one_time_pseudoIDs_count,omitempty"`
}
func (r Response) MarshalJSON() ([]byte, error) {
@ -427,6 +428,7 @@ func NewResponse() *Response {
res.DeviceLists = &DeviceLists{}
res.ToDevice = &ToDeviceResponse{}
res.DeviceListsOTKCount = map[string]int{}
res.OTPseudoIDsCount = map[string]int{}
return &res
}

View file

@ -669,6 +669,7 @@ type UploadDeviceKeysAPI interface {
type SyncKeyAPI interface {
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
QueryOneTimePseudoIDs(ctx context.Context, userID string) (OneTimePseudoIDsCount, *KeyError)
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
}
@ -772,12 +773,25 @@ type OneTimeKeys struct {
KeyJSON map[string]json.RawMessage
}
type OneTimePseudoIDs struct {
// The user who owns this device
UserID string
// A map of algorithm:key_id => key JSON
KeyJSON map[string]json.RawMessage
}
// Split a key in KeyJSON into algorithm and key ID
func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
segments := strings.Split(keyIDWithAlgo, ":")
return segments[0], segments[1]
}
// Split a key in KeyJSON into algorithm and key ID
func (k *OneTimePseudoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) {
segments := strings.Split(keyIDWithAlgo, ":")
return segments[0], segments[1]
}
// OneTimeKeysCount represents the counts of one-time keys for a single device
type OneTimeKeysCount struct {
// The user who owns this device
@ -792,12 +806,23 @@ type OneTimeKeysCount struct {
KeyCount map[string]int
}
type OneTimePseudoIDsCount struct {
// The user who owns this device
UserID string
// algorithm to count e.g:
// {
// "pseudoid_curve25519": 10,
// }
KeyCount map[string]int
}
// PerformUploadKeysRequest is the request to PerformUploadKeys
type PerformUploadKeysRequest struct {
UserID string // Required - User performing the request
DeviceID string // Optional - Device performing the request, for fetching OTK count
DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys
OneTimePseudoIDs []OneTimePseudoIDs
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
// the display name for their respective device, and NOT to modify the keys. The key
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
@ -812,6 +837,7 @@ type PerformUploadKeysResponse struct {
// A map of user_id -> device_id -> Error for tracking failures.
KeyErrors map[string]map[string]*KeyError
OneTimeKeyCounts []OneTimeKeysCount
OneTimePseudoIDCounts []OneTimePseudoIDsCount
}
// PerformDeleteKeysRequest asks the keyserver to forget about certain

View file

@ -55,11 +55,19 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res)
}
if len(req.OneTimePseudoIDs) > 0 {
a.uploadOneTimePseudoIDs(ctx, req, res)
}
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
return err
}
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
otpIDs, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID)
if err != nil {
return err
}
res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs}
return nil
}
@ -181,6 +189,17 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
return nil
}
func (a *UserInternalAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (api.OneTimePseudoIDsCount, *api.KeyError) {
count, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, userID)
if err != nil {
return api.OneTimePseudoIDsCount{}, &api.KeyError{
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
}
}
return *count, nil
}
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil {
@ -773,6 +792,61 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
}
func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if req.UserID == "" {
res.Error = &api.KeyError{
Err: "user ID missing",
}
}
if len(req.OneTimePseudoIDs) == 0 {
counts, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.KeyDatabase.OneTimePseudoIDsCount: %s", err),
}
}
if counts != nil {
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
}
return
}
for _, key := range req.OneTimePseudoIDs {
// grab existing keys based on (user/algorithm/key ID)
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
i := 0
for keyIDWithAlgo := range key.KeyJSON {
keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++
}
existingKeys, err := a.KeyDatabase.ExistingOneTimePseudoIDs(ctx, req.UserID, keyIDsWithAlgorithms)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: "failed to query existing one-time pseudoIDs: " + err.Error(),
})
continue
}
for keyIDWithAlgo := range existingKeys {
// if keys exist and the JSON doesn't match, error out as the key already exists
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time pseudoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
})
continue
}
}
// store one-time keys
counts, err := a.KeyDatabase.StoreOneTimePseudoIDs(ctx, key)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store one-time pseudoIDs: %s", req.UserID, req.DeviceID, err.Error()),
})
continue
}
// collect counts
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
}
}
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
// if we only want to update the display names, we can skip the checks below
if onlyUpdateDisplayName {

View file

@ -175,6 +175,10 @@ type KeyDatabase interface {
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error

View file

@ -0,0 +1,191 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var oneTimePseudoIDsSchema = `
-- Stores one-time pseudoIDs for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
user_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
-- Clobber based on 3-uple of user/key/algorithm.
CONSTRAINT keyserver_one_time_pseudoids_unique UNIQUE (user_id, key_id, algorithm)
);
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id);
`
const upsertPseudoIDsSQL = "" +
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT ON CONSTRAINT keyserver_one_time_pseudoids_unique" +
" DO UPDATE SET key_json = $5"
const selectOneTimePseudoIDsSQL = "" +
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"
const selectPseudoIDsCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM " +
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" +
" x GROUP BY algorithm"
const deleteOneTimePseudoIDSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
const selectPseudoIDByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
const deleteOneTimePseudoIDsSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1"
type oneTimePseudoIDsStatements struct {
db *sql.DB
upsertPseudoIDsStmt *sql.Stmt
selectPseudoIDsStmt *sql.Stmt
selectPseudoIDsCountStmt *sql.Stmt
selectPseudoIDByAlgorithmStmt *sql.Stmt
deleteOneTimePseudoIDStmt *sql.Stmt
deleteOneTimePseudoIDsStmt *sql.Stmt
}
func NewPostgresOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) {
s := &oneTimePseudoIDsStatements{
db: db,
}
_, err := db.Exec(oneTimePseudoIDsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL},
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL},
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL},
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL},
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL},
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL},
}.Prepare(db)
}
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed")
result := make(map[string]json.RawMessage)
var (
algorithmWithID string
keyJSONStr string
)
for rows.Next() {
if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
return nil, err
}
result[algorithmWithID] = json.RawMessage(keyJSONStr)
}
return result, rows.Err()
}
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
counts := &api.OneTimePseudoIDsCount{
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, nil
}
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) {
now := time.Now().Unix()
counts := &api.OneTimePseudoIDsCount{
UserID: keys.UserID,
KeyCount: make(map[string]int),
}
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext(
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, rows.Err()
}
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
ctx context.Context, txn *sql.Tx, userID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID)
return err
}

View file

@ -149,6 +149,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil {
return nil, err
}
otpid, err := NewPostgresOneTimePseudoIDsTable(db)
if err != nil {
return nil, err
}
dk, err := NewPostgresDeviceKeysTable(db)
if err != nil {
return nil, err
@ -172,6 +176,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
OneTimePseudoIDsTable: otpid,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,

View file

@ -65,6 +65,7 @@ type Database struct {
type KeyDatabase struct {
OneTimeKeysTable tables.OneTimeKeys
OneTimePseudoIDsTable tables.OneTimePseudoIDs
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists
@ -945,6 +946,22 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
func (d *KeyDatabase) ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
return d.OneTimePseudoIDsTable.SelectOneTimePseudoIDs(ctx, userID, keyIDsWithAlgorithms)
}
func (d *KeyDatabase) StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (counts *api.OneTimePseudoIDsCount, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
counts, err = d.OneTimePseudoIDsTable.InsertOneTimePseudoIDs(ctx, txn, keys)
return err
})
return
}
func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID)
}
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}

View file

@ -0,0 +1,205 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var oneTimePseudoIDsSchema = `
-- Stores one-time pseudoIDs for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
user_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
-- Clobber based on 3-uple of user/key/algorithm.
UNIQUE (user_id, key_id, algorithm)
);
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id);
`
const upsertPseudoIDsSQL = "" +
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (user_id, key_id, algorithm)" +
" DO UPDATE SET key_json = $5"
const selectOneTimePseudoIDsSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1"
const selectPseudoIDsCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM " +
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" +
" x GROUP BY algorithm"
const deleteOneTimePseudoIDSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
const selectPseudoIDByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
const deleteOneTimePseudoIDsSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1"
type oneTimePseudoIDsStatements struct {
db *sql.DB
upsertPseudoIDsStmt *sql.Stmt
selectPseudoIDsStmt *sql.Stmt
selectPseudoIDsCountStmt *sql.Stmt
selectPseudoIDByAlgorithmStmt *sql.Stmt
deleteOneTimePseudoIDStmt *sql.Stmt
deleteOneTimePseudoIDsStmt *sql.Stmt
}
func NewSqliteOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) {
s := &oneTimePseudoIDsStatements{
db: db,
}
_, err := db.Exec(oneTimePseudoIDsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL},
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL},
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL},
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL},
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL},
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL},
}.Prepare(db)
}
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed")
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
for _, ka := range keyIDsWithAlgorithms {
wantSet[ka] = true
}
result := make(map[string]json.RawMessage)
for rows.Next() {
var keyID string
var algorithm string
var keyJSONStr string
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
return nil, err
}
keyIDWithAlgo := algorithm + ":" + keyID
if wantSet[keyIDWithAlgo] {
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
}
}
return result, rows.Err()
}
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
counts := &api.OneTimePseudoIDsCount{
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, nil
}
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(
ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs,
) (*api.OneTimePseudoIDsCount, error) {
now := time.Now().Unix()
counts := &api.OneTimePseudoIDsCount{
UserID: keys.UserID,
KeyCount: make(map[string]int),
}
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext(
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, rows.Err()
}
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
ctx context.Context, txn *sql.Tx, userID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
if err != nil {
return nil, err
}
if keyJSON == "" {
return nil, nil
}
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID)
return err
}

View file

@ -146,6 +146,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil {
return nil, err
}
otpid, err := NewSqliteOneTimePseudoIDsTable(db)
if err != nil {
return nil, err
}
dk, err := NewSqliteDeviceKeysTable(db)
if err != nil {
return nil, err
@ -169,6 +173,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
OneTimePseudoIDsTable: otpid,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,

View file

@ -168,6 +168,14 @@ type OneTimeKeys interface {
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
}
type OneTimePseudoIDs interface {
SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
SelectAndDeleteOneTimePseudoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error)
DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error
}
type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error