Change cryptoid references from pseudoids

This commit is contained in:
Devon Hudson 2023-11-17 17:34:01 -07:00
parent 3cbccb9ed7
commit b45e72830e
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
27 changed files with 219 additions and 214 deletions

View file

@ -100,7 +100,7 @@ type queryKeysRequest struct {
type uploadKeysCryptoIDsRequest struct { type uploadKeysCryptoIDsRequest struct {
DeviceKeys json.RawMessage `json:"device_keys"` DeviceKeys json.RawMessage `json:"device_keys"`
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"` OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
OneTimePseudoIDs map[string]json.RawMessage `json:"one_time_pseudoids"` OneTimeCryptoIDs map[string]json.RawMessage `json:"one_time_cryptoids"`
} }
func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse { func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
@ -132,11 +132,11 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
}, },
} }
} }
if r.OneTimePseudoIDs != nil { if r.OneTimeCryptoIDs != nil {
uploadReq.OneTimePseudoIDs = []api.OneTimePseudoIDs{ uploadReq.OneTimeCryptoIDs = []api.OneTimeCryptoIDs{
{ {
UserID: device.UserID, UserID: device.UserID,
KeyJSON: r.OneTimePseudoIDs, KeyJSON: r.OneTimeCryptoIDs,
}, },
} }
} }
@ -144,7 +144,7 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
util.GetLogger(req.Context()). util.GetLogger(req.Context()).
WithField("device keys", r.DeviceKeys). WithField("device keys", r.DeviceKeys).
WithField("one-time keys", r.OneTimeKeys). WithField("one-time keys", r.OneTimeKeys).
WithField("one-time pseudoids", r.OneTimePseudoIDs). WithField("one-time cryptoids", r.OneTimeCryptoIDs).
Info("Uploading keys") Info("Uploading keys")
var uploadRes api.PerformUploadKeysResponse var uploadRes api.PerformUploadKeysResponse
@ -170,16 +170,16 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
if len(uploadRes.OneTimeKeyCounts) > 0 { if len(uploadRes.OneTimeKeyCounts) > 0 {
keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount
} }
pseudoIDCount := make(map[string]int) cryptoIDCount := make(map[string]int)
if len(uploadRes.OneTimePseudoIDCounts) > 0 { if len(uploadRes.OneTimeCryptoIDCounts) > 0 {
pseudoIDCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount cryptoIDCount = uploadRes.OneTimeCryptoIDCounts[0].KeyCount
} }
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: struct { JSON: struct {
OTKCounts interface{} `json:"one_time_key_counts"` OTKCounts interface{} `json:"one_time_key_counts"`
OTPIDCounts interface{} `json:"one_time_pseudoid_counts"` OTIDCounts interface{} `json:"one_time_cryptoid_counts"`
}{keyCount, pseudoIDCount}, }{keyCount, cryptoIDCount},
} }
} }

View file

