Add one time pseudoids to upload keys & sync endpoints
This commit is contained in:
parent
29cd14baf5
commit
038103ac7f
|
@ -97,6 +97,86 @@ type queryKeysRequest struct {
|
|||
DeviceKeys map[string][]string `json:"device_keys"`
|
||||
}
|
||||
|
||||
type uploadKeysCryptoIDsRequest struct {
|
||||
DeviceKeys json.RawMessage `json:"device_keys"`
|
||||
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
|
||||
OneTimePseudoIDs map[string]json.RawMessage `json:"one_time_pseudoids"`
|
||||
}
|
||||
|
||||
func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
|
||||
var r uploadKeysCryptoIDsRequest
|
||||
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||
if resErr != nil {
|
||||
return *resErr
|
||||
}
|
||||
|
||||
uploadReq := &api.PerformUploadKeysRequest{
|
||||
DeviceID: device.ID,
|
||||
UserID: device.UserID,
|
||||
}
|
||||
if r.DeviceKeys != nil {
|
||||
uploadReq.DeviceKeys = []api.DeviceKeys{
|
||||
{
|
||||
DeviceID: device.ID,
|
||||
UserID: device.UserID,
|
||||
KeyJSON: r.DeviceKeys,
|
||||
},
|
||||
}
|
||||
}
|
||||
if r.OneTimeKeys != nil {
|
||||
uploadReq.OneTimeKeys = []api.OneTimeKeys{
|
||||
{
|
||||
DeviceID: device.ID,
|
||||
UserID: device.UserID,
|
||||
KeyJSON: r.OneTimeKeys,
|
||||
},
|
||||
}
|
||||
}
|
||||
if r.OneTimePseudoIDs != nil {
|
||||
uploadReq.OneTimePseudoIDs = []api.OneTimePseudoIDs{
|
||||
{
|
||||
UserID: device.UserID,
|
||||
KeyJSON: r.OneTimePseudoIDs,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var uploadRes api.PerformUploadKeysResponse
|
||||
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
if uploadRes.Error != nil {
|
||||
util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
if len(uploadRes.KeyErrors) > 0 {
|
||||
util.GetLogger(req.Context()).WithField("key_errors", uploadRes.KeyErrors).Error("Failed to upload one or more keys")
|
||||
return util.JSONResponse{
|
||||
Code: 400,
|
||||
JSON: uploadRes.KeyErrors,
|
||||
}
|
||||
}
|
||||
|
||||
keyCount := make(map[string]int)
|
||||
if len(uploadRes.OneTimeKeyCounts) > 0 {
|
||||
keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount
|
||||
}
|
||||
pseudoIDCount := make(map[string]int)
|
||||
if len(uploadRes.OneTimePseudoIDCounts) > 0 {
|
||||
keyCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: struct {
|
||||
OTKCounts interface{} `json:"one_time_key_counts"`
|
||||
OTPIDCounts interface{} `json:"one_time_pseudoid_counts"`
|
||||
}{keyCount, pseudoIDCount},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *queryKeysRequest) GetTimeout() time.Duration {
|
||||
if r.Timeout == 0 {
|
||||
return 10 * time.Second
|
||||
|
|
|
@ -1596,6 +1596,11 @@ func Setup(
|
|||
return UploadKeys(req, userAPI, device)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc_cryptoids/keys/upload",
|
||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return UploadKeysCryptoIDs(req, userAPI, device)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
v3mux.Handle("/keys/query",
|
||||
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return QueryKeys(req, userAPI, device)
|
||||
|
|
|
@ -132,7 +132,7 @@ func SendPDUs(
|
|||
err = json.Unmarshal(pdu.Content(), &membership)
|
||||
switch {
|
||||
case err != nil:
|
||||
util.GetLogger(req.Context()).Errorf("m.room.member event content invalid", pdu.Content(), pdu.EventID())
|
||||
util.GetLogger(req.Context()).Errorf("m.room.member event (%s) content invalid: %v", pdu.EventID(), pdu.Content())
|
||||
continue
|
||||
case membership.Membership == spec.Join:
|
||||
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
||||
|
|
|
@ -46,6 +46,16 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
|
|||
return nil
|
||||
}
|
||||
|
||||
// OTPseudoIDCounts adds one-time pseudoID counts to the /sync response
|
||||
func OTPseudoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error {
|
||||
count, err := keyAPI.QueryOneTimePseudoIDs(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.OTPseudoIDsCount = count.KeyCount
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
||||
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
||||
// be already filled in with join/leave information.
|
||||
|
|
|
@ -50,7 +50,9 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyC
|
|||
}
|
||||
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
|
||||
return nil
|
||||
|
||||
}
|
||||
func (a *mockKeyAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) {
|
||||
return userapi.OneTimePseudoIDsCount{}, nil
|
||||
}
|
||||
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
|
||||
return nil
|
||||
|
|
|
@ -38,7 +38,12 @@ func (p *DeviceListStreamProvider) IncrementalSync(
|
|||
}
|
||||
err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response)
|
||||
if err != nil {
|
||||
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
|
||||
req.Log.WithError(err).Error("internal.DeviceOTKCounts failed")
|
||||
return from
|
||||
}
|
||||
err = internal.OTPseudoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response)
|
||||
if err != nil {
|
||||
req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed")
|
||||
return from
|
||||
}
|
||||
|
||||
|
|
|
@ -280,6 +280,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
if err != nil && err != context.Canceled {
|
||||
syncReq.Log.WithError(err).Warn("failed to get OTK counts")
|
||||
}
|
||||
err = internal.OTPseudoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response)
|
||||
if err != nil && err != context.Canceled {
|
||||
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
|
|
|
@ -112,6 +112,10 @@ func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOn
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *syncUserAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) {
|
||||
return userapi.OneTimePseudoIDsCount{}, nil
|
||||
}
|
||||
|
||||
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -365,6 +365,7 @@ type Response struct {
|
|||
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
||||
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||
OTPseudoIDsCount map[string]int `json:"one_time_pseudoIDs_count,omitempty"`
|
||||
}
|
||||
|
||||
func (r Response) MarshalJSON() ([]byte, error) {
|
||||
|
@ -427,6 +428,7 @@ func NewResponse() *Response {
|
|||
res.DeviceLists = &DeviceLists{}
|
||||
res.ToDevice = &ToDeviceResponse{}
|
||||
res.DeviceListsOTKCount = map[string]int{}
|
||||
res.OTPseudoIDsCount = map[string]int{}
|
||||
|
||||
return &res
|
||||
}
|
||||
|
|
|
@ -669,6 +669,7 @@ type UploadDeviceKeysAPI interface {
|
|||
type SyncKeyAPI interface {
|
||||
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
|
||||
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
|
||||
QueryOneTimePseudoIDs(ctx context.Context, userID string) (OneTimePseudoIDsCount, *KeyError)
|
||||
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
|
||||
}
|
||||
|
||||
|
@ -772,12 +773,25 @@ type OneTimeKeys struct {
|
|||
KeyJSON map[string]json.RawMessage
|
||||
}
|
||||
|
||||
type OneTimePseudoIDs struct {
|
||||
// The user who owns this device
|
||||
UserID string
|
||||
// A map of algorithm:key_id => key JSON
|
||||
KeyJSON map[string]json.RawMessage
|
||||
}
|
||||
|
||||
// Split a key in KeyJSON into algorithm and key ID
|
||||
func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||
segments := strings.Split(keyIDWithAlgo, ":")
|
||||
return segments[0], segments[1]
|
||||
}
|
||||
|
||||
// Split a key in KeyJSON into algorithm and key ID
|
||||
func (k *OneTimePseudoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||
segments := strings.Split(keyIDWithAlgo, ":")
|
||||
return segments[0], segments[1]
|
||||
}
|
||||
|
||||
// OneTimeKeysCount represents the counts of one-time keys for a single device
|
||||
type OneTimeKeysCount struct {
|
||||
// The user who owns this device
|
||||
|
@ -792,12 +806,23 @@ type OneTimeKeysCount struct {
|
|||
KeyCount map[string]int
|
||||
}
|
||||
|
||||
type OneTimePseudoIDsCount struct {
|
||||
// The user who owns this device
|
||||
UserID string
|
||||
// algorithm to count e.g:
|
||||
// {
|
||||
// "pseudoid_curve25519": 10,
|
||||
// }
|
||||
KeyCount map[string]int
|
||||
}
|
||||
|
||||
// PerformUploadKeysRequest is the request to PerformUploadKeys
|
||||
type PerformUploadKeysRequest struct {
|
||||
UserID string // Required - User performing the request
|
||||
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
||||
DeviceKeys []DeviceKeys
|
||||
OneTimeKeys []OneTimeKeys
|
||||
UserID string // Required - User performing the request
|
||||
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
||||
DeviceKeys []DeviceKeys
|
||||
OneTimeKeys []OneTimeKeys
|
||||
OneTimePseudoIDs []OneTimePseudoIDs
|
||||
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
||||
// the display name for their respective device, and NOT to modify the keys. The key
|
||||
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
||||
|
@ -810,8 +835,9 @@ type PerformUploadKeysResponse struct {
|
|||
// A fatal error when processing e.g database failures
|
||||
Error *KeyError
|
||||
// A map of user_id -> device_id -> Error for tracking failures.
|
||||
KeyErrors map[string]map[string]*KeyError
|
||||
OneTimeKeyCounts []OneTimeKeysCount
|
||||
KeyErrors map[string]map[string]*KeyError
|
||||
OneTimeKeyCounts []OneTimeKeysCount
|
||||
OneTimePseudoIDCounts []OneTimePseudoIDsCount
|
||||
}
|
||||
|
||||
// PerformDeleteKeysRequest asks the keyserver to forget about certain
|
||||
|
|
|
@ -55,11 +55,19 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
|||
if len(req.OneTimeKeys) > 0 {
|
||||
a.uploadOneTimeKeys(ctx, req, res)
|
||||
}
|
||||
if len(req.OneTimePseudoIDs) > 0 {
|
||||
a.uploadOneTimePseudoIDs(ctx, req, res)
|
||||
}
|
||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
||||
otpIDs, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -181,6 +189,17 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (api.OneTimePseudoIDsCount, *api.KeyError) {
|
||||
count, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, userID)
|
||||
if err != nil {
|
||||
return api.OneTimePseudoIDsCount{}, &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
|
||||
}
|
||||
}
|
||||
return *count, nil
|
||||
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
|
||||
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
||||
if err != nil {
|
||||
|
@ -773,6 +792,61 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
|
|||
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
if req.UserID == "" {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "user ID missing",
|
||||
}
|
||||
}
|
||||
if len(req.OneTimePseudoIDs) == 0 {
|
||||
counts, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.KeyDatabase.OneTimePseudoIDsCount: %s", err),
|
||||
}
|
||||
}
|
||||
if counts != nil {
|
||||
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, key := range req.OneTimePseudoIDs {
|
||||
// grab existing keys based on (user/algorithm/key ID)
|
||||
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
||||
i := 0
|
||||
for keyIDWithAlgo := range key.KeyJSON {
|
||||
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
||||
i++
|
||||
}
|
||||
existingKeys, err := a.KeyDatabase.ExistingOneTimePseudoIDs(ctx, req.UserID, keyIDsWithAlgorithms)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: "failed to query existing one-time pseudoIDs: " + err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
for keyIDWithAlgo := range existingKeys {
|
||||
// if keys exist and the JSON doesn't match, error out as the key already exists
|
||||
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time pseudoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
// store one-time keys
|
||||
counts, err := a.KeyDatabase.StoreOneTimePseudoIDs(ctx, key)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s : failed to store one-time pseudoIDs: %s", req.UserID, req.DeviceID, err.Error()),
|
||||
})
|
||||
continue
|
||||
}
|
||||
// collect counts
|
||||
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
||||
}
|
||||
}
|
||||
|
||||
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
||||
// if we only want to update the display names, we can skip the checks below
|
||||
if onlyUpdateDisplayName {
|
||||
|
|
|
@ -175,6 +175,10 @@ type KeyDatabase interface {
|
|||
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||
|
||||
ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
|
||||
OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
|
||||
|
||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
|
|
191
userapi/storage/postgres/one_time_pseudoids_table.go
Normal file
191
userapi/storage/postgres/one_time_pseudoids_table.go
Normal file
|
@ -0,0 +1,191 @@
|
|||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var oneTimePseudoIDsSchema = `
|
||||
-- Stores one-time pseudoIDs for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
|
||||
user_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- Clobber based on 3-uple of user/key/algorithm.
|
||||
CONSTRAINT keyserver_one_time_pseudoids_unique UNIQUE (user_id, key_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id);
|
||||
`
|
||||
|
||||
const upsertPseudoIDsSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
" VALUES ($1, $2, $3, $4, $5)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_one_time_pseudoids_unique" +
|
||||
" DO UPDATE SET key_json = $5"
|
||||
|
||||
const selectOneTimePseudoIDsSQL = "" +
|
||||
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"
|
||||
|
||||
const selectPseudoIDsCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" +
|
||||
" x GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimePseudoIDSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||
|
||||
const selectPseudoIDByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||
|
||||
const deleteOneTimePseudoIDsSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1"
|
||||
|
||||
type oneTimePseudoIDsStatements struct {
|
||||
db *sql.DB
|
||||
upsertPseudoIDsStmt *sql.Stmt
|
||||
selectPseudoIDsStmt *sql.Stmt
|
||||
selectPseudoIDsCountStmt *sql.Stmt
|
||||
selectPseudoIDByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimePseudoIDStmt *sql.Stmt
|
||||
deleteOneTimePseudoIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) {
|
||||
s := &oneTimePseudoIDsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(oneTimePseudoIDsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL},
|
||||
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL},
|
||||
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL},
|
||||
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL},
|
||||
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL},
|
||||
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed")
|
||||
|
||||
result := make(map[string]json.RawMessage)
|
||||
var (
|
||||
algorithmWithID string
|
||||
keyJSONStr string
|
||||
)
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[algorithmWithID] = json.RawMessage(keyJSONStr)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
|
||||
counts := &api.OneTimePseudoIDsCount{
|
||||
UserID: userID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) {
|
||||
now := time.Now().Unix()
|
||||
counts := &api.OneTimePseudoIDsCount{
|
||||
UserID: keys.UserID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext(
|
||||
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
|
||||
ctx context.Context, txn *sql.Tx, userID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
|
@ -149,6 +149,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otpid, err := NewPostgresOneTimePseudoIDsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewPostgresDeviceKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -172,6 +176,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
OneTimePseudoIDsTable: otpid,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
|
|
|
@ -65,6 +65,7 @@ type Database struct {
|
|||
|
||||
type KeyDatabase struct {
|
||||
OneTimeKeysTable tables.OneTimeKeys
|
||||
OneTimePseudoIDsTable tables.OneTimePseudoIDs
|
||||
DeviceKeysTable tables.DeviceKeys
|
||||
KeyChangesTable tables.KeyChanges
|
||||
StaleDeviceListsTable tables.StaleDeviceLists
|
||||
|
@ -945,6 +946,22 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
|
|||
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
return d.OneTimePseudoIDsTable.SelectOneTimePseudoIDs(ctx, userID, keyIDsWithAlgorithms)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (counts *api.OneTimePseudoIDsCount, err error) {
|
||||
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
counts, err = d.OneTimePseudoIDsTable.InsertOneTimePseudoIDs(ctx, txn, keys)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
|
||||
return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||
}
|
||||
|
|
205
userapi/storage/sqlite3/one_time_pseudoids_table.go
Normal file
205
userapi/storage/sqlite3/one_time_pseudoids_table.go
Normal file
|
@ -0,0 +1,205 @@
|
|||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var oneTimePseudoIDsSchema = `
|
||||
-- Stores one-time pseudoIDs for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
|
||||
user_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- Clobber based on 3-uple of user/key/algorithm.
|
||||
UNIQUE (user_id, key_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id);
|
||||
`
|
||||
|
||||
const upsertPseudoIDsSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
" VALUES ($1, $2, $3, $4, $5)" +
|
||||
" ON CONFLICT (user_id, key_id, algorithm)" +
|
||||
" DO UPDATE SET key_json = $5"
|
||||
|
||||
const selectOneTimePseudoIDsSQL = "" +
|
||||
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1"
|
||||
|
||||
const selectPseudoIDsCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" +
|
||||
" x GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimePseudoIDSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||
|
||||
const selectPseudoIDByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||
|
||||
const deleteOneTimePseudoIDsSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1"
|
||||
|
||||
type oneTimePseudoIDsStatements struct {
|
||||
db *sql.DB
|
||||
upsertPseudoIDsStmt *sql.Stmt
|
||||
selectPseudoIDsStmt *sql.Stmt
|
||||
selectPseudoIDsCountStmt *sql.Stmt
|
||||
selectPseudoIDByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimePseudoIDStmt *sql.Stmt
|
||||
deleteOneTimePseudoIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) {
|
||||
s := &oneTimePseudoIDsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(oneTimePseudoIDsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL},
|
||||
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL},
|
||||
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL},
|
||||
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL},
|
||||
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL},
|
||||
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed")
|
||||
|
||||
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
|
||||
for _, ka := range keyIDsWithAlgorithms {
|
||||
wantSet[ka] = true
|
||||
}
|
||||
|
||||
result := make(map[string]json.RawMessage)
|
||||
for rows.Next() {
|
||||
var keyID string
|
||||
var algorithm string
|
||||
var keyJSONStr string
|
||||
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyIDWithAlgo := algorithm + ":" + keyID
|
||||
if wantSet[keyIDWithAlgo] {
|
||||
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
|
||||
}
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
|
||||
counts := &api.OneTimePseudoIDsCount{
|
||||
UserID: userID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(
|
||||
ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs,
|
||||
) (*api.OneTimePseudoIDsCount, error) {
|
||||
now := time.Now().Unix()
|
||||
counts := &api.OneTimePseudoIDsCount{
|
||||
UserID: keys.UserID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext(
|
||||
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
|
||||
ctx context.Context, txn *sql.Tx, userID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if keyJSON == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
|
@ -146,6 +146,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otpid, err := NewSqliteOneTimePseudoIDsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewSqliteDeviceKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -169,6 +173,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
OneTimePseudoIDsTable: otpid,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
|
|
|
@ -168,6 +168,14 @@ type OneTimeKeys interface {
|
|||
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||
}
|
||||
|
||||
type OneTimePseudoIDs interface {
|
||||
SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
|
||||
InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
|
||||
SelectAndDeleteOneTimePseudoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error)
|
||||
DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error
|
||||
}
|
||||
|
||||
type DeviceKeys interface {
|
||||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||
|
|
Loading…
Reference in a new issue