diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/membership.go b/src/github.com/matrix-org/dendrite/clientapi/writers/membership.go index 6dd52a003..27404c03c 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/membership.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/membership.go @@ -15,6 +15,7 @@ package writers import ( + "errors" "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -32,6 +33,8 @@ import ( "github.com/matrix-org/util" ) +var errMissingUserID = errors.New("'user_id' must be supplied.") + // SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite) // by building a m.room.member event then sending it to the room server func SendMembership( @@ -50,14 +53,46 @@ func SendMembership( return *res } - stateKey, reason, reqErr := getMembershipStateKey(body, device, membership) - if reqErr != nil { - return *reqErr + event, err := buildMembershipEvent( + body, accountDB, device, membership, roomID, cfg, queryAPI, + ) + if err == errMissingUserID { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON(err.Error()), + } + } else if err == events.ErrRoomNoExists { + return util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound(err.Error()), + } + } else if err != nil { + return httputil.LogThenError(req, err) + } + + if err := producer.SendEvents([]gomatrixserverlib.Event{*event}, cfg.Matrix.ServerName); err != nil { + return httputil.LogThenError(req, err) + } + + return util.JSONResponse{ + Code: 200, + JSON: struct{}{}, + } +} + +func buildMembershipEvent( + body threepid.MembershipRequest, accountDB *accounts.Database, + device *authtypes.Device, membership string, roomID string, cfg config.Dendrite, + queryAPI api.RoomserverQueryAPI, +) (*gomatrixserverlib.Event, error) { + stateKey, reason, err := getMembershipStateKey(body, device, membership) + if err != nil { + return nil, err } profile, err := loadProfile(stateKey, cfg, accountDB) if err != nil { - return httputil.LogThenError(req, err) + return nil, err } builder := gomatrixserverlib.EventBuilder{ @@ -80,27 +115,10 @@ func SendMembership( } if err = builder.SetContent(content); err != nil { - return httputil.LogThenError(req, err) + return nil, err } - event, err := events.BuildEvent(&builder, cfg, queryAPI, nil) - if err == events.ErrRoomNoExists { - return util.JSONResponse{ - Code: 404, - JSON: jsonerror.NotFound(err.Error()), - } - } else if err != nil { - return httputil.LogThenError(req, err) - } - - if err := producer.SendEvents([]gomatrixserverlib.Event{*event}, cfg.Matrix.ServerName); err != nil { - return httputil.LogThenError(req, err) - } - - return util.JSONResponse{ - Code: 200, - JSON: struct{}{}, - } + return events.BuildEvent(&builder, cfg, queryAPI, nil) } // loadProfile lookups the profile of a given user from the database and returns @@ -130,16 +148,13 @@ func loadProfile(userID string, cfg config.Dendrite, accountDB *accounts.Databas // returns a JSONResponse with a corresponding error code and message. func getMembershipStateKey( body threepid.MembershipRequest, device *authtypes.Device, membership string, -) (stateKey string, reason string, response *util.JSONResponse) { +) (stateKey string, reason string, err error) { if membership == "ban" || membership == "unban" || membership == "kick" || membership == "invite" { // If we're in this case, the state key is contained in the request body, // possibly along with a reason (for "kick" and "ban") so we need to parse // it if body.UserID == "" { - response = &util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON("'user_id' must be supplied."), - } + err = errMissingUserID return }