From f2a6633a62298b19d84cd415aa2f0364f04356c6 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Mon, 20 Feb 2017 16:18:09 +0000 Subject: [PATCH 1/6] Add auth package - Extract the access token from the HTTP request --- .../dendrite/clientapi/auth/auth.go | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 src/github.com/matrix-org/dendrite/clientapi/auth/auth.go diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go b/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go new file mode 100644 index 000000000..4d2c9b094 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go @@ -0,0 +1,43 @@ +package auth + +import ( + "fmt" + "net/http" + "strings" +) + +// VerifyAccessToken verifies that an access token was supplied in the given HTTP request +// and returns the user ID it corresponds to. Returns an error if there is no access token +// or the token is invalid. +func VerifyAccessToken(req *http.Request) (userID string, err error) { + _, tokenErr := extractAccessToken(req) + if tokenErr != nil { + // err = MatrixError(MatrixError.M_MISSING_TOKEN, tokenErr.Error()) + return + } + // TODO: Do something with the token + return +} + +// extractAccessToken from a request, or return an error detailing what went wrong. +func extractAccessToken(req *http.Request) (string, error) { + authBearer := req.Header.Get("Authorization") + queryToken := req.URL.Query().Get("access_token") + if authBearer != "" && queryToken != "" { + return "", fmt.Errorf("mixing Authorization headers and access_token query parameters") + } + + if queryToken != "" { + return queryToken, nil + } + + if authBearer != "" { + parts := strings.SplitN(authBearer, " ", 2) + if len(parts) != 2 || parts[0] != "Bearer" { + return "", fmt.Errorf("invalid Authorization header") + } + return parts[1], nil + } + + return "", fmt.Errorf("missing access token") +} From 62a9375f5643070f8e0375ac9debb4e0c3071a29 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 21 Feb 2017 16:32:53 +0000 Subject: [PATCH 2/6] Return the correct error --- .../matrix-org/dendrite/clientapi/auth/auth.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go b/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go index 4d2c9b094..927e294c5 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go @@ -4,6 +4,8 @@ import ( "fmt" "net/http" "strings" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" ) // VerifyAccessToken verifies that an access token was supplied in the given HTTP request @@ -12,15 +14,17 @@ import ( func VerifyAccessToken(req *http.Request) (userID string, err error) { _, tokenErr := extractAccessToken(req) if tokenErr != nil { - // err = MatrixError(MatrixError.M_MISSING_TOKEN, tokenErr.Error()) + err = jsonerror.MissingToken(tokenErr.Error()) return } - // TODO: Do something with the token + // TODO: Check the token against the database return } -// extractAccessToken from a request, or return an error detailing what went wrong. +// extractAccessToken from a request, or return an error detailing what went wrong. The +// error message MUST be human-readable and comprehensible to the client. func extractAccessToken(req *http.Request) (string, error) { + // cf https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/api/auth.py#L631 authBearer := req.Header.Get("Authorization") queryToken := req.URL.Query().Get("access_token") if authBearer != "" && queryToken != "" { From 17cc782affbcac3a2b72733b1c9b572f7dde1d4c Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 22 Feb 2017 11:01:21 +0000 Subject: [PATCH 3/6] 'verify' the access token --- .../matrix-org/dendrite/clientapi/readers/sync.go | 11 ++++++++++- .../dendrite/clientapi/writers/sendmessage.go | 15 ++++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go b/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go index 27a470945..b733a3dbc 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go @@ -3,13 +3,22 @@ package readers import ( "net/http" + "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/util" ) // Sync implements /sync func Sync(req *http.Request) (interface{}, *util.HTTPError) { logger := util.GetLogger(req.Context()) - logger.Info("Doing stuff...") + userID, err := auth.VerifyAccessToken(req) + if err != nil { + return nil, &util.HTTPError{ + Code: 403, + JSON: err, + } + } + + logger.WithField("userID", userID).Info("Doing stuff...") return nil, &util.HTTPError{ Code: 404, Message: "Not implemented yet", diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go b/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go index 11b17740d..c5d2c8340 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go @@ -3,13 +3,26 @@ package writers import ( "net/http" + log "github.com/Sirupsen/logrus" + "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/util" ) // SendMessage implements /rooms/{roomID}/send/{eventType} func SendMessage(req *http.Request, roomID, eventType string) (interface{}, *util.HTTPError) { logger := util.GetLogger(req.Context()) - logger.WithField("roomID", roomID).WithField("eventType", eventType).Info("Doing stuff...") + userID, err := auth.VerifyAccessToken(req) + if err != nil { + return nil, &util.HTTPError{ + Code: 403, + JSON: err, + } + } + logger.WithFields(log.Fields{ + "roomID": roomID, + "eventType": eventType, + "userID": userID, + }).Info("Doing stuff...") return nil, &util.HTTPError{ Code: 404, Message: "Not implemented yet", From 31c57a8579bff15134175cc9dc575dd5f35598f1 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Fri, 24 Feb 2017 10:41:28 +0000 Subject: [PATCH 4/6] Update to use util.JSONResponse --- .../dendrite/clientapi/readers/sync.go | 7 +- .../dendrite/clientapi/routing/routing.go | 10 +- .../dendrite/clientapi/writers/sendmessage.go | 7 +- vendor/manifest | 4 +- .../src/github.com/matrix-org/util/context.go | 35 ++++ .../src/github.com/matrix-org/util/error.go | 24 --- vendor/src/github.com/matrix-org/util/json.go | 153 +++++++----------- .../github.com/matrix-org/util/json_test.go | 138 +++++++++++++--- 8 files changed, 218 insertions(+), 160 deletions(-) create mode 100644 vendor/src/github.com/matrix-org/util/context.go delete mode 100644 vendor/src/github.com/matrix-org/util/error.go diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go b/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go index 27a470945..5f4516fb1 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go @@ -7,11 +7,8 @@ import ( ) // Sync implements /sync -func Sync(req *http.Request) (interface{}, *util.HTTPError) { +func Sync(req *http.Request) util.JSONResponse { logger := util.GetLogger(req.Context()) logger.Info("Doing stuff...") - return nil, &util.HTTPError{ - Code: 404, - Message: "Not implemented yet", - } + return util.MessageResponse(404, "Not implemented yet") } diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go index 13cf4048d..6e830e03d 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go @@ -17,11 +17,11 @@ const pathPrefixR0 = "/_matrix/client/r0" func Setup(servMux *http.ServeMux, httpClient *http.Client) { apiMux := mux.NewRouter() r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() - r0mux.Handle("/sync", make("sync", wrap(func(req *http.Request) (interface{}, *util.HTTPError) { + r0mux.Handle("/sync", make("sync", wrap(func(req *http.Request) util.JSONResponse { return readers.Sync(req) }))) r0mux.Handle("/rooms/{roomID}/send/{eventType}", - make("send_message", wrap(func(req *http.Request) (interface{}, *util.HTTPError) { + make("send_message", wrap(func(req *http.Request) util.JSONResponse { vars := mux.Vars(req) return writers.SendMessage(req, vars["roomID"], vars["eventType"]) })), @@ -38,12 +38,12 @@ func make(metricsName string, h util.JSONRequestHandler) http.Handler { // jsonRequestHandlerWrapper is a wrapper to allow in-line functions to conform to util.JSONRequestHandler type jsonRequestHandlerWrapper struct { - function func(req *http.Request) (interface{}, *util.HTTPError) + function func(req *http.Request) util.JSONResponse } -func (r *jsonRequestHandlerWrapper) OnIncomingRequest(req *http.Request) (interface{}, *util.HTTPError) { +func (r *jsonRequestHandlerWrapper) OnIncomingRequest(req *http.Request) util.JSONResponse { return r.function(req) } -func wrap(f func(req *http.Request) (interface{}, *util.HTTPError)) *jsonRequestHandlerWrapper { +func wrap(f func(req *http.Request) util.JSONResponse) *jsonRequestHandlerWrapper { return &jsonRequestHandlerWrapper{f} } diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go b/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go index 11b17740d..ae4103da4 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go @@ -7,11 +7,8 @@ import ( ) // SendMessage implements /rooms/{roomID}/send/{eventType} -func SendMessage(req *http.Request, roomID, eventType string) (interface{}, *util.HTTPError) { +func SendMessage(req *http.Request, roomID, eventType string) util.JSONResponse { logger := util.GetLogger(req.Context()) logger.WithField("roomID", roomID).WithField("eventType", eventType).Info("Doing stuff...") - return nil, &util.HTTPError{ - Code: 404, - Message: "Not implemented yet", - } + return util.MessageResponse(404, "Not implemented yet") } diff --git a/vendor/manifest b/vendor/manifest index 99f433a94..79bac494c 100644 --- a/vendor/manifest +++ b/vendor/manifest @@ -98,7 +98,7 @@ { "importpath": "github.com/matrix-org/util", "repository": "https://github.com/matrix-org/util", - "revision": "4de125c773716ad380f2f80cc6c04789ef4c906a", + "revision": "ccef6dc7c24a7c896d96b433a9107b7c47ecf828", "branch": "master" }, { @@ -206,4 +206,4 @@ "branch": "master" } ] -} \ No newline at end of file +} diff --git a/vendor/src/github.com/matrix-org/util/context.go b/vendor/src/github.com/matrix-org/util/context.go new file mode 100644 index 000000000..d8def4f9b --- /dev/null +++ b/vendor/src/github.com/matrix-org/util/context.go @@ -0,0 +1,35 @@ +package util + +import ( + "context" + + log "github.com/Sirupsen/logrus" +) + +// contextKeys is a type alias for string to namespace Context keys per-package. +type contextKeys string + +// ctxValueRequestID is the key to extract the request ID for an HTTP request +const ctxValueRequestID = contextKeys("requestid") + +// GetRequestID returns the request ID associated with this context, or the empty string +// if one is not associated with this context. +func GetRequestID(ctx context.Context) string { + id := ctx.Value(ctxValueRequestID) + if id == nil { + return "" + } + return id.(string) +} + +// ctxValueLogger is the key to extract the logrus Logger. +const ctxValueLogger = contextKeys("logger") + +// GetLogger retrieves the logrus logger from the supplied context. Returns nil if there is no logger. +func GetLogger(ctx context.Context) *log.Entry { + l := ctx.Value(ctxValueLogger) + if l == nil { + return nil + } + return l.(*log.Entry) +} diff --git a/vendor/src/github.com/matrix-org/util/error.go b/vendor/src/github.com/matrix-org/util/error.go deleted file mode 100644 index 530a581b1..000000000 --- a/vendor/src/github.com/matrix-org/util/error.go +++ /dev/null @@ -1,24 +0,0 @@ -package util - -import "fmt" - -// HTTPError An HTTP Error response, which may wrap an underlying native Go Error. -type HTTPError struct { - WrappedError error - // A human-readable message to return to the client in a JSON response. This - // is ignored if JSON is supplied. - Message string - // HTTP status code. - Code int - // JSON represents the JSON that should be serialized and sent to the client - // instead of the given Message. - JSON interface{} -} - -func (e HTTPError) Error() string { - var wrappedErrMsg string - if e.WrappedError != nil { - wrappedErrMsg = e.WrappedError.Error() - } - return fmt.Sprintf("%s: %d: %s", e.Message, e.Code, wrappedErrMsg) -} diff --git a/vendor/src/github.com/matrix-org/util/json.go b/vendor/src/github.com/matrix-org/util/json.go index 92604c54b..b0834eac7 100644 --- a/vendor/src/github.com/matrix-org/util/json.go +++ b/vendor/src/github.com/matrix-org/util/json.go @@ -3,7 +3,6 @@ package util import ( "context" "encoding/json" - "fmt" "math/rand" "net/http" "runtime/debug" @@ -12,46 +11,51 @@ import ( log "github.com/Sirupsen/logrus" ) -// contextKeys is a type alias for string to namespace Context keys per-package. -type contextKeys string - -// ctxValueRequestID is the key to extract the request ID for an HTTP request -const ctxValueRequestID = contextKeys("requestid") - -// GetRequestID returns the request ID associated with this context, or the empty string -// if one is not associated with this context. -func GetRequestID(ctx context.Context) string { - id := ctx.Value(ctxValueRequestID) - if id == nil { - return "" - } - return id.(string) +// JSONResponse represents an HTTP response which contains a JSON body. +type JSONResponse struct { + // HTTP status code. + Code int + // JSON represents the JSON that should be serialized and sent to the client + JSON interface{} + // Headers represent any headers that should be sent to the client + Headers map[string]string } -// ctxValueLogger is the key to extract the logrus Logger. -const ctxValueLogger = contextKeys("logger") +// Is2xx returns true if the Code is between 200 and 299. +func (r JSONResponse) Is2xx() bool { + return r.Code/100 == 2 +} -// GetLogger retrieves the logrus logger from the supplied context. Returns nil if there is no logger. -func GetLogger(ctx context.Context) *log.Entry { - l := ctx.Value(ctxValueLogger) - if l == nil { - return nil +// RedirectResponse returns a JSONResponse which 302s the client to the given location. +func RedirectResponse(location string) JSONResponse { + headers := make(map[string]string) + headers["Location"] = location + return JSONResponse{ + Code: 302, + JSON: struct{}{}, + Headers: headers, } - return l.(*log.Entry) +} + +// MessageResponse returns a JSONResponse with a 'message' key containing the given text. +func MessageResponse(code int, msg string) JSONResponse { + return JSONResponse{ + Code: code, + JSON: struct { + Message string `json:"message"` + }{msg}, + } +} + +// ErrorResponse returns an HTTP 500 JSONResponse with the stringified form of the given error. +func ErrorResponse(err error) JSONResponse { + return MessageResponse(500, err.Error()) } // JSONRequestHandler represents an interface that must be satisfied in order to respond to incoming -// HTTP requests with JSON. The interface returned will be marshalled into JSON to be sent to the client, -// unless the interface is []byte in which case the bytes are sent to the client unchanged. -// If an error is returned, a JSON error response will also be returned, unless the error code -// is a 302 REDIRECT in which case a redirect is sent based on the Message field. +// HTTP requests with JSON. type JSONRequestHandler interface { - OnIncomingRequest(req *http.Request) (interface{}, *HTTPError) -} - -// JSONError represents a JSON API error response -type JSONError struct { - Message string `json:"message"` + OnIncomingRequest(req *http.Request) JSONResponse } // Protect panicking HTTP requests from taking down the entire process, and log them using @@ -67,12 +71,7 @@ func Protect(handler http.HandlerFunc) http.HandlerFunc { }).Errorf( "Request panicked!\n%s", debug.Stack(), ) - jsonErrorResponse( - w, req, &HTTPError{ - Message: "Internal Server Error", - Code: 500, - }, - ) + respond(w, req, MessageResponse(500, "Internal Server Error")) } }() handler(w, req) @@ -81,11 +80,11 @@ func Protect(handler http.HandlerFunc) http.HandlerFunc { // MakeJSONAPI creates an HTTP handler which always responds to incoming requests with JSON responses. // Incoming http.Requests will have a logger (with a request ID/method/path logged) attached to the Context. -// This can be accessed via GetLogger(Context). The type of the logger is *log.Entry from github.com/Sirupsen/logrus +// This can be accessed via GetLogger(Context). func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc { return Protect(func(w http.ResponseWriter, req *http.Request) { reqID := RandomString(12) - // Set a Logger on the context + // Set a Logger and request ID on the context ctx := context.WithValue(req.Context(), ctxValueLogger, log.WithFields(log.Fields{ "req.method": req.Method, "req.path": req.URL.Path, @@ -97,75 +96,39 @@ func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc { logger := req.Context().Value(ctxValueLogger).(*log.Entry) logger.Print("Incoming request") - res, httpErr := handler.OnIncomingRequest(req) + res := handler.OnIncomingRequest(req) // Set common headers returned regardless of the outcome of the request w.Header().Set("Content-Type", "application/json") SetCORSHeaders(w) - if httpErr != nil { - jsonErrorResponse(w, req, httpErr) - return - } - - // if they've returned bytes as the response, then just return them rather than marshalling as JSON. - // This gives handlers an escape hatch if they want to return cached bytes. - var resBytes []byte - resBytes, ok := res.([]byte) - if !ok { - r, err := json.Marshal(res) - if err != nil { - jsonErrorResponse(w, req, &HTTPError{ - Message: "Failed to serialise response as JSON", - Code: 500, - }) - return - } - resBytes = r - } - logger.Print(fmt.Sprintf("Responding (%d bytes)", len(resBytes))) - w.Write(resBytes) + respond(w, req, res) }) } -func jsonErrorResponse(w http.ResponseWriter, req *http.Request, httpErr *HTTPError) { +func respond(w http.ResponseWriter, req *http.Request, res JSONResponse) { logger := req.Context().Value(ctxValueLogger).(*log.Entry) - if httpErr.Code == 302 { - logger.WithField("err", httpErr.Error()).Print("Redirecting") - http.Redirect(w, req, httpErr.Message, 302) - return - } - logger.WithFields(log.Fields{ - log.ErrorKey: httpErr, - }).Print("Responding with error") - w.WriteHeader(httpErr.Code) // Set response code - - var err error - var r []byte - if httpErr.JSON != nil { - r, err = json.Marshal(httpErr.JSON) - if err != nil { - // failed to marshal the supplied interface. Whine and fallback to the HTTP message. - logger.WithError(err).Error("Failed to marshal HTTPError.JSON") + // Set custom headers + if res.Headers != nil { + for h, val := range res.Headers { + w.Header().Set(h, val) } } - // failed to marshal or no custom JSON was supplied, send message JSON. - if err != nil || httpErr.JSON == nil { - r, err = json.Marshal(&JSONError{ - Message: httpErr.Message, - }) + // Marshal JSON response into raw bytes to send as the HTTP body + resBytes, err := json.Marshal(res.JSON) + if err != nil { + logger.WithError(err).Error("Failed to marshal JSONResponse") + // this should never fail to be marshalled so drop err to the floor + res = MessageResponse(500, "Internal Server Error") + resBytes, _ = json.Marshal(res.JSON) } - if err != nil { - // We should never fail to marshal the JSON error response, but in this event just skip - // marshalling altogether - logger.Warn("Failed to marshal error response") - w.Write([]byte(`{}`)) - return - } - w.Write(r) + // Set status code and write the body + w.WriteHeader(res.Code) + logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes)) + w.Write(resBytes) } // WithCORSOptions intercepts all OPTIONS requests and responds with CORS headers. The request handler diff --git a/vendor/src/github.com/matrix-org/util/json_test.go b/vendor/src/github.com/matrix-org/util/json_test.go index 2248ac3ff..687db277f 100644 --- a/vendor/src/github.com/matrix-org/util/json_test.go +++ b/vendor/src/github.com/matrix-org/util/json_test.go @@ -2,6 +2,7 @@ package util import ( "context" + "errors" "net/http" "net/http/httptest" "testing" @@ -10,10 +11,10 @@ import ( ) type MockJSONRequestHandler struct { - handler func(req *http.Request) (interface{}, *HTTPError) + handler func(req *http.Request) JSONResponse } -func (h *MockJSONRequestHandler) OnIncomingRequest(req *http.Request) (interface{}, *HTTPError) { +func (h *MockJSONRequestHandler) OnIncomingRequest(req *http.Request) JSONResponse { return h.handler(req) } @@ -24,36 +25,27 @@ type MockResponse struct { func TestMakeJSONAPI(t *testing.T) { log.SetLevel(log.PanicLevel) // suppress logs in test output tests := []struct { - Return interface{} - Err *HTTPError + Return JSONResponse ExpectCode int ExpectJSON string }{ - // Error message return values - {nil, &HTTPError{nil, "Everything is broken", 500, nil}, 500, `{"message":"Everything is broken"}`}, - // Error JSON return values - {nil, &HTTPError{nil, "Everything is broken", 500, struct { - Foo string `json:"foo"` - }{"yep"}}, 500, `{"foo":"yep"}`}, + // MessageResponse return values + {MessageResponse(500, "Everything is broken"), 500, `{"message":"Everything is broken"}`}, + // interface return values + {JSONResponse{500, MockResponse{"yep"}, nil}, 500, `{"foo":"yep"}`}, // Error JSON return values which fail to be marshalled should fallback to text - {nil, &HTTPError{nil, "Everything is broken", 500, struct { + {JSONResponse{500, struct { Foo interface{} `json:"foo"` - }{func(cannotBe, marshalled string) {}}}, 500, `{"message":"Everything is broken"}`}, + }{func(cannotBe, marshalled string) {}}, nil}, 500, `{"message":"Internal Server Error"}`}, // With different status codes - {nil, &HTTPError{nil, "Not here", 404, nil}, 404, `{"message":"Not here"}`}, - // Success return values - {&MockResponse{"yep"}, nil, 200, `{"foo":"yep"}`}, + {JSONResponse{201, MockResponse{"narp"}, nil}, 201, `{"foo":"narp"}`}, // Top-level array success values - {[]MockResponse{{"yep"}, {"narp"}}, nil, 200, `[{"foo":"yep"},{"foo":"narp"}]`}, - // raw []byte escape hatch - {[]byte(`actually bytes`), nil, 200, `actually bytes`}, - // impossible marshal - {func(cannotBe, marshalled string) {}, nil, 500, `{"message":"Failed to serialise response as JSON"}`}, + {JSONResponse{200, []MockResponse{{"yep"}, {"narp"}}, nil}, 200, `[{"foo":"yep"},{"foo":"narp"}]`}, } for _, tst := range tests { - mock := MockJSONRequestHandler{func(req *http.Request) (interface{}, *HTTPError) { - return tst.Return, tst.Err + mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { + return tst.Return }} mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) mockWriter := httptest.NewRecorder() @@ -69,10 +61,38 @@ func TestMakeJSONAPI(t *testing.T) { } } +func TestMakeJSONAPICustomHeaders(t *testing.T) { + mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { + headers := make(map[string]string) + headers["Custom"] = "Thing" + headers["X-Custom"] = "Things" + return JSONResponse{ + Code: 200, + JSON: MockResponse{"yep"}, + Headers: headers, + } + }} + mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + mockWriter := httptest.NewRecorder() + handlerFunc := MakeJSONAPI(&mock) + handlerFunc(mockWriter, mockReq) + if mockWriter.Code != 200 { + t.Errorf("TestMakeJSONAPICustomHeaders wanted HTTP status 200, got %d", mockWriter.Code) + } + h := mockWriter.Header().Get("Custom") + if h != "Thing" { + t.Errorf("TestMakeJSONAPICustomHeaders wanted header 'Custom: Thing' , got 'Custom: %s'", h) + } + h = mockWriter.Header().Get("X-Custom") + if h != "Things" { + t.Errorf("TestMakeJSONAPICustomHeaders wanted header 'X-Custom: Things' , got 'X-Custom: %s'", h) + } +} + func TestMakeJSONAPIRedirect(t *testing.T) { log.SetLevel(log.PanicLevel) // suppress logs in test output - mock := MockJSONRequestHandler{func(req *http.Request) (interface{}, *HTTPError) { - return nil, &HTTPError{nil, "https://matrix.org", 302, nil} + mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { + return RedirectResponse("https://matrix.org") }} mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) mockWriter := httptest.NewRecorder() @@ -87,6 +107,50 @@ func TestMakeJSONAPIRedirect(t *testing.T) { } } +func TestMakeJSONAPIError(t *testing.T) { + log.SetLevel(log.PanicLevel) // suppress logs in test output + mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { + err := errors.New("oops") + return ErrorResponse(err) + }} + mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + mockWriter := httptest.NewRecorder() + handlerFunc := MakeJSONAPI(&mock) + handlerFunc(mockWriter, mockReq) + if mockWriter.Code != 500 { + t.Errorf("TestMakeJSONAPIError wanted HTTP status 500, got %d", mockWriter.Code) + } + actualBody := mockWriter.Body.String() + expect := `{"message":"oops"}` + if actualBody != expect { + t.Errorf("TestMakeJSONAPIError wanted body '%s', got '%s'", expect, actualBody) + } +} + +func TestIs2xx(t *testing.T) { + tests := []struct { + Code int + Expect bool + }{ + {200, true}, + {201, true}, + {299, true}, + {300, false}, + {199, false}, + {0, false}, + {500, false}, + } + for _, test := range tests { + j := JSONResponse{ + Code: test.Code, + } + actual := j.Is2xx() + if actual != test.Expect { + t.Errorf("TestIs2xx wanted %t, got %t", test.Expect, actual) + } + } +} + func TestGetLogger(t *testing.T) { log.SetLevel(log.PanicLevel) // suppress logs in test output entry := log.WithField("test", "yep") @@ -130,6 +194,32 @@ func TestProtect(t *testing.T) { } } +func TestWithCORSOptions(t *testing.T) { + log.SetLevel(log.PanicLevel) // suppress logs in test output + mockWriter := httptest.NewRecorder() + mockReq, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil) + h := WithCORSOptions(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(200) + w.Write([]byte("yep")) + }) + h(mockWriter, mockReq) + if mockWriter.Code != 200 { + t.Errorf("TestWithCORSOptions wanted HTTP status 200, got %d", mockWriter.Code) + } + + origin := mockWriter.Header().Get("Access-Control-Allow-Origin") + if origin != "*" { + t.Errorf("TestWithCORSOptions wanted Access-Control-Allow-Origin header '*', got '%s'", origin) + } + + // OPTIONS request shouldn't hit the handler func + expectBody := "" + actualBody := mockWriter.Body.String() + if actualBody != expectBody { + t.Errorf("TestWithCORSOptions wanted body %s, got %s", expectBody, actualBody) + } +} + func TestGetRequestID(t *testing.T) { log.SetLevel(log.PanicLevel) // suppress logs in test output reqID := "alphabetsoup" From 4137b6185d3a31bb69bbbfb6da8cc3094fa2beb7 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Fri, 24 Feb 2017 15:14:47 +0000 Subject: [PATCH 5/6] three-value returns --- .../dendrite/clientapi/auth/auth.go | 17 ++++++++++++----- .../dendrite/clientapi/jsonerror/jsonerror.go | 19 ++++++++++++++++++- .../dendrite/clientapi/readers/sync.go | 12 +++++++----- .../dendrite/clientapi/writers/sendmessage.go | 12 +++++++----- 4 files changed, 44 insertions(+), 16 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go b/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go index 927e294c5..cb809dcd3 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/auth.go @@ -6,17 +6,24 @@ import ( "strings" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/util" ) // VerifyAccessToken verifies that an access token was supplied in the given HTTP request -// and returns the user ID it corresponds to. Returns an error if there is no access token -// or the token is invalid. -func VerifyAccessToken(req *http.Request) (userID string, err error) { - _, tokenErr := extractAccessToken(req) +// and returns the user ID it corresponds to. Returns err if there was a fatal problem checking +// the token. Returns resErr (an error response which can be sent to the client) if the token is invalid. +func VerifyAccessToken(req *http.Request) (userID string, resErr *util.JSONResponse, err error) { + token, tokenErr := extractAccessToken(req) if tokenErr != nil { - err = jsonerror.MissingToken(tokenErr.Error()) + resErr = &util.JSONResponse{ + Code: 401, + JSON: jsonerror.MissingToken(tokenErr.Error()), + } return } + if token == "fail" { + err = fmt.Errorf("Fatal error") + } // TODO: Check the token against the database return } diff --git a/src/github.com/matrix-org/dendrite/clientapi/jsonerror/jsonerror.go b/src/github.com/matrix-org/dendrite/clientapi/jsonerror/jsonerror.go index a0111197c..ea64896db 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/jsonerror/jsonerror.go +++ b/src/github.com/matrix-org/dendrite/clientapi/jsonerror/jsonerror.go @@ -1,6 +1,9 @@ package jsonerror -import "fmt" +import ( + "fmt" + "github.com/matrix-org/util" +) // MatrixError represents the "standard error response" in Matrix. // http://matrix.org/docs/spec/client_server/r0.2.0.html#api-standards @@ -13,6 +16,20 @@ func (e *MatrixError) Error() string { return fmt.Sprintf("%s: %s", e.ErrCode, e.Err) } +// InternalServerError returns a 500 Internal Server Error in a matrix-compliant +// format. +func InternalServerError() util.JSONResponse { + return util.JSONResponse{ + Code: 500, + JSON: Unknown("Internal Server Error"), + } +} + +// Unknown is an unexpected error +func Unknown(msg string) *MatrixError { + return &MatrixError{"M_UNKNOWN", msg} +} + // Forbidden is an error when the client tries to access a resource // they are not allowed to access. func Forbidden(msg string) *MatrixError { diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go b/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go index 8db717103..05def8148 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/sync.go @@ -4,18 +4,20 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/util" ) // Sync implements /sync func Sync(req *http.Request) util.JSONResponse { logger := util.GetLogger(req.Context()) - userID, err := auth.VerifyAccessToken(req) + userID, resErr, err := auth.VerifyAccessToken(req) if err != nil { - return util.JSONResponse{ - Code: 403, - JSON: err, - } + logger.WithError(err).Error("Failed to verify access token") + return jsonerror.InternalServerError() + } + if resErr != nil { + return *resErr } logger.WithField("userID", userID).Info("Doing stuff...") diff --git a/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go b/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go index d6713b1af..4b8ad9052 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/writers/sendmessage.go @@ -5,18 +5,20 @@ import ( log "github.com/Sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/util" ) // SendMessage implements /rooms/{roomID}/send/{eventType} func SendMessage(req *http.Request, roomID, eventType string) util.JSONResponse { logger := util.GetLogger(req.Context()) - userID, err := auth.VerifyAccessToken(req) + userID, resErr, err := auth.VerifyAccessToken(req) if err != nil { - return util.JSONResponse{ - Code: 403, - JSON: err, - } + logger.WithError(err).Error("Failed to verify access token") + return jsonerror.InternalServerError() + } + if resErr != nil { + return *resErr } logger.WithFields(log.Fields{ "roomID": roomID, From 3d575aca37c310b8609fff8611268d14345ff07a Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Fri, 24 Feb 2017 16:24:21 +0000 Subject: [PATCH 6/6] Add user_ips, users, access_tokens tables Done in the same style as `roomserver`. We should probably think about how we will do upgrades to the schema. --- .../clientapi/storage/access_tokens_table.go | 30 +++++++++++++++ .../dendrite/clientapi/storage/sql.go | 37 +++++++++++++++++++ .../clientapi/storage/users_ips_table.go | 30 +++++++++++++++ .../dendrite/clientapi/storage/users_table.go | 34 +++++++++++++++++ 4 files changed, 131 insertions(+) create mode 100644 src/github.com/matrix-org/dendrite/clientapi/storage/access_tokens_table.go create mode 100644 src/github.com/matrix-org/dendrite/clientapi/storage/sql.go create mode 100644 src/github.com/matrix-org/dendrite/clientapi/storage/users_ips_table.go create mode 100644 src/github.com/matrix-org/dendrite/clientapi/storage/users_table.go diff --git a/src/github.com/matrix-org/dendrite/clientapi/storage/access_tokens_table.go b/src/github.com/matrix-org/dendrite/clientapi/storage/access_tokens_table.go new file mode 100644 index 000000000..fe804443e --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/storage/access_tokens_table.go @@ -0,0 +1,30 @@ +package storage + +import ( + "database/sql" +) + +const accessTokensSchema = ` +CREATE TABLE IF NOT EXISTS access_tokens ( + id BIGINT PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT, + token TEXT NOT NULL, + -- Timestamp (ms) when this access token was last used. + last_used BIGINT, + UNIQUE(token) +); + +CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id) ; +` + +type accessTokenStatements struct { +} + +func (s *accessTokenStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(accessTokensSchema) + if err != nil { + return + } + return +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/storage/sql.go b/src/github.com/matrix-org/dendrite/clientapi/storage/sql.go new file mode 100644 index 000000000..91092e42c --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/storage/sql.go @@ -0,0 +1,37 @@ +package storage + +import ( + "database/sql" +) + +type statements struct { + userIPStatements + accessTokenStatements + usersStatements +} + +func (s *statements) prepare(db *sql.DB) error { + var err error + + if err = s.prepareUserStatements(db); err != nil { + return err + } + + return nil +} + +func (s *statements) prepareUserStatements(db *sql.DB) error { + var err error + + if err = s.accessTokenStatements.prepare(db); err != nil { + return err + } + if err = s.userIPStatements.prepare(db); err != nil { + return err + } + if err = s.usersStatements.prepare(db); err != nil { + return err + } + + return nil +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/storage/users_ips_table.go b/src/github.com/matrix-org/dendrite/clientapi/storage/users_ips_table.go new file mode 100644 index 000000000..26f2f5b80 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/storage/users_ips_table.go @@ -0,0 +1,30 @@ +package storage + +import ( + "database/sql" +) + +const userIPsSchema = ` +CREATE TABLE IF NOT EXISTS user_ips ( + user_id TEXT NOT NULL, + access_token TEXT NOT NULL, + device_id TEXT, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + last_seen BIGINT NOT NULL +); + +CREATE INDEX user_ips_user_ip ON user_ips(user_id, access_token, ip); +CREATE INDEX user_ips_device_id ON user_ips (user_id, device_id, last_seen); +` + +type userIPStatements struct { +} + +func (s *userIPStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(userIPsSchema) + if err != nil { + return + } + return +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/storage/users_table.go b/src/github.com/matrix-org/dendrite/clientapi/storage/users_table.go new file mode 100644 index 000000000..f7c9c6b05 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/storage/users_table.go @@ -0,0 +1,34 @@ +package storage + +import ( + "database/sql" +) + +const usersSchema = ` +CREATE TABLE IF NOT EXISTS users ( + user_id TEXT NOT NULL, + -- bcrypt hash of the users password. Can be null for passwordless users like + -- application service users. + password_hash TEXT, + -- Timestamp (ms) when this user was registered on the server. + created_at BIGINT, + -- The ID of the application service which created this user, if applicable. + appservice_id TEXT, + -- Flag which if set indicates this user is a server administrator. + is_admin SMALLINT DEFAULT 0 NOT NULL, + -- Flag which if set indicates this user is a guest. + is_guest SMALLINT DEFAULT 0 NOT NULL, + UNIQUE(user_id) +); +` + +type usersStatements struct { +} + +func (s *usersStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(usersSchema) + if err != nil { + return + } + return +}