mediaapi/writers/download: Add logger to downloadRequest and use it

This commit is contained in:
Robert Swain 2017-05-17 16:24:30 +02:00
parent 6e24fb86cb
commit 13b1051a3e

View file

@ -39,10 +39,11 @@ import (
// https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-media-r0-download
type downloadRequest struct {
MediaMetadata *types.MediaMetadata
Logger *log.Entry
}
// Validate validates the downloadRequest fields
func (r downloadRequest) Validate() *util.JSONResponse {
func (r *downloadRequest) Validate() *util.JSONResponse {
// FIXME: the following errors aren't bad JSON, rather just a bad request path
// maybe give the URL pattern in the routing, these are not even possible as the handler would not be hit...?
if r.MediaMetadata.MediaID == "" {
@ -60,11 +61,11 @@ func (r downloadRequest) Validate() *util.JSONResponse {
return nil
}
func jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse, logger *log.Entry) {
func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse) {
// Marshal JSON response into raw bytes to send as the HTTP body
resBytes, err := json.Marshal(res.JSON)
if err != nil {
logger.WithError(err).Error("Failed to marshal JSONResponse")
r.Logger.WithError(err).Error("Failed to marshal JSONResponse")
// this should never fail to be marshalled so drop err to the floor
res = util.MessageResponse(500, "Internal Server Error")
resBytes, _ = json.Marshal(res.JSON)
@ -72,7 +73,7 @@ func jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse, logger *log
// Set status code and write the body
w.WriteHeader(res.Code)
logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes))
r.Logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes))
w.Write(resBytes)
}
@ -90,26 +91,25 @@ var nTries = 5
// If they are not present in the cache, they are obtained from the remote server and
// simultaneously served back to the client and written into the cache.
func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, mediaID types.MediaID, cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) {
logger := util.GetLogger(req.Context())
// request validation
if req.Method != "GET" {
jsonErrorResponse(w, util.JSONResponse{
Code: 405,
JSON: jsonerror.Unknown("request method must be GET"),
}, logger)
return
}
r := &downloadRequest{
MediaMetadata: &types.MediaMetadata{
MediaID: mediaID,
Origin: origin,
},
Logger: util.GetLogger(req.Context()),
}
// request validation
if req.Method != "GET" {
r.jsonErrorResponse(w, util.JSONResponse{
Code: 405,
JSON: jsonerror.Unknown("request method must be GET"),
})
return
}
if resErr := r.Validate(); resErr != nil {
jsonErrorResponse(w, *resErr, logger)
r.jsonErrorResponse(w, *resErr)
return
}
@ -118,7 +118,7 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName,
if err == nil {
// If we have a record, we can respond from the local file
r.respondFromLocalFile(w, logger, cfg)
r.respondFromLocalFile(w, cfg)
return
} else if err == sql.ErrNoRows && r.MediaMetadata.Origin != cfg.ServerName {
// If we do not have a record and the origin is remote, we need to fetch it and respond with that file
@ -133,31 +133,31 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName,
err = db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin, r.MediaMetadata)
if err == nil {
// If we have a record, we can respond from the local file
r.respondFromLocalFile(w, logger, cfg)
r.respondFromLocalFile(w, cfg)
activeRemoteRequests.Unlock()
return
}
if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok {
if tries >= nTries {
logger.WithFields(log.Fields{
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
}).Warnf("Other goroutines are trying to download the remote file and failing.")
jsonErrorResponse(w, util.JSONResponse{
r.jsonErrorResponse(w, util.JSONResponse{
Code: 500,
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
}, logger)
})
activeRemoteRequests.Unlock()
return
}
logger.WithFields(log.Fields{
r.Logger.WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
}).Infof("Waiting for another goroutine to fetch the remote file.")
activeRemoteRequestCondition.Wait()
activeRemoteRequests.Unlock()
} else {
logger.WithFields(log.Fields{
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
}).Infof("Fetching remote file")
@ -167,18 +167,18 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName,
}
}
r.respondFromRemoteFile(w, logger, cfg, db, activeRemoteRequests)
r.respondFromRemoteFile(w, cfg, db, activeRemoteRequests)
} else {
// If we do not have a record and the origin is local, or if we have another error from the database, the file is not found
jsonErrorResponse(w, util.JSONResponse{
r.jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
}, logger)
})
}
}
func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.MediaAPI) {
logger.WithFields(log.Fields{
func (r *downloadRequest) respondFromLocalFile(w http.ResponseWriter, cfg config.MediaAPI) {
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
"UploadName": r.MediaMetadata.UploadName,
@ -191,7 +191,7 @@ func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.M
file, err := os.Open(filePath)
if err != nil {
// FIXME: Remove erroneous file from database?
jsonErrorResponse(w, util.JSONResponse{
r.jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
})
@ -201,15 +201,15 @@ func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.M
stat, err := file.Stat()
if err != nil {
// FIXME: Remove erroneous file from database?
jsonErrorResponse(w, util.JSONResponse{
r.jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
}, logger)
})
return
}
if r.MediaMetadata.ContentLength > 0 && int64(r.MediaMetadata.ContentLength) != stat.Size() {
logger.Warnf("File size in database (%v) and on disk (%v) differ.", r.MediaMetadata.ContentLength, stat.Size())
r.Logger.Warnf("File size in database (%v) and on disk (%v) differ.", r.MediaMetadata.ContentLength, stat.Size())
// FIXME: Remove erroneous file from database?
}
@ -223,22 +223,22 @@ func respondFromLocalFile(w http.ResponseWriter, logger *log.Entry, cfg config.M
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
if bytesResponded, err := io.Copy(w, file); err != nil {
logger.Warnf("Failed to copy from cache %v\n", err)
r.Logger.Warnf("Failed to copy from cache %v\n", err)
if bytesResponded == 0 {
jsonErrorResponse(w, util.JSONResponse{
r.jsonErrorResponse(w, util.JSONResponse{
Code: 500,
JSON: jsonerror.NotFound(fmt.Sprintf("Failed to respond with file with media ID %q", r.MediaMetadata.MediaID)),
}, logger)
})
}
// If we have written any data then we have already responded with 200 OK and all we can do is close the connection
return
}
}
func (r *downloadRequest) createRemoteRequest(logger *log.Entry) (*http.Response, *util.JSONResponse) {
func (r *downloadRequest) createRemoteRequest() (*http.Response, *util.JSONResponse) {
urls := getMatrixUrls(r.MediaMetadata.Origin)
logger.Printf("Connecting to remote %q\n", urls[0])
r.Logger.Printf("Connecting to remote %q\n", urls[0])
remoteReqAddr := urls[0] + "/_matrix/media/v1/download/" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID)
remoteReq, err := http.NewRequest("GET", remoteReqAddr, nil)
@ -261,7 +261,7 @@ func (r *downloadRequest) createRemoteRequest(logger *log.Entry) (*http.Response
}
if resp.StatusCode != 200 {
logger.Printf("Server responded with %d\n", resp.StatusCode)
r.Logger.Printf("Server responded with %d\n", resp.StatusCode)
if resp.StatusCode == 404 {
return nil, &util.JSONResponse{
Code: 404,
@ -335,16 +335,16 @@ func copyToActiveAndPassive(r io.Reader, wActive io.Writer, wPassive io.Writer,
return bytesResponded, bytesWritten, fetchError
}
func closeConnection(w http.ResponseWriter, logger *log.Entry) {
logger.Println("Attempting to close the connection.")
func (r *downloadRequest) closeConnection(w http.ResponseWriter) {
r.Logger.Println("Attempting to close the connection.")
hijacker, ok := w.(http.Hijacker)
if ok {
connection, _, hijackErr := hijacker.Hijack()
if hijackErr == nil {
logger.Println("Closing")
r.Logger.Println("Closing")
connection.Close()
} else {
logger.Printf("Error trying to hijack: %v", hijackErr)
r.Logger.Printf("Error trying to hijack: %v", hijackErr)
}
}
}
@ -357,10 +357,10 @@ func completeRemoteRequest(activeRemoteRequests *types.ActiveRemoteRequests, mxc
activeRemoteRequests.Unlock()
}
func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath types.Path, activeRemoteRequests *types.ActiveRemoteRequests, db *storage.Database, mxcURL string, logger *log.Entry) bool {
func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath types.Path, activeRemoteRequests *types.ActiveRemoteRequests, db *storage.Database, mxcURL string) bool {
updateActiveRemoteRequests := true
logger.WithFields(log.Fields{
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
"UploadName": r.MediaMetadata.UploadName,
@ -377,7 +377,7 @@ func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath type
if err != nil {
tmpDirErr := os.RemoveAll(string(tmpDir))
if tmpDirErr != nil {
logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
r.Logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
}
return updateActiveRemoteRequests
}
@ -396,12 +396,12 @@ func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath type
finalDir := path.Dir(getPathFromMediaMetadata(r.MediaMetadata, basePath))
finalDirErr := os.RemoveAll(finalDir)
if finalDirErr != nil {
logger.Warnf("Failed to remove finalDir (%v): %q\n", finalDir, finalDirErr)
r.Logger.Warnf("Failed to remove finalDir (%v): %q\n", finalDir, finalDirErr)
}
completeRemoteRequest(activeRemoteRequests, mxcURL)
return updateActiveRemoteRequests
}
logger.WithFields(log.Fields{
r.Logger.WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
}).Infof("Signalling other goroutines waiting for us to fetch the file.")
@ -409,8 +409,8 @@ func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, basePath type
return updateActiveRemoteRequests
}
func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *log.Entry, cfg config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) {
logger.WithFields(log.Fields{
func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, cfg config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) {
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
}).Infof("Fetching remote file")
@ -431,9 +431,9 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l
}()
// create request for remote file
resp, errorResponse := r.createRemoteRequest(logger)
resp, errorResponse := r.createRemoteRequest()
if errorResponse != nil {
jsonErrorResponse(w, *errorResponse, logger)
r.jsonErrorResponse(w, *errorResponse)
return
}
defer resp.Body.Close()
@ -441,7 +441,7 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l
// get metadata from request and set metadata on response
contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
if err != nil {
logger.Warn("Failed to parse content length")
r.Logger.Warn("Failed to parse content length")
}
r.MediaMetadata.ContentLength = types.ContentLength(contentLength)
@ -450,7 +450,7 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l
// FIXME: parse from Content-Disposition header if possible, else fall back
//r.MediaMetadata.UploadName = types.Filename()
logger.WithFields(log.Fields{
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
}).Infof("Connected to remote")
@ -465,39 +465,39 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
// create the temporary file writer
tmpFileWriter, tmpFile, tmpDir, errorResponse := createTempFileWriter(cfg.BasePath, logger)
tmpFileWriter, tmpFile, tmpDir, errorResponse := createTempFileWriter(cfg.BasePath, r.Logger)
if errorResponse != nil {
jsonErrorResponse(w, *errorResponse, logger)
r.jsonErrorResponse(w, *errorResponse)
return
}
defer tmpFile.Close()
// read the remote request's response body
// simultaneously write it to the incoming request's response body and the temporary file
logger.WithFields(log.Fields{
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
}).Infof("Proxying and caching remote file")
// bytesResponded is the total number of bytes written to the response to the client request
// bytesWritten is the total number of bytes written to disk
bytesResponded, bytesWritten, fetchError := copyToActiveAndPassive(resp.Body, w, tmpFileWriter, cfg.MaxFileSize, r.MediaMetadata, logger)
bytesResponded, bytesWritten, fetchError := copyToActiveAndPassive(resp.Body, w, tmpFileWriter, cfg.MaxFileSize, r.MediaMetadata, r.Logger)
tmpFileWriter.Flush()
if fetchError != nil {
tmpDirErr := os.RemoveAll(string(tmpDir))
if tmpDirErr != nil {
logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
r.Logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
}
// Note: if we have responded with any data in the body at all then we have already sent 200 OK and we can only abort at this point
if bytesResponded < 1 {
jsonErrorResponse(w, util.JSONResponse{
r.jsonErrorResponse(w, util.JSONResponse{
Code: 502,
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
}, logger)
})
} else {
// We attempt to bluntly close the connection because that is the
// best thing we can do after we've sent a 200 OK
closeConnection(w, logger)
r.closeConnection(w)
}
return
}
@ -514,11 +514,11 @@ func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, logger *l
r.MediaMetadata.ContentLength = types.ContentLength(bytesWritten)
r.MediaMetadata.UserID = types.MatrixUserID("@:" + string(r.MediaMetadata.Origin))
updateActiveRemoteRequests = r.commitFileAndMetadata(tmpDir, cfg.BasePath, activeRemoteRequests, db, mxcURL, logger)
updateActiveRemoteRequests = r.commitFileAndMetadata(tmpDir, cfg.BasePath, activeRemoteRequests, db, mxcURL)
// TODO: generate thumbnails
logger.WithFields(log.Fields{
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
"UploadName": r.MediaMetadata.UploadName,