cryptoID changes

This commit is contained in:
Devon Hudson 2023-10-31 16:52:58 -06:00
parent 60be1391bf
commit b7d320f8d1
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
13 changed files with 148 additions and 47 deletions

View file

@ -17,6 +17,7 @@ package routing
import ( import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
@ -279,6 +280,7 @@ func SendInvite(
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
roomID string, cfg *config.ClientAPI, roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
cryptoIDs bool,
) util.JSONResponse { ) util.JSONResponse {
body, evTime, reqErr := extractRequestData(req) body, evTime, reqErr := extractRequestData(req)
if reqErr != nil { if reqErr != nil {
@ -323,7 +325,7 @@ func SendInvite(
} }
// We already received the return value, so no need to check for an error here. // We already received the return value, so no need to check for an error here.
response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime) response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime, cryptoIDs)
return response return response
} }
@ -336,6 +338,7 @@ func sendInvite(
cfg *config.ClientAPI, cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI, rsAPI roomserverAPI.ClientRoomserverAPI,
asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time,
cryptoIDs bool,
) (util.JSONResponse, error) { ) (util.JSONResponse, error) {
validRoomID, err := spec.NewRoomID(roomID) validRoomID, err := spec.NewRoomID(roomID)
if err != nil { if err != nil {
@ -372,7 +375,7 @@ func sendInvite(
JSON: spec.InternalServerError{}, JSON: spec.InternalServerError{},
}, err }, err
} }
err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{ inviteEvent, err := rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
InviteInput: roomserverAPI.InviteInput{ InviteInput: roomserverAPI.InviteInput{
RoomID: *validRoomID, RoomID: *validRoomID,
Inviter: *inviter, Inviter: *inviter,
@ -387,7 +390,7 @@ func sendInvite(
}, },
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
SendAsServer: string(device.UserDomain()), SendAsServer: string(device.UserDomain()),
}) }, cryptoIDs)
switch e := err.(type) { switch e := err.(type) {
case roomserverAPI.ErrInvalidID: case roomserverAPI.ErrInvalidID:
@ -410,10 +413,22 @@ func sendInvite(
}, err }, err
} }
return util.JSONResponse{ response := util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: struct{}{}, JSON: struct{}{},
}, nil }
type inviteCryptoIDResponse struct {
PDU json.RawMessage `json:"pdu"`
}
if inviteEvent != nil {
response.JSON = inviteCryptoIDResponse{
PDU: json.RawMessage(inviteEvent.JSON()),
}
}
return response, nil
} }
func buildMembershipEventDirect( func buildMembershipEventDirect(

View file

@ -480,7 +480,6 @@ func Setup(
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/rooms/{roomID}/invite", v3mux.Handle("/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req, device); r != nil { if r := rateLimits.Limit(req, device); r != nil {
@ -490,7 +489,20 @@ func Setup(
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
}),
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc_cryptoids/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc_cryptoids/rooms/{roomID}/invite")
if r := rateLimits.Limit(req, device); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
// TODO: update for cryptoIDs // TODO: update for cryptoIDs

View file

@ -215,7 +215,7 @@ func SendServerNotice(
} }
if !membershipRes.IsInRoom { if !membershipRes.IsInRoom {
// re-invite the user // re-invite the user
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now(), false)
if err != nil { if err != nil {
return res return res
} }

View file

@ -92,6 +92,7 @@ type UserRoomPrivateKeyCreator interface {
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error
ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
} }
type InputRoomEventsAPI interface { type InputRoomEventsAPI interface {
@ -243,7 +244,7 @@ type ClientRoomserverAPI interface {
PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error
PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error) PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error)
PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error
PerformInvite(ctx context.Context, req *PerformInviteRequest) error PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error) PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
PerformSendJoinCryptoIDs(ctx context.Context, req *PerformJoinRequestCryptoIDs) error PerformSendJoinCryptoIDs(ctx context.Context, req *PerformJoinRequestCryptoIDs) error
PerformJoinCryptoIDs(ctx context.Context, req *PerformJoinRequest) (join gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error) PerformJoinCryptoIDs(ctx context.Context, req *PerformJoinRequest) (join gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error)
@ -309,7 +310,7 @@ type FederationRoomserverAPI interface {
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
PerformInvite(ctx context.Context, req *PerformInviteRequest) error PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
// Query a given amount (or less) of events prior to a given set of events. // Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error

View file

@ -54,6 +54,7 @@ type RoomserverInternalAPI struct {
ServerACLs *acls.ServerACLs ServerACLs *acls.ServerACLs
fsAPI fsAPI.RoomserverFederationAPI fsAPI fsAPI.RoomserverFederationAPI
asAPI asAPI.AppServiceInternalAPI asAPI asAPI.AppServiceInternalAPI
usAPI userapi.RoomserverUserAPI
NATSClient *nats.Conn NATSClient *nats.Conn
JetStream nats.JetStreamContext JetStream nats.JetStreamContext
Durable string Durable string
@ -214,6 +215,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
r.Leaver.UserAPI = userAPI r.Leaver.UserAPI = userAPI
r.Inputer.UserAPI = userAPI r.Inputer.UserAPI = userAPI
r.usAPI = userAPI
} }
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
@ -251,8 +253,9 @@ func (r *RoomserverInternalAPI) PerformCreateRoom(
func (r *RoomserverInternalAPI) PerformInvite( func (r *RoomserverInternalAPI) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,
) error { cryptoIDs bool,
return r.Inviter.PerformInvite(ctx, req) ) (gomatrixserverlib.PDU, error) {
return r.Inviter.PerformInvite(ctx, req, cryptoIDs)
} }
func (r *RoomserverInternalAPI) PerformLeave( func (r *RoomserverInternalAPI) PerformLeave(
@ -308,6 +311,10 @@ func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, send
return err return err
} }
func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return r.usAPI.ClaimOneTimePseudoID(ctx, roomID, userID)
}
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) { func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
roomVersion, ok := r.Cache.GetRoomVersion(roomID.String()) roomVersion, ok := r.Cache.GetRoomVersion(roomID.String())
if !ok { if !ok {

View file

@ -918,7 +918,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
} }
} }
err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{ _, err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
InviteInput: api.InviteInput{ InviteInput: api.InviteInput{
RoomID: roomID, RoomID: roomID,
Inviter: userID, Inviter: userID,
@ -933,7 +933,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}, },
InviteRoomState: globalStrippedState, InviteRoomState: globalStrippedState,
SendAsServer: string(userID.Domain()), SendAsServer: string(userID.Domain()),
}) }, false)
switch e := err.(type) { switch e := err.(type) {
case api.ErrInvalidID: case api.ErrInvalidID:
return "", &util.JSONResponse{ return "", &util.JSONResponse{

View file

@ -125,16 +125,17 @@ func (r *Inviter) ProcessInviteMembership(
func (r *Inviter) PerformInvite( func (r *Inviter) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,
) error { cryptoIDs bool,
) (gomatrixserverlib.PDU, error) {
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter) senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter)
if err != nil { if err != nil {
return err return nil, err
} else if senderID == nil { } else if senderID == nil {
return fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID) return nil, fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID)
} }
info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String()) info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String())
if err != nil { if err != nil {
return err return nil, err
} }
proto := gomatrixserverlib.ProtoEvent{ proto := gomatrixserverlib.ProtoEvent{
@ -152,11 +153,11 @@ func (r *Inviter) PerformInvite(
} }
if err = proto.SetContent(content); err != nil { if err = proto.SetContent(content); err != nil {
return err return nil, err
} }
if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) { if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) {
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} return nil, api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
} }
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain()) isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain())
@ -165,7 +166,7 @@ func (r *Inviter) PerformInvite(
if info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { if info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID) signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID)
if err != nil { if err != nil {
return err return nil, err
} }
} }
@ -222,6 +223,10 @@ func (r *Inviter) PerformInvite(
} }
return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID) return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
}, },
CryptoIDs: cryptoIDs,
ClaimSenderID: func(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return r.RSAPI.ClaimOneTimeSenderIDForUser(ctx, roomID, userID)
},
} }
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
@ -229,33 +234,38 @@ func (r *Inviter) PerformInvite(
switch e := err.(type) { switch e := err.(type) {
case spec.MatrixError: case spec.MatrixError:
if e.ErrCode == spec.ErrorForbidden { if e.ErrCode == spec.ErrorForbidden {
return api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)} return nil, api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)}
} }
} }
return err return nil, err
} }
// Send the invite event to the roomserver input stream. This will var response gomatrixserverlib.PDU
// notify existing users in the room about the invite, update the if !cryptoIDs {
// membership table and ensure that the event is ready and available // Send the invite event to the roomserver input stream. This will
// to use as an auth event when accepting the invite. // notify existing users in the room about the invite, update the
// It will NOT notify the invitee of this invite. // membership table and ensure that the event is ready and available
inputReq := &api.InputRoomEventsRequest{ // to use as an auth event when accepting the invite.
InputRoomEvents: []api.InputRoomEvent{ // It will NOT notify the invitee of this invite.
{ inputReq := &api.InputRoomEventsRequest{
Kind: api.KindNew, InputRoomEvents: []api.InputRoomEvent{
Event: &types.HeaderedEvent{PDU: inviteEvent}, {
Origin: req.InviteInput.Inviter.Domain(), Kind: api.KindNew,
SendAsServer: req.SendAsServer, Event: &types.HeaderedEvent{PDU: inviteEvent},
Origin: req.InviteInput.Inviter.Domain(),
SendAsServer: req.SendAsServer,
},
}, },
}, }
} inputRes := &api.InputRoomEventsResponse{}
inputRes := &api.InputRoomEventsResponse{} r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) if err := inputRes.Err(); err != nil {
if err := inputRes.Err(); err != nil { util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed")
util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed") return nil, api.ErrNotAllowed{Err: err}
return api.ErrNotAllowed{Err: err} }
} else {
response = inviteEvent
} }
return nil return response, nil
} }

