Fix one time pseudoids in upload keys & sync endpoints

This commit is contained in:
Devon Hudson 2023-10-25 16:23:46 -06:00
parent 038103ac7f
commit 60be1391bf
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
6 changed files with 50 additions and 2 deletions

View file

@ -141,6 +141,12 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
}
}
util.GetLogger(req.Context()).
WithField("device keys", r.DeviceKeys).
WithField("one-time keys", r.OneTimeKeys).
WithField("one-time pseudoids", r.OneTimePseudoIDs).
Info("Uploading keys")
var uploadRes api.PerformUploadKeysResponse
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
return util.ErrorResponse(err)
@ -166,7 +172,7 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
}
pseudoIDCount := make(map[string]int)
if len(uploadRes.OneTimePseudoIDCounts) > 0 {
keyCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
pseudoIDCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
}
return util.JSONResponse{
Code: 200,

View file

@ -1598,6 +1598,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc_cryptoids/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc_cryptoids/keys/upload")
return UploadKeysCryptoIDs(req, userAPI, device)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)

View file

@ -269,6 +269,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
defer userStreamListener.Close()
giveup := func() util.JSONResponse {
syncReq.Log.Info("Responding to sync since client gave up or timeout was reached")
syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached")
syncReq.Response.NextBatch = syncReq.Since
// We should always try to include OTKs in sync responses, otherwise clients might upload keys
@ -284,6 +285,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
}
syncReq.Log.Infof("one-time pseudoID counts: %v", syncReq.Response.OTPseudoIDsCount)
syncReq.Log.Infof("one-time key counts: %v", syncReq.Response.DeviceListsOTKCount)
}
return util.JSONResponse{
Code: http.StatusOK,

View file

@ -365,7 +365,7 @@ type Response struct {
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
OTPseudoIDsCount map[string]int `json:"one_time_pseudoIDs_count,omitempty"`
OTPseudoIDsCount map[string]int `json:"one_time_pseudoids_count,omitempty"`
}
func (r Response) MarshalJSON() ([]byte, error) {

View file

@ -58,6 +58,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
if len(req.OneTimePseudoIDs) > 0 {
a.uploadOneTimePseudoIDs(ctx, req, res)
}
logrus.Infof("One time pseudoIDs count before: %v", res.OneTimePseudoIDCounts)
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
return err
@ -68,6 +69,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
return err
}
res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs}
logrus.Infof("One time pseudoIDs count after: %v", res.OneTimePseudoIDCounts)
return nil
}
@ -806,6 +808,7 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
}
}
if counts != nil {
logrus.Infof("Uploading one-time pseudoIDs: early result count: %v", *counts)
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
}
return
@ -843,6 +846,7 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
continue
}
// collect counts
logrus.Infof("Uploading one-time pseudoIDs: result count: %v", *counts)
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
}
}

View file

@ -1,6 +1,7 @@
package storage_test
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -758,3 +759,35 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
}
})
}
func TestOneTimePseudoIDs(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
userID := "@alice:localhost"
otk := api.OneTimePseudoIDs{
UserID: userID,
KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)},
}
// Add a one time pseudoID to the DB
_, err := db.StoreOneTimePseudoIDs(ctx, otk)
MustNotError(t, err)
// Check the count of one time pseudoIDs is correct
count, err := db.OneTimePseudoIDsCount(ctx, userID)
MustNotError(t, err)
if count.KeyCount["pseudoid_curve25519"] != 1 {
t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"])
}
// Check the actual pseudoid contents are correct
keysJSON, err := db.ExistingOneTimePseudoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
MustNotError(t, err)
keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON()
MustNotError(t, err)
if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) {
t.Fatalf("Existing pseudoIDs do not match expected. Got %v", keysJSON["pseudoid_curve25519:KEY1"])
}
})
}