@ -320,7 +320,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc4080/send_pdus/{txnID}", unstableMux.Handle("/org.matrix.msc4080/send_pdus/{txnID}",
httputil.MakeAuthAPI("send_pdus", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("send_pdus", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc4080/sendPDUs") logrus.Info("Processing request to /org.matrix.msc4080/send_pdus")
if r := rateLimits.Limit(req, device); r != nil { if r := rateLimits.Limit(req, device); r != nil {
return *r return *r
} }

View file

@ -100,7 +100,7 @@ func SendEvent(
} }
// Translate user ID state keys to room keys in pseudo ID rooms // Translate user ID state keys to room keys in pseudo ID rooms
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil { if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
parsedRoomID, innerErr := spec.NewRoomID(roomID) parsedRoomID, innerErr := spec.NewRoomID(roomID)
if innerErr != nil { if innerErr != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -154,7 +154,7 @@ func SendEvent(
} }
// for power level events we need to replace the userID with the pseudoID // for power level events we need to replace the userID with the pseudoID
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels { if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
err = updatePowerLevels(req, r, roomID, rsAPI) err = updatePowerLevels(req, r, roomID, rsAPI)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -299,7 +299,7 @@ func SendEventCryptoIDs(
} }
// Translate user ID state keys to room keys in pseudo ID rooms // Translate user ID state keys to room keys in pseudo ID rooms
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil { if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
parsedRoomID, innerErr := spec.NewRoomID(roomID) parsedRoomID, innerErr := spec.NewRoomID(roomID)
if innerErr != nil { if innerErr != nil {
return util.JSONResponse{ return util.JSONResponse{
@ -345,7 +345,7 @@ func SendEventCryptoIDs(
} }
// for power level events we need to replace the userID with the pseudoID // for power level events we need to replace the userID with the pseudoID
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels { if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
err = updatePowerLevels(req, r, roomID, rsAPI) err = updatePowerLevels(req, r, roomID, rsAPI)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{

View file

@ -214,7 +214,7 @@ func OnIncomingStateTypeRequest(
} }
// Translate user ID state keys to room keys in pseudo ID rooms // Translate user ID state keys to room keys in pseudo ID rooms
if roomVer == gomatrixserverlib.RoomVersionPseudoIDs { if roomVer == gomatrixserverlib.RoomVersionPseudoIDs || roomVer == gomatrixserverlib.RoomVersionCryptoIDs {
parsedRoomID, err := spec.NewRoomID(roomID) parsedRoomID, err := spec.NewRoomID(roomID)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{

View file

@ -314,7 +314,7 @@ func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, send
} }
func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return r.usAPI.ClaimOneTimePseudoID(ctx, roomID, userID) return r.usAPI.ClaimOneTimeCryptoID(ctx, roomID, userID)
} }
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) { func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
@ -328,7 +328,7 @@ func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID s
roomVersion = roomInfo.RoomVersion roomVersion = roomInfo.RoomVersion
} }
} }
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs { if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
privKey, err := r.GetOrCreateUserRoomPrivateKey(ctx, senderID, roomID) privKey, err := r.GetOrCreateUserRoomPrivateKey(ctx, senderID, roomID)
if err != nil { if err != nil {
return fclient.SigningIdentity{}, err return fclient.SigningIdentity{}, err

View file

@ -445,7 +445,7 @@ func (r *Inputer) processRoomEvent(
} }
// TODO: Revist this to ensure we don't replace a current state mxid_mapping with an older one. // TODO: Revist this to ensure we don't replace a current state mxid_mapping with an older one.
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && event.Type() == spec.MRoomMember { if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && event.Type() == spec.MRoomMember {
mapping := gomatrixserverlib.MemberContent{} mapping := gomatrixserverlib.MemberContent{}
if err = json.Unmarshal(event.Content(), &mapping); err != nil { if err = json.Unmarshal(event.Content(), &mapping); err != nil {
return err return err

View file

@ -69,7 +69,7 @@ func (c *Creator) PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.Us
return nil, spec.InternalServerError{Err: err.Error()} return nil, spec.InternalServerError{Err: err.Error()}
} }
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
util.GetLogger(ctx).Infof("StoreUserRoomPublicKey - SenderID: %s UserID: %s RoomID: %s", senderID, userID.String(), roomID.String()) util.GetLogger(ctx).Infof("StoreUserRoomPublicKey - SenderID: %s UserID: %s RoomID: %s", senderID, userID.String(), roomID.String())
bytes := spec.Base64Bytes{} bytes := spec.Base64Bytes{}
err = bytes.Decode(string(senderID)) err = bytes.Decode(string(senderID))
@ -152,7 +152,7 @@ func (c *Creator) PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.Us
} }
// If we are creating a room with pseudo IDs, create and sign the MXIDMapping // If we are creating a room with pseudo IDs, create and sign the MXIDMapping
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
mapping := &gomatrixserverlib.MXIDMapping{ mapping := &gomatrixserverlib.MXIDMapping{
UserRoomKey: senderID, UserRoomKey: senderID,
UserID: userID.String(), UserID: userID.String(),

View file

@ -577,7 +577,7 @@ func (r *Joiner) performJoinRoomByIDCryptoIDs(
info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias) info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias)
if err == nil && info != nil { if err == nil && info != nil {
switch info.RoomVersion { switch info.RoomVersion {
case gomatrixserverlib.RoomVersionPseudoIDs: case gomatrixserverlib.RoomVersionCryptoIDs:
senderIDPtr, queryErr := r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID) senderIDPtr, queryErr := r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID)
if queryErr == nil { if queryErr == nil {
checkInvitePending = true checkInvitePending = true
@ -664,7 +664,7 @@ func (r *Joiner) performJoinRoomByIDCryptoIDs(
identity := r.Cfg.Matrix.SigningIdentity identity := r.Cfg.Matrix.SigningIdentity
// at this point we know we have an existing room // at this point we know we have an existing room
if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
mapping := &gomatrixserverlib.MXIDMapping{ mapping := &gomatrixserverlib.MXIDMapping{
UserRoomKey: senderID, UserRoomKey: senderID,
UserID: userID.String(), UserID: userID.String(),

View file

@ -1044,7 +1044,7 @@ func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID,
} }
switch version { switch version {
case gomatrixserverlib.RoomVersionPseudoIDs: case gomatrixserverlib.RoomVersionPseudoIDs, gomatrixserverlib.RoomVersionCryptoIDs:
key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID) key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -17,7 +17,7 @@ type RoomServer struct {
func (c *RoomServer) Defaults(opts DefaultOpts) { func (c *RoomServer) Defaults(opts DefaultOpts) {
//c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10 //c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionPseudoIDs c.DefaultRoomVersion = gomatrixserverlib.RoomVersionCryptoIDs
if opts.Generate { if opts.Generate {
if !opts.SingleDatabase { if !opts.SingleDatabase {
c.Database.ConnectionString = "file:roomserver.db" c.Database.ConnectionString = "file:roomserver.db"

View file

@ -46,13 +46,13 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
return nil return nil
} }
// OTPseudoIDCounts adds one-time pseudoID counts to the /sync response // OTCryptoIDCounts adds one-time pseudoID counts to the /sync response
func OTPseudoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error { func OTCryptoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error {
count, err := keyAPI.QueryOneTimePseudoIDs(ctx, userID) count, err := keyAPI.QueryOneTimeCryptoIDs(ctx, userID)
if err != nil { if err != nil {
return err return err
} }
res.OTPseudoIDsCount = count.KeyCount res.OTCryptoIDsCount = count.KeyCount
return nil return nil
} }

View file

@ -51,8 +51,8 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyC
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error { func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
return nil return nil
} }
func (a *mockKeyAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) { func (a *mockKeyAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
return userapi.OneTimePseudoIDsCount{}, nil return userapi.OneTimeCryptoIDsCount{}, nil
} }
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error { func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
return nil return nil

View file

@ -41,7 +41,7 @@ func (p *DeviceListStreamProvider) IncrementalSync(
req.Log.WithError(err).Error("internal.DeviceOTKCounts failed") req.Log.WithError(err).Error("internal.DeviceOTKCounts failed")
return from return from
} }
err = internal.OTPseudoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response) err = internal.OTCryptoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response)
if err != nil { if err != nil {
req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed") req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed")
return from return from

View file

@ -280,7 +280,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
if err != nil && err != context.Canceled { if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTK counts") syncReq.Log.WithError(err).Warn("failed to get OTK counts")
} }
err = internal.OTPseudoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response) err = internal.OTCryptoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response)
if err != nil && err != context.Canceled { if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts") syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
} }

View file

@ -112,8 +112,8 @@ func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOn
return nil return nil
} }
func (a *syncUserAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) { func (a *syncUserAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
return userapi.OneTimePseudoIDsCount{}, nil return userapi.OneTimeCryptoIDsCount{}, nil
} }
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error { func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {

View file

@ -153,7 +153,7 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDFor
// TODO: Set Signatures & Hashes fields // TODO: Set Signatures & Hashes fields
} }
if format != FormatSyncFederation && se.Version() == gomatrixserverlib.RoomVersionPseudoIDs { if format != FormatSyncFederation && (se.Version() == gomatrixserverlib.RoomVersionPseudoIDs || se.Version() == gomatrixserverlib.RoomVersionCryptoIDs) {
err := updatePseudoIDs(&ce, se, userIDForSender, format) err := updatePseudoIDs(&ce, se, userIDForSender, format)
if err != nil { if err != nil {
return nil, err return nil, err
@ -304,7 +304,7 @@ func GetUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomS
return nil, err return nil, err
} }
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation { if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != FormatSyncFederation {
for i, ev := range inviteStateEvents { for i, ev := range inviteStateEvents {
userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID)) userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID))
if userIDErr != nil { if userIDErr != nil {

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -365,7 +366,7 @@ type Response struct {
ToDevice *ToDeviceResponse `json:"to_device,omitempty"` ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
DeviceLists *DeviceLists `json:"device_lists,omitempty"` DeviceLists *DeviceLists `json:"device_lists,omitempty"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"` DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
OTPseudoIDsCount map[string]int `json:"one_time_pseudoids_count,omitempty"` OTCryptoIDsCount map[string]int `json:"one_time_cryptoids_count,omitempty"`
} }
func (r Response) MarshalJSON() ([]byte, error) { func (r Response) MarshalJSON() ([]byte, error) {
@ -428,7 +429,7 @@ func NewResponse() *Response {
res.DeviceLists = &DeviceLists{} res.DeviceLists = &DeviceLists{}
res.ToDevice = &ToDeviceResponse{} res.ToDevice = &ToDeviceResponse{}
res.DeviceListsOTKCount = map[string]int{} res.DeviceListsOTKCount = map[string]int{}
res.OTPseudoIDsCount = map[string]int{} res.OTCryptoIDsCount = map[string]int{}
return &res return &res
} }
@ -532,7 +533,7 @@ type InviteResponse struct {
InviteState struct { InviteState struct {
Events []json.RawMessage `json:"events"` Events []json.RawMessage `json:"events"`
} `json:"invite_state"` } `json:"invite_state"`
OneTimePseudoID string `json:"one_time_pseudoid,omitempty"` OneTimeCryptoID string `json:"one_time_cryptoid,omitempty"`
} }
// NewInviteResponse creates an empty response with initialised arrays. // NewInviteResponse creates an empty response with initialised arrays.
@ -540,13 +541,17 @@ func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *t
res := InviteResponse{} res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{} res.InviteState.Events = []json.RawMessage{}
res.OneTimePseudoID = *event.PDU.StateKey() logrus.Infof("Room version: %s", event.Version())
if event.Version() == gomatrixserverlib.RoomVersionCryptoIDs {
logrus.Infof("Setting invite cryptoID to %s", *event.PDU.StateKey())
res.OneTimeCryptoID = *event.PDU.StateKey()
}
// First see if there's invite_room_state in the unsigned key of the invite. // First see if there's invite_room_state in the unsigned key of the invite.
// If there is then unmarshal it into the response. This will contain the // If there is then unmarshal it into the response. This will contain the
// partial room state such as join rules, room name etc. // partial room state such as join rules, room name etc.
if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() {
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != synctypes.FormatSyncFederation { if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != synctypes.FormatSyncFederation {
updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, inviteRoomState, event.PDU, event.RoomID(), eventFormat) }, inviteRoomState, event.PDU, event.RoomID(), eventFormat)

View file

@ -51,7 +51,7 @@ type AppserviceUserAPI interface {
type RoomserverUserAPI interface { type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
} }
// api functions required by the media api // api functions required by the media api
@ -670,7 +670,7 @@ type UploadDeviceKeysAPI interface {
type SyncKeyAPI interface { type SyncKeyAPI interface {
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
QueryOneTimePseudoIDs(ctx context.Context, userID string) (OneTimePseudoIDsCount, *KeyError) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (OneTimeCryptoIDsCount, *KeyError)
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
} }
@ -774,7 +774,7 @@ type OneTimeKeys struct {
KeyJSON map[string]json.RawMessage KeyJSON map[string]json.RawMessage
} }
type OneTimePseudoIDs struct { type OneTimeCryptoIDs struct {
// The user who owns this device // The user who owns this device
UserID string UserID string
// A map of algorithm:key_id => key JSON // A map of algorithm:key_id => key JSON
@ -788,7 +788,7 @@ func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
} }
// Split a key in KeyJSON into algorithm and key ID // Split a key in KeyJSON into algorithm and key ID
func (k *OneTimePseudoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) { func (k *OneTimeCryptoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) {
segments := strings.Split(keyIDWithAlgo, ":") segments := strings.Split(keyIDWithAlgo, ":")
return segments[0], segments[1] return segments[0], segments[1]
} }
@ -807,7 +807,7 @@ type OneTimeKeysCount struct {
KeyCount map[string]int KeyCount map[string]int
} }
type OneTimePseudoIDsCount struct { type OneTimeCryptoIDsCount struct {
// The user who owns this device // The user who owns this device
UserID string UserID string
// algorithm to count e.g: // algorithm to count e.g:
@ -823,7 +823,7 @@ type PerformUploadKeysRequest struct {
DeviceID string // Optional - Device performing the request, for fetching OTK count DeviceID string // Optional - Device performing the request, for fetching OTK count
DeviceKeys []DeviceKeys DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys OneTimeKeys []OneTimeKeys
OneTimePseudoIDs []OneTimePseudoIDs OneTimeCryptoIDs []OneTimeCryptoIDs
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
// the display name for their respective device, and NOT to modify the keys. The key // the display name for their respective device, and NOT to modify the keys. The key
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
@ -838,7 +838,7 @@ type PerformUploadKeysResponse struct {
// A map of user_id -> device_id -> Error for tracking failures. // A map of user_id -> device_id -> Error for tracking failures.
KeyErrors map[string]map[string]*KeyError KeyErrors map[string]map[string]*KeyError
OneTimeKeyCounts []OneTimeKeysCount OneTimeKeyCounts []OneTimeKeysCount
OneTimePseudoIDCounts []OneTimePseudoIDsCount OneTimeCryptoIDCounts []OneTimeCryptoIDsCount
} }
// PerformDeleteKeysRequest asks the keyserver to forget about certain // PerformDeleteKeysRequest asks the keyserver to forget about certain

View file

@ -57,21 +57,21 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
if len(req.OneTimeKeys) > 0 { if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res)
} }
if len(req.OneTimePseudoIDs) > 0 { if len(req.OneTimeCryptoIDs) > 0 {
a.uploadOneTimePseudoIDs(ctx, req, res) a.uploadOneTimeCryptoIDs(ctx, req, res)
} }
logrus.Infof("One time pseudoIDs count before: %v", res.OneTimePseudoIDCounts) logrus.Infof("One time cryptoIDs count before: %v", res.OneTimeCryptoIDCounts)
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil { if err != nil {
return err return err
} }
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks} res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
otpIDs, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID) otpIDs, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
if err != nil { if err != nil {
return err return err
} }
res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs} res.OneTimeCryptoIDCounts = []api.OneTimeCryptoIDsCount{*otpIDs}
logrus.Infof("One time pseudoIDs count after: %v", res.OneTimePseudoIDCounts) logrus.Infof("One time cryptoIDs count after: %v", res.OneTimeCryptoIDCounts)
return nil return nil
} }
@ -193,11 +193,11 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
return nil return nil
} }
func (a *UserInternalAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (api.OneTimePseudoIDsCount, *api.KeyError) { func (a *UserInternalAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (api.OneTimeCryptoIDsCount, *api.KeyError) {
count, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, userID) count, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, userID)
if err != nil { if err != nil {
return api.OneTimePseudoIDsCount{}, &api.KeyError{ return api.OneTimeCryptoIDsCount{}, &api.KeyError{
Err: fmt.Sprintf("Failed to query OTK counts: %s", err), Err: fmt.Sprintf("Failed to query OTID counts: %s", err),
} }
} }
return *count, nil return *count, nil
@ -796,26 +796,26 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
} }
func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *UserInternalAPI) uploadOneTimeCryptoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if req.UserID == "" { if req.UserID == "" {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: "user ID missing", Err: "user ID missing",
} }
} }
if len(req.OneTimePseudoIDs) == 0 { if len(req.OneTimeCryptoIDs) == 0 {
counts, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID) counts, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("a.KeyDatabase.OneTimePseudoIDsCount: %s", err), Err: fmt.Sprintf("a.KeyDatabase.OneTimeCryptoIDsCount: %s", err),
} }
} }
if counts != nil { if counts != nil {
logrus.Infof("Uploading one-time pseudoIDs: early result count: %v", *counts) logrus.Infof("Uploading one-time cryptoIDs: early result count: %v", *counts)
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts) res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
} }
return return
} }
for _, key := range req.OneTimePseudoIDs { for _, key := range req.OneTimeCryptoIDs {
// grab existing keys based on (user/algorithm/key ID) // grab existing keys based on (user/algorithm/key ID)
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON)) keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
i := 0 i := 0
@ -823,10 +823,10 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
keyIDsWithAlgorithms[i] = keyIDWithAlgo keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++ i++
} }
existingKeys, err := a.KeyDatabase.ExistingOneTimePseudoIDs(ctx, req.UserID, keyIDsWithAlgorithms) existingKeys, err := a.KeyDatabase.ExistingOneTimeCryptoIDs(ctx, req.UserID, keyIDsWithAlgorithms)
if err != nil { if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: "failed to query existing one-time pseudoIDs: " + err.Error(), Err: "failed to query existing one-time cryptoIDs: " + err.Error(),
}) })
continue continue
} }
@ -834,22 +834,22 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
// if keys exist and the JSON doesn't match, error out as the key already exists // if keys exist and the JSON doesn't match, error out as the key already exists
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) { if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time pseudoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo), Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time cryptoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
}) })
continue continue
} }
} }
// store one-time keys // store one-time keys
counts, err := a.KeyDatabase.StoreOneTimePseudoIDs(ctx, key) counts, err := a.KeyDatabase.StoreOneTimeCryptoIDs(ctx, key)
if err != nil { if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store one-time pseudoIDs: %s", req.UserID, req.DeviceID, err.Error()), Err: fmt.Sprintf("%s device %s : failed to store one-time cryptoIDs: %s", req.UserID, req.DeviceID, err.Error()),
}) })
continue continue
} }
// collect counts // collect counts
logrus.Infof("Uploading one-time pseudoIDs: result count: %v", *counts) logrus.Infof("Uploading one-time cryptoIDs: result count: %v", *counts)
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts) res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
} }
} }
@ -857,16 +857,16 @@ type Ed25519Key struct {
Key spec.Base64Bytes `json:"key"` Key spec.Base64Bytes `json:"key"`
} }
func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { func (a *UserInternalAPI) ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
pseudoID, err := a.KeyDatabase.ClaimOneTimePseudoID(ctx, userID, "ed25519") cryptoID, err := a.KeyDatabase.ClaimOneTimeCryptoID(ctx, userID, "ed25519")
if err != nil { if err != nil {
return "", err return "", err
} }
logrus.Infof("Claimed one time pseuodID: %s", pseudoID) logrus.Infof("Claimed one time cryptoID: %s", cryptoID)
if pseudoID != nil { if cryptoID != nil {
for key, value := range pseudoID.KeyJSON { for key, value := range cryptoID.KeyJSON {
keyParts := strings.Split(key, ":") keyParts := strings.Split(key, ":")
if keyParts[0] == "ed25519" { if keyParts[0] == "ed25519" {
var key_bytes Ed25519Key var key_bytes Ed25519Key
@ -885,7 +885,7 @@ func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec.
} }
} }
return "", fmt.Errorf("failed claiming a valid one time pseudoID for this user: %s", userID.String()) return "", fmt.Errorf("failed claiming a valid one time cryptoID for this user: %s", userID.String())
} }
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {

View file

@ -175,10 +175,10 @@ type KeyDatabase interface {
// OneTimeKeysCount returns a count of all OTKs for this device. // OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error) ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error

View file

@ -27,78 +27,78 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
) )
var oneTimePseudoIDsSchema = ` var oneTimeCryptoIDsSchema = `
-- Stores one-time pseudoIDs for users -- Stores one-time cryptoIDs for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids ( CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
key_id TEXT NOT NULL, key_id TEXT NOT NULL,
algorithm TEXT NOT NULL, algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL, ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL, key_json TEXT NOT NULL,
-- Clobber based on 3-uple of user/key/algorithm. -- Clobber based on 3-uple of user/key/algorithm.
CONSTRAINT keyserver_one_time_pseudoids_unique UNIQUE (user_id, key_id, algorithm) CONSTRAINT keyserver_one_time_cryptoids_unique UNIQUE (user_id, key_id, algorithm)
); );
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id); CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
` `
const upsertPseudoIDsSQL = "" + const upsertCryptoIDsSQL = "" +
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" + "INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4, $5)" + " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT ON CONSTRAINT keyserver_one_time_pseudoids_unique" + " ON CONFLICT ON CONSTRAINT keyserver_one_time_cryptoids_unique" +
" DO UPDATE SET key_json = $5" " DO UPDATE SET key_json = $5"
const selectOneTimePseudoIDsSQL = "" + const selectOneTimeCryptoIDsSQL = "" +
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);" "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"
const selectPseudoIDsCountSQL = "" + const selectCryptoIDsCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM " + "SELECT algorithm, COUNT(key_id) FROM " +
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" + " (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
" x GROUP BY algorithm" " x GROUP BY algorithm"
const deleteOneTimePseudoIDSQL = "" + const deleteOneTimeCryptoIDSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3" "DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
const selectPseudoIDByAlgorithmSQL = "" + const selectCryptoIDByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1" "SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
const deleteOneTimePseudoIDsSQL = "" + const deleteOneTimeCryptoIDsSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1" "DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
type oneTimePseudoIDsStatements struct { type oneTimeCryptoIDsStatements struct {
db *sql.DB db *sql.DB
upsertPseudoIDsStmt *sql.Stmt upsertCryptoIDsStmt *sql.Stmt
selectPseudoIDsStmt *sql.Stmt selectCryptoIDsStmt *sql.Stmt
selectPseudoIDsCountStmt *sql.Stmt selectCryptoIDsCountStmt *sql.Stmt
selectPseudoIDByAlgorithmStmt *sql.Stmt selectCryptoIDByAlgorithmStmt *sql.Stmt
deleteOneTimePseudoIDStmt *sql.Stmt deleteOneTimeCryptoIDStmt *sql.Stmt
deleteOneTimePseudoIDsStmt *sql.Stmt deleteOneTimeCryptoIDsStmt *sql.Stmt
} }
func NewPostgresOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) { func NewPostgresOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
s := &oneTimePseudoIDsStatements{ s := &oneTimeCryptoIDsStatements{
db: db, db: db,
} }
_, err := db.Exec(oneTimePseudoIDsSchema) _, err := db.Exec(oneTimeCryptoIDsSchema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL}, {&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL}, {&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL}, {&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL}, {&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL}, {&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL}, {&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms)) rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
result := make(map[string]json.RawMessage) result := make(map[string]json.RawMessage)
var ( var (
@ -114,16 +114,16 @@ func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context,
return result, rows.Err() return result, rows.Err()
} }
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) { func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
counts := &api.OneTimePseudoIDsCount{ counts := &api.OneTimeCryptoIDsCount{
UserID: userID, UserID: userID,
KeyCount: make(map[string]int), KeyCount: make(map[string]int),
} }
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID) rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() { for rows.Next() {
var algorithm string var algorithm string
var count int var count int
@ -135,26 +135,26 @@ func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context,
return counts, nil return counts, nil
} }
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) { func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error) {
now := time.Now().Unix() now := time.Now().Unix()
counts := &api.OneTimePseudoIDsCount{ counts := &api.OneTimeCryptoIDsCount{
UserID: keys.UserID, UserID: keys.UserID,
KeyCount: make(map[string]int), KeyCount: make(map[string]int),
} }
for keyIDWithAlgo, keyJSON := range keys.KeyJSON { for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo) algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
ctx, keys.UserID, keyID, algo, now, string(keyJSON), ctx, keys.UserID, keyID, algo, now, string(keyJSON),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID) rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() { for rows.Next() {
var algorithm string var algorithm string
var count int var count int
@ -167,25 +167,25 @@ func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context,
return counts, rows.Err() return counts, rows.Err()
} }
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID( func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
ctx context.Context, txn *sql.Tx, userID, algorithm string, ctx context.Context, txn *sql.Tx, userID, algorithm string,
) (map[string]json.RawMessage, error) { ) (map[string]json.RawMessage, error) {
var keyID string var keyID string
var keyJSON string var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON) err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID) _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
return map[string]json.RawMessage{ return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON), algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err }, err
} }
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error { func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID) _, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
return err return err
} }