View file

@ -269,7 +269,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
defer userStreamListener.Close() defer userStreamListener.Close()
giveup := func() util.JSONResponse { giveup := func() util.JSONResponse {
syncReq.Log.Info("Responding to sync since client gave up or timeout was reached")
syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached") syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached")
syncReq.Response.NextBatch = syncReq.Since syncReq.Response.NextBatch = syncReq.Since
// We should always try to include OTKs in sync responses, otherwise clients might upload keys // We should always try to include OTKs in sync responses, otherwise clients might upload keys
@ -285,9 +284,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
if err != nil && err != context.Canceled { if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts") syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
} }
syncReq.Log.Infof("one-time pseudoID counts: %v", syncReq.Response.OTPseudoIDsCount)
syncReq.Log.Infof("one-time key counts: %v", syncReq.Response.DeviceListsOTKCount)
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,

View file

@ -51,6 +51,7 @@ type AppserviceUserAPI interface {
type RoomserverUserAPI interface { type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
} }
// api functions required by the media api // api functions required by the media api

View file

@ -17,9 +17,11 @@ package internal
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/ed25519"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"strings"
"sync" "sync"
"time" "time"
@ -851,6 +853,41 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
} }
} }
type Ed25519Key struct {
Key spec.Base64Bytes `json:"key"`
}
func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
pseudoIDs, err := a.KeyDatabase.ClaimOneTimePseudoID(ctx, userID, "ed25519")
if err != nil {
return "", err
}
logrus.Infof("Claimed one time pseuodID: %v", pseudoIDs)
if pseudoIDs != nil {
for key, value := range pseudoIDs.KeyJSON {
keyParts := strings.Split(key, ":")
if keyParts[0] == "ed25519" {
var key_bytes Ed25519Key
err := json.Unmarshal(value, &key_bytes)
if err != nil {
return "", err
}
length := len(key_bytes.Key)
if length != ed25519.PublicKeySize {
return "", fmt.Errorf("Invalid ed25519 public key, %d is the wrong size", length)
}
// TODO: cryptoIDs - store senderID for this user here?
return spec.SenderID(key_bytes.Key.Encode()), nil
}
}
}
return "", fmt.Errorf("failed claiming a valid one time pseudoID for this user: %s", userID.String())
}
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
// if we only want to update the display names, we can skip the checks below // if we only want to update the display names, we can skip the checks below
if onlyUpdateDisplayName { if onlyUpdateDisplayName {

View file

@ -178,6 +178,7 @@ type KeyDatabase interface {
ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error

View file

@ -962,6 +962,24 @@ func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string)
return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID) return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID)
} }
func (d *KeyDatabase) ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error) {
var result *api.OneTimePseudoIDs
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
keyJSON, err := d.OneTimePseudoIDsTable.SelectAndDeleteOneTimePseudoID(ctx, txn, userID.String(), algorithm)
if err != nil {
return err
}
if keyJSON != nil {
result = &api.OneTimePseudoIDs{
UserID: userID.String(),
KeyJSON: keyJSON,
}
}
return nil
})
return result, err
}
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
} }

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/sirupsen/logrus"
) )
var oneTimePseudoIDsSchema = ` var oneTimePseudoIDsSchema = `
@ -183,6 +184,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON) err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
logrus.Warnf("No rows found for one time pseudoIDs")
return nil, nil return nil, nil
} }
return nil, err return nil, err
@ -192,6 +194,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
return nil, err return nil, err
} }
if keyJSON == "" { if keyJSON == "" {
logrus.Warnf("Empty key JSON for one time pseudoIDs")
return nil, nil return nil, nil
} }
return map[string]json.RawMessage{ return map[string]json.RawMessage{