diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 4d37189bc..441b410f3 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -100,7 +100,7 @@ type queryKeysRequest struct { type uploadKeysCryptoIDsRequest struct { DeviceKeys json.RawMessage `json:"device_keys"` OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"` - OneTimePseudoIDs map[string]json.RawMessage `json:"one_time_pseudoids"` + OneTimeCryptoIDs map[string]json.RawMessage `json:"one_time_cryptoids"` } func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { @@ -132,11 +132,11 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api }, } } - if r.OneTimePseudoIDs != nil { - uploadReq.OneTimePseudoIDs = []api.OneTimePseudoIDs{ + if r.OneTimeCryptoIDs != nil { + uploadReq.OneTimeCryptoIDs = []api.OneTimeCryptoIDs{ { UserID: device.UserID, - KeyJSON: r.OneTimePseudoIDs, + KeyJSON: r.OneTimeCryptoIDs, }, } } @@ -144,7 +144,7 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api util.GetLogger(req.Context()). WithField("device keys", r.DeviceKeys). WithField("one-time keys", r.OneTimeKeys). - WithField("one-time pseudoids", r.OneTimePseudoIDs). + WithField("one-time cryptoids", r.OneTimeCryptoIDs). Info("Uploading keys") var uploadRes api.PerformUploadKeysResponse @@ -170,16 +170,16 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api if len(uploadRes.OneTimeKeyCounts) > 0 { keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount } - pseudoIDCount := make(map[string]int) - if len(uploadRes.OneTimePseudoIDCounts) > 0 { - pseudoIDCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount + cryptoIDCount := make(map[string]int) + if len(uploadRes.OneTimeCryptoIDCounts) > 0 { + cryptoIDCount = uploadRes.OneTimeCryptoIDCounts[0].KeyCount } return util.JSONResponse{ Code: 200, JSON: struct { - OTKCounts interface{} `json:"one_time_key_counts"` - OTPIDCounts interface{} `json:"one_time_pseudoid_counts"` - }{keyCount, pseudoIDCount}, + OTKCounts interface{} `json:"one_time_key_counts"` + OTIDCounts interface{} `json:"one_time_cryptoid_counts"` + }{keyCount, cryptoIDCount}, } } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 441136a2c..6b51bb10a 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -320,7 +320,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) unstableMux.Handle("/org.matrix.msc4080/send_pdus/{txnID}", httputil.MakeAuthAPI("send_pdus", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - logrus.Info("Processing request to /org.matrix.msc4080/sendPDUs") + logrus.Info("Processing request to /org.matrix.msc4080/send_pdus") if r := rateLimits.Limit(req, device); r != nil { return *r } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 184936380..9f71bb593 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -100,7 +100,7 @@ func SendEvent( } // Translate user ID state keys to room keys in pseudo ID rooms - if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil { + if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil { parsedRoomID, innerErr := spec.NewRoomID(roomID) if innerErr != nil { return util.JSONResponse{ @@ -154,7 +154,7 @@ func SendEvent( } // for power level events we need to replace the userID with the pseudoID - if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels { + if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels { err = updatePowerLevels(req, r, roomID, rsAPI) if err != nil { return util.JSONResponse{ @@ -299,7 +299,7 @@ func SendEventCryptoIDs( } // Translate user ID state keys to room keys in pseudo ID rooms - if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil { + if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil { parsedRoomID, innerErr := spec.NewRoomID(roomID) if innerErr != nil { return util.JSONResponse{ @@ -345,7 +345,7 @@ func SendEventCryptoIDs( } // for power level events we need to replace the userID with the pseudoID - if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels { + if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels { err = updatePowerLevels(req, r, roomID, rsAPI) if err != nil { return util.JSONResponse{ diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 18f9a0e9c..471294921 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -214,7 +214,7 @@ func OnIncomingStateTypeRequest( } // Translate user ID state keys to room keys in pseudo ID rooms - if roomVer == gomatrixserverlib.RoomVersionPseudoIDs { + if roomVer == gomatrixserverlib.RoomVersionPseudoIDs || roomVer == gomatrixserverlib.RoomVersionCryptoIDs { parsedRoomID, err := spec.NewRoomID(roomID) if err != nil { return util.JSONResponse{ diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index ad22b3cc9..7a4e5e9ad 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -314,7 +314,7 @@ func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, send } func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { - return r.usAPI.ClaimOneTimePseudoID(ctx, roomID, userID) + return r.usAPI.ClaimOneTimeCryptoID(ctx, roomID, userID) } func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) { @@ -328,7 +328,7 @@ func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID s roomVersion = roomInfo.RoomVersion } } - if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs { privKey, err := r.GetOrCreateUserRoomPrivateKey(ctx, senderID, roomID) if err != nil { return fclient.SigningIdentity{}, err diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 77b50d0e2..6b63e4872 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -445,7 +445,7 @@ func (r *Inputer) processRoomEvent( } // TODO: Revist this to ensure we don't replace a current state mxid_mapping with an older one. - if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && event.Type() == spec.MRoomMember { + if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && event.Type() == spec.MRoomMember { mapping := gomatrixserverlib.MemberContent{} if err = json.Unmarshal(event.Content(), &mapping); err != nil { return err diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 553a3fda1..70f42c1c4 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -69,7 +69,7 @@ func (c *Creator) PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.Us return nil, spec.InternalServerError{Err: err.Error()} } - if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs { util.GetLogger(ctx).Infof("StoreUserRoomPublicKey - SenderID: %s UserID: %s RoomID: %s", senderID, userID.String(), roomID.String()) bytes := spec.Base64Bytes{} err = bytes.Decode(string(senderID)) @@ -152,7 +152,7 @@ func (c *Creator) PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.Us } // If we are creating a room with pseudo IDs, create and sign the MXIDMapping - if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs { mapping := &gomatrixserverlib.MXIDMapping{ UserRoomKey: senderID, UserID: userID.String(), diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 2df20618d..4f58fb184 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -577,7 +577,7 @@ func (r *Joiner) performJoinRoomByIDCryptoIDs( info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias) if err == nil && info != nil { switch info.RoomVersion { - case gomatrixserverlib.RoomVersionPseudoIDs: + case gomatrixserverlib.RoomVersionCryptoIDs: senderIDPtr, queryErr := r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID) if queryErr == nil { checkInvitePending = true @@ -664,7 +664,7 @@ func (r *Joiner) performJoinRoomByIDCryptoIDs( identity := r.Cfg.Matrix.SigningIdentity // at this point we know we have an existing room - if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs { mapping := &gomatrixserverlib.MXIDMapping{ UserRoomKey: senderID, UserID: userID.String(), diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index f87a3f7ed..6c2d6e177 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -1044,7 +1044,7 @@ func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, } switch version { - case gomatrixserverlib.RoomVersionPseudoIDs: + case gomatrixserverlib.RoomVersionPseudoIDs, gomatrixserverlib.RoomVersionCryptoIDs: key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID) if err != nil { return nil, err diff --git a/setup/config/config_roomserver.go b/setup/config/config_roomserver.go index f47803f48..3465551cd 100644 --- a/setup/config/config_roomserver.go +++ b/setup/config/config_roomserver.go @@ -17,7 +17,7 @@ type RoomServer struct { func (c *RoomServer) Defaults(opts DefaultOpts) { //c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10 - c.DefaultRoomVersion = gomatrixserverlib.RoomVersionPseudoIDs + c.DefaultRoomVersion = gomatrixserverlib.RoomVersionCryptoIDs if opts.Generate { if !opts.SingleDatabase { c.Database.ConnectionString = "file:roomserver.db" diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index a24bf61e8..21fe11810 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -46,13 +46,13 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI return nil } -// OTPseudoIDCounts adds one-time pseudoID counts to the /sync response -func OTPseudoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error { - count, err := keyAPI.QueryOneTimePseudoIDs(ctx, userID) +// OTCryptoIDCounts adds one-time pseudoID counts to the /sync response +func OTCryptoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error { + count, err := keyAPI.QueryOneTimeCryptoIDs(ctx, userID) if err != nil { return err } - res.OTPseudoIDsCount = count.KeyCount + res.OTCryptoIDsCount = count.KeyCount return nil } diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index ec5c9aa84..de3c155bc 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -51,8 +51,8 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyC func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { return nil } -func (a *mockKeyAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) { - return userapi.OneTimePseudoIDsCount{}, nil +func (a *mockKeyAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) { + return userapi.OneTimeCryptoIDsCount{}, nil } func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error { return nil diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index fc02311be..be52f81fa 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -41,7 +41,7 @@ func (p *DeviceListStreamProvider) IncrementalSync( req.Log.WithError(err).Error("internal.DeviceOTKCounts failed") return from } - err = internal.OTPseudoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response) + err = internal.OTCryptoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response) if err != nil { req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed") return from diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 28862937f..b2be4f606 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -280,7 +280,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. if err != nil && err != context.Canceled { syncReq.Log.WithError(err).Warn("failed to get OTK counts") } - err = internal.OTPseudoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response) + err = internal.OTCryptoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response) if err != nil && err != context.Canceled { syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts") } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index a56c16b2a..0cf16e676 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -112,8 +112,8 @@ func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOn return nil } -func (a *syncUserAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) { - return userapi.OneTimePseudoIDsCount{}, nil +func (a *syncUserAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) { + return userapi.OneTimeCryptoIDsCount{}, nil } func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error { diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index fe4f6c07f..885202df2 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -153,7 +153,7 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDFor // TODO: Set Signatures & Hashes fields } - if format != FormatSyncFederation && se.Version() == gomatrixserverlib.RoomVersionPseudoIDs { + if format != FormatSyncFederation && (se.Version() == gomatrixserverlib.RoomVersionPseudoIDs || se.Version() == gomatrixserverlib.RoomVersionCryptoIDs) { err := updatePseudoIDs(&ce, se, userIDForSender, format) if err != nil { return nil, err @@ -304,7 +304,7 @@ func GetUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomS return nil, err } - if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation { + if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != FormatSyncFederation { for i, ev := range inviteStateEvents { userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID)) if userIDErr != nil { diff --git a/syncapi/types/types.go b/syncapi/types/types.go index cf5fc99b6..cb23d71ff 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/roomserver/api" @@ -365,7 +366,7 @@ type Response struct { ToDevice *ToDeviceResponse `json:"to_device,omitempty"` DeviceLists *DeviceLists `json:"device_lists,omitempty"` DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"` - OTPseudoIDsCount map[string]int `json:"one_time_pseudoids_count,omitempty"` + OTCryptoIDsCount map[string]int `json:"one_time_cryptoids_count,omitempty"` } func (r Response) MarshalJSON() ([]byte, error) { @@ -428,7 +429,7 @@ func NewResponse() *Response { res.DeviceLists = &DeviceLists{} res.ToDevice = &ToDeviceResponse{} res.DeviceListsOTKCount = map[string]int{} - res.OTPseudoIDsCount = map[string]int{} + res.OTCryptoIDsCount = map[string]int{} return &res } @@ -532,7 +533,7 @@ type InviteResponse struct { InviteState struct { Events []json.RawMessage `json:"events"` } `json:"invite_state"` - OneTimePseudoID string `json:"one_time_pseudoid,omitempty"` + OneTimeCryptoID string `json:"one_time_cryptoid,omitempty"` } // NewInviteResponse creates an empty response with initialised arrays. @@ -540,13 +541,17 @@ func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *t res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} - res.OneTimePseudoID = *event.PDU.StateKey() + logrus.Infof("Room version: %s", event.Version()) + if event.Version() == gomatrixserverlib.RoomVersionCryptoIDs { + logrus.Infof("Setting invite cryptoID to %s", *event.PDU.StateKey()) + res.OneTimeCryptoID = *event.PDU.StateKey() + } // First see if there's invite_room_state in the unsigned key of the invite. // If there is then unmarshal it into the response. This will contain the // partial room state such as join rules, room name etc. if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { - if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != synctypes.FormatSyncFederation { + if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != synctypes.FormatSyncFederation { updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }, inviteRoomState, event.PDU, event.RoomID(), eventFormat) diff --git a/userapi/api/api.go b/userapi/api/api.go index 83909622c..c3e63bf95 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -51,7 +51,7 @@ type AppserviceUserAPI interface { type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) - ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) + ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) } // api functions required by the media api @@ -670,7 +670,7 @@ type UploadDeviceKeysAPI interface { type SyncKeyAPI interface { QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error - QueryOneTimePseudoIDs(ctx context.Context, userID string) (OneTimePseudoIDsCount, *KeyError) + QueryOneTimeCryptoIDs(ctx context.Context, userID string) (OneTimeCryptoIDsCount, *KeyError) PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error } @@ -774,7 +774,7 @@ type OneTimeKeys struct { KeyJSON map[string]json.RawMessage } -type OneTimePseudoIDs struct { +type OneTimeCryptoIDs struct { // The user who owns this device UserID string // A map of algorithm:key_id => key JSON @@ -788,7 +788,7 @@ func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { } // Split a key in KeyJSON into algorithm and key ID -func (k *OneTimePseudoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) { +func (k *OneTimeCryptoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) { segments := strings.Split(keyIDWithAlgo, ":") return segments[0], segments[1] } @@ -807,7 +807,7 @@ type OneTimeKeysCount struct { KeyCount map[string]int } -type OneTimePseudoIDsCount struct { +type OneTimeCryptoIDsCount struct { // The user who owns this device UserID string // algorithm to count e.g: @@ -823,7 +823,7 @@ type PerformUploadKeysRequest struct { DeviceID string // Optional - Device performing the request, for fetching OTK count DeviceKeys []DeviceKeys OneTimeKeys []OneTimeKeys - OneTimePseudoIDs []OneTimePseudoIDs + OneTimeCryptoIDs []OneTimeCryptoIDs // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update // the display name for their respective device, and NOT to modify the keys. The key // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. @@ -838,7 +838,7 @@ type PerformUploadKeysResponse struct { // A map of user_id -> device_id -> Error for tracking failures. KeyErrors map[string]map[string]*KeyError OneTimeKeyCounts []OneTimeKeysCount - OneTimePseudoIDCounts []OneTimePseudoIDsCount + OneTimeCryptoIDCounts []OneTimeCryptoIDsCount } // PerformDeleteKeysRequest asks the keyserver to forget about certain diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index 584316121..acf20621a 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -57,21 +57,21 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor if len(req.OneTimeKeys) > 0 { a.uploadOneTimeKeys(ctx, req, res) } - if len(req.OneTimePseudoIDs) > 0 { - a.uploadOneTimePseudoIDs(ctx, req, res) + if len(req.OneTimeCryptoIDs) > 0 { + a.uploadOneTimeCryptoIDs(ctx, req, res) } - logrus.Infof("One time pseudoIDs count before: %v", res.OneTimePseudoIDCounts) + logrus.Infof("One time cryptoIDs count before: %v", res.OneTimeCryptoIDCounts) otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) if err != nil { return err } res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks} - otpIDs, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID) + otpIDs, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID) if err != nil { return err } - res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs} - logrus.Infof("One time pseudoIDs count after: %v", res.OneTimePseudoIDCounts) + res.OneTimeCryptoIDCounts = []api.OneTimeCryptoIDsCount{*otpIDs} + logrus.Infof("One time cryptoIDs count after: %v", res.OneTimeCryptoIDCounts) return nil } @@ -193,11 +193,11 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn return nil } -func (a *UserInternalAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (api.OneTimePseudoIDsCount, *api.KeyError) { - count, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, userID) +func (a *UserInternalAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (api.OneTimeCryptoIDsCount, *api.KeyError) { + count, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, userID) if err != nil { - return api.OneTimePseudoIDsCount{}, &api.KeyError{ - Err: fmt.Sprintf("Failed to query OTK counts: %s", err), + return api.OneTimeCryptoIDsCount{}, &api.KeyError{ + Err: fmt.Sprintf("Failed to query OTID counts: %s", err), } } return *count, nil @@ -796,26 +796,26 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor } -func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { +func (a *UserInternalAPI) uploadOneTimeCryptoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { if req.UserID == "" { res.Error = &api.KeyError{ Err: "user ID missing", } } - if len(req.OneTimePseudoIDs) == 0 { - counts, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID) + if len(req.OneTimeCryptoIDs) == 0 { + counts, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID) if err != nil { res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.KeyDatabase.OneTimePseudoIDsCount: %s", err), + Err: fmt.Sprintf("a.KeyDatabase.OneTimeCryptoIDsCount: %s", err), } } if counts != nil { - logrus.Infof("Uploading one-time pseudoIDs: early result count: %v", *counts) - res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts) + logrus.Infof("Uploading one-time cryptoIDs: early result count: %v", *counts) + res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts) } return } - for _, key := range req.OneTimePseudoIDs { + for _, key := range req.OneTimeCryptoIDs { // grab existing keys based on (user/algorithm/key ID) keyIDsWithAlgorithms := make([]string, len(key.KeyJSON)) i := 0 @@ -823,10 +823,10 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P keyIDsWithAlgorithms[i] = keyIDWithAlgo i++ } - existingKeys, err := a.KeyDatabase.ExistingOneTimePseudoIDs(ctx, req.UserID, keyIDsWithAlgorithms) + existingKeys, err := a.KeyDatabase.ExistingOneTimeCryptoIDs(ctx, req.UserID, keyIDsWithAlgorithms) if err != nil { res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ - Err: "failed to query existing one-time pseudoIDs: " + err.Error(), + Err: "failed to query existing one-time cryptoIDs: " + err.Error(), }) continue } @@ -834,22 +834,22 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P // if keys exist and the JSON doesn't match, error out as the key already exists if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) { res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ - Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time pseudoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo), + Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time cryptoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo), }) continue } } // store one-time keys - counts, err := a.KeyDatabase.StoreOneTimePseudoIDs(ctx, key) + counts, err := a.KeyDatabase.StoreOneTimeCryptoIDs(ctx, key) if err != nil { res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ - Err: fmt.Sprintf("%s device %s : failed to store one-time pseudoIDs: %s", req.UserID, req.DeviceID, err.Error()), + Err: fmt.Sprintf("%s device %s : failed to store one-time cryptoIDs: %s", req.UserID, req.DeviceID, err.Error()), }) continue } // collect counts - logrus.Infof("Uploading one-time pseudoIDs: result count: %v", *counts) - res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts) + logrus.Infof("Uploading one-time cryptoIDs: result count: %v", *counts) + res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts) } } @@ -857,16 +857,16 @@ type Ed25519Key struct { Key spec.Base64Bytes `json:"key"` } -func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { - pseudoID, err := a.KeyDatabase.ClaimOneTimePseudoID(ctx, userID, "ed25519") +func (a *UserInternalAPI) ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { + cryptoID, err := a.KeyDatabase.ClaimOneTimeCryptoID(ctx, userID, "ed25519") if err != nil { return "", err } - logrus.Infof("Claimed one time pseuodID: %s", pseudoID) + logrus.Infof("Claimed one time cryptoID: %s", cryptoID) - if pseudoID != nil { - for key, value := range pseudoID.KeyJSON { + if cryptoID != nil { + for key, value := range cryptoID.KeyJSON { keyParts := strings.Split(key, ":") if keyParts[0] == "ed25519" { var key_bytes Ed25519Key @@ -885,7 +885,7 @@ func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec. } } - return "", fmt.Errorf("failed claiming a valid one time pseudoID for this user: %s", userID.String()) + return "", fmt.Errorf("failed claiming a valid one time cryptoID for this user: %s", userID.String()) } func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 315d4e4ac..b0c619cf8 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -175,10 +175,10 @@ type KeyDatabase interface { // OneTimeKeysCount returns a count of all OTKs for this device. OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) - OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) - ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error) + ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error) + OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) + ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error) // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error diff --git a/userapi/storage/postgres/one_time_pseudoids_table.go b/userapi/storage/postgres/one_time_cryptoids_table.go similarity index 55% rename from userapi/storage/postgres/one_time_pseudoids_table.go rename to userapi/storage/postgres/one_time_cryptoids_table.go index b83770669..db6fb240d 100644 --- a/userapi/storage/postgres/one_time_pseudoids_table.go +++ b/userapi/storage/postgres/one_time_cryptoids_table.go @@ -27,78 +27,78 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/tables" ) -var oneTimePseudoIDsSchema = ` --- Stores one-time pseudoIDs for users -CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids ( +var oneTimeCryptoIDsSchema = ` +-- Stores one-time cryptoIDs for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids ( user_id TEXT NOT NULL, key_id TEXT NOT NULL, algorithm TEXT NOT NULL, ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, -- Clobber based on 3-uple of user/key/algorithm. - CONSTRAINT keyserver_one_time_pseudoids_unique UNIQUE (user_id, key_id, algorithm) + CONSTRAINT keyserver_one_time_cryptoids_unique UNIQUE (user_id, key_id, algorithm) ); -CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id); +CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id); ` -const upsertPseudoIDsSQL = "" + - "INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" + +const upsertCryptoIDsSQL = "" + + "INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" + " VALUES ($1, $2, $3, $4, $5)" + - " ON CONFLICT ON CONSTRAINT keyserver_one_time_pseudoids_unique" + + " ON CONFLICT ON CONSTRAINT keyserver_one_time_cryptoids_unique" + " DO UPDATE SET key_json = $5" -const selectOneTimePseudoIDsSQL = "" + - "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);" +const selectOneTimeCryptoIDsSQL = "" + + "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);" -const selectPseudoIDsCountSQL = "" + +const selectCryptoIDsCountSQL = "" + "SELECT algorithm, COUNT(key_id) FROM " + - " (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" + + " (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" + " x GROUP BY algorithm" -const deleteOneTimePseudoIDSQL = "" + - "DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3" +const deleteOneTimeCryptoIDSQL = "" + + "DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3" -const selectPseudoIDByAlgorithmSQL = "" + - "SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1" +const selectCryptoIDByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1" -const deleteOneTimePseudoIDsSQL = "" + - "DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1" +const deleteOneTimeCryptoIDsSQL = "" + + "DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1" -type oneTimePseudoIDsStatements struct { +type oneTimeCryptoIDsStatements struct { db *sql.DB - upsertPseudoIDsStmt *sql.Stmt - selectPseudoIDsStmt *sql.Stmt - selectPseudoIDsCountStmt *sql.Stmt - selectPseudoIDByAlgorithmStmt *sql.Stmt - deleteOneTimePseudoIDStmt *sql.Stmt - deleteOneTimePseudoIDsStmt *sql.Stmt + upsertCryptoIDsStmt *sql.Stmt + selectCryptoIDsStmt *sql.Stmt + selectCryptoIDsCountStmt *sql.Stmt + selectCryptoIDByAlgorithmStmt *sql.Stmt + deleteOneTimeCryptoIDStmt *sql.Stmt + deleteOneTimeCryptoIDsStmt *sql.Stmt } -func NewPostgresOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) { - s := &oneTimePseudoIDsStatements{ +func NewPostgresOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) { + s := &oneTimeCryptoIDsStatements{ db: db, } - _, err := db.Exec(oneTimePseudoIDsSchema) + _, err := db.Exec(oneTimeCryptoIDsSchema) if err != nil { return nil, err } return s, sqlutil.StatementList{ - {&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL}, - {&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL}, - {&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL}, - {&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL}, - {&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL}, - {&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL}, + {&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL}, + {&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL}, + {&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL}, + {&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL}, + {&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL}, + {&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL}, }.Prepare(db) } -func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms)) +func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms)) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed") result := make(map[string]json.RawMessage) var ( @@ -114,16 +114,16 @@ func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, return result, rows.Err() } -func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) { - counts := &api.OneTimePseudoIDsCount{ +func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) { + counts := &api.OneTimeCryptoIDsCount{ UserID: userID, KeyCount: make(map[string]int), } - rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID) + rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed") for rows.Next() { var algorithm string var count int @@ -135,26 +135,26 @@ func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, return counts, nil } -func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) { +func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error) { now := time.Now().Unix() - counts := &api.OneTimePseudoIDsCount{ + counts := &api.OneTimeCryptoIDsCount{ UserID: keys.UserID, KeyCount: make(map[string]int), } for keyIDWithAlgo, keyJSON := range keys.KeyJSON { algo, keyID := keys.Split(keyIDWithAlgo) - _, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext( ctx, keys.UserID, keyID, algo, now, string(keyJSON), ) if err != nil { return nil, err } } - rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID) + rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed") for rows.Next() { var algorithm string var count int @@ -167,25 +167,25 @@ func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context, return counts, rows.Err() } -func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID( +func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID( ctx context.Context, txn *sql.Tx, userID, algorithm string, ) (map[string]json.RawMessage, error) { var keyID string var keyJSON string - err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON) + err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON) if err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, err } - _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID) + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID) return map[string]json.RawMessage{ algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err } -func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID) +func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID) return err } diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 644a2f364..1f04beb4f 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -149,7 +149,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp if err != nil { return nil, err } - otpid, err := NewPostgresOneTimePseudoIDsTable(db) + otpid, err := NewPostgresOneTimeCryptoIDsTable(db) if err != nil { return nil, err } @@ -176,7 +176,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp return &shared.KeyDatabase{ OneTimeKeysTable: otk, - OneTimePseudoIDsTable: otpid, + OneTimeCryptoIDsTable: otpid, DeviceKeysTable: dk, KeyChangesTable: kc, StaleDeviceListsTable: sdl, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 374dab24c..2eb3a54a7 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -65,7 +65,7 @@ type Database struct { type KeyDatabase struct { OneTimeKeysTable tables.OneTimeKeys - OneTimePseudoIDsTable tables.OneTimePseudoIDs + OneTimeCryptoIDsTable tables.OneTimeCryptoIDs DeviceKeysTable tables.DeviceKeys KeyChangesTable tables.KeyChanges StaleDeviceListsTable tables.StaleDeviceLists @@ -946,31 +946,31 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) } -func (d *KeyDatabase) ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - return d.OneTimePseudoIDsTable.SelectOneTimePseudoIDs(ctx, userID, keyIDsWithAlgorithms) +func (d *KeyDatabase) ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + return d.OneTimeCryptoIDsTable.SelectOneTimeCryptoIDs(ctx, userID, keyIDsWithAlgorithms) } -func (d *KeyDatabase) StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (counts *api.OneTimePseudoIDsCount, err error) { +func (d *KeyDatabase) StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (counts *api.OneTimeCryptoIDsCount, err error) { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - counts, err = d.OneTimePseudoIDsTable.InsertOneTimePseudoIDs(ctx, txn, keys) + counts, err = d.OneTimeCryptoIDsTable.InsertOneTimeCryptoIDs(ctx, txn, keys) return err }) return } -func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) { - return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID) +func (d *KeyDatabase) OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) { + return d.OneTimeCryptoIDsTable.CountOneTimeCryptoIDs(ctx, userID) } -func (d *KeyDatabase) ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error) { - var result *api.OneTimePseudoIDs +func (d *KeyDatabase) ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error) { + var result *api.OneTimeCryptoIDs err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - keyJSON, err := d.OneTimePseudoIDsTable.SelectAndDeleteOneTimePseudoID(ctx, txn, userID.String(), algorithm) + keyJSON, err := d.OneTimeCryptoIDsTable.SelectAndDeleteOneTimeCryptoID(ctx, txn, userID.String(), algorithm) if err != nil { return err } if keyJSON != nil { - result = &api.OneTimePseudoIDs{ + result = &api.OneTimeCryptoIDs{ UserID: userID.String(), KeyJSON: keyJSON, } diff --git a/userapi/storage/sqlite3/one_time_pseudoids_table.go b/userapi/storage/sqlite3/one_time_cryptoids_table.go similarity index 56% rename from userapi/storage/sqlite3/one_time_pseudoids_table.go rename to userapi/storage/sqlite3/one_time_cryptoids_table.go index e8015373b..ee5c367a9 100644 --- a/userapi/storage/sqlite3/one_time_pseudoids_table.go +++ b/userapi/storage/sqlite3/one_time_cryptoids_table.go @@ -27,9 +27,9 @@ import ( "github.com/sirupsen/logrus" ) -var oneTimePseudoIDsSchema = ` --- Stores one-time pseudoIDs for users -CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids ( +var oneTimeCryptoIDsSchema = ` +-- Stores one-time cryptoIDs for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids ( user_id TEXT NOT NULL, key_id TEXT NOT NULL, algorithm TEXT NOT NULL, @@ -39,66 +39,66 @@ CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids ( UNIQUE (user_id, key_id, algorithm) ); -CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id); +CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id); ` -const upsertPseudoIDsSQL = "" + - "INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" + +const upsertCryptoIDsSQL = "" + + "INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" + " VALUES ($1, $2, $3, $4, $5)" + " ON CONFLICT (user_id, key_id, algorithm)" + " DO UPDATE SET key_json = $5" -const selectOneTimePseudoIDsSQL = "" + - "SELECT key_id, algorithm, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1" +const selectOneTimeCryptoIDsSQL = "" + + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1" -const selectPseudoIDsCountSQL = "" + +const selectCryptoIDsCountSQL = "" + "SELECT algorithm, COUNT(key_id) FROM " + - " (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" + + " (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" + " x GROUP BY algorithm" -const deleteOneTimePseudoIDSQL = "" + - "DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3" +const deleteOneTimeCryptoIDSQL = "" + + "DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3" -const selectPseudoIDByAlgorithmSQL = "" + - "SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1" +const selectCryptoIDByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1" -const deleteOneTimePseudoIDsSQL = "" + - "DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1" +const deleteOneTimeCryptoIDsSQL = "" + + "DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1" -type oneTimePseudoIDsStatements struct { +type oneTimeCryptoIDsStatements struct { db *sql.DB - upsertPseudoIDsStmt *sql.Stmt - selectPseudoIDsStmt *sql.Stmt - selectPseudoIDsCountStmt *sql.Stmt - selectPseudoIDByAlgorithmStmt *sql.Stmt - deleteOneTimePseudoIDStmt *sql.Stmt - deleteOneTimePseudoIDsStmt *sql.Stmt + upsertCryptoIDsStmt *sql.Stmt + selectCryptoIDsStmt *sql.Stmt + selectCryptoIDsCountStmt *sql.Stmt + selectCryptoIDByAlgorithmStmt *sql.Stmt + deleteOneTimeCryptoIDStmt *sql.Stmt + deleteOneTimeCryptoIDsStmt *sql.Stmt } -func NewSqliteOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) { - s := &oneTimePseudoIDsStatements{ +func NewSqliteOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) { + s := &oneTimeCryptoIDsStatements{ db: db, } - _, err := db.Exec(oneTimePseudoIDsSchema) + _, err := db.Exec(oneTimeCryptoIDsSchema) if err != nil { return nil, err } return s, sqlutil.StatementList{ - {&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL}, - {&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL}, - {&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL}, - {&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL}, - {&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL}, - {&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL}, + {&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL}, + {&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL}, + {&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL}, + {&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL}, + {&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL}, + {&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL}, }.Prepare(db) } -func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID) +func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed") wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) for _, ka := range keyIDsWithAlgorithms { @@ -121,16 +121,16 @@ func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, return result, rows.Err() } -func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) { - counts := &api.OneTimePseudoIDsCount{ +func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) { + counts := &api.OneTimeCryptoIDsCount{ UserID: userID, KeyCount: make(map[string]int), } - rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID) + rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed") for rows.Next() { var algorithm string var count int @@ -142,28 +142,28 @@ func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, return counts, nil } -func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs( - ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs, -) (*api.OneTimePseudoIDsCount, error) { +func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs( + ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs, +) (*api.OneTimeCryptoIDsCount, error) { now := time.Now().Unix() - counts := &api.OneTimePseudoIDsCount{ + counts := &api.OneTimeCryptoIDsCount{ UserID: keys.UserID, KeyCount: make(map[string]int), } for keyIDWithAlgo, keyJSON := range keys.KeyJSON { algo, keyID := keys.Split(keyIDWithAlgo) - _, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext( ctx, keys.UserID, keyID, algo, now, string(keyJSON), ) if err != nil { return nil, err } } - rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID) + rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed") for rows.Next() { var algorithm string var count int @@ -176,25 +176,25 @@ func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs( return counts, rows.Err() } -func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID( +func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID( ctx context.Context, txn *sql.Tx, userID, algorithm string, ) (map[string]json.RawMessage, error) { var keyID string var keyJSON string - err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON) + err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON) if err != nil { if err == sql.ErrNoRows { - logrus.Warnf("No rows found for one time pseudoIDs") + logrus.Warnf("No rows found for one time cryptoIDs") return nil, nil } return nil, err } - _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID) + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID) if err != nil { return nil, err } if keyJSON == "" { - logrus.Warnf("Empty key JSON for one time pseudoIDs") + logrus.Warnf("Empty key JSON for one time cryptoIDs") return nil, nil } return map[string]json.RawMessage{ @@ -202,7 +202,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID( }, err } -func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID) +func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID) return err } diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 356920263..c0fe4c077 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -146,7 +146,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp if err != nil { return nil, err } - otpid, err := NewSqliteOneTimePseudoIDsTable(db) + otpid, err := NewSqliteOneTimeCryptoIDsTable(db) if err != nil { return nil, err } @@ -173,7 +173,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp return &shared.KeyDatabase{ OneTimeKeysTable: otk, - OneTimePseudoIDsTable: otpid, + OneTimeCryptoIDsTable: otpid, DeviceKeysTable: dk, KeyChangesTable: kc, StaleDeviceListsTable: sdl, diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 35a41d516..7ee6ab668 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -760,29 +760,29 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { }) } -func TestOneTimePseudoIDs(t *testing.T) { +func TestOneTimeCryptoIDs(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, clean := mustCreateKeyDatabase(t, dbType) defer clean() userID := "@alice:localhost" - otk := api.OneTimePseudoIDs{ + otk := api.OneTimeCryptoIDs{ UserID: userID, KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)}, } // Add a one time pseudoID to the DB - _, err := db.StoreOneTimePseudoIDs(ctx, otk) + _, err := db.StoreOneTimeCryptoIDs(ctx, otk) MustNotError(t, err) // Check the count of one time pseudoIDs is correct - count, err := db.OneTimePseudoIDsCount(ctx, userID) + count, err := db.OneTimeCryptoIDsCount(ctx, userID) MustNotError(t, err) if count.KeyCount["pseudoid_curve25519"] != 1 { t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"]) } // Check the actual pseudoid contents are correct - keysJSON, err := db.ExistingOneTimePseudoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"}) + keysJSON, err := db.ExistingOneTimeCryptoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"}) MustNotError(t, err) keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON() MustNotError(t, err) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 14c04b0f5..d3964b245 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -168,12 +168,12 @@ type OneTimeKeys interface { DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error } -type OneTimePseudoIDs interface { - SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) - InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) - SelectAndDeleteOneTimePseudoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error) - DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error +type OneTimeCryptoIDs interface { + SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) + InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error) + SelectAndDeleteOneTimeCryptoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error) + DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error } type DeviceKeys interface {