diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index afd193a95..4d37189bc 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -141,6 +141,12 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api } } + util.GetLogger(req.Context()). + WithField("device keys", r.DeviceKeys). + WithField("one-time keys", r.OneTimeKeys). + WithField("one-time pseudoids", r.OneTimePseudoIDs). + Info("Uploading keys") + var uploadRes api.PerformUploadKeysResponse if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil { return util.ErrorResponse(err) @@ -166,7 +172,7 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api } pseudoIDCount := make(map[string]int) if len(uploadRes.OneTimePseudoIDCounts) > 0 { - keyCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount + pseudoIDCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount } return util.JSONResponse{ Code: 200, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 4ba1a3782..1cf689e6a 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -1598,6 +1598,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/org.matrix.msc_cryptoids/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + logrus.Info("Processing request to /org.matrix.msc_cryptoids/keys/upload") return UploadKeysCryptoIDs(req, userAPI, device) }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 28862937f..2734bbac2 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -269,6 +269,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. defer userStreamListener.Close() giveup := func() util.JSONResponse { + syncReq.Log.Info("Responding to sync since client gave up or timeout was reached") syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached") syncReq.Response.NextBatch = syncReq.Since // We should always try to include OTKs in sync responses, otherwise clients might upload keys @@ -284,6 +285,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. if err != nil && err != context.Canceled { syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts") } + + syncReq.Log.Infof("one-time pseudoID counts: %v", syncReq.Response.OTPseudoIDsCount) + syncReq.Log.Infof("one-time key counts: %v", syncReq.Response.DeviceListsOTKCount) } return util.JSONResponse{ Code: http.StatusOK, diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 2f57d5df6..e74700a7d 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -365,7 +365,7 @@ type Response struct { ToDevice *ToDeviceResponse `json:"to_device,omitempty"` DeviceLists *DeviceLists `json:"device_lists,omitempty"` DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"` - OTPseudoIDsCount map[string]int `json:"one_time_pseudoIDs_count,omitempty"` + OTPseudoIDsCount map[string]int `json:"one_time_pseudoids_count,omitempty"` } func (r Response) MarshalJSON() ([]byte, error) { diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index ae342f9fb..85f245435 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -58,6 +58,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor if len(req.OneTimePseudoIDs) > 0 { a.uploadOneTimePseudoIDs(ctx, req, res) } + logrus.Infof("One time pseudoIDs count before: %v", res.OneTimePseudoIDCounts) otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { return err @@ -68,6 +69,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor return err } res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs} + logrus.Infof("One time pseudoIDs count after: %v", res.OneTimePseudoIDCounts) return nil } @@ -806,6 +808,7 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P } } if counts != nil { + logrus.Infof("Uploading one-time pseudoIDs: early result count: %v", *counts) res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts) } return @@ -843,6 +846,7 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P continue } // collect counts + logrus.Infof("Uploading one-time pseudoIDs: result count: %v", *counts) res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts) } } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index a46ee9ebb..35a41d516 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -1,6 +1,7 @@ package storage_test import ( + "bytes" "context" "encoding/json" "fmt" @@ -758,3 +759,35 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { } }) } + +func TestOneTimePseudoIDs(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + userID := "@alice:localhost" + otk := api.OneTimePseudoIDs{ + UserID: userID, + KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)}, + } + + // Add a one time pseudoID to the DB + _, err := db.StoreOneTimePseudoIDs(ctx, otk) + MustNotError(t, err) + + // Check the count of one time pseudoIDs is correct + count, err := db.OneTimePseudoIDsCount(ctx, userID) + MustNotError(t, err) + if count.KeyCount["pseudoid_curve25519"] != 1 { + t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"]) + } + + // Check the actual pseudoid contents are correct + keysJSON, err := db.ExistingOneTimePseudoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"}) + MustNotError(t, err) + keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON() + MustNotError(t, err) + if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) { + t.Fatalf("Existing pseudoIDs do not match expected. Got %v", keysJSON["pseudoid_curve25519:KEY1"]) + } + }) +}