Pass through content, try to handle multiple server names

This commit is contained in:
Neil Alexander 2020-05-01 16:36:19 +01:00
parent 913f2ab8c1
commit f53dffc02a
3 changed files with 65 additions and 19 deletions

View file

@ -19,6 +19,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
@ -40,12 +41,21 @@ func JoinRoomByIDOrAlias(
keyRing gomatrixserverlib.KeyRing, // nolint:unparam keyRing gomatrixserverlib.KeyRing, // nolint:unparam
accountDB accounts.Database, // nolint:unparam accountDB accounts.Database, // nolint:unparam
) util.JSONResponse { ) util.JSONResponse {
// Prepare to ask the roomserver to perform the room join.
joinReq := roomserverAPI.PerformJoinRequest{ joinReq := roomserverAPI.PerformJoinRequest{
RoomIDOrAlias: roomIDOrAlias, RoomIDOrAlias: roomIDOrAlias,
UserID: device.UserID, UserID: device.UserID,
Content: nil,
} }
joinRes := roomserverAPI.PerformJoinResponse{} joinRes := roomserverAPI.PerformJoinResponse{}
// If content was provided in the request then incude that
// in the request. It'll get used as a part of the membership
// event content.
if err := httputil.UnmarshalJSONRequest(req, &joinReq.Content); err != nil {
return *err
}
// Ask the roomserver to perform the join.
if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,

View file

@ -4,6 +4,7 @@ import (
"context" "context"
commonHTTP "github.com/matrix-org/dendrite/common/http" commonHTTP "github.com/matrix-org/dendrite/common/http"
"github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
) )
@ -19,6 +20,7 @@ type PerformJoinRequest struct {
RoomIDOrAlias string `json:"room_id_or_alias"` RoomIDOrAlias string `json:"room_id_or_alias"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
Content map[string]interface{} `json:"content"` Content map[string]interface{} `json:"content"`
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
} }
type PerformJoinResponse struct { type PerformJoinResponse struct {

View file

@ -10,6 +10,7 @@ import (
fsAPI "github.com/matrix-org/dendrite/federationsender/api" fsAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
) )
// WriteOutputEvents implements OutputRoomEventWriter // WriteOutputEvents implements OutputRoomEventWriter
@ -39,6 +40,13 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias(
req *api.PerformJoinRequest, req *api.PerformJoinRequest,
res *api.PerformJoinResponse, res *api.PerformJoinResponse,
) error { ) error {
// Get the domain part of the room alias.
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
if err != nil {
return fmt.Errorf("supplied room alias %q in incorrect format", req.RoomIDOrAlias)
}
req.ServerNames = append(req.ServerNames, domain)
// Look up if we know this room alias. // Look up if we know this room alias.
roomID, err := r.DB.GetRoomIDForAlias(ctx, req.RoomIDOrAlias) roomID, err := r.DB.GetRoomIDForAlias(ctx, req.RoomIDOrAlias)
if err != nil { if err != nil {
@ -50,11 +58,20 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias(
return r.performJoinRoomByID(ctx, req, res) return r.performJoinRoomByID(ctx, req, res)
} }
// TODO: Break this function up a bit
// nolint:gocyclo
func (r *RoomserverInternalAPI) performJoinRoomByID( func (r *RoomserverInternalAPI) performJoinRoomByID(
ctx context.Context, ctx context.Context,
req *api.PerformJoinRequest, req *api.PerformJoinRequest,
res *api.PerformJoinResponse, res *api.PerformJoinResponse, // nolint:unparam
) error { ) error {
// Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
if err != nil {
return fmt.Errorf("supplied room alias %q in incorrect format", req.RoomIDOrAlias)
}
req.ServerNames = append(req.ServerNames, domain)
// Prepare the template for the join event. // Prepare the template for the join event.
userID := req.UserID userID := req.UserID
eb := gomatrixserverlib.EventBuilder{ eb := gomatrixserverlib.EventBuilder{
@ -64,18 +81,18 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
RoomID: req.RoomIDOrAlias, RoomID: req.RoomIDOrAlias,
Redacts: "", Redacts: "",
} }
if err := eb.SetUnsigned(struct{}{}); err != nil { if err = eb.SetUnsigned(struct{}{}); err != nil {
return fmt.Errorf("eb.SetUnsigned: %w", err) return fmt.Errorf("eb.SetUnsigned: %w", err)
} }
// It is possible for the requestoto include some "content" for the // It is possible for the request to include some "content" for the
// event. We'll always overwrite the "membership" key, but the rest, // event. We'll always overwrite the "membership" key, but the rest,
// like "display_name" or "avatar_url", will be kept if supplied. // like "display_name" or "avatar_url", will be kept if supplied.
if req.Content == nil { if req.Content == nil {
req.Content = map[string]interface{}{} req.Content = map[string]interface{}{}
} }
req.Content["membership"] = "join" req.Content["membership"] = "join"
if err := eb.SetContent(req.Content); err != nil { if err = eb.SetContent(req.Content); err != nil {
return fmt.Errorf("eb.SetContent: %w", err) return fmt.Errorf("eb.SetContent: %w", err)
} }
@ -124,18 +141,35 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
return fmt.Errorf("error trying to join %q room: %w", req.RoomIDOrAlias, derr) return fmt.Errorf("error trying to join %q room: %w", req.RoomIDOrAlias, derr)
} }
// Try joining by all of the supplied server names.
// TODO: Update the FS API so that it accepts a list of server names and
// does this bit by itself.
joined := false
for _, serverName := range req.ServerNames {
// Otherwise, if we've reached this point, the room isn't a local room // Otherwise, if we've reached this point, the room isn't a local room
// and we should ask the federation sender to try and join for us. // and we should ask the federation sender to try and join for us.
fedReq := fsAPI.PerformJoinRequest{ fedReq := fsAPI.PerformJoinRequest{
RoomID: req.RoomIDOrAlias, // the room ID to try and join RoomID: req.RoomIDOrAlias, // the room ID to try and join
UserID: req.UserID, // the user ID joining the room UserID: req.UserID, // the user ID joining the room
ServerName: domain, // the server to try joining with ServerName: serverName, // the server to try joining with
Content: req.Content, // the membership event content Content: req.Content, // the membership event content
} }
fedRes := fsAPI.PerformJoinResponse{} fedRes := fsAPI.PerformJoinResponse{}
err = r.fsAPI.PerformJoin(ctx, &fedReq, &fedRes) err = r.fsAPI.PerformJoin(ctx, &fedReq, &fedRes)
if err != nil { if err != nil {
return fmt.Errorf("error joining federated room %q: %w", req.RoomIDOrAlias, err) logrus.WithError(err).Errorf("error joining federated room %q", req.RoomIDOrAlias)
continue
}
joined = true
}
// If we didn't successfully join the room using any of the supplied
// servers then return an error saying such.
if !joined {
return fmt.Errorf(
"failed to join %q using %d server(s)",
req.RoomIDOrAlias, len(req.ServerNames),
)
} }
default: default: