mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-26 00:01:55 -06:00
Change cryptoid references from pseudoids
This commit is contained in:
parent
3cbccb9ed7
commit
b45e72830e
|
@ -100,7 +100,7 @@ type queryKeysRequest struct {
|
|||
type uploadKeysCryptoIDsRequest struct {
|
||||
DeviceKeys json.RawMessage `json:"device_keys"`
|
||||
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
|
||||
OneTimePseudoIDs map[string]json.RawMessage `json:"one_time_pseudoids"`
|
||||
OneTimeCryptoIDs map[string]json.RawMessage `json:"one_time_cryptoids"`
|
||||
}
|
||||
|
||||
func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
|
||||
|
@ -132,11 +132,11 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
|
|||
},
|
||||
}
|
||||
}
|
||||
if r.OneTimePseudoIDs != nil {
|
||||
uploadReq.OneTimePseudoIDs = []api.OneTimePseudoIDs{
|
||||
if r.OneTimeCryptoIDs != nil {
|
||||
uploadReq.OneTimeCryptoIDs = []api.OneTimeCryptoIDs{
|
||||
{
|
||||
UserID: device.UserID,
|
||||
KeyJSON: r.OneTimePseudoIDs,
|
||||
KeyJSON: r.OneTimeCryptoIDs,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -144,7 +144,7 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
|
|||
util.GetLogger(req.Context()).
|
||||
WithField("device keys", r.DeviceKeys).
|
||||
WithField("one-time keys", r.OneTimeKeys).
|
||||
WithField("one-time pseudoids", r.OneTimePseudoIDs).
|
||||
WithField("one-time cryptoids", r.OneTimeCryptoIDs).
|
||||
Info("Uploading keys")
|
||||
|
||||
var uploadRes api.PerformUploadKeysResponse
|
||||
|
@ -170,16 +170,16 @@ func UploadKeysCryptoIDs(req *http.Request, keyAPI api.ClientKeyAPI, device *api
|
|||
if len(uploadRes.OneTimeKeyCounts) > 0 {
|
||||
keyCount = uploadRes.OneTimeKeyCounts[0].KeyCount
|
||||
}
|
||||
pseudoIDCount := make(map[string]int)
|
||||
if len(uploadRes.OneTimePseudoIDCounts) > 0 {
|
||||
pseudoIDCount = uploadRes.OneTimePseudoIDCounts[0].KeyCount
|
||||
cryptoIDCount := make(map[string]int)
|
||||
if len(uploadRes.OneTimeCryptoIDCounts) > 0 {
|
||||
cryptoIDCount = uploadRes.OneTimeCryptoIDCounts[0].KeyCount
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: 200,
|
||||
JSON: struct {
|
||||
OTKCounts interface{} `json:"one_time_key_counts"`
|
||||
OTPIDCounts interface{} `json:"one_time_pseudoid_counts"`
|
||||
}{keyCount, pseudoIDCount},
|
||||
OTKCounts interface{} `json:"one_time_key_counts"`
|
||||
OTIDCounts interface{} `json:"one_time_cryptoid_counts"`
|
||||
}{keyCount, cryptoIDCount},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -320,7 +320,7 @@ func Setup(
|
|||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
unstableMux.Handle("/org.matrix.msc4080/send_pdus/{txnID}",
|
||||
httputil.MakeAuthAPI("send_pdus", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/sendPDUs")
|
||||
logrus.Info("Processing request to /org.matrix.msc4080/send_pdus")
|
||||
if r := rateLimits.Limit(req, device); r != nil {
|
||||
return *r
|
||||
}
|
||||
|
|
|
@ -100,7 +100,7 @@ func SendEvent(
|
|||
}
|
||||
|
||||
// Translate user ID state keys to room keys in pseudo ID rooms
|
||||
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil {
|
||||
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
|
||||
parsedRoomID, innerErr := spec.NewRoomID(roomID)
|
||||
if innerErr != nil {
|
||||
return util.JSONResponse{
|
||||
|
@ -154,7 +154,7 @@ func SendEvent(
|
|||
}
|
||||
|
||||
// for power level events we need to replace the userID with the pseudoID
|
||||
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels {
|
||||
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
|
||||
err = updatePowerLevels(req, r, roomID, rsAPI)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
|
@ -299,7 +299,7 @@ func SendEventCryptoIDs(
|
|||
}
|
||||
|
||||
// Translate user ID state keys to room keys in pseudo ID rooms
|
||||
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && stateKey != nil {
|
||||
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && stateKey != nil {
|
||||
parsedRoomID, innerErr := spec.NewRoomID(roomID)
|
||||
if innerErr != nil {
|
||||
return util.JSONResponse{
|
||||
|
@ -345,7 +345,7 @@ func SendEventCryptoIDs(
|
|||
}
|
||||
|
||||
// for power level events we need to replace the userID with the pseudoID
|
||||
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs && eventType == spec.MRoomPowerLevels {
|
||||
if (roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs) && eventType == spec.MRoomPowerLevels {
|
||||
err = updatePowerLevels(req, r, roomID, rsAPI)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
|
|
|
@ -214,7 +214,7 @@ func OnIncomingStateTypeRequest(
|
|||
}
|
||||
|
||||
// Translate user ID state keys to room keys in pseudo ID rooms
|
||||
if roomVer == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
if roomVer == gomatrixserverlib.RoomVersionPseudoIDs || roomVer == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
parsedRoomID, err := spec.NewRoomID(roomID)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
|
|
|
@ -314,7 +314,7 @@ func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, send
|
|||
}
|
||||
|
||||
func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||
return r.usAPI.ClaimOneTimePseudoID(ctx, roomID, userID)
|
||||
return r.usAPI.ClaimOneTimeCryptoID(ctx, roomID, userID)
|
||||
}
|
||||
|
||||
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
|
||||
|
@ -328,7 +328,7 @@ func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID s
|
|||
roomVersion = roomInfo.RoomVersion
|
||||
}
|
||||
}
|
||||
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs || roomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
privKey, err := r.GetOrCreateUserRoomPrivateKey(ctx, senderID, roomID)
|
||||
if err != nil {
|
||||
return fclient.SigningIdentity{}, err
|
||||
|
|
|
@ -445,7 +445,7 @@ func (r *Inputer) processRoomEvent(
|
|||
}
|
||||
|
||||
// TODO: Revist this to ensure we don't replace a current state mxid_mapping with an older one.
|
||||
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && event.Type() == spec.MRoomMember {
|
||||
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && event.Type() == spec.MRoomMember {
|
||||
mapping := gomatrixserverlib.MemberContent{}
|
||||
if err = json.Unmarshal(event.Content(), &mapping); err != nil {
|
||||
return err
|
||||
|
|
|
@ -69,7 +69,7 @@ func (c *Creator) PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.Us
|
|||
return nil, spec.InternalServerError{Err: err.Error()}
|
||||
}
|
||||
|
||||
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
util.GetLogger(ctx).Infof("StoreUserRoomPublicKey - SenderID: %s UserID: %s RoomID: %s", senderID, userID.String(), roomID.String())
|
||||
bytes := spec.Base64Bytes{}
|
||||
err = bytes.Decode(string(senderID))
|
||||
|
@ -152,7 +152,7 @@ func (c *Creator) PerformCreateRoomCryptoIDs(ctx context.Context, userID spec.Us
|
|||
}
|
||||
|
||||
// If we are creating a room with pseudo IDs, create and sign the MXIDMapping
|
||||
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
mapping := &gomatrixserverlib.MXIDMapping{
|
||||
UserRoomKey: senderID,
|
||||
UserID: userID.String(),
|
||||
|
|
|
@ -577,7 +577,7 @@ func (r *Joiner) performJoinRoomByIDCryptoIDs(
|
|||
info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias)
|
||||
if err == nil && info != nil {
|
||||
switch info.RoomVersion {
|
||||
case gomatrixserverlib.RoomVersionPseudoIDs:
|
||||
case gomatrixserverlib.RoomVersionCryptoIDs:
|
||||
senderIDPtr, queryErr := r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID)
|
||||
if queryErr == nil {
|
||||
checkInvitePending = true
|
||||
|
@ -664,7 +664,7 @@ func (r *Joiner) performJoinRoomByIDCryptoIDs(
|
|||
identity := r.Cfg.Matrix.SigningIdentity
|
||||
|
||||
// at this point we know we have an existing room
|
||||
if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
if inRoomRes.RoomVersion == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
mapping := &gomatrixserverlib.MXIDMapping{
|
||||
UserRoomKey: senderID,
|
||||
UserID: userID.String(),
|
||||
|
|
|
@ -1044,7 +1044,7 @@ func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID,
|
|||
}
|
||||
|
||||
switch version {
|
||||
case gomatrixserverlib.RoomVersionPseudoIDs:
|
||||
case gomatrixserverlib.RoomVersionPseudoIDs, gomatrixserverlib.RoomVersionCryptoIDs:
|
||||
key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -17,7 +17,7 @@ type RoomServer struct {
|
|||
|
||||
func (c *RoomServer) Defaults(opts DefaultOpts) {
|
||||
//c.DefaultRoomVersion = gomatrixserverlib.RoomVersionV10
|
||||
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionPseudoIDs
|
||||
c.DefaultRoomVersion = gomatrixserverlib.RoomVersionCryptoIDs
|
||||
if opts.Generate {
|
||||
if !opts.SingleDatabase {
|
||||
c.Database.ConnectionString = "file:roomserver.db"
|
||||
|
|
|
@ -46,13 +46,13 @@ func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceI
|
|||
return nil
|
||||
}
|
||||
|
||||
// OTPseudoIDCounts adds one-time pseudoID counts to the /sync response
|
||||
func OTPseudoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error {
|
||||
count, err := keyAPI.QueryOneTimePseudoIDs(ctx, userID)
|
||||
// OTCryptoIDCounts adds one-time pseudoID counts to the /sync response
|
||||
func OTCryptoIDCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID string, res *types.Response) error {
|
||||
count, err := keyAPI.QueryOneTimeCryptoIDs(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.OTPseudoIDsCount = count.KeyCount
|
||||
res.OTCryptoIDsCount = count.KeyCount
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -51,8 +51,8 @@ func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyC
|
|||
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
|
||||
return nil
|
||||
}
|
||||
func (a *mockKeyAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) {
|
||||
return userapi.OneTimePseudoIDsCount{}, nil
|
||||
func (a *mockKeyAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
|
||||
return userapi.OneTimeCryptoIDsCount{}, nil
|
||||
}
|
||||
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
|
||||
return nil
|
||||
|
|
|
@ -41,7 +41,7 @@ func (p *DeviceListStreamProvider) IncrementalSync(
|
|||
req.Log.WithError(err).Error("internal.DeviceOTKCounts failed")
|
||||
return from
|
||||
}
|
||||
err = internal.OTPseudoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response)
|
||||
err = internal.OTCryptoIDCounts(req.Context, p.userAPI, req.Device.UserID, req.Response)
|
||||
if err != nil {
|
||||
req.Log.WithError(err).Error("internal.OTPseudoIDCounts failed")
|
||||
return from
|
||||
|
|
|
@ -280,7 +280,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
if err != nil && err != context.Canceled {
|
||||
syncReq.Log.WithError(err).Warn("failed to get OTK counts")
|
||||
}
|
||||
err = internal.OTPseudoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response)
|
||||
err = internal.OTCryptoIDCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Response)
|
||||
if err != nil && err != context.Canceled {
|
||||
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
|
||||
}
|
||||
|
|
|
@ -112,8 +112,8 @@ func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOn
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *syncUserAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (userapi.OneTimePseudoIDsCount, *userapi.KeyError) {
|
||||
return userapi.OneTimePseudoIDsCount{}, nil
|
||||
func (a *syncUserAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (userapi.OneTimeCryptoIDsCount, *userapi.KeyError) {
|
||||
return userapi.OneTimeCryptoIDsCount{}, nil
|
||||
}
|
||||
|
||||
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
|
||||
|
|
|
@ -153,7 +153,7 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDFor
|
|||
// TODO: Set Signatures & Hashes fields
|
||||
}
|
||||
|
||||
if format != FormatSyncFederation && se.Version() == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
if format != FormatSyncFederation && (se.Version() == gomatrixserverlib.RoomVersionPseudoIDs || se.Version() == gomatrixserverlib.RoomVersionCryptoIDs) {
|
||||
err := updatePseudoIDs(&ce, se, userIDForSender, format)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -304,7 +304,7 @@ func GetUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomS
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation {
|
||||
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != FormatSyncFederation {
|
||||
for i, ev := range inviteStateEvents {
|
||||
userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID))
|
||||
if userIDErr != nil {
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
|
@ -365,7 +366,7 @@ type Response struct {
|
|||
ToDevice *ToDeviceResponse `json:"to_device,omitempty"`
|
||||
DeviceLists *DeviceLists `json:"device_lists,omitempty"`
|
||||
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
|
||||
OTPseudoIDsCount map[string]int `json:"one_time_pseudoids_count,omitempty"`
|
||||
OTCryptoIDsCount map[string]int `json:"one_time_cryptoids_count,omitempty"`
|
||||
}
|
||||
|
||||
func (r Response) MarshalJSON() ([]byte, error) {
|
||||
|
@ -428,7 +429,7 @@ func NewResponse() *Response {
|
|||
res.DeviceLists = &DeviceLists{}
|
||||
res.ToDevice = &ToDeviceResponse{}
|
||||
res.DeviceListsOTKCount = map[string]int{}
|
||||
res.OTPseudoIDsCount = map[string]int{}
|
||||
res.OTCryptoIDsCount = map[string]int{}
|
||||
|
||||
return &res
|
||||
}
|
||||
|
@ -532,7 +533,7 @@ type InviteResponse struct {
|
|||
InviteState struct {
|
||||
Events []json.RawMessage `json:"events"`
|
||||
} `json:"invite_state"`
|
||||
OneTimePseudoID string `json:"one_time_pseudoid,omitempty"`
|
||||
OneTimeCryptoID string `json:"one_time_cryptoid,omitempty"`
|
||||
}
|
||||
|
||||
// NewInviteResponse creates an empty response with initialised arrays.
|
||||
|
@ -540,13 +541,17 @@ func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *t
|
|||
res := InviteResponse{}
|
||||
res.InviteState.Events = []json.RawMessage{}
|
||||
|
||||
res.OneTimePseudoID = *event.PDU.StateKey()
|
||||
logrus.Infof("Room version: %s", event.Version())
|
||||
if event.Version() == gomatrixserverlib.RoomVersionCryptoIDs {
|
||||
logrus.Infof("Setting invite cryptoID to %s", *event.PDU.StateKey())
|
||||
res.OneTimeCryptoID = *event.PDU.StateKey()
|
||||
}
|
||||
|
||||
// First see if there's invite_room_state in the unsigned key of the invite.
|
||||
// If there is then unmarshal it into the response. This will contain the
|
||||
// partial room state such as join rules, room name etc.
|
||||
if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() {
|
||||
if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != synctypes.FormatSyncFederation {
|
||||
if (event.Version() == gomatrixserverlib.RoomVersionPseudoIDs || event.Version() == gomatrixserverlib.RoomVersionCryptoIDs) && eventFormat != synctypes.FormatSyncFederation {
|
||||
updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}, inviteRoomState, event.PDU, event.RoomID(), eventFormat)
|
||||
|
|
|
@ -51,7 +51,7 @@ type AppserviceUserAPI interface {
|
|||
type RoomserverUserAPI interface {
|
||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
||||
ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
|
||||
ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
|
||||
}
|
||||
|
||||
// api functions required by the media api
|
||||
|
@ -670,7 +670,7 @@ type UploadDeviceKeysAPI interface {
|
|||
type SyncKeyAPI interface {
|
||||
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
|
||||
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
|
||||
QueryOneTimePseudoIDs(ctx context.Context, userID string) (OneTimePseudoIDsCount, *KeyError)
|
||||
QueryOneTimeCryptoIDs(ctx context.Context, userID string) (OneTimeCryptoIDsCount, *KeyError)
|
||||
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
|
||||
}
|
||||
|
||||
|
@ -774,7 +774,7 @@ type OneTimeKeys struct {
|
|||
KeyJSON map[string]json.RawMessage
|
||||
}
|
||||
|
||||
type OneTimePseudoIDs struct {
|
||||
type OneTimeCryptoIDs struct {
|
||||
// The user who owns this device
|
||||
UserID string
|
||||
// A map of algorithm:key_id => key JSON
|
||||
|
@ -788,7 +788,7 @@ func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
|||
}
|
||||
|
||||
// Split a key in KeyJSON into algorithm and key ID
|
||||
func (k *OneTimePseudoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||
func (k *OneTimeCryptoIDs) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||
segments := strings.Split(keyIDWithAlgo, ":")
|
||||
return segments[0], segments[1]
|
||||
}
|
||||
|
@ -807,7 +807,7 @@ type OneTimeKeysCount struct {
|
|||
KeyCount map[string]int
|
||||
}
|
||||
|
||||
type OneTimePseudoIDsCount struct {
|
||||
type OneTimeCryptoIDsCount struct {
|
||||
// The user who owns this device
|
||||
UserID string
|
||||
// algorithm to count e.g:
|
||||
|
@ -823,7 +823,7 @@ type PerformUploadKeysRequest struct {
|
|||
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
||||
DeviceKeys []DeviceKeys
|
||||
OneTimeKeys []OneTimeKeys
|
||||
OneTimePseudoIDs []OneTimePseudoIDs
|
||||
OneTimeCryptoIDs []OneTimeCryptoIDs
|
||||
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
||||
// the display name for their respective device, and NOT to modify the keys. The key
|
||||
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
||||
|
@ -838,7 +838,7 @@ type PerformUploadKeysResponse struct {
|
|||
// A map of user_id -> device_id -> Error for tracking failures.
|
||||
KeyErrors map[string]map[string]*KeyError
|
||||
OneTimeKeyCounts []OneTimeKeysCount
|
||||
OneTimePseudoIDCounts []OneTimePseudoIDsCount
|
||||
OneTimeCryptoIDCounts []OneTimeCryptoIDsCount
|
||||
}
|
||||
|
||||
// PerformDeleteKeysRequest asks the keyserver to forget about certain
|
||||
|
|
|
@ -57,21 +57,21 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
|
|||
if len(req.OneTimeKeys) > 0 {
|
||||
a.uploadOneTimeKeys(ctx, req, res)
|
||||
}
|
||||
if len(req.OneTimePseudoIDs) > 0 {
|
||||
a.uploadOneTimePseudoIDs(ctx, req, res)
|
||||
if len(req.OneTimeCryptoIDs) > 0 {
|
||||
a.uploadOneTimeCryptoIDs(ctx, req, res)
|
||||
}
|
||||
logrus.Infof("One time pseudoIDs count before: %v", res.OneTimePseudoIDCounts)
|
||||
logrus.Infof("One time cryptoIDs count before: %v", res.OneTimeCryptoIDCounts)
|
||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
||||
otpIDs, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID)
|
||||
otpIDs, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.OneTimePseudoIDCounts = []api.OneTimePseudoIDsCount{*otpIDs}
|
||||
logrus.Infof("One time pseudoIDs count after: %v", res.OneTimePseudoIDCounts)
|
||||
res.OneTimeCryptoIDCounts = []api.OneTimeCryptoIDsCount{*otpIDs}
|
||||
logrus.Infof("One time cryptoIDs count after: %v", res.OneTimeCryptoIDCounts)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -193,11 +193,11 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryOneTimePseudoIDs(ctx context.Context, userID string) (api.OneTimePseudoIDsCount, *api.KeyError) {
|
||||
count, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, userID)
|
||||
func (a *UserInternalAPI) QueryOneTimeCryptoIDs(ctx context.Context, userID string) (api.OneTimeCryptoIDsCount, *api.KeyError) {
|
||||
count, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, userID)
|
||||
if err != nil {
|
||||
return api.OneTimePseudoIDsCount{}, &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
|
||||
return api.OneTimeCryptoIDsCount{}, &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to query OTID counts: %s", err),
|
||||
}
|
||||
}
|
||||
return *count, nil
|
||||
|
@ -796,26 +796,26 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
|
|||
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
func (a *UserInternalAPI) uploadOneTimeCryptoIDs(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
if req.UserID == "" {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "user ID missing",
|
||||
}
|
||||
}
|
||||
if len(req.OneTimePseudoIDs) == 0 {
|
||||
counts, err := a.KeyDatabase.OneTimePseudoIDsCount(ctx, req.UserID)
|
||||
if len(req.OneTimeCryptoIDs) == 0 {
|
||||
counts, err := a.KeyDatabase.OneTimeCryptoIDsCount(ctx, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.KeyDatabase.OneTimePseudoIDsCount: %s", err),
|
||||
Err: fmt.Sprintf("a.KeyDatabase.OneTimeCryptoIDsCount: %s", err),
|
||||
}
|
||||
}
|
||||
if counts != nil {
|
||||
logrus.Infof("Uploading one-time pseudoIDs: early result count: %v", *counts)
|
||||
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
||||
logrus.Infof("Uploading one-time cryptoIDs: early result count: %v", *counts)
|
||||
res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, key := range req.OneTimePseudoIDs {
|
||||
for _, key := range req.OneTimeCryptoIDs {
|
||||
// grab existing keys based on (user/algorithm/key ID)
|
||||
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
||||
i := 0
|
||||
|
@ -823,10 +823,10 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
|
|||
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
||||
i++
|
||||
}
|
||||
existingKeys, err := a.KeyDatabase.ExistingOneTimePseudoIDs(ctx, req.UserID, keyIDsWithAlgorithms)
|
||||
existingKeys, err := a.KeyDatabase.ExistingOneTimeCryptoIDs(ctx, req.UserID, keyIDsWithAlgorithms)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: "failed to query existing one-time pseudoIDs: " + err.Error(),
|
||||
Err: "failed to query existing one-time cryptoIDs: " + err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
@ -834,22 +834,22 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
|
|||
// if keys exist and the JSON doesn't match, error out as the key already exists
|
||||
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time pseudoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
|
||||
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time cryptoID already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
// store one-time keys
|
||||
counts, err := a.KeyDatabase.StoreOneTimePseudoIDs(ctx, key)
|
||||
counts, err := a.KeyDatabase.StoreOneTimeCryptoIDs(ctx, key)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s : failed to store one-time pseudoIDs: %s", req.UserID, req.DeviceID, err.Error()),
|
||||
Err: fmt.Sprintf("%s device %s : failed to store one-time cryptoIDs: %s", req.UserID, req.DeviceID, err.Error()),
|
||||
})
|
||||
continue
|
||||
}
|
||||
// collect counts
|
||||
logrus.Infof("Uploading one-time pseudoIDs: result count: %v", *counts)
|
||||
res.OneTimePseudoIDCounts = append(res.OneTimePseudoIDCounts, *counts)
|
||||
logrus.Infof("Uploading one-time cryptoIDs: result count: %v", *counts)
|
||||
res.OneTimeCryptoIDCounts = append(res.OneTimeCryptoIDCounts, *counts)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -857,16 +857,16 @@ type Ed25519Key struct {
|
|||
Key spec.Base64Bytes `json:"key"`
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||
pseudoID, err := a.KeyDatabase.ClaimOneTimePseudoID(ctx, userID, "ed25519")
|
||||
func (a *UserInternalAPI) ClaimOneTimeCryptoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||
cryptoID, err := a.KeyDatabase.ClaimOneTimeCryptoID(ctx, userID, "ed25519")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
logrus.Infof("Claimed one time pseuodID: %s", pseudoID)
|
||||
logrus.Infof("Claimed one time cryptoID: %s", cryptoID)
|
||||
|
||||
if pseudoID != nil {
|
||||
for key, value := range pseudoID.KeyJSON {
|
||||
if cryptoID != nil {
|
||||
for key, value := range cryptoID.KeyJSON {
|
||||
keyParts := strings.Split(key, ":")
|
||||
if keyParts[0] == "ed25519" {
|
||||
var key_bytes Ed25519Key
|
||||
|
@ -885,7 +885,7 @@ func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec.
|
|||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("failed claiming a valid one time pseudoID for this user: %s", userID.String())
|
||||
return "", fmt.Errorf("failed claiming a valid one time cryptoID for this user: %s", userID.String())
|
||||
}
|
||||
|
||||
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
||||
|
|
|
@ -175,10 +175,10 @@ type KeyDatabase interface {
|
|||
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||
|
||||
ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
|
||||
OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
|
||||
ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error)
|
||||
ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
|
||||
OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
|
||||
ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error)
|
||||
|
||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
|
|
@ -27,78 +27,78 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var oneTimePseudoIDsSchema = `
|
||||
-- Stores one-time pseudoIDs for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
|
||||
var oneTimeCryptoIDsSchema = `
|
||||
-- Stores one-time cryptoIDs for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
|
||||
user_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- Clobber based on 3-uple of user/key/algorithm.
|
||||
CONSTRAINT keyserver_one_time_pseudoids_unique UNIQUE (user_id, key_id, algorithm)
|
||||
CONSTRAINT keyserver_one_time_cryptoids_unique UNIQUE (user_id, key_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id);
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
|
||||
`
|
||||
|
||||
const upsertPseudoIDsSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
const upsertCryptoIDsSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
" VALUES ($1, $2, $3, $4, $5)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_one_time_pseudoids_unique" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_one_time_cryptoids_unique" +
|
||||
" DO UPDATE SET key_json = $5"
|
||||
|
||||
const selectOneTimePseudoIDsSQL = "" +
|
||||
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"
|
||||
const selectOneTimeCryptoIDsSQL = "" +
|
||||
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1 AND concat(algorithm, ':', key_id) = ANY($2);"
|
||||
|
||||
const selectPseudoIDsCountSQL = "" +
|
||||
const selectCryptoIDsCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
|
||||
" x GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimePseudoIDSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||
const deleteOneTimeCryptoIDSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||
|
||||
const selectPseudoIDByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||
const selectCryptoIDByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||
|
||||
const deleteOneTimePseudoIDsSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1"
|
||||
const deleteOneTimeCryptoIDsSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
|
||||
|
||||
type oneTimePseudoIDsStatements struct {
|
||||
type oneTimeCryptoIDsStatements struct {
|
||||
db *sql.DB
|
||||
upsertPseudoIDsStmt *sql.Stmt
|
||||
selectPseudoIDsStmt *sql.Stmt
|
||||
selectPseudoIDsCountStmt *sql.Stmt
|
||||
selectPseudoIDByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimePseudoIDStmt *sql.Stmt
|
||||
deleteOneTimePseudoIDsStmt *sql.Stmt
|
||||
upsertCryptoIDsStmt *sql.Stmt
|
||||
selectCryptoIDsStmt *sql.Stmt
|
||||
selectCryptoIDsCountStmt *sql.Stmt
|
||||
selectCryptoIDByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimeCryptoIDStmt *sql.Stmt
|
||||
deleteOneTimeCryptoIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) {
|
||||
s := &oneTimePseudoIDsStatements{
|
||||
func NewPostgresOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
|
||||
s := &oneTimeCryptoIDsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(oneTimePseudoIDsSchema)
|
||||
_, err := db.Exec(oneTimeCryptoIDsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL},
|
||||
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL},
|
||||
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL},
|
||||
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL},
|
||||
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL},
|
||||
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL},
|
||||
{&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
|
||||
{&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
|
||||
{&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
|
||||
{&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
|
||||
{&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
|
||||
{&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms))
|
||||
func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID, pq.Array(keyIDsWithAlgorithms))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed")
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
|
||||
|
||||
result := make(map[string]json.RawMessage)
|
||||
var (
|
||||
|
@ -114,16 +114,16 @@ func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context,
|
|||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
|
||||
counts := &api.OneTimePseudoIDsCount{
|
||||
func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
|
||||
counts := &api.OneTimeCryptoIDsCount{
|
||||
UserID: userID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID)
|
||||
rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
|
@ -135,26 +135,26 @@ func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context,
|
|||
return counts, nil
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) {
|
||||
func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error) {
|
||||
now := time.Now().Unix()
|
||||
counts := &api.OneTimePseudoIDsCount{
|
||||
counts := &api.OneTimeCryptoIDsCount{
|
||||
UserID: keys.UserID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext(
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
|
||||
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
|
@ -167,25 +167,25 @@ func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(ctx context.Context,
|
|||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
|
||||
func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
|
||||
ctx context.Context, txn *sql.Tx, userID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID)
|
||||
func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
|
@ -149,7 +149,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otpid, err := NewPostgresOneTimePseudoIDsTable(db)
|
||||
otpid, err := NewPostgresOneTimeCryptoIDsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -176,7 +176,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
OneTimePseudoIDsTable: otpid,
|
||||
OneTimeCryptoIDsTable: otpid,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
|
|
|
@ -65,7 +65,7 @@ type Database struct {
|
|||
|
||||
type KeyDatabase struct {
|
||||
OneTimeKeysTable tables.OneTimeKeys
|
||||
OneTimePseudoIDsTable tables.OneTimePseudoIDs
|
||||
OneTimeCryptoIDsTable tables.OneTimeCryptoIDs
|
||||
DeviceKeysTable tables.DeviceKeys
|
||||
KeyChangesTable tables.KeyChanges
|
||||
StaleDeviceListsTable tables.StaleDeviceLists
|
||||
|
@ -946,31 +946,31 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
|
|||
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
return d.OneTimePseudoIDsTable.SelectOneTimePseudoIDs(ctx, userID, keyIDsWithAlgorithms)
|
||||
func (d *KeyDatabase) ExistingOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
return d.OneTimeCryptoIDsTable.SelectOneTimeCryptoIDs(ctx, userID, keyIDsWithAlgorithms)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (counts *api.OneTimePseudoIDsCount, err error) {
|
||||
func (d *KeyDatabase) StoreOneTimeCryptoIDs(ctx context.Context, keys api.OneTimeCryptoIDs) (counts *api.OneTimeCryptoIDsCount, err error) {
|
||||
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
counts, err = d.OneTimePseudoIDsTable.InsertOneTimePseudoIDs(ctx, txn, keys)
|
||||
counts, err = d.OneTimeCryptoIDsTable.InsertOneTimeCryptoIDs(ctx, txn, keys)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
|
||||
return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID)
|
||||
func (d *KeyDatabase) OneTimeCryptoIDsCount(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
|
||||
return d.OneTimeCryptoIDsTable.CountOneTimeCryptoIDs(ctx, userID)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error) {
|
||||
var result *api.OneTimePseudoIDs
|
||||
func (d *KeyDatabase) ClaimOneTimeCryptoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimeCryptoIDs, error) {
|
||||
var result *api.OneTimeCryptoIDs
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
keyJSON, err := d.OneTimePseudoIDsTable.SelectAndDeleteOneTimePseudoID(ctx, txn, userID.String(), algorithm)
|
||||
keyJSON, err := d.OneTimeCryptoIDsTable.SelectAndDeleteOneTimeCryptoID(ctx, txn, userID.String(), algorithm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if keyJSON != nil {
|
||||
result = &api.OneTimePseudoIDs{
|
||||
result = &api.OneTimeCryptoIDs{
|
||||
UserID: userID.String(),
|
||||
KeyJSON: keyJSON,
|
||||
}
|
||||
|
|
|
@ -27,9 +27,9 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var oneTimePseudoIDsSchema = `
|
||||
-- Stores one-time pseudoIDs for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
|
||||
var oneTimeCryptoIDsSchema = `
|
||||
-- Stores one-time cryptoIDs for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_cryptoids (
|
||||
user_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
|
@ -39,66 +39,66 @@ CREATE TABLE IF NOT EXISTS keyserver_one_time_pseudoids (
|
|||
UNIQUE (user_id, key_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_pseudoids_idx ON keyserver_one_time_pseudoids (user_id);
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_cryptoids_idx ON keyserver_one_time_cryptoids (user_id);
|
||||
`
|
||||
|
||||
const upsertPseudoIDsSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_pseudoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
const upsertCryptoIDsSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_cryptoids (user_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
" VALUES ($1, $2, $3, $4, $5)" +
|
||||
" ON CONFLICT (user_id, key_id, algorithm)" +
|
||||
" DO UPDATE SET key_json = $5"
|
||||
|
||||
const selectOneTimePseudoIDsSQL = "" +
|
||||
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_pseudoids WHERE user_id=$1"
|
||||
const selectOneTimeCryptoIDsSQL = "" +
|
||||
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_cryptoids WHERE user_id=$1"
|
||||
|
||||
const selectPseudoIDsCountSQL = "" +
|
||||
const selectCryptoIDsCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_pseudoids WHERE user_id = $1 LIMIT 100)" +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_cryptoids WHERE user_id = $1 LIMIT 100)" +
|
||||
" x GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimePseudoIDSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||
const deleteOneTimeCryptoIDSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 AND key_id = $3"
|
||||
|
||||
const selectPseudoIDByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_pseudoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||
const selectCryptoIDByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_cryptoids WHERE user_id = $1 AND algorithm = $2 LIMIT 1"
|
||||
|
||||
const deleteOneTimePseudoIDsSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_pseudoids WHERE user_id = $1"
|
||||
const deleteOneTimeCryptoIDsSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_cryptoids WHERE user_id = $1"
|
||||
|
||||
type oneTimePseudoIDsStatements struct {
|
||||
type oneTimeCryptoIDsStatements struct {
|
||||
db *sql.DB
|
||||
upsertPseudoIDsStmt *sql.Stmt
|
||||
selectPseudoIDsStmt *sql.Stmt
|
||||
selectPseudoIDsCountStmt *sql.Stmt
|
||||
selectPseudoIDByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimePseudoIDStmt *sql.Stmt
|
||||
deleteOneTimePseudoIDsStmt *sql.Stmt
|
||||
upsertCryptoIDsStmt *sql.Stmt
|
||||
selectCryptoIDsStmt *sql.Stmt
|
||||
selectCryptoIDsCountStmt *sql.Stmt
|
||||
selectCryptoIDByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimeCryptoIDStmt *sql.Stmt
|
||||
deleteOneTimeCryptoIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteOneTimePseudoIDsTable(db *sql.DB) (tables.OneTimePseudoIDs, error) {
|
||||
s := &oneTimePseudoIDsStatements{
|
||||
func NewSqliteOneTimeCryptoIDsTable(db *sql.DB) (tables.OneTimeCryptoIDs, error) {
|
||||
s := &oneTimeCryptoIDsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(oneTimePseudoIDsSchema)
|
||||
_, err := db.Exec(oneTimeCryptoIDsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertPseudoIDsStmt, upsertPseudoIDsSQL},
|
||||
{&s.selectPseudoIDsStmt, selectOneTimePseudoIDsSQL},
|
||||
{&s.selectPseudoIDsCountStmt, selectPseudoIDsCountSQL},
|
||||
{&s.selectPseudoIDByAlgorithmStmt, selectPseudoIDByAlgorithmSQL},
|
||||
{&s.deleteOneTimePseudoIDStmt, deleteOneTimePseudoIDSQL},
|
||||
{&s.deleteOneTimePseudoIDsStmt, deleteOneTimePseudoIDsSQL},
|
||||
{&s.upsertCryptoIDsStmt, upsertCryptoIDsSQL},
|
||||
{&s.selectCryptoIDsStmt, selectOneTimeCryptoIDsSQL},
|
||||
{&s.selectCryptoIDsCountStmt, selectCryptoIDsCountSQL},
|
||||
{&s.selectCryptoIDByAlgorithmStmt, selectCryptoIDByAlgorithmSQL},
|
||||
{&s.deleteOneTimeCryptoIDStmt, deleteOneTimeCryptoIDSQL},
|
||||
{&s.deleteOneTimeCryptoIDsStmt, deleteOneTimeCryptoIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectPseudoIDsStmt.QueryContext(ctx, userID)
|
||||
func (s *oneTimeCryptoIDsStatements) SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectCryptoIDsStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsStmt: rows.close() failed")
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsStmt: rows.close() failed")
|
||||
|
||||
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
|
||||
for _, ka := range keyIDsWithAlgorithms {
|
||||
|
@ -121,16 +121,16 @@ func (s *oneTimePseudoIDsStatements) SelectOneTimePseudoIDs(ctx context.Context,
|
|||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) {
|
||||
counts := &api.OneTimePseudoIDsCount{
|
||||
func (s *oneTimeCryptoIDsStatements) CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error) {
|
||||
counts := &api.OneTimeCryptoIDsCount{
|
||||
UserID: userID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
rows, err := s.selectPseudoIDsCountStmt.QueryContext(ctx, userID)
|
||||
rows, err := s.selectCryptoIDsCountStmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
|
@ -142,28 +142,28 @@ func (s *oneTimePseudoIDsStatements) CountOneTimePseudoIDs(ctx context.Context,
|
|||
return counts, nil
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(
|
||||
ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs,
|
||||
) (*api.OneTimePseudoIDsCount, error) {
|
||||
func (s *oneTimeCryptoIDsStatements) InsertOneTimeCryptoIDs(
|
||||
ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs,
|
||||
) (*api.OneTimeCryptoIDsCount, error) {
|
||||
now := time.Now().Unix()
|
||||
counts := &api.OneTimePseudoIDsCount{
|
||||
counts := &api.OneTimeCryptoIDsCount{
|
||||
UserID: keys.UserID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertPseudoIDsStmt).ExecContext(
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertCryptoIDsStmt).ExecContext(
|
||||
ctx, keys.UserID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPseudoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCryptoIDsCountStmt).QueryContext(ctx, keys.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectPseudoIDsCountStmt: rows.close() failed")
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCryptoIDsCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
|
@ -176,25 +176,25 @@ func (s *oneTimePseudoIDsStatements) InsertOneTimePseudoIDs(
|
|||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
|
||||
func (s *oneTimeCryptoIDsStatements) SelectAndDeleteOneTimeCryptoID(
|
||||
ctx context.Context, txn *sql.Tx, userID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectCryptoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
logrus.Warnf("No rows found for one time pseudoIDs")
|
||||
logrus.Warnf("No rows found for one time cryptoIDs")
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimePseudoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeCryptoIDStmt).ExecContext(ctx, userID, algorithm, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if keyJSON == "" {
|
||||
logrus.Warnf("Empty key JSON for one time pseudoIDs")
|
||||
logrus.Warnf("Empty key JSON for one time cryptoIDs")
|
||||
return nil, nil
|
||||
}
|
||||
return map[string]json.RawMessage{
|
||||
|
@ -202,7 +202,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
|
|||
}, err
|
||||
}
|
||||
|
||||
func (s *oneTimePseudoIDsStatements) DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimePseudoIDsStmt).ExecContext(ctx, userID)
|
||||
func (s *oneTimeCryptoIDsStatements) DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeCryptoIDsStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
|
@ -146,7 +146,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otpid, err := NewSqliteOneTimePseudoIDsTable(db)
|
||||
otpid, err := NewSqliteOneTimeCryptoIDsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -173,7 +173,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
|
|||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
OneTimePseudoIDsTable: otpid,
|
||||
OneTimeCryptoIDsTable: otpid,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
|
|
|
@ -760,29 +760,29 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestOneTimePseudoIDs(t *testing.T) {
|
||||
func TestOneTimeCryptoIDs(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
userID := "@alice:localhost"
|
||||
otk := api.OneTimePseudoIDs{
|
||||
otk := api.OneTimeCryptoIDs{
|
||||
UserID: userID,
|
||||
KeyJSON: map[string]json.RawMessage{"pseudoid_curve25519:KEY1": []byte(`{"key":"v1"}`)},
|
||||
}
|
||||
|
||||
// Add a one time pseudoID to the DB
|
||||
_, err := db.StoreOneTimePseudoIDs(ctx, otk)
|
||||
_, err := db.StoreOneTimeCryptoIDs(ctx, otk)
|
||||
MustNotError(t, err)
|
||||
|
||||
// Check the count of one time pseudoIDs is correct
|
||||
count, err := db.OneTimePseudoIDsCount(ctx, userID)
|
||||
count, err := db.OneTimeCryptoIDsCount(ctx, userID)
|
||||
MustNotError(t, err)
|
||||
if count.KeyCount["pseudoid_curve25519"] != 1 {
|
||||
t.Fatalf("Expected 1 pseudoID, got %d", count.KeyCount["pseudoid_curve25519"])
|
||||
}
|
||||
|
||||
// Check the actual pseudoid contents are correct
|
||||
keysJSON, err := db.ExistingOneTimePseudoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
|
||||
keysJSON, err := db.ExistingOneTimeCryptoIDs(ctx, userID, []string{"pseudoid_curve25519:KEY1"})
|
||||
MustNotError(t, err)
|
||||
keyJSON, err := keysJSON["pseudoid_curve25519:KEY1"].MarshalJSON()
|
||||
MustNotError(t, err)
|
||||
|
|
|
@ -168,12 +168,12 @@ type OneTimeKeys interface {
|
|||
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||
}
|
||||
|
||||
type OneTimePseudoIDs interface {
|
||||
SelectOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
CountOneTimePseudoIDs(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
|
||||
InsertOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
|
||||
SelectAndDeleteOneTimePseudoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error)
|
||||
DeleteOneTimePseudoIDs(ctx context.Context, txn *sql.Tx, userID string) error
|
||||
type OneTimeCryptoIDs interface {
|
||||
SelectOneTimeCryptoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
CountOneTimeCryptoIDs(ctx context.Context, userID string) (*api.OneTimeCryptoIDsCount, error)
|
||||
InsertOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, keys api.OneTimeCryptoIDs) (*api.OneTimeCryptoIDsCount, error)
|
||||
SelectAndDeleteOneTimeCryptoID(ctx context.Context, txn *sql.Tx, userID, algorithm string) (map[string]json.RawMessage, error)
|
||||
DeleteOneTimeCryptoIDs(ctx context.Context, txn *sql.Tx, userID string) error
|
||||
}
|
||||
|
||||
type DeviceKeys interface {
|
||||
|
|
Loading…
Reference in a new issue