mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-26 08:11:55 -06:00
Fix one time pseudoids in upload keys & sync endpoints
This commit is contained in:
parent
038103ac7f
commit
60be1391bf
|
@ -141,6 +141,12 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
|
|||
}
|
||||
}
|
||||
|
||||
util.GetLogger(req.Context()).
|
||||
WithField("device keys", r.DeviceKeys).
|
||||
WithField("one-time keys", r.OneTimeKeys).
|
||||
WithField("one-time pseudoids", r.OneTimePseudoIDs).
|
||||
Info("Uploading keys")
|
||||
|
||||
var uploadRes api.PerformUploadKeysResponse
|
||||
if err := keyAPI.PerformUploadKeys(req.Context(), uploadReq, &uploadRes); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
|
@ -166,7 +172,7 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
|
|||
}
|
||||
pseudoIDCount := make(map[string]int)
|
||||
if len(uploadRes.OneTimePseudoIDCounts) > 0 {
|
||||
keyCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
|
||||
pseudoIDCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
|
|
|
@ -1598,6 +1598,7 @@ func Setup(
|
|||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc_cryptoids/keys/upload",
|
||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc_cryptoids/keys/upload")
|
||||
return UploadKeysCryptoIDs(req, userAPI, device)
|
||||
}, httputil.WithAllowGuests()),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
|
|
@ -269,6 +269,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
defer userStreamListener.Close()
|
||||
|
||||
giveup := func() util.JSONResponse {
|
||||
syncReq.Log.Info("Responding to sync since client gave up or timeout was reached")
|
||||
syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached")
|
||||
syncReq.Response.NextBatch = syncReq.Since
|
||||
// We should always try to include OTKs in sync responses, otherwise clients might upload keys
|
||||
|
@ -284,6 +285,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
if err != nil && err != context.Canceled {
|
||||
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
|
||||
}
|
||||
|
||||
syncReq.Log.Infof("one-time pseudoID counts: %v", syncReq.Response.OTPseudoIDsCount)
|
||||
syncReq.Log.Infof("one-time key counts: %v", syncReq.Response.DeviceListsOTKCount)
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
|
|
|
@ -365,7 +365,7 @@ type Response struct {
|
|||
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
||||
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||
OTPseudoIDsCount map[string]int `json:"one_time_pseudoIDs_count,omitempty"`
|
||||
OTPseudoIDsCount map[string]int `json:"one_time_pseudoids_count,omitempty"`
|
||||
}
|
||||
|
||||
func (r Response) MarshalJSON() ([]byte, error) {
|
||||
|
|
|
@ -58,6 +58,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
|||
if len(req.OneTimePseudoIDs) > 0 {
|
||||
a.uploadOneTimePseudoIDs(ctx, req, res)
|
||||
}
|
||||
logrus.Infof("One time pseudoIDs count before: %v", res.OneTimePseudoIDCounts)
|
||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -68,6 +69,7 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
|||
return err
|
||||
}
|
||||
res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs}
|
||||
logrus.Infof("One time pseudoIDs count after: %v", res.OneTimePseudoIDCounts)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -806,6 +808,7 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
|
|||
}
|
||||
}
|
||||
if counts != nil {
|
||||
logrus.Infof("Uploading one-time pseudoIDs: early result count: %v", *counts)
|
||||
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
||||
}
|
||||
return
|
||||
|
@ -843,6 +846,7 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
|
|||
continue
|
||||
}
|
||||
// collect counts
|
||||
logrus.Infof("Uploading one-time pseudoIDs: result count: %v", *counts)
|
||||
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package storage_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
@ -758,3 +759,35 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOneTimePseudoIDs(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
userID := "@alice:localhost"
|
||||
otk := api.OneTimePseudoIDs{
|
||||
UserID: userID,
|
||||
KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)},
|
||||
}
|
||||
|
||||
// Add a one time pseudoID to the DB
|
||||
_, err := db.StoreOneTimePseudoIDs(ctx, otk)
|
||||
MustNotError(t, err)
|
||||
|
||||
// Check the count of one time pseudoIDs is correct
|
||||
count, err := db.OneTimePseudoIDsCount(ctx, userID)
|
||||
MustNotError(t, err)
|
||||
if count.KeyCount["pseudoid_curve25519"] != 1 {
|
||||
t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"])
|
||||
}
|
||||
|
||||
// Check the actual pseudoid contents are correct
|
||||
keysJSON, err := db.ExistingOneTimePseudoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
|
||||
MustNotError(t, err)
|
||||
keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON()
|
||||
MustNotError(t, err)
|
||||
if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) {
|
||||
t.Fatalf("Existing pseudoIDs do not match expected. Got %v", keysJSON["pseudoid_curve25519:KEY1"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue