diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index a79b369a4..0274d030e 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -65,8 +65,16 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } var senderID spec.SenderID if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { - // TODO: pseudoIDs - generate senderID kere! - senderID = "pseudo_id.sender.key" + // create user room key if needed + key, err := c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + senderID = spec.SenderID(spec.Base64Bytes(key).Encode()) } else { senderID = spec.SenderID(userID.String()) } @@ -363,18 +371,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - // create user room key if needed - if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { - _, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") - return "", &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: spec.InternalServerError{}, - } - } - } - // send the remaining events if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 11d2bf395..fa66bc61c 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -25,6 +25,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -191,49 +192,58 @@ func (r *Joiner) performJoinRoomByID( return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)} } + // Look up the room NID for the supplied room ID. var senderID spec.SenderID - var roomVersion gomatrixserverlib.RoomVersion - if forceFederatedJoin { - // TODO: pseudoIDs - lookup room version kere! - } else { - roomVersion, err = r.RSAPI.QueryRoomVersionForRoom(ctx, roomID.String()) - if err != nil { - return "", "", err + checkInvitePending := false + info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias) + if err == nil && info != nil { + switch info.RoomVersion { + case gomatrixserverlib.RoomVersionPseudoIDs: + senderID, err = r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID) + if err == nil { + checkInvitePending = true + } else { + // create user room key if needed + key, err := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", "", fmt.Errorf("GetOrCreateUserRoomPrivateKey failed: %w", err) + } + senderID = spec.SenderID(spec.Base64Bytes(key).Encode()) + } + default: + checkInvitePending = true + senderID = spec.SenderID(userID.String()) } } - if roomVersion == gomatrixserverlib.RoomVersionPseudoIDs { - // TODO: pseudoIDs - generate senderID kere! - senderID = "pseudo_id.sender.key" - } else { - senderID = spec.SenderID(userID.String()) - } - senderIDString := string(senderID) userDomain := userID.Domain() // Force a federated join if we're dealing with a pending invite // and we aren't in the room. - isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) - if err == nil && !serverInRoom && isInvitePending { - inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender) - if queryErr != nil { - return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) - } + if checkInvitePending { + isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID) + if err == nil && !serverInRoom && isInvitePending { + inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender) + if queryErr != nil { + return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr) + } - // If we were invited by someone from another server then we can - // assume they are in the room so we can join via them. - if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) { - req.ServerNames = append(req.ServerNames, inviter.Domain()) - forceFederatedJoin = true - memberEvent := gjson.Parse(string(inviteEvent.JSON())) - // only set unsigned if we've got a content.membership, which we _should_ - if memberEvent.Get("content.membership").Exists() { - req.Unsigned = map[string]interface{}{ - "prev_sender": memberEvent.Get("sender").Str, - "prev_content": map[string]interface{}{ - "is_direct": memberEvent.Get("content.is_direct").Bool(), - "membership": memberEvent.Get("content.membership").Str, - }, + // If we were invited by someone from another server then we can + // assume they are in the room so we can join via them. + if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) { + req.ServerNames = append(req.ServerNames, inviter.Domain()) + forceFederatedJoin = true + memberEvent := gjson.Parse(string(inviteEvent.JSON())) + // only set unsigned if we've got a content.membership, which we _should_ + if memberEvent.Get("content.membership").Exists() { + req.Unsigned = map[string]interface{}{ + "prev_sender": memberEvent.Get("sender").Str, + "prev_content": map[string]interface{}{ + "is_direct": memberEvent.Get("content.is_direct").Bool(), + "membership": memberEvent.Get("content.membership").Str, + }, + } } } } @@ -261,6 +271,7 @@ func (r *Joiner) performJoinRoomByID( // If we should do a forced federated join then do that. var joinedVia spec.ServerName if forceFederatedJoin { + // TODO : pseudoIDs - pass through userID here since we don't know what the senderID should be yet joinedVia, err = r.performFederatedJoinRoomByID(ctx, req, senderID) return req.RoomIDOrAlias, joinedVia, err } @@ -277,6 +288,8 @@ func (r *Joiner) performJoinRoomByID( return "", "", fmt.Errorf("error joining local room: %q", err) } + senderIDString := string(senderID) + // Prepare the template for the join event. proto := gomatrixserverlib.ProtoEvent{ Type: spec.MRoomMember, @@ -308,15 +321,6 @@ func (r *Joiner) performJoinRoomByID( switch err.(type) { case nil: - // create user room key if needed - if buildRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { - _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) - if err != nil { - logrus.WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") - return "", "", fmt.Errorf("failed to get user room private key: %w", err) - } - } - // The room join is local. Send the new join event into the // roomserver. First of all check that the user isn't already // a member of the room. This is best-effort (as in we won't