Add one time pseudoids to upload keys & sync endpoints
This commit is contained in:
parent
29cd14baf5
commit
038103ac7f
|
@ -97,6 +97,86 @@ type queryKeysRequest struct {
|
||||||
DeviceKeys map[string][]string `json:"device_keys"`
|
DeviceKeys map[string][]string `json:"device_keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type uploadKeysCryptoIDsRequest struct {
|
||||||
|
DeviceKeys json.RawMessage `json:"device_keys"`
|
||||||
|
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
|
||||||
|
OneTimePseudoIDs map[string]json.RawMessage `json:"one_time_pseudoids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
|
||||||
|
var r uploadKeysCryptoIDsRequest
|
||||||
|
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||||
|
if resErr != nil {
|
||||||
|
return *resErr
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadReq := &api.PerformUploadKeysRequest{
|
||||||
|
DeviceID: device.ID,
|
||||||
|
UserID: device.UserID,
|
||||||
|
}
|
||||||
|
if r.DeviceKeys != nil {
|
||||||
|
uploadReq.DeviceKeys = []api.DeviceKeys{
|
||||||
|
{
|
||||||
|
DeviceID: device.ID,
|
||||||
|
UserID: device.UserID,
|
||||||
|
KeyJSON: r.DeviceKeys,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.OneTimeKeys != nil {
|
||||||
|
uploadReq.OneTimeKeys = []api.OneTimeKeys{
|
||||||
|
{
|
||||||
|
DeviceID: device.ID,
|
||||||
|
UserID: device.UserID,
|
||||||
|
KeyJSON: r.OneTimeKeys,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.OneTimePseudoIDs != nil {
|
||||||
|
uploadReq.OneTimePseudoIDs = []api.OneTimePseudoIDs{
|
||||||
|
{
|
||||||
|
UserID: device.UserID,
|
||||||
|
KeyJSON: r.OneTimePseudoIDs,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var uploadRes api.PerformUploadKeysResponse
|
||||||
|
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
if uploadRes.Error != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(uploadRes.Error).Error("Failed to PerformUploadKeys")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(uploadRes.KeyErrors) > 0 {
|
||||||
|
util.GetLogger(req.Context()).WithField("key_errors", uploadRes.KeyErrors).Error("Failed to upload one or more keys")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 400,
|
||||||
|
JSON: uploadRes.KeyErrors,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
keyCount := make(map[string]int)
|
||||||
|
if len(uploadRes.OneTimeKeyCounts) > 0 {
|
||||||
|
keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount
|
||||||
|
}
|
||||||
|
pseudoIDCount := make(map[string]int)
|
||||||
|
if len(uploadRes.OneTimePseudoIDCounts) > 0 {
|
||||||
|
keyCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: struct {
|
||||||
|
OTKCounts interface{} `json:"one_time_key_counts"`
|
||||||
|
OTPIDCounts interface{} `json:"one_time_pseudoid_counts"`
|
||||||
|
}{keyCount, pseudoIDCount},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *queryKeysRequest) GetTimeout() time.Duration {
|
func (r *queryKeysRequest) GetTimeout() time.Duration {
|
||||||
if r.Timeout == 0 {
|
if r.Timeout == 0 {
|
||||||
return 10 * time.Second
|
return 10 * time.Second
|
||||||
|
|
|
@ -1596,6 +1596,11 @@ func Setup(
|
||||||
return UploadKeys(req, userAPI, device)
|
return UploadKeys(req, userAPI, device)
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc_cryptoids/keys/upload",
|
||||||
|
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return UploadKeysCryptoIDs(req, userAPI, device)
|
||||||
|
}, httputil.WithAllowGuests()),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
v3mux.Handle("/keys/query",
|
v3mux.Handle("/keys/query",
|
||||||
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return QueryKeys(req, userAPI, device)
|
return QueryKeys(req, userAPI, device)
|
||||||
|
|
|
@ -132,7 +132,7 @@ func SendPDUs(
|
||||||
err = json.Unmarshal(pdu.Content(), &membership)
|
err = json.Unmarshal(pdu.Content(), &membership)
|
||||||
switch {
|
switch {
|
||||||
case err != nil:
|
case err != nil:
|
||||||
util.GetLogger(req.Context()).Errorf("m.room.member event content invalid", pdu.Content(), pdu.EventID())
|
util.GetLogger(req.Context()).Errorf("m.room.member event (%s) content invalid: %v", pdu.EventID(), pdu.Content())
|
||||||
continue
|
continue
|
||||||
case membership.Membership == spec.Join:
|
case membership.Membership == spec.Join:
|
||||||
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
deviceUserID, err := spec.NewUserID(device.UserID, true)
|
||||||
|
|
|
@ -46,6 +46,16 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OTPseudoIDCounts adds one-time pseudoID counts to the /sync response
|
||||||
|
func OTPseudoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error {
|
||||||
|
count, err := keyAPI.QueryOneTimePseudoIDs(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.OTPseudoIDsCount = count.KeyCount
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
||||||
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
||||||
// be already filled in with join/leave information.
|
// be already filled in with join/leave information.
|
||||||
|
|
|
@ -50,7 +50,9 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyC
|
||||||
}
|
}
|
||||||
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
|
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
func (a *mockKeyAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) {
|
||||||
|
return userapi.OneTimePseudoIDsCount{}, nil
|
||||||
}
|
}
|
||||||
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
|
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -38,7 +38,12 @@ func (p *DeviceListStreamProvider) IncrementalSync(
|
||||||
}
|
}
|
||||||
err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response)
|
err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
|
req.Log.WithError(err).Error("internal.DeviceOTKCounts failed")
|
||||||
|
return from
|
||||||
|
}
|
||||||
|
err = internal.OTPseudoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response)
|
||||||
|
if err != nil {
|
||||||
|
req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed")
|
||||||
return from
|
return from
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -280,6 +280,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
||||||
if err != nil && err != context.Canceled {
|
if err != nil && err != context.Canceled {
|
||||||
syncReq.Log.WithError(err).Warn("failed to get OTK counts")
|
syncReq.Log.WithError(err).Warn("failed to get OTK counts")
|
||||||
}
|
}
|
||||||
|
err = internal.OTPseudoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response)
|
||||||
|
if err != nil && err != context.Canceled {
|
||||||
|
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
|
|
@ -112,6 +112,10 @@ func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *syncUserAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) {
|
||||||
|
return userapi.OneTimePseudoIDsCount{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
|
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -365,6 +365,7 @@ type Response struct {
|
||||||
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
||||||
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
||||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||||
|
OTPseudoIDsCount map[string]int `json:"one_time_pseudoIDs_count,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Response) MarshalJSON() ([]byte, error) {
|
func (r Response) MarshalJSON() ([]byte, error) {
|
||||||
|
@ -427,6 +428,7 @@ func NewResponse() *Response {
|
||||||
res.DeviceLists = &DeviceLists{}
|
res.DeviceLists = &DeviceLists{}
|
||||||
res.ToDevice = &ToDeviceResponse{}
|
res.ToDevice = &ToDeviceResponse{}
|
||||||
res.DeviceListsOTKCount = map[string]int{}
|
res.DeviceListsOTKCount = map[string]int{}
|
||||||
|
res.OTPseudoIDsCount = map[string]int{}
|
||||||
|
|
||||||
return &res
|
return &res
|
||||||
}
|
}
|
||||||
|
|
|
@ -669,6 +669,7 @@ type UploadDeviceKeysAPI interface {
|
||||||
type SyncKeyAPI interface {
|
type SyncKeyAPI interface {
|
||||||
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
|
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
|
||||||
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
|
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
|
||||||
|
QueryOneTimePseudoIDs(ctx context.Context, userID string) (OneTimePseudoIDsCount, *KeyError)
|
||||||
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
|
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -772,12 +773,25 @@ type OneTimeKeys struct {
|
||||||
KeyJSON map[string]json.RawMessage
|
KeyJSON map[string]json.RawMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OneTimePseudoIDs struct {
|
||||||
|
// The user who owns this device
|
||||||
|
UserID string
|
||||||
|
// A map of algorithm:key_id => key JSON
|
||||||
|
KeyJSON map[string]json.RawMessage
|
||||||
|
}
|
||||||
|
|
||||||
// Split a key in KeyJSON into algorithm and key ID
|
// Split a key in KeyJSON into algorithm and key ID
|
||||||
func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||||
segments := strings.Split(keyIDWithAlgo, ":")
|
segments := strings.Split(keyIDWithAlgo, ":")
|
||||||
return segments[0], segments[1]
|
return segments[0], segments[1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Split a key in KeyJSON into algorithm and key ID
|
||||||
|
func (k *OneTimePseudoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||||
|
segments := strings.Split(keyIDWithAlgo, ":")
|
||||||
|
return segments[0], segments[1]
|
||||||
|
}
|
||||||
|
|
||||||
// OneTimeKeysCount represents the counts of one-time keys for a single device
|
// OneTimeKeysCount represents the counts of one-time keys for a single device
|
||||||
type OneTimeKeysCount struct {
|
type OneTimeKeysCount struct {
|
||||||
// The user who owns this device
|
// The user who owns this device
|
||||||
|
@ -792,12 +806,23 @@ type OneTimeKeysCount struct {
|
||||||
KeyCount map[string]int
|
KeyCount map[string]int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OneTimePseudoIDsCount struct {
|
||||||
|
// The user who owns this device
|
||||||
|
UserID string
|
||||||
|
// algorithm to count e.g:
|
||||||
|
// {
|
||||||
|
// "pseudoid_curve25519": 10,
|
||||||
|
// }
|
||||||
|
KeyCount map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
// PerformUploadKeysRequest is the request to PerformUploadKeys
|
// PerformUploadKeysRequest is the request to PerformUploadKeys
|
||||||
type PerformUploadKeysRequest struct {
|
type PerformUploadKeysRequest struct {
|
||||||
UserID string // Required - User performing the request
|
UserID string // Required - User performing the request
|
||||||
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
||||||
DeviceKeys []DeviceKeys
|
DeviceKeys []DeviceKeys
|
||||||
OneTimeKeys []OneTimeKeys
|
OneTimeKeys []OneTimeKeys
|
||||||
|
OneTimePseudoIDs []OneTimePseudoIDs
|
||||||
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
||||||
// the display name for their respective device, and NOT to modify the keys. The key
|
// the display name for their respective device, and NOT to modify the keys. The key
|
||||||
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
||||||
|
@ -810,8 +835,9 @@ type PerformUploadKeysResponse struct {
|
||||||
// A fatal error when processing e.g database failures
|
// A fatal error when processing e.g database failures
|
||||||
Error *KeyError
|
Error *KeyError
|
||||||
// A map of user_id -> device_id -> Error for tracking failures.
|
// A map of user_id -> device_id -> Error for tracking failures.
|
||||||
KeyErrors map[string]map[string]*KeyError
|
KeyErrors map[string]map[string]*KeyError
|
||||||
OneTimeKeyCounts []OneTimeKeysCount
|
OneTimeKeyCounts []OneTimeKeysCount
|
||||||
|
OneTimePseudoIDCounts []OneTimePseudoIDsCount
|
||||||
}
|
}
|
||||||
|
|
||||||
// PerformDeleteKeysRequest asks the keyserver to forget about certain
|
// PerformDeleteKeysRequest asks the keyserver to forget about certain
|
||||||
|
|
|
@ -55,11 +55,19 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
||||||
if len(req.OneTimeKeys) > 0 {
|
if len(req.OneTimeKeys) > 0 {
|
||||||
a.uploadOneTimeKeys(ctx, req, res)
|
a.uploadOneTimeKeys(ctx, req, res)
|
||||||
}
|
}
|
||||||
|
if len(req.OneTimePseudoIDs) > 0 {
|
||||||
|
a.uploadOneTimePseudoIDs(ctx, req, res)
|
||||||
|
}
|
||||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
||||||
|
otpIDs, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -181,6 +189,17 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (api.OneTimePseudoIDsCount, *api.KeyError) {
|
||||||
|
count, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return api.OneTimePseudoIDsCount{}, &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return *count, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
|
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
|
||||||
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -773,6 +792,61 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
|
if req.UserID == "" {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: "user ID missing",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(req.OneTimePseudoIDs) == 0 {
|
||||||
|
counts, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("a.KeyDatabase.OneTimePseudoIDsCount: %s", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if counts != nil {
|
||||||
|
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, key := range req.OneTimePseudoIDs {
|
||||||
|
// grab existing keys based on (user/algorithm/key ID)
|
||||||
|
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
||||||
|
i := 0
|
||||||
|
for keyIDWithAlgo := range key.KeyJSON {
|
||||||
|
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
existingKeys, err := a.KeyDatabase.ExistingOneTimePseudoIDs(ctx, req.UserID, keyIDsWithAlgorithms)
|
||||||
|
if err != nil {
|
||||||
|
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||||
|
Err: "failed to query existing one-time pseudoIDs: " + err.Error(),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for keyIDWithAlgo := range existingKeys {
|
||||||
|
// if keys exist and the JSON doesn't match, error out as the key already exists
|
||||||
|
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
|
||||||
|
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time pseudoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// store one-time keys
|
||||||
|
counts, err := a.KeyDatabase.StoreOneTimePseudoIDs(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("%s device %s : failed to store one-time pseudoIDs: %s", req.UserID, req.DeviceID, err.Error()),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// collect counts
|
||||||
|
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
||||||
// if we only want to update the display names, we can skip the checks below
|
// if we only want to update the display names, we can skip the checks below
|
||||||
if onlyUpdateDisplayName {
|
if onlyUpdateDisplayName {
|
||||||
|
|
|
@ -175,6 +175,10 @@ type KeyDatabase interface {
|
||||||
// OneTimeKeysCount returns a count of all OTKs for this device.
|
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||||
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||||
|
|
||||||
|
ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||||
|
StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
|
||||||
|
OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
|
||||||
|
|
||||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
|
|
191
userapi/storage/postgres/one_time_pseudoids_table.go
Normal file
191
userapi/storage/postgres/one_time_pseudoids_table.go
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
var oneTimePseudoIDsSchema = `
|
||||||
|
-- Stores one-time pseudoIDs for users
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
key_id TEXT NOT NULL,
|
||||||
|
algorithm TEXT NOT NULL,
|
||||||
|
ts_added_secs BIGINT NOT NULL,
|
||||||
|
key_json TEXT NOT NULL,
|
||||||
|
-- Clobber based on 3-uple of user/key/algorithm.
|
||||||
|
CONSTRAINT keyserver_one_time_pseudoids_unique UNIQUE (user_id, key_id, algorithm)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id);
|
||||||
|
`
|
||||||
|
|
||||||
|
const upsertPseudoIDsSQL = "" +
|
||||||
|
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||||
|
" VALUES ($1, $2, $3, $4, $5)" +
|
||||||
|
" ON CONFLICT ON CONSTRAINT keyserver_one_time_pseudoids_unique" +
|
||||||
|
" DO UPDATE SET key_json = $5"
|
||||||
|
|
||||||
|
const selectOneTimePseudoIDsSQL = "" +
|
||||||
|
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"
|
||||||
|
|
||||||
|
const selectPseudoIDsCountSQL = "" +
|
||||||
|
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||||
|
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" +
|
||||||
|
" x GROUP BY algorithm"
|
||||||
|
|
||||||
|
const deleteOneTimePseudoIDSQL = "" +
|
||||||
|
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||||
|
|
||||||
|
const selectPseudoIDByAlgorithmSQL = "" +
|
||||||
|
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||||
|
|
||||||
|
const deleteOneTimePseudoIDsSQL = "" +
|
||||||
|
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1"
|
||||||
|
|
||||||
|
type oneTimePseudoIDsStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
upsertPseudoIDsStmt *sql.Stmt
|
||||||
|
selectPseudoIDsStmt *sql.Stmt
|
||||||
|
selectPseudoIDsCountStmt *sql.Stmt
|
||||||
|
selectPseudoIDByAlgorithmStmt *sql.Stmt
|
||||||
|
deleteOneTimePseudoIDStmt *sql.Stmt
|
||||||
|
deleteOneTimePseudoIDsStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) {
|
||||||
|
s := &oneTimePseudoIDsStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(oneTimePseudoIDsSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL},
|
||||||
|
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL},
|
||||||
|
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL},
|
||||||
|
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL},
|
||||||
|
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL},
|
||||||
|
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||||
|
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed")
|
||||||
|
|
||||||
|
result := make(map[string]json.RawMessage)
|
||||||
|
var (
|
||||||
|
algorithmWithID string
|
||||||
|
keyJSONStr string
|
||||||
|
)
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[algorithmWithID] = json.RawMessage(keyJSONStr)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
|
||||||
|
counts := &api.OneTimePseudoIDsCount{
|
||||||
|
UserID: userID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
counts := &api.OneTimePseudoIDsCount{
|
||||||
|
UserID: keys.UserID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext(
|
||||||
|
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
|
||||||
|
return counts, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, algorithm string,
|
||||||
|
) (map[string]json.RawMessage, error) {
|
||||||
|
var keyID string
|
||||||
|
var keyJSON string
|
||||||
|
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||||
|
return map[string]json.RawMessage{
|
||||||
|
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -149,6 +149,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
otpid, err := NewPostgresOneTimePseudoIDsTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
dk, err := NewPostgresDeviceKeysTable(db)
|
dk, err := NewPostgresDeviceKeysTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -172,6 +176,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
|
|
||||||
return &shared.KeyDatabase{
|
return &shared.KeyDatabase{
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
|
OneTimePseudoIDsTable: otpid,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
KeyChangesTable: kc,
|
KeyChangesTable: kc,
|
||||||
StaleDeviceListsTable: sdl,
|
StaleDeviceListsTable: sdl,
|
||||||
|
|
|
@ -65,6 +65,7 @@ type Database struct {
|
||||||
|
|
||||||
type KeyDatabase struct {
|
type KeyDatabase struct {
|
||||||
OneTimeKeysTable tables.OneTimeKeys
|
OneTimeKeysTable tables.OneTimeKeys
|
||||||
|
OneTimePseudoIDsTable tables.OneTimePseudoIDs
|
||||||
DeviceKeysTable tables.DeviceKeys
|
DeviceKeysTable tables.DeviceKeys
|
||||||
KeyChangesTable tables.KeyChanges
|
KeyChangesTable tables.KeyChanges
|
||||||
StaleDeviceListsTable tables.StaleDeviceLists
|
StaleDeviceListsTable tables.StaleDeviceLists
|
||||||
|
@ -945,6 +946,22 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
|
||||||
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *KeyDatabase) ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||||
|
return d.OneTimePseudoIDsTable.SelectOneTimePseudoIDs(ctx, userID, keyIDsWithAlgorithms)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *KeyDatabase) StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (counts *api.OneTimePseudoIDsCount, err error) {
|
||||||
|
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
counts, err = d.OneTimePseudoIDsTable.InsertOneTimePseudoIDs(ctx, txn, keys)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
|
||||||
|
return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||||
}
|
}
|
||||||
|
|
205
userapi/storage/sqlite3/one_time_pseudoids_table.go
Normal file
205
userapi/storage/sqlite3/one_time_pseudoids_table.go
Normal file
|
@ -0,0 +1,205 @@
|
||||||
|
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
var oneTimePseudoIDsSchema = `
|
||||||
|
-- Stores one-time pseudoIDs for users
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
key_id TEXT NOT NULL,
|
||||||
|
algorithm TEXT NOT NULL,
|
||||||
|
ts_added_secs BIGINT NOT NULL,
|
||||||
|
key_json TEXT NOT NULL,
|
||||||
|
-- Clobber based on 3-uple of user/key/algorithm.
|
||||||
|
UNIQUE (user_id, key_id, algorithm)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id);
|
||||||
|
`
|
||||||
|
|
||||||
|
const upsertPseudoIDsSQL = "" +
|
||||||
|
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||||
|
" VALUES ($1, $2, $3, $4, $5)" +
|
||||||
|
" ON CONFLICT (user_id, key_id, algorithm)" +
|
||||||
|
" DO UPDATE SET key_json = $5"
|
||||||
|
|
||||||
|
const selectOneTimePseudoIDsSQL = "" +
|
||||||
|
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1"
|
||||||
|
|
||||||
|
const selectPseudoIDsCountSQL = "" +
|
||||||
|
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||||
|
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" +
|
||||||
|
" x GROUP BY algorithm"
|
||||||
|
|
||||||
|
const deleteOneTimePseudoIDSQL = "" +
|
||||||
|
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||||
|
|
||||||
|
const selectPseudoIDByAlgorithmSQL = "" +
|
||||||
|
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||||
|
|
||||||
|
const deleteOneTimePseudoIDsSQL = "" +
|
||||||
|
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1"
|
||||||
|
|
||||||
|
type oneTimePseudoIDsStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
upsertPseudoIDsStmt *sql.Stmt
|
||||||
|
selectPseudoIDsStmt *sql.Stmt
|
||||||
|
selectPseudoIDsCountStmt *sql.Stmt
|
||||||
|
selectPseudoIDByAlgorithmStmt *sql.Stmt
|
||||||
|
deleteOneTimePseudoIDStmt *sql.Stmt
|
||||||
|
deleteOneTimePseudoIDsStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSqliteOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) {
|
||||||
|
s := &oneTimePseudoIDsStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(oneTimePseudoIDsSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL},
|
||||||
|
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL},
|
||||||
|
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL},
|
||||||
|
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL},
|
||||||
|
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL},
|
||||||
|
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||||
|
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed")
|
||||||
|
|
||||||
|
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
|
||||||
|
for _, ka := range keyIDsWithAlgorithms {
|
||||||
|
wantSet[ka] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(map[string]json.RawMessage)
|
||||||
|
for rows.Next() {
|
||||||
|
var keyID string
|
||||||
|
var algorithm string
|
||||||
|
var keyJSONStr string
|
||||||
|
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
keyIDWithAlgo := algorithm + ":" + keyID
|
||||||
|
if wantSet[keyIDWithAlgo] {
|
||||||
|
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
|
||||||
|
counts := &api.OneTimePseudoIDsCount{
|
||||||
|
UserID: userID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(
|
||||||
|
ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs,
|
||||||
|
) (*api.OneTimePseudoIDsCount, error) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
counts := &api.OneTimePseudoIDsCount{
|
||||||
|
UserID: keys.UserID,
|
||||||
|
KeyCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||||
|
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext(
|
||||||
|
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var algorithm string
|
||||||
|
var count int
|
||||||
|
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts.KeyCount[algorithm] = count
|
||||||
|
}
|
||||||
|
|
||||||
|
return counts, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, algorithm string,
|
||||||
|
) (map[string]json.RawMessage, error) {
|
||||||
|
var keyID string
|
||||||
|
var keyJSON string
|
||||||
|
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if keyJSON == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return map[string]json.RawMessage{
|
||||||
|
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
|
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -146,6 +146,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
otpid, err := NewSqliteOneTimePseudoIDsTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
dk, err := NewSqliteDeviceKeysTable(db)
|
dk, err := NewSqliteDeviceKeysTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -169,6 +173,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
||||||
|
|
||||||
return &shared.KeyDatabase{
|
return &shared.KeyDatabase{
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
|
OneTimePseudoIDsTable: otpid,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
KeyChangesTable: kc,
|
KeyChangesTable: kc,
|
||||||
StaleDeviceListsTable: sdl,
|
StaleDeviceListsTable: sdl,
|
||||||
|
|
|
@ -168,6 +168,14 @@ type OneTimeKeys interface {
|
||||||
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OneTimePseudoIDs interface {
|
||||||
|
SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||||
|
CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
|
||||||
|
InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
|
||||||
|
SelectAndDeleteOneTimePseudoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error)
|
||||||
|
DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error
|
||||||
|
}
|
||||||
|
|
||||||
type DeviceKeys interface {
|
type DeviceKeys interface {
|
||||||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||||
|
|
Loading…
Reference in a new issue