Hacks for supporting Riot iOS (#1148)
* Join room body is optional * Support deprecated login by user/password * Implement dummy key upload endpoint * Make a very determinate end to /messages if we hit the create event in back-pagination * Linting
This commit is contained in:
parent
84a7881468
commit
ddf1c8adf1
|
@ -43,9 +43,7 @@ func JoinRoomByIDOrAlias(
|
||||||
// If content was provided in the request then incude that
|
// If content was provided in the request then incude that
|
||||||
// in the request. It'll get used as a part of the membership
|
// in the request. It'll get used as a part of the membership
|
||||||
// event content.
|
// event content.
|
||||||
if err := httputil.UnmarshalJSONRequest(req, &joinReq.Content); err != nil {
|
_ = httputil.UnmarshalJSONRequest(req, &joinReq.Content)
|
||||||
return *err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Work out our localpart for the client profile request.
|
// Work out our localpart for the client profile request.
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
|
|
|
@ -47,6 +47,7 @@ type loginIdentifier struct {
|
||||||
|
|
||||||
type passwordRequest struct {
|
type passwordRequest struct {
|
||||||
Identifier loginIdentifier `json:"identifier"`
|
Identifier loginIdentifier `json:"identifier"`
|
||||||
|
User string `json:"user"` // deprecated in favour of identifier
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
|
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
|
||||||
// Thus a pointer is needed to differentiate between the two
|
// Thus a pointer is needed to differentiate between the two
|
||||||
|
@ -81,6 +82,7 @@ func Login(
|
||||||
} else if req.Method == http.MethodPost {
|
} else if req.Method == http.MethodPost {
|
||||||
var r passwordRequest
|
var r passwordRequest
|
||||||
var acc *api.Account
|
var acc *api.Account
|
||||||
|
var errJSON *util.JSONResponse
|
||||||
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
resErr := httputil.UnmarshalJSONRequest(req, &r)
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return *resErr
|
return *resErr
|
||||||
|
@ -93,32 +95,24 @@ func Login(
|
||||||
JSON: jsonerror.BadJSON("'user' must be supplied."),
|
JSON: jsonerror.BadJSON("'user' must be supplied."),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
acc, errJSON = r.processUsernamePasswordLoginRequest(req, accountDB, cfg, r.Identifier.User)
|
||||||
util.GetLogger(req.Context()).WithField("user", r.Identifier.User).Info("Processing login request")
|
if errJSON != nil {
|
||||||
|
return *errJSON
|
||||||
localpart, err := userutil.ParseUsernameParam(r.Identifier.User, &cfg.Matrix.ServerName)
|
|
||||||
if err != nil {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
JSON: jsonerror.InvalidUsername(err.Error()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
acc, err = accountDB.GetAccountByPassword(req.Context(), localpart, r.Password)
|
|
||||||
if err != nil {
|
|
||||||
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
|
|
||||||
// but that would leak the existence of the user.
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusForbidden,
|
|
||||||
JSON: jsonerror.Forbidden("username or password was incorrect, or the account does not exist"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
// TODO: The below behaviour is deprecated but without it Riot iOS won't log in
|
||||||
|
if r.User != "" {
|
||||||
|
acc, errJSON = r.processUsernamePasswordLoginRequest(req, accountDB, cfg, r.User)
|
||||||
|
if errJSON != nil {
|
||||||
|
return *errJSON
|
||||||
|
}
|
||||||
|
} else {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.BadJSON("login identifier '" + r.Identifier.Type + "' not supported"),
|
JSON: jsonerror.BadJSON("login identifier '" + r.Identifier.Type + "' not supported"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
token, err := auth.GenerateAccessToken()
|
token, err := auth.GenerateAccessToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -163,3 +157,32 @@ func getDevice(
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *passwordRequest) processUsernamePasswordLoginRequest(
|
||||||
|
req *http.Request, accountDB accounts.Database,
|
||||||
|
cfg *config.Dendrite, username string,
|
||||||
|
) (acc *api.Account, errJSON *util.JSONResponse) {
|
||||||
|
util.GetLogger(req.Context()).WithField("user", username).Info("Processing login request")
|
||||||
|
|
||||||
|
localpart, err := userutil.ParseUsernameParam(username, &cfg.Matrix.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
errJSON = &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: jsonerror.InvalidUsername(err.Error()),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
acc, err = accountDB.GetAccountByPassword(req.Context(), localpart, r.Password)
|
||||||
|
if err != nil {
|
||||||
|
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
|
||||||
|
// but that would leak the existence of the user.
|
||||||
|
errJSON = &util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden("username or password was incorrect, or the account does not exist"),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -36,9 +36,19 @@ func Setup(
|
||||||
publicAPIMux *mux.Router, cfg *config.Dendrite, userAPI userapi.UserInternalAPI,
|
publicAPIMux *mux.Router, cfg *config.Dendrite, userAPI userapi.UserInternalAPI,
|
||||||
) {
|
) {
|
||||||
r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter()
|
r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter()
|
||||||
|
|
||||||
r0mux.Handle("/keys/query",
|
r0mux.Handle("/keys/query",
|
||||||
httputil.MakeAuthAPI("queryKeys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("queryKeys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return QueryKeys(req)
|
return QueryKeys(req)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
|
r0mux.Handle("/keys/upload/{keyID}",
|
||||||
|
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: map[string]interface{}{},
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
}
|
}
|
||||||
|
|
|
@ -158,6 +158,7 @@ func OnIncomingMessagesRequest(
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed")
|
util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
util.GetLogger(req.Context()).WithFields(logrus.Fields{
|
util.GetLogger(req.Context()).WithFields(logrus.Fields{
|
||||||
"from": from.String(),
|
"from": from.String(),
|
||||||
"to": to.String(),
|
"to": to.String(),
|
||||||
|
@ -246,6 +247,12 @@ func (r *messagesReq) retrieveEvents() (
|
||||||
// change the way topological positions are defined (as depth isn't the most
|
// change the way topological positions are defined (as depth isn't the most
|
||||||
// reliable way to define it), it would be easier and less troublesome to
|
// reliable way to define it), it would be easier and less troublesome to
|
||||||
// only have to change it in one place, i.e. the database.
|
// only have to change it in one place, i.e. the database.
|
||||||
|
start, end, err = r.getStartEnd(events)
|
||||||
|
|
||||||
|
return clientEvents, start, end, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *messagesReq) getStartEnd(events []gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
|
||||||
start, err = r.db.EventPositionInTopology(
|
start, err = r.db.EventPositionInTopology(
|
||||||
r.ctx, events[0].EventID(),
|
r.ctx, events[0].EventID(),
|
||||||
)
|
)
|
||||||
|
@ -253,6 +260,11 @@ func (r *messagesReq) retrieveEvents() (
|
||||||
err = fmt.Errorf("EventPositionInTopology: for start event %s: %w", events[0].EventID(), err)
|
err = fmt.Errorf("EventPositionInTopology: for start event %s: %w", events[0].EventID(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if r.backwardOrdering && events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate {
|
||||||
|
// We've hit the beginning of the room so there's really nowhere else
|
||||||
|
// to go. This seems to fix Riot iOS from looping on /messages endlessly.
|
||||||
|
end = types.NewTopologyToken(0, 0)
|
||||||
|
} else {
|
||||||
end, err = r.db.EventPositionInTopology(
|
end, err = r.db.EventPositionInTopology(
|
||||||
r.ctx, events[len(events)-1].EventID(),
|
r.ctx, events[len(events)-1].EventID(),
|
||||||
)
|
)
|
||||||
|
@ -260,7 +272,6 @@ func (r *messagesReq) retrieveEvents() (
|
||||||
err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err)
|
err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.backwardOrdering {
|
if r.backwardOrdering {
|
||||||
// A stream/topological position is a cursor located between two events.
|
// A stream/topological position is a cursor located between two events.
|
||||||
// While they are identified in the code by the event on their right (if
|
// While they are identified in the code by the event on their right (if
|
||||||
|
@ -269,8 +280,8 @@ func (r *messagesReq) retrieveEvents() (
|
||||||
// end position we send in the response if we're going backward.
|
// end position we send in the response if we're going backward.
|
||||||
end.Decrement()
|
end.Decrement()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return clientEvents, start, end, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleEmptyEventsSlice handles the case where the initial request to the
|
// handleEmptyEventsSlice handles the case where the initial request to the
|
||||||
|
|
Loading…
Reference in a new issue