View file

@ -149,7 +149,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil { if err != nil {
return nil, err return nil, err
} }
otpid, err := NewPostgresOneTimePseudoIDsTable(db) otpid, err := NewPostgresOneTimeCryptoIDsTable(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -176,7 +176,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{ return &shared.KeyDatabase{
OneTimeKeysTable: otk, OneTimeKeysTable: otk,
OneTimePseudoIDsTable: otpid, OneTimeCryptoIDsTable: otpid,
DeviceKeysTable: dk, DeviceKeysTable: dk,
KeyChangesTable: kc, KeyChangesTable: kc,
StaleDeviceListsTable: sdl, StaleDeviceListsTable: sdl,

View file

@ -65,7 +65,7 @@ type Database struct {
type KeyDatabase struct { type KeyDatabase struct {
OneTimeKeysTable tables.OneTimeKeys OneTimeKeysTable tables.OneTimeKeys
OneTimePseudoIDsTable tables.OneTimePseudoIDs OneTimeCryptoIDsTable tables.OneTimeCryptoIDs
DeviceKeysTable tables.DeviceKeys DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists StaleDeviceListsTable tables.StaleDeviceLists
@ -946,31 +946,31 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
} }
func (d *KeyDatabase) ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { func (d *KeyDatabase) ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
return d.OneTimePseudoIDsTable.SelectOneTimePseudoIDs(ctx, userID, keyIDsWithAlgorithms) return d.OneTimeCryptoIDsTable.SelectOneTimeCryptoIDs(ctx, userID, keyIDsWithAlgorithms)
} }
func (d *KeyDatabase) StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (counts *api.OneTimePseudoIDsCount, err error) { func (d *KeyDatabase) StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (counts *api.OneTimeCryptoIDsCount, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
counts, err = d.OneTimePseudoIDsTable.InsertOneTimePseudoIDs(ctx, txn, keys) counts, err = d.OneTimeCryptoIDsTable.InsertOneTimeCryptoIDs(ctx, txn, keys)
return err return err
}) })
return return
} }
func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) { func (d *KeyDatabase) OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID) return d.OneTimeCryptoIDsTable.CountOneTimeCryptoIDs(ctx, userID)
} }
func (d *KeyDatabase) ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error) { func (d *KeyDatabase) ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error) {
var result *api.OneTimePseudoIDs var result *api.OneTimeCryptoIDs
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
keyJSON, err := d.OneTimePseudoIDsTable.SelectAndDeleteOneTimePseudoID(ctx, txn, userID.String(), algorithm) keyJSON, err := d.OneTimeCryptoIDsTable.SelectAndDeleteOneTimeCryptoID(ctx, txn, userID.String(), algorithm)
if err != nil { if err != nil {
return err return err
} }
if keyJSON != nil { if keyJSON != nil {
result = &api.OneTimePseudoIDs{ result = &api.OneTimeCryptoIDs{
UserID: userID.String(), UserID: userID.String(),
KeyJSON: keyJSON, KeyJSON: keyJSON,
} }

View file

@ -27,9 +27,9 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
var oneTimePseudoIDsSchema = ` var oneTimeCryptoIDsSchema = `
-- Stores one-time pseudoIDs for users -- Stores one-time cryptoIDs for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids ( CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
key_id TEXT NOT NULL, key_id TEXT NOT NULL,
algorithm TEXT NOT NULL, algorithm TEXT NOT NULL,
@ -39,66 +39,66 @@ CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
UNIQUE (user_id, key_id, algorithm) UNIQUE (user_id, key_id, algorithm)
); );
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id); CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
` `
const upsertPseudoIDsSQL = "" + const upsertCryptoIDsSQL = "" +
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" + "INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4, $5)" + " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (user_id, key_id, algorithm)" + " ON CONFLICT (user_id, key_id, algorithm)" +
" DO UPDATE SET key_json = $5" " DO UPDATE SET key_json = $5"
const selectOneTimePseudoIDsSQL = "" + const selectOneTimeCryptoIDsSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1" "SELECT key_id, algorithm, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1"
const selectPseudoIDsCountSQL = "" + const selectCryptoIDsCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM " + "SELECT algorithm, COUNT(key_id) FROM " +
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" + " (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
" x GROUP BY algorithm" " x GROUP BY algorithm"
const deleteOneTimePseudoIDSQL = "" + const deleteOneTimeCryptoIDSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3" "DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
const selectPseudoIDByAlgorithmSQL = "" + const selectCryptoIDByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1" "SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
const deleteOneTimePseudoIDsSQL = "" + const deleteOneTimeCryptoIDsSQL = "" +
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1" "DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
type oneTimePseudoIDsStatements struct { type oneTimeCryptoIDsStatements struct {
db *sql.DB db *sql.DB
upsertPseudoIDsStmt *sql.Stmt upsertCryptoIDsStmt *sql.Stmt
selectPseudoIDsStmt *sql.Stmt selectCryptoIDsStmt *sql.Stmt
selectPseudoIDsCountStmt *sql.Stmt selectCryptoIDsCountStmt *sql.Stmt
selectPseudoIDByAlgorithmStmt *sql.Stmt selectCryptoIDByAlgorithmStmt *sql.Stmt
deleteOneTimePseudoIDStmt *sql.Stmt deleteOneTimeCryptoIDStmt *sql.Stmt
deleteOneTimePseudoIDsStmt *sql.Stmt deleteOneTimeCryptoIDsStmt *sql.Stmt
} }
func NewSqliteOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) { func NewSqliteOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
s := &oneTimePseudoIDsStatements{ s := &oneTimeCryptoIDsStatements{
db: db, db: db,
} }
_, err := db.Exec(oneTimePseudoIDsSchema) _, err := db.Exec(oneTimeCryptoIDsSchema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL}, {&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL}, {&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL}, {&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL}, {&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL}, {&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL}, {&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID) rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
for _, ka := range keyIDsWithAlgorithms { for _, ka := range keyIDsWithAlgorithms {
@ -121,16 +121,16 @@ func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context,
return result, rows.Err() return result, rows.Err()
} }
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) { func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
counts := &api.OneTimePseudoIDsCount{ counts := &api.OneTimeCryptoIDsCount{
UserID: userID, UserID: userID,
KeyCount: make(map[string]int), KeyCount: make(map[string]int),
} }
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID) rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() { for rows.Next() {
var algorithm string var algorithm string
var count int var count int
@ -142,28 +142,28 @@ func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context,
return counts, nil return counts, nil
} }
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs( func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(
ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs, ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs,
) (*api.OneTimePseudoIDsCount, error) { ) (*api.OneTimeCryptoIDsCount, error) {
now := time.Now().Unix() now := time.Now().Unix()
counts := &api.OneTimePseudoIDsCount{ counts := &api.OneTimeCryptoIDsCount{
UserID: keys.UserID, UserID: keys.UserID,
KeyCount: make(map[string]int), KeyCount: make(map[string]int),
} }
for keyIDWithAlgo, keyJSON := range keys.KeyJSON { for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo) algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext( _, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
ctx, keys.UserID, keyID, algo, now, string(keyJSON), ctx, keys.UserID, keyID, algo, now, string(keyJSON),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID) rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
for rows.Next() { for rows.Next() {
var algorithm string var algorithm string
var count int var count int
@ -176,25 +176,25 @@ func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(
return counts, rows.Err() return counts, rows.Err()
} }
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID( func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
ctx context.Context, txn *sql.Tx, userID, algorithm string, ctx context.Context, txn *sql.Tx, userID, algorithm string,
) (map[string]json.RawMessage, error) { ) (map[string]json.RawMessage, error) {
var keyID string var keyID string
var keyJSON string var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON) err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
logrus.Warnf("No rows found for one time pseudoIDs") logrus.Warnf("No rows found for one time cryptoIDs")
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID) _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if keyJSON == "" { if keyJSON == "" {
logrus.Warnf("Empty key JSON for one time pseudoIDs") logrus.Warnf("Empty key JSON for one time cryptoIDs")
return nil, nil return nil, nil
} }
return map[string]json.RawMessage{ return map[string]json.RawMessage{
@ -202,7 +202,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
}, err }, err
} }
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error { func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID) _, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
return err return err
} }

