Fix one time pseudoids in upload keys & sync endpoints

This commit is contained in:
Devon Hudson 2023-10-25 16:23:46 -06:00
parent 038103ac7f
commit 60be1391bf
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
6 changed files with 50 additions and 2 deletions

View file

@ -141,6 +141,12 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
} }
} }
util.GetLogger(req.Context()).
WithField("device keys", r.DeviceKeys).
WithField("one-time keys", r.OneTimeKeys).
WithField("one-time pseudoids", r.OneTimePseudoIDs).
Info("Uploading keys")
var uploadRes api.PerformUploadKeysResponse var uploadRes api.PerformUploadKeysResponse
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil { if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -166,7 +172,7 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
} }
pseudoIDCount := make(map[string]int) pseudoIDCount := make(map[string]int)
if len(uploadRes.OneTimePseudoIDCounts) > 0 { if len(uploadRes.OneTimePseudoIDCounts) > 0 {
keyCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount pseudoIDCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
} }
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,

View file

@ -1598,6 +1598,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc_cryptoids/keys/upload", unstableMux.Handle("/org.matrix.msc_cryptoids/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc_cryptoids/keys/upload")
return UploadKeysCryptoIDs(req, userAPI, device) return UploadKeysCryptoIDs(req, userAPI, device)
}, httputil.WithAllowGuests()), }, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)

View file

@ -269,6 +269,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
defer userStreamListener.Close() defer userStreamListener.Close()
giveup := func() util.JSONResponse { giveup := func() util.JSONResponse {
syncReq.Log.Info("Responding to sync since client gave up or timeout was reached")
syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached") syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached")
syncReq.Response.NextBatch = syncReq.Since syncReq.Response.NextBatch = syncReq.Since
// We should always try to include OTKs in sync responses, otherwise clients might upload keys // We should always try to include OTKs in sync responses, otherwise clients might upload keys
@ -284,6 +285,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
if err != nil && err != context.Canceled { if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts") syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
} }
syncReq.Log.Infof("one-time pseudoID counts: %v", syncReq.Response.OTPseudoIDsCount)
syncReq.Log.Infof("one-time key counts: %v", syncReq.Response.DeviceListsOTKCount)
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,

View file

@ -365,7 +365,7 @@ type Response struct {
ToDevice *ToDeviceResponse `json:"to_device,omitempty"` ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
DeviceLists *DeviceLists `json:"device_lists,omitempty"` DeviceLists *DeviceLists `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"` DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
OTPseudoIDsCount map[string]int `json:"one_time_pseudoIDs_count,omitempty"` OTPseudoIDsCount map[string]int `json:"one_time_pseudoids_count,omitempty"`
} }
func (r Response) MarshalJSON() ([]byte, error) { func (r Response) MarshalJSON() ([]byte, error) {

View file

@ -58,6 +58,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
if len(req.OneTimePseudoIDs) > 0 { if len(req.OneTimePseudoIDs) > 0 {
a.uploadOneTimePseudoIDs(ctx, req, res) a.uploadOneTimePseudoIDs(ctx, req, res)
} }
logrus.Infof("One time pseudoIDs count before: %v", res.OneTimePseudoIDCounts)
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil { if err != nil {
return err return err
@ -68,6 +69,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
return err return err
} }
res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs} res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs}
logrus.Infof("One time pseudoIDs count after: %v", res.OneTimePseudoIDCounts)
return nil return nil
} }
@ -806,6 +808,7 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
} }
} }
if counts != nil { if counts != nil {
logrus.Infof("Uploading one-time pseudoIDs: early result count: %v", *counts)
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts) res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
} }
return return
@ -843,6 +846,7 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
continue continue
} }
// collect counts // collect counts
logrus.Infof("Uploading one-time pseudoIDs: result count: %v", *counts)
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts) res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
} }
} }

View file

@ -1,6 +1,7 @@
package storage_test package storage_test
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -758,3 +759,35 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
} }
}) })
} }
func TestOneTimePseudoIDs(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
userID := "@alice:localhost"
otk := api.OneTimePseudoIDs{
UserID: userID,
KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)},
}
// Add a one time pseudoID to the DB
_, err := db.StoreOneTimePseudoIDs(ctx, otk)
MustNotError(t, err)
// Check the count of one time pseudoIDs is correct
count, err := db.OneTimePseudoIDsCount(ctx, userID)
MustNotError(t, err)
if count.KeyCount["pseudoid_curve25519"] != 1 {
t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"])
}
// Check the actual pseudoid contents are correct
keysJSON, err := db.ExistingOneTimePseudoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
MustNotError(t, err)
keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON()
MustNotError(t, err)
if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) {
t.Fatalf("Existing pseudoIDs do not match expected. Got %v", keysJSON["pseudoid_curve25519:KEY1"])
}
})
}