cryptoID changes

This commit is contained in:
Devon Hudson 2023-10-31 16:52:58 -06:00
parent 60be1391bf
commit b7d320f8d1
No known key found for this signature in database
GPG key ID: CD06B18E77F6A628
13 changed files with 148 additions and 47 deletions

View file

@ -17,6 +17,7 @@ package routing
import (
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"net/http"
"time"
@ -279,6 +280,7 @@ func SendInvite(
req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI,
cryptoIDs bool,
) util.JSONResponse {
body, evTime, reqErr := extractRequestData(req)
if reqErr != nil {
@ -323,7 +325,7 @@ func SendInvite(
}
// We already received the return value, so no need to check for an error here.
response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime)
response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime, cryptoIDs)
return response
}
@ -336,6 +338,7 @@ func sendInvite(
cfg *config.ClientAPI,
rsAPI roomserverAPI.ClientRoomserverAPI,
asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time,
cryptoIDs bool,
) (util.JSONResponse, error) {
validRoomID, err := spec.NewRoomID(roomID)
if err != nil {
@ -372,7 +375,7 @@ func sendInvite(
JSON: spec.InternalServerError{},
}, err
}
err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
inviteEvent, err := rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{
InviteInput: roomserverAPI.InviteInput{
RoomID: *validRoomID,
Inviter: *inviter,
@ -387,7 +390,7 @@ func sendInvite(
},
InviteRoomState: nil, // ask the roomserver to draw up invite room state for us
SendAsServer: string(device.UserDomain()),
})
}, cryptoIDs)
switch e := err.(type) {
case roomserverAPI.ErrInvalidID:
@ -410,10 +413,22 @@ func sendInvite(
}, err
}
return util.JSONResponse{
response := util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}, nil
}
type inviteCryptoIDResponse struct {
PDU json.RawMessage `json:"pdu"`
}
if inviteEvent != nil {
response.JSON = inviteCryptoIDResponse{
PDU: json.RawMessage(inviteEvent.JSON()),
}
}
return response, nil
}
func buildMembershipEventDirect(

View file

@ -480,7 +480,6 @@ func Setup(
return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
// TODO: update for cryptoIDs
v3mux.Handle("/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req, device); r != nil {
@ -490,7 +489,20 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI)
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false)
}),
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/org.matrix.msc_cryptoids/rooms/{roomID}/invite",
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
logrus.Info("Processing request to /org.matrix.msc_cryptoids/rooms/{roomID}/invite")
if r := rateLimits.Limit(req, device); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true)
}),
).Methods(http.MethodPost, http.MethodOptions)
// TODO: update for cryptoIDs

View file

@ -215,7 +215,7 @@ func SendServerNotice(
}
if !membershipRes.IsInRoom {
// re-invite the user
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now())
res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now(), false)
if err != nil {
return res
}

View file

@ -92,6 +92,7 @@ type UserRoomPrivateKeyCreator interface {
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error
ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
}
type InputRoomEventsAPI interface {
@ -243,7 +244,7 @@ type ClientRoomserverAPI interface {
PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error
PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error)
PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error
PerformInvite(ctx context.Context, req *PerformInviteRequest) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error)
PerformSendJoinCryptoIDs(ctx context.Context, req *PerformJoinRequestCryptoIDs) error
PerformJoinCryptoIDs(ctx context.Context, req *PerformJoinRequest) (join gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error)
@ -309,7 +310,7 @@ type FederationRoomserverAPI interface {
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error
PerformInvite(ctx context.Context, req *PerformInviteRequest) error
PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error)
// Query a given amount (or less) of events prior to a given set of events.
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error

View file

@ -54,6 +54,7 @@ type RoomserverInternalAPI struct {
ServerACLs *acls.ServerACLs
fsAPI fsAPI.RoomserverFederationAPI
asAPI asAPI.AppServiceInternalAPI
usAPI userapi.RoomserverUserAPI
NATSClient *nats.Conn
JetStream nats.JetStreamContext
Durable string
@ -214,6 +215,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) {
r.Leaver.UserAPI = userAPI
r.Inputer.UserAPI = userAPI
r.usAPI = userAPI
}
func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) {
@ -251,8 +253,9 @@ func (r *RoomserverInternalAPI) PerformCreateRoom(
func (r *RoomserverInternalAPI) PerformInvite(
ctx context.Context,
req *api.PerformInviteRequest,
) error {
return r.Inviter.PerformInvite(ctx, req)
cryptoIDs bool,
) (gomatrixserverlib.PDU, error) {
return r.Inviter.PerformInvite(ctx, req, cryptoIDs)
}
func (r *RoomserverInternalAPI) PerformLeave(
@ -308,6 +311,10 @@ func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, send
return err
}
func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return r.usAPI.ClaimOneTimePseudoID(ctx, roomID, userID)
}
func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) {
roomVersion, ok := r.Cache.GetRoomVersion(roomID.String())
if !ok {

View file

@ -918,7 +918,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}
}
err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
_, err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{
InviteInput: api.InviteInput{
RoomID: roomID,
Inviter: userID,
@ -933,7 +933,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
},
InviteRoomState: globalStrippedState,
SendAsServer: string(userID.Domain()),
})
}, false)
switch e := err.(type) {
case api.ErrInvalidID:
return "", &util.JSONResponse{

View file

@ -125,16 +125,17 @@ func (r *Inviter) ProcessInviteMembership(
func (r *Inviter) PerformInvite(
ctx context.Context,
req *api.PerformInviteRequest,
) error {
cryptoIDs bool,
) (gomatrixserverlib.PDU, error) {
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter)
if err != nil {
return err
return nil, err
} else if senderID == nil {
return fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID)
return nil, fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID)
}
info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String())
if err != nil {
return err
return nil, err
}
proto := gomatrixserverlib.ProtoEvent{
@ -152,11 +153,11 @@ func (r *Inviter) PerformInvite(
}
if err = proto.SetContent(content); err != nil {
return err
return nil, err
}
if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) {
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
return nil, api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
}
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain())
@ -165,7 +166,7 @@ func (r *Inviter) PerformInvite(
if info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID)
if err != nil {
return err
return nil, err
}
}
@ -222,6 +223,10 @@ func (r *Inviter) PerformInvite(
}
return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID)
},
CryptoIDs: cryptoIDs,
ClaimSenderID: func(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
return r.RSAPI.ClaimOneTimeSenderIDForUser(ctx, roomID, userID)
},
}
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
@ -229,12 +234,14 @@ func (r *Inviter) PerformInvite(
switch e := err.(type) {
case spec.MatrixError:
if e.ErrCode == spec.ErrorForbidden {
return api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)}
return nil, api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)}
}
}
return err
return nil, err
}
var response gomatrixserverlib.PDU
if !cryptoIDs {
// Send the invite event to the roomserver input stream. This will
// notify existing users in the room about the invite, update the
// membership table and ensure that the event is ready and available
@ -254,8 +261,11 @@ func (r *Inviter) PerformInvite(
r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
if err := inputRes.Err(); err != nil {
util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed")
return api.ErrNotAllowed{Err: err}
return nil, api.ErrNotAllowed{Err: err}
}
} else {
response = inviteEvent
}
return nil
return response, nil
}

View file

@ -269,7 +269,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
defer userStreamListener.Close()
giveup := func() util.JSONResponse {
syncReq.Log.Info("Responding to sync since client gave up or timeout was reached")
syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached")
syncReq.Response.NextBatch = syncReq.Since
// We should always try to include OTKs in sync responses, otherwise clients might upload keys
@ -285,9 +284,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts")
}
syncReq.Log.Infof("one-time pseudoID counts: %v", syncReq.Response.OTPseudoIDsCount)
syncReq.Log.Infof("one-time key counts: %v", syncReq.Response.DeviceListsOTKCount)
}
return util.JSONResponse{
Code: http.StatusOK,

View file

@ -51,6 +51,7 @@ type AppserviceUserAPI interface {
type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
}
// api functions required by the media api

View file

@ -17,9 +17,11 @@ package internal
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
@ -851,6 +853,41 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P
}
}
type Ed25519Key struct {
Key spec.Base64Bytes `json:"key"`
}
func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
pseudoIDs, err := a.KeyDatabase.ClaimOneTimePseudoID(ctx, userID, "ed25519")
if err != nil {
return "", err
}
logrus.Infof("Claimed one time pseuodID: %v", pseudoIDs)
if pseudoIDs != nil {
for key, value := range pseudoIDs.KeyJSON {
keyParts := strings.Split(key, ":")
if keyParts[0] == "ed25519" {
var key_bytes Ed25519Key
err := json.Unmarshal(value, &key_bytes)
if err != nil {
return "", err
}
length := len(key_bytes.Key)
if length != ed25519.PublicKeySize {
return "", fmt.Errorf("Invalid ed25519 public key, %d is the wrong size", length)
}
// TODO: cryptoIDs - store senderID for this user here?
return spec.SenderID(key_bytes.Key.Encode()), nil
}
}
}
return "", fmt.Errorf("failed claiming a valid one time pseudoID for this user: %s", userID.String())
}
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
// if we only want to update the display names, we can skip the checks below
if onlyUpdateDisplayName {

View file

@ -178,6 +178,7 @@ type KeyDatabase interface {
ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error)
OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error)
ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error

View file

@ -962,6 +962,24 @@ func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string)
return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID)
}
func (d *KeyDatabase) ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error) {
var result *api.OneTimePseudoIDs
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
keyJSON, err := d.OneTimePseudoIDsTable.SelectAndDeleteOneTimePseudoID(ctx, txn, userID.String(), algorithm)
if err != nil {
return err
}
if keyJSON != nil {
result = &api.OneTimePseudoIDs{
UserID: userID.String(),
KeyJSON: keyJSON,
}
}
return nil
})
return result, err
}
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/sirupsen/logrus"
)
var oneTimePseudoIDsSchema = `
@ -183,6 +184,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
logrus.Warnf("No rows found for one time pseudoIDs")
return nil, nil
}
return nil, err
@ -192,6 +194,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID(
return nil, err
}
if keyJSON == "" {
logrus.Warnf("Empty key JSON for one time pseudoIDs")
return nil, nil
}
return map[string]json.RawMessage{