cryptoID changes
This commit is contained in:
parent
60be1391bf
commit
b7d320f8d1
|
@ -17,6 +17,7 @@ package routing
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -279,6 +280,7 @@ func SendInvite(
|
||||||
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
|
||||||
roomID string, cfg *config.ClientAPI,
|
roomID string, cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
|
||||||
|
cryptoIDs bool,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
body, evTime, reqErr := extractRequestData(req)
|
body, evTime, reqErr := extractRequestData(req)
|
||||||
if reqErr != nil {
|
if reqErr != nil {
|
||||||
|
@ -323,7 +325,7 @@ func SendInvite(
|
||||||
}
|
}
|
||||||
|
|
||||||
// We already received the return value, so no need to check for an error here.
|
// We already received the return value, so no need to check for an error here.
|
||||||
response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime)
|
response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime, cryptoIDs)
|
||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -336,6 +338,7 @@ func sendInvite(
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.ClientRoomserverAPI,
|
rsAPI roomserverAPI.ClientRoomserverAPI,
|
||||||
asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time,
|
asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time,
|
||||||
|
cryptoIDs bool,
|
||||||
) (util.JSONResponse, error) {
|
) (util.JSONResponse, error) {
|
||||||
validRoomID, err := spec.NewRoomID(roomID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -372,7 +375,7 @@ func sendInvite(
|
||||||
JSON: spec.InternalServerError{},
|
JSON: spec.InternalServerError{},
|
||||||
}, err
|
}, err
|
||||||
}
|
}
|
||||||
err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
inviteEvent, err := rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
||||||
InviteInput: roomserverAPI.InviteInput{
|
InviteInput: roomserverAPI.InviteInput{
|
||||||
RoomID: *validRoomID,
|
RoomID: *validRoomID,
|
||||||
Inviter: *inviter,
|
Inviter: *inviter,
|
||||||
|
@ -387,7 +390,7 @@ func sendInvite(
|
||||||
},
|
},
|
||||||
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
|
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
|
||||||
SendAsServer: string(device.UserDomain()),
|
SendAsServer: string(device.UserDomain()),
|
||||||
})
|
}, cryptoIDs)
|
||||||
|
|
||||||
switch e := err.(type) {
|
switch e := err.(type) {
|
||||||
case roomserverAPI.ErrInvalidID:
|
case roomserverAPI.ErrInvalidID:
|
||||||
|
@ -410,10 +413,22 @@ func sendInvite(
|
||||||
}, err
|
}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return util.JSONResponse{
|
response := util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: struct{}{},
|
JSON: struct{}{},
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
type inviteCryptoIDResponse struct {
|
||||||
|
PDU json.RawMessage `json:"pdu"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if inviteEvent != nil {
|
||||||
|
response.JSON = inviteCryptoIDResponse{
|
||||||
|
PDU: json.RawMessage(inviteEvent.JSON()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildMembershipEventDirect(
|
func buildMembershipEventDirect(
|
||||||
|
|
|
@ -480,7 +480,6 @@ func Setup(
|
||||||
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
// TODO: update for cryptoIDs
|
|
||||||
v3mux.Handle("/rooms/{roomID}/invite",
|
v3mux.Handle("/rooms/{roomID}/invite",
|
||||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
if r := rateLimits.Limit(req, device); r != nil {
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
@ -490,7 +489,20 @@ func Setup(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
|
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
unstableMux.Handle("/org.matrix.msc_cryptoids/rooms/{roomID}/invite",
|
||||||
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
logrus.Info("Processing request to /org.matrix.msc_cryptoids/rooms/{roomID}/invite")
|
||||||
|
if r := rateLimits.Limit(req, device); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
// TODO: update for cryptoIDs
|
// TODO: update for cryptoIDs
|
||||||
|
|
|
@ -215,7 +215,7 @@ func SendServerNotice(
|
||||||
}
|
}
|
||||||
if !membershipRes.IsInRoom {
|
if !membershipRes.IsInRoom {
|
||||||
// re-invite the user
|
// re-invite the user
|
||||||
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
|
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now(), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,6 +92,7 @@ type UserRoomPrivateKeyCreator interface {
|
||||||
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
|
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
|
||||||
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
|
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
|
||||||
StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error
|
StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error
|
||||||
|
ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type InputRoomEventsAPI interface {
|
type InputRoomEventsAPI interface {
|
||||||
|
@ -243,7 +244,7 @@ type ClientRoomserverAPI interface {
|
||||||
PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error
|
PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error
|
||||||
PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error)
|
PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error)
|
||||||
PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error
|
PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error
|
||||||
PerformInvite(ctx context.Context, req *PerformInviteRequest) error
|
PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
|
||||||
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
|
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
|
||||||
PerformSendJoinCryptoIDs(ctx context.Context, req *PerformJoinRequestCryptoIDs) error
|
PerformSendJoinCryptoIDs(ctx context.Context, req *PerformJoinRequestCryptoIDs) error
|
||||||
PerformJoinCryptoIDs(ctx context.Context, req *PerformJoinRequest) (join gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error)
|
PerformJoinCryptoIDs(ctx context.Context, req *PerformJoinRequest) (join gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error)
|
||||||
|
@ -309,7 +310,7 @@ type FederationRoomserverAPI interface {
|
||||||
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
|
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
|
||||||
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
|
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
|
||||||
|
|
||||||
PerformInvite(ctx context.Context, req *PerformInviteRequest) error
|
PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
|
||||||
// Query a given amount (or less) of events prior to a given set of events.
|
// Query a given amount (or less) of events prior to a given set of events.
|
||||||
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
|
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,7 @@ type RoomserverInternalAPI struct {
|
||||||
ServerACLs *acls.ServerACLs
|
ServerACLs *acls.ServerACLs
|
||||||
fsAPI fsAPI.RoomserverFederationAPI
|
fsAPI fsAPI.RoomserverFederationAPI
|
||||||
asAPI asAPI.AppServiceInternalAPI
|
asAPI asAPI.AppServiceInternalAPI
|
||||||
|
usAPI userapi.RoomserverUserAPI
|
||||||
NATSClient *nats.Conn
|
NATSClient *nats.Conn
|
||||||
JetStream nats.JetStreamContext
|
JetStream nats.JetStreamContext
|
||||||
Durable string
|
Durable string
|
||||||
|
@ -214,6 +215,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
||||||
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
|
||||||
r.Leaver.UserAPI = userAPI
|
r.Leaver.UserAPI = userAPI
|
||||||
r.Inputer.UserAPI = userAPI
|
r.Inputer.UserAPI = userAPI
|
||||||
|
r.usAPI = userAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
|
||||||
|
@ -251,8 +253,9 @@ func (r *RoomserverInternalAPI) PerformCreateRoom(
|
||||||
func (r *RoomserverInternalAPI) PerformInvite(
|
func (r *RoomserverInternalAPI) PerformInvite(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformInviteRequest,
|
req *api.PerformInviteRequest,
|
||||||
) error {
|
cryptoIDs bool,
|
||||||
return r.Inviter.PerformInvite(ctx, req)
|
) (gomatrixserverlib.PDU, error) {
|
||||||
|
return r.Inviter.PerformInvite(ctx, req, cryptoIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) PerformLeave(
|
func (r *RoomserverInternalAPI) PerformLeave(
|
||||||
|
@ -308,6 +311,10 @@ func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, send
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||||
|
return r.usAPI.ClaimOneTimePseudoID(ctx, roomID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
|
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
|
||||||
roomVersion, ok := r.Cache.GetRoomVersion(roomID.String())
|
roomVersion, ok := r.Cache.GetRoomVersion(roomID.String())
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
|
@ -918,7 +918,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
_, err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
|
||||||
InviteInput: api.InviteInput{
|
InviteInput: api.InviteInput{
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
Inviter: userID,
|
Inviter: userID,
|
||||||
|
@ -933,7 +933,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
||||||
},
|
},
|
||||||
InviteRoomState: globalStrippedState,
|
InviteRoomState: globalStrippedState,
|
||||||
SendAsServer: string(userID.Domain()),
|
SendAsServer: string(userID.Domain()),
|
||||||
})
|
}, false)
|
||||||
switch e := err.(type) {
|
switch e := err.(type) {
|
||||||
case api.ErrInvalidID:
|
case api.ErrInvalidID:
|
||||||
return "", &util.JSONResponse{
|
return "", &util.JSONResponse{
|
||||||
|
|
|
@ -125,16 +125,17 @@ func (r *Inviter) ProcessInviteMembership(
|
||||||
func (r *Inviter) PerformInvite(
|
func (r *Inviter) PerformInvite(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformInviteRequest,
|
req *api.PerformInviteRequest,
|
||||||
) error {
|
cryptoIDs bool,
|
||||||
|
) (gomatrixserverlib.PDU, error) {
|
||||||
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter)
|
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
} else if senderID == nil {
|
} else if senderID == nil {
|
||||||
return fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID)
|
return nil, fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID)
|
||||||
}
|
}
|
||||||
info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String())
|
info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
proto := gomatrixserverlib.ProtoEvent{
|
proto := gomatrixserverlib.ProtoEvent{
|
||||||
|
@ -152,11 +153,11 @@ func (r *Inviter) PerformInvite(
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = proto.SetContent(content); err != nil {
|
if err = proto.SetContent(content); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) {
|
if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) {
|
||||||
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
|
return nil, api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
|
||||||
}
|
}
|
||||||
|
|
||||||
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain())
|
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain())
|
||||||
|
@ -165,7 +166,7 @@ func (r *Inviter) PerformInvite(
|
||||||
if info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
if info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||||
signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID)
|
signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,6 +223,10 @@ func (r *Inviter) PerformInvite(
|
||||||
}
|
}
|
||||||
return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
|
return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
|
||||||
},
|
},
|
||||||
|
CryptoIDs: cryptoIDs,
|
||||||
|
ClaimSenderID: func(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||||
|
return r.RSAPI.ClaimOneTimeSenderIDForUser(ctx, roomID, userID)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
|
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
|
||||||
|
@ -229,12 +234,14 @@ func (r *Inviter) PerformInvite(
|
||||||
switch e := err.(type) {
|
switch e := err.(type) {
|
||||||
case spec.MatrixError:
|
case spec.MatrixError:
|
||||||
if e.ErrCode == spec.ErrorForbidden {
|
if e.ErrCode == spec.ErrorForbidden {
|
||||||
return api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)}
|
return nil, api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var response gomatrixserverlib.PDU
|
||||||
|
if !cryptoIDs {
|
||||||
// Send the invite event to the roomserver input stream. This will
|
// Send the invite event to the roomserver input stream. This will
|
||||||
// notify existing users in the room about the invite, update the
|
// notify existing users in the room about the invite, update the
|
||||||
// membership table and ensure that the event is ready and available
|
// membership table and ensure that the event is ready and available
|
||||||
|
@ -254,8 +261,11 @@ func (r *Inviter) PerformInvite(
|
||||||
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
|
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
|
||||||
if err := inputRes.Err(); err != nil {
|
if err := inputRes.Err(); err != nil {
|
||||||
util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed")
|
util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed")
|
||||||
return api.ErrNotAllowed{Err: err}
|
return nil, api.ErrNotAllowed{Err: err}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
response = inviteEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -269,7 +269,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
||||||
defer userStreamListener.Close()
|
defer userStreamListener.Close()
|
||||||
|
|
||||||
giveup := func() util.JSONResponse {
|
giveup := func() util.JSONResponse {
|
||||||
syncReq.Log.Info("Responding to sync since client gave up or timeout was reached")
|
|
||||||
syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached")
|
syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached")
|
||||||
syncReq.Response.NextBatch = syncReq.Since
|
syncReq.Response.NextBatch = syncReq.Since
|
||||||
// We should always try to include OTKs in sync responses, otherwise clients might upload keys
|
// We should always try to include OTKs in sync responses, otherwise clients might upload keys
|
||||||
|
@ -285,9 +284,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
||||||
if err != nil && err != context.Canceled {
|
if err != nil && err != context.Canceled {
|
||||||
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
|
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
|
||||||
}
|
}
|
||||||
|
|
||||||
syncReq.Log.Infof("one-time pseudoID counts: %v", syncReq.Response.OTPseudoIDsCount)
|
|
||||||
syncReq.Log.Infof("one-time key counts: %v", syncReq.Response.DeviceListsOTKCount)
|
|
||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
|
|
@ -51,6 +51,7 @@ type AppserviceUserAPI interface {
|
||||||
type RoomserverUserAPI interface {
|
type RoomserverUserAPI interface {
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
|
||||||
|
ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// api functions required by the media api
|
// api functions required by the media api
|
||||||
|
|
|
@ -17,9 +17,11 @@ package internal
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -851,6 +853,41 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Ed25519Key struct {
|
||||||
|
Key spec.Base64Bytes `json:"key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||||
|
pseudoIDs, err := a.KeyDatabase.ClaimOneTimePseudoID(ctx, userID, "ed25519")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Infof("Claimed one time pseuodID: %v", pseudoIDs)
|
||||||
|
|
||||||
|
if pseudoIDs != nil {
|
||||||
|
for key, value := range pseudoIDs.KeyJSON {
|
||||||
|
keyParts := strings.Split(key, ":")
|
||||||
|
if keyParts[0] == "ed25519" {
|
||||||
|
var key_bytes Ed25519Key
|
||||||
|
err := json.Unmarshal(value, &key_bytes)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
length := len(key_bytes.Key)
|
||||||
|
if length != ed25519.PublicKeySize {
|
||||||
|
return "", fmt.Errorf("Invalid ed25519 public key, %d is the wrong size", length)
|
||||||
|
}
|
||||||
|
// TODO: cryptoIDs - store senderID for this user here?
|
||||||
|
return spec.SenderID(key_bytes.Key.Encode()), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("failed claiming a valid one time pseudoID for this user: %s", userID.String())
|
||||||
|
}
|
||||||
|
|
||||||
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
||||||
// if we only want to update the display names, we can skip the checks below
|
// if we only want to update the display names, we can skip the checks below
|
||||||
if onlyUpdateDisplayName {
|
if onlyUpdateDisplayName {
|
||||||
|
|
|
@ -178,6 +178,7 @@ type KeyDatabase interface {
|
||||||
ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||||
StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
|
StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
|
||||||
OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
|
OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
|
||||||
|
ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error)
|
||||||
|
|
||||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
|
@ -962,6 +962,24 @@ func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string)
|
||||||
return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID)
|
return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *KeyDatabase) ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error) {
|
||||||
|
var result *api.OneTimePseudoIDs
|
||||||
|
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
keyJSON, err := d.OneTimePseudoIDsTable.SelectAndDeleteOneTimePseudoID(ctx, txn, userID.String(), algorithm)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if keyJSON != nil {
|
||||||
|
result = &api.OneTimePseudoIDs{
|
||||||
|
UserID: userID.String(),
|
||||||
|
KeyJSON: keyJSON,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var oneTimePseudoIDsSchema = `
|
var oneTimePseudoIDsSchema = `
|
||||||
|
@ -183,6 +184,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
|
||||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
|
logrus.Warnf("No rows found for one time pseudoIDs")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -192,6 +194,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if keyJSON == "" {
|
if keyJSON == "" {
|
||||||
|
logrus.Warnf("Empty key JSON for one time pseudoIDs")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return map[string]json.RawMessage{
|
return map[string]json.RawMessage{
|
||||||
|
|
Loading…
Reference in a new issue