mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-22 14:21:55 -06:00
Fix one time pseudoids in upload keys & sync endpoints
This commit is contained in:
parent
038103ac7f
commit
60be1391bf
|
@ -141,6 +141,12 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
util.GetLogger(req.Context()).
|
||||||
|
WithField("device keys", r.DeviceKeys).
|
||||||
|
WithField("one-time keys", r.OneTimeKeys).
|
||||||
|
WithField("one-time pseudoids", r.OneTimePseudoIDs).
|
||||||
|
Info("Uploading keys")
|
||||||
|
|
||||||
var uploadRes api.PerformUploadKeysResponse
|
var uploadRes api.PerformUploadKeysResponse
|
||||||
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
|
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
@ -166,7 +172,7 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
|
||||||
}
|
}
|
||||||
pseudoIDCount := make(map[string]int)
|
pseudoIDCount := make(map[string]int)
|
||||||
if len(uploadRes.OneTimePseudoIDCounts) > 0 {
|
if len(uploadRes.OneTimePseudoIDCounts) > 0 {
|
||||||
keyCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
|
pseudoIDCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 200,
|
Code: 200,
|
||||||
|
|
|
@ -1598,6 +1598,7 @@ func Setup(
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
unstableMux.Handle("/org.matrix.msc_cryptoids/keys/upload",
|
unstableMux.Handle("/org.matrix.msc_cryptoids/keys/upload",
|
||||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc_cryptoids/keys/upload")
|
||||||
return UploadKeysCryptoIDs(req, userAPI, device)
|
return UploadKeysCryptoIDs(req, userAPI, device)
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
|
@ -269,6 +269,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
||||||
defer userStreamListener.Close()
|
defer userStreamListener.Close()
|
||||||
|
|
||||||
giveup := func() util.JSONResponse {
|
giveup := func() util.JSONResponse {
|
||||||
|
syncReq.Log.Info("Responding to sync since client gave up or timeout was reached")
|
||||||
syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached")
|
syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached")
|
||||||
syncReq.Response.NextBatch = syncReq.Since
|
syncReq.Response.NextBatch = syncReq.Since
|
||||||
// We should always try to include OTKs in sync responses, otherwise clients might upload keys
|
// We should always try to include OTKs in sync responses, otherwise clients might upload keys
|
||||||
|
@ -284,6 +285,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
||||||
if err != nil && err != context.Canceled {
|
if err != nil && err != context.Canceled {
|
||||||
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
|
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
syncReq.Log.Infof("one-time pseudoID counts: %v", syncReq.Response.OTPseudoIDsCount)
|
||||||
|
syncReq.Log.Infof("one-time key counts: %v", syncReq.Response.DeviceListsOTKCount)
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
|
|
@ -365,7 +365,7 @@ type Response struct {
|
||||||
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
||||||
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
||||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||||
OTPseudoIDsCount map[string]int `json:"one_time_pseudoIDs_count,omitempty"`
|
OTPseudoIDsCount map[string]int `json:"one_time_pseudoids_count,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Response) MarshalJSON() ([]byte, error) {
|
func (r Response) MarshalJSON() ([]byte, error) {
|
||||||
|
|
|
@ -58,6 +58,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
||||||
if len(req.OneTimePseudoIDs) > 0 {
|
if len(req.OneTimePseudoIDs) > 0 {
|
||||||
a.uploadOneTimePseudoIDs(ctx, req, res)
|
a.uploadOneTimePseudoIDs(ctx, req, res)
|
||||||
}
|
}
|
||||||
|
logrus.Infof("One time pseudoIDs count before: %v", res.OneTimePseudoIDCounts)
|
||||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -68,6 +69,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs}
|
res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs}
|
||||||
|
logrus.Infof("One time pseudoIDs count after: %v", res.OneTimePseudoIDCounts)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -806,6 +808,7 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if counts != nil {
|
if counts != nil {
|
||||||
|
logrus.Infof("Uploading one-time pseudoIDs: early result count: %v", *counts)
|
||||||
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -843,6 +846,7 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// collect counts
|
// collect counts
|
||||||
|
logrus.Infof("Uploading one-time pseudoIDs: result count: %v", *counts)
|
||||||
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package storage_test
|
package storage_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -758,3 +759,35 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOneTimePseudoIDs(t *testing.T) {
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||||
|
defer clean()
|
||||||
|
userID := "@alice:localhost"
|
||||||
|
otk := api.OneTimePseudoIDs{
|
||||||
|
UserID: userID,
|
||||||
|
KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a one time pseudoID to the DB
|
||||||
|
_, err := db.StoreOneTimePseudoIDs(ctx, otk)
|
||||||
|
MustNotError(t, err)
|
||||||
|
|
||||||
|
// Check the count of one time pseudoIDs is correct
|
||||||
|
count, err := db.OneTimePseudoIDsCount(ctx, userID)
|
||||||
|
MustNotError(t, err)
|
||||||
|
if count.KeyCount["pseudoid_curve25519"] != 1 {
|
||||||
|
t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the actual pseudoid contents are correct
|
||||||
|
keysJSON, err := db.ExistingOneTimePseudoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
|
||||||
|
MustNotError(t, err)
|
||||||
|
keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON()
|
||||||
|
MustNotError(t, err)
|
||||||
|
if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) {
|
||||||
|
t.Fatalf("Existing pseudoIDs do not match expected. Got %v", keysJSON["pseudoid_curve25519:KEY1"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue