diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 8b8cc47bc..63f820cd8 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -17,6 +17,7 @@ package routing import ( "context" "crypto/ed25519" + "encoding/json" "fmt" "net/http" "time" @@ -279,6 +280,7 @@ func SendInvite( req *http.Request, profileAPI userapi.ClientUserAPI, device *userapi.Device, roomID string, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, + cryptoIDs bool, ) util.JSONResponse { body, evTime, reqErr := extractRequestData(req) if reqErr != nil { @@ -323,7 +325,7 @@ func SendInvite( } // We already received the return value, so no need to check for an error here. - response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime) + response, _ := sendInvite(req.Context(), profileAPI, device, roomID, body.UserID, body.Reason, cfg, rsAPI, asAPI, evTime, cryptoIDs) return response } @@ -336,6 +338,7 @@ func sendInvite( cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, evTime time.Time, + cryptoIDs bool, ) (util.JSONResponse, error) { validRoomID, err := spec.NewRoomID(roomID) if err != nil { @@ -372,7 +375,7 @@ func sendInvite( JSON: spec.InternalServerError{}, }, err } - err = rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{ + inviteEvent, err := rsAPI.PerformInvite(ctx, &api.PerformInviteRequest{ InviteInput: roomserverAPI.InviteInput{ RoomID: *validRoomID, Inviter: *inviter, @@ -387,7 +390,7 @@ func sendInvite( }, InviteRoomState: nil, // ask the roomserver to draw up invite room state for us SendAsServer: string(device.UserDomain()), - }) + }, cryptoIDs) switch e := err.(type) { case roomserverAPI.ErrInvalidID: @@ -410,10 +413,22 @@ func sendInvite( }, err } - return util.JSONResponse{ + response := util.JSONResponse{ Code: http.StatusOK, JSON: struct{}{}, - }, nil + } + + type inviteCryptoIDResponse struct { + PDU json.RawMessage `json:"pdu"` + } + + if inviteEvent != nil { + response.JSON = inviteCryptoIDResponse{ + PDU: json.RawMessage(inviteEvent.JSON()), + } + } + + return response, nil } func buildMembershipEventDirect( diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 1cf689e6a..62cdb584e 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -480,7 +480,6 @@ func Setup( return SendBan(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) }), ).Methods(http.MethodPost, http.MethodOptions) - // TODO: update for cryptoIDs v3mux.Handle("/rooms/{roomID}/invite", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if r := rateLimits.Limit(req, device); r != nil { @@ -490,7 +489,20 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI) + return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, false) + }), + ).Methods(http.MethodPost, http.MethodOptions) + unstableMux.Handle("/org.matrix.msc_cryptoids/rooms/{roomID}/invite", + httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + logrus.Info("Processing request to /org.matrix.msc_cryptoids/rooms/{roomID}/invite") + if r := rateLimits.Limit(req, device); r != nil { + return *r + } + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return SendInvite(req, userAPI, device, vars["roomID"], cfg, rsAPI, asAPI, true) }), ).Methods(http.MethodPost, http.MethodOptions) // TODO: update for cryptoIDs diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index 5deb559df..c711f623d 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -215,7 +215,7 @@ func SendServerNotice( } if !membershipRes.IsInRoom { // re-invite the user - res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now()) + res, err := sendInvite(ctx, userAPI, senderDevice, roomID, r.UserID, "Server notice room", cfgClient, rsAPI, asAPI, time.Now(), false) if err != nil { return res } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 6e02550f0..2f2ed295d 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -92,6 +92,7 @@ type UserRoomPrivateKeyCreator interface { // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) StoreUserRoomPublicKey(ctx context.Context, senderID spec.SenderID, userID spec.UserID, roomID spec.RoomID) error + ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) } type InputRoomEventsAPI interface { @@ -243,7 +244,7 @@ type ClientRoomserverAPI interface { PerformAdminDownloadState(ctx context.Context, roomID, userID string, serverName spec.ServerName) error PerformPeek(ctx context.Context, req *PerformPeekRequest) (roomID string, err error) PerformUnpeek(ctx context.Context, roomID, userID, deviceID string) error - PerformInvite(ctx context.Context, req *PerformInviteRequest) error + PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error) PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error) PerformSendJoinCryptoIDs(ctx context.Context, req *PerformJoinRequestCryptoIDs) error PerformJoinCryptoIDs(ctx context.Context, req *PerformJoinRequest) (join gomatrixserverlib.PDU, roomID string, version gomatrixserverlib.RoomVersion, serverName spec.ServerName, err error) @@ -309,7 +310,7 @@ type FederationRoomserverAPI interface { PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error - PerformInvite(ctx context.Context, req *PerformInviteRequest) error + PerformInvite(ctx context.Context, req *PerformInviteRequest, cryptoIDs bool) (gomatrixserverlib.PDU, error) // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1e08f6a3a..64f7f3c2d 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -54,6 +54,7 @@ type RoomserverInternalAPI struct { ServerACLs *acls.ServerACLs fsAPI fsAPI.RoomserverFederationAPI asAPI asAPI.AppServiceInternalAPI + usAPI userapi.RoomserverUserAPI NATSClient *nats.Conn JetStream nats.JetStreamContext Durable string @@ -214,6 +215,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { r.Leaver.UserAPI = userAPI r.Inputer.UserAPI = userAPI + r.usAPI = userAPI } func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { @@ -251,8 +253,9 @@ func (r *RoomserverInternalAPI) PerformCreateRoom( func (r *RoomserverInternalAPI) PerformInvite( ctx context.Context, req *api.PerformInviteRequest, -) error { - return r.Inviter.PerformInvite(ctx, req) + cryptoIDs bool, +) (gomatrixserverlib.PDU, error) { + return r.Inviter.PerformInvite(ctx, req, cryptoIDs) } func (r *RoomserverInternalAPI) PerformLeave( @@ -308,6 +311,10 @@ func (r *RoomserverInternalAPI) StoreUserRoomPublicKey(ctx context.Context, send return err } +func (r *RoomserverInternalAPI) ClaimOneTimeSenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { + return r.usAPI.ClaimOneTimePseudoID(ctx, roomID, userID) +} + func (r *RoomserverInternalAPI) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) { roomVersion, ok := r.Cache.GetRoomVersion(roomID.String()) if !ok { diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 38ce06eab..996c797e7 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -918,7 +918,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{ + _, err = c.RSAPI.PerformInvite(ctx, &api.PerformInviteRequest{ InviteInput: api.InviteInput{ RoomID: roomID, Inviter: userID, @@ -933,7 +933,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo }, InviteRoomState: globalStrippedState, SendAsServer: string(userID.Domain()), - }) + }, false) switch e := err.(type) { case api.ErrInvalidID: return "", &util.JSONResponse{ diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 3abb69cb9..6f3eb036e 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -125,16 +125,17 @@ func (r *Inviter) ProcessInviteMembership( func (r *Inviter) PerformInvite( ctx context.Context, req *api.PerformInviteRequest, -) error { + cryptoIDs bool, +) (gomatrixserverlib.PDU, error) { senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.InviteInput.RoomID, req.InviteInput.Inviter) if err != nil { - return err + return nil, err } else if senderID == nil { - return fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID) + return nil, fmt.Errorf("sender ID not found for %s in %s", req.InviteInput.Inviter, req.InviteInput.RoomID) } info, err := r.DB.RoomInfo(ctx, req.InviteInput.RoomID.String()) if err != nil { - return err + return nil, err } proto := gomatrixserverlib.ProtoEvent{ @@ -152,11 +153,11 @@ func (r *Inviter) PerformInvite( } if err = proto.SetContent(content); err != nil { - return err + return nil, err } if !r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Inviter.Domain()) { - return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} + return nil, api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} } isTargetLocal := r.Cfg.Matrix.IsLocalServerName(req.InviteInput.Invitee.Domain()) @@ -165,7 +166,7 @@ func (r *Inviter) PerformInvite( if info.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { signingKey, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, req.InviteInput.Inviter, req.InviteInput.RoomID) if err != nil { - return err + return nil, err } } @@ -222,6 +223,10 @@ func (r *Inviter) PerformInvite( } return r.RSAPI.StoreUserRoomPublicKey(ctx, senderID, *storeUserID, roomID) }, + CryptoIDs: cryptoIDs, + ClaimSenderID: func(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { + return r.RSAPI.ClaimOneTimeSenderIDForUser(ctx, roomID, userID) + }, } inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) @@ -229,33 +234,38 @@ func (r *Inviter) PerformInvite( switch e := err.(type) { case spec.MatrixError: if e.ErrCode == spec.ErrorForbidden { - return api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)} + return nil, api.ErrNotAllowed{Err: fmt.Errorf("%s", e.Err)} } } - return err + return nil, err } - // Send the invite event to the roomserver input stream. This will - // notify existing users in the room about the invite, update the - // membership table and ensure that the event is ready and available - // to use as an auth event when accepting the invite. - // It will NOT notify the invitee of this invite. - inputReq := &api.InputRoomEventsRequest{ - InputRoomEvents: []api.InputRoomEvent{ - { - Kind: api.KindNew, - Event: &types.HeaderedEvent{PDU: inviteEvent}, - Origin: req.InviteInput.Inviter.Domain(), - SendAsServer: req.SendAsServer, + var response gomatrixserverlib.PDU + if !cryptoIDs { + // Send the invite event to the roomserver input stream. This will + // notify existing users in the room about the invite, update the + // membership table and ensure that the event is ready and available + // to use as an auth event when accepting the invite. + // It will NOT notify the invitee of this invite. + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{ + { + Kind: api.KindNew, + Event: &types.HeaderedEvent{PDU: inviteEvent}, + Origin: req.InviteInput.Inviter.Domain(), + SendAsServer: req.SendAsServer, + }, }, - }, - } - inputRes := &api.InputRoomEventsResponse{} - r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) - if err := inputRes.Err(); err != nil { - util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed") - return api.ErrNotAllowed{Err: err} + } + inputRes := &api.InputRoomEventsResponse{} + r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) + if err := inputRes.Err(); err != nil { + util.GetLogger(ctx).WithField("event_id", inviteEvent.EventID()).Error("r.InputRoomEvents failed") + return nil, api.ErrNotAllowed{Err: err} + } + } else { + response = inviteEvent } - return nil + return response, nil } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 2734bbac2..28862937f 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -269,7 +269,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. defer userStreamListener.Close() giveup := func() util.JSONResponse { - syncReq.Log.Info("Responding to sync since client gave up or timeout was reached") syncReq.Log.Debugln("Responding to sync since client gave up or timeout was reached") syncReq.Response.NextBatch = syncReq.Since // We should always try to include OTKs in sync responses, otherwise clients might upload keys @@ -285,9 +284,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. if err != nil && err != context.Canceled { syncReq.Log.WithError(err).Warn("failed to get OTPseudoID counts") } - - syncReq.Log.Infof("one-time pseudoID counts: %v", syncReq.Response.OTPseudoIDsCount) - syncReq.Log.Infof("one-time key counts: %v", syncReq.Response.DeviceListsOTKCount) } return util.JSONResponse{ Code: http.StatusOK, diff --git a/userapi/api/api.go b/userapi/api/api.go index 56a409f6d..83909622c 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -51,6 +51,7 @@ type AppserviceUserAPI interface { type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) + ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) } // api functions required by the media api diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go index 85f245435..3b538e0b0 100644 --- a/userapi/internal/key_api.go +++ b/userapi/internal/key_api.go @@ -17,9 +17,11 @@ package internal import ( "bytes" "context" + "crypto/ed25519" "encoding/json" "errors" "fmt" + "strings" "sync" "time" @@ -851,6 +853,41 @@ func (a *UserInternalAPI) uploadOneTimePseudoIDs(ctx context.Context, req *api.P } } +type Ed25519Key struct { + Key spec.Base64Bytes `json:"key"` +} + +func (a *UserInternalAPI) ClaimOneTimePseudoID(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { + pseudoIDs, err := a.KeyDatabase.ClaimOneTimePseudoID(ctx, userID, "ed25519") + if err != nil { + return "", err + } + + logrus.Infof("Claimed one time pseuodID: %v", pseudoIDs) + + if pseudoIDs != nil { + for key, value := range pseudoIDs.KeyJSON { + keyParts := strings.Split(key, ":") + if keyParts[0] == "ed25519" { + var key_bytes Ed25519Key + err := json.Unmarshal(value, &key_bytes) + if err != nil { + return "", err + } + + length := len(key_bytes.Key) + if length != ed25519.PublicKeySize { + return "", fmt.Errorf("Invalid ed25519 public key, %d is the wrong size", length) + } + // TODO: cryptoIDs - store senderID for this user here? + return spec.SenderID(key_bytes.Key.Encode()), nil + } + } + } + + return "", fmt.Errorf("failed claiming a valid one time pseudoID for this user: %s", userID.String()) +} + func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { // if we only want to update the display names, we can skip the checks below if onlyUpdateDisplayName { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index af1da509c..315d4e4ac 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -178,6 +178,7 @@ type KeyDatabase interface { ExistingOneTimePseudoIDs(ctx context.Context, userID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) StoreOneTimePseudoIDs(ctx context.Context, keys api.OneTimePseudoIDs) (*api.OneTimePseudoIDsCount, error) OneTimePseudoIDsCount(ctx context.Context, userID string) (*api.OneTimePseudoIDsCount, error) + ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error) // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 1a8c2a0d6..374dab24c 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -962,6 +962,24 @@ func (d *KeyDatabase) OneTimePseudoIDsCount(ctx context.Context, userID string) return d.OneTimePseudoIDsTable.CountOneTimePseudoIDs(ctx, userID) } +func (d *KeyDatabase) ClaimOneTimePseudoID(ctx context.Context, userID spec.UserID, algorithm string) (*api.OneTimePseudoIDs, error) { + var result *api.OneTimePseudoIDs + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + keyJSON, err := d.OneTimePseudoIDsTable.SelectAndDeleteOneTimePseudoID(ctx, txn, userID.String(), algorithm) + if err != nil { + return err + } + if keyJSON != nil { + result = &api.OneTimePseudoIDs{ + UserID: userID.String(), + KeyJSON: keyJSON, + } + } + return nil + }) + return result, err +} + func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } diff --git a/userapi/storage/sqlite3/one_time_pseudoids_table.go b/userapi/storage/sqlite3/one_time_pseudoids_table.go index abb71e09a..e8015373b 100644 --- a/userapi/storage/sqlite3/one_time_pseudoids_table.go +++ b/userapi/storage/sqlite3/one_time_pseudoids_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/sirupsen/logrus" ) var oneTimePseudoIDsSchema = ` @@ -183,6 +184,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID( err := sqlutil.TxStmtContext(ctx, txn, s.selectPseudoIDByAlgorithmStmt).QueryRowContext(ctx, userID, algorithm).Scan(&keyID, &keyJSON) if err != nil { if err == sql.ErrNoRows { + logrus.Warnf("No rows found for one time pseudoIDs") return nil, nil } return nil, err @@ -192,6 +194,7 @@ func (s *oneTimePseudoIDsStatements) SelectAndDeleteOneTimePseudoID( return nil, err } if keyJSON == "" { + logrus.Warnf("Empty key JSON for one time pseudoIDs") return nil, nil } return map[string]json.RawMessage{