View file

@ -146,7 +146,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil { if err != nil {
return nil, err return nil, err
} }
otpid, err := NewSqliteOneTimePseudoIDsTable(db) otpid, err := NewSqliteOneTimeCryptoIDsTable(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -173,7 +173,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{ return &shared.KeyDatabase{
OneTimeKeysTable: otk, OneTimeKeysTable: otk,
OneTimePseudoIDsTable: otpid, OneTimeCryptoIDsTable: otpid,
DeviceKeysTable: dk, DeviceKeysTable: dk,
KeyChangesTable: kc, KeyChangesTable: kc,
StaleDeviceListsTable: sdl, StaleDeviceListsTable: sdl,

View file

@ -760,29 +760,29 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
}) })
} }
func TestOneTimePseudoIDs(t *testing.T) { func TestOneTimeCryptoIDs(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType) db, clean := mustCreateKeyDatabase(t, dbType)
defer clean() defer clean()
userID := "@alice:localhost" userID := "@alice:localhost"
otk := api.OneTimePseudoIDs{ otk := api.OneTimeCryptoIDs{
UserID: userID, UserID: userID,
KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)}, KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)},
} }
// Add a one time pseudoID to the DB // Add a one time pseudoID to the DB
_, err := db.StoreOneTimePseudoIDs(ctx, otk) _, err := db.StoreOneTimeCryptoIDs(ctx, otk)
MustNotError(t, err) MustNotError(t, err)
// Check the count of one time pseudoIDs is correct // Check the count of one time pseudoIDs is correct
count, err := db.OneTimePseudoIDsCount(ctx, userID) count, err := db.OneTimeCryptoIDsCount(ctx, userID)
MustNotError(t, err) MustNotError(t, err)
if count.KeyCount["pseudoid_curve25519"] != 1 { if count.KeyCount["pseudoid_curve25519"] != 1 {
t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"]) t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"])
} }
// Check the actual pseudoid contents are correct // Check the actual pseudoid contents are correct
keysJSON, err := db.ExistingOneTimePseudoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"}) keysJSON, err := db.ExistingOneTimeCryptoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
MustNotError(t, err) MustNotError(t, err)
keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON() keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON()
MustNotError(t, err) MustNotError(t, err)

View file

@ -168,12 +168,12 @@ type OneTimeKeys interface {
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
} }
type OneTimePseudoIDs interface { type OneTimeCryptoIDs interface {
SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
SelectAndDeleteOneTimePseudoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error) SelectAndDeleteOneTimeCryptoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error)
DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error
} }
type DeviceKeys interface { type DeviceKeys interface {