diff --git a/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go b/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go index 7d9725f72..7b525984e 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go @@ -16,6 +16,7 @@ package routing import ( "net/http" + "sync" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/mediaapi/config" @@ -37,6 +38,9 @@ func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg config.MediaAPI, return writers.Upload(req, cfg, db) }))) + activeRemoteRequests := &types.ActiveRemoteRequests{ + Set: map[string]*sync.Cond{}, + } r0mux.Handle("/download/{serverName}/{mediaId}", prometheus.InstrumentHandler("download", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { req = util.RequestWithLogging(req) @@ -47,7 +51,7 @@ func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg config.MediaAPI, w.Header().Set("Content-Type", "application/json") vars := mux.Vars(req) - writers.Download(w, req, types.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, db) + writers.Download(w, req, types.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, db, activeRemoteRequests) })), ) diff --git a/src/github.com/matrix-org/dendrite/mediaapi/types/types.go b/src/github.com/matrix-org/dendrite/mediaapi/types/types.go index e1e1a3a44..34bf80655 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/types/types.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/types/types.go @@ -14,6 +14,8 @@ package types +import "sync" + // ContentDisposition is an HTTP Content-Disposition header string type ContentDisposition string @@ -55,3 +57,10 @@ type MediaMetadata struct { UploadName Filename UserID MatrixUserID } + +// ActiveRemoteRequests is a lockable map of media URIs requested from remote homeservers +// It is used for ensuring multiple requests for the same file do not clobber each other. +type ActiveRemoteRequests struct { + sync.Mutex + Set map[string]*sync.Cond +} diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go index 4204c6ffb..38aa8cc8f 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go @@ -25,6 +25,7 @@ import ( "path" "strconv" "strings" + "sync" log "github.com/Sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -80,13 +81,15 @@ var errRead = fmt.Errorf("failed to read response from remote server") var errResponse = fmt.Errorf("failed to write file data to response body") var errWrite = fmt.Errorf("failed to write file to disk") +var nAttempts = 5 + // Download implements /download // Files from this server (i.e. origin == cfg.ServerName) are served directly // Files from remote servers (i.e. origin != cfg.ServerName) are cached locally. // If they are present in the cache, they are served directly. // If they are not present in the cache, they are obtained from the remote server and // simultaneously served back to the client and written into the cache. -func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, mediaID types.MediaID, cfg config.MediaAPI, db *storage.Database) { +func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, mediaID types.MediaID, cfg config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) { logger := util.GetLogger(req.Context()) // request validation @@ -124,7 +127,38 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, "Origin": r.MediaMetadata.Origin, }).Infof("Fetching remote file") - // TODO: lock request in hash set + mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) + + for attempts := 0; ; attempts++ { + activeRemoteRequests.Lock() + err = db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin, r.MediaMetadata) + if err == nil { + // If we have a record, we can respond from the local file + respondFromLocalFile(w, logger, r.MediaMetadata, cfg) + activeRemoteRequests.Unlock() + return + } + if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok { + if attempts >= nAttempts { + logger.Warnf("Other goroutines are trying to download the remote file and failing.") + jsonErrorResponse(w, util.JSONResponse{ + Code: 500, + JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), + }, logger) + return + } + logger.WithFields(log.Fields{ + "Origin": r.MediaMetadata.Origin, + "MediaID": r.MediaMetadata.MediaID, + }).Infof("Waiting for another goroutine to fetch the file.") + activeRemoteRequestCondition.Wait() + activeRemoteRequests.Unlock() + } else { + activeRemoteRequests.Set[mxcURL] = &sync.Cond{L: activeRemoteRequests} + activeRemoteRequests.Unlock() + break + } + } // FIXME: Only request once (would race if multiple requests for the same remote file) // Use a hash set based on the origin and media ID (the request URL should be fine...) and synchronise adding / removing members @@ -319,20 +353,7 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, "Content-Disposition": r.MediaMetadata.ContentDisposition, }).Infof("Storing file metadata to media repository database") - // if written to disk, add to db - err = db.StoreMediaMetadata(r.MediaMetadata) - if err != nil { - tmpDirErr := os.RemoveAll(string(tmpDir)) - if tmpDirErr != nil { - logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) - } - return - } - - // TODO: unlock request in hash set - - // TODO: generate thumbnails - + // The database is the source of truth so we need to have moved the file first err = moveFile( types.Path(path.Join(string(tmpDir), "content")), types.Path(getPathFromMediaMetadata(r.MediaMetadata, cfg.BasePath)), @@ -345,6 +366,36 @@ func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, return } + // Writing the metadata to the media repository database and removing the mxcURL from activeRemoteRequests needs to be atomic. + // If it were not atomic, a new request for the same file could come in in routine A and check the database before the INSERT. + // Routine B which was fetching could then have its INSERT complete and remove the mxcURL from the activeRemoteRequests. + // If routine A then checked the activeRemoteRequests it would think it needed to fetch the file when it's already in the database. + // The locking below mitigates this situation. + activeRemoteRequests.Lock() + // FIXME: unlock after timeout of db request + // if written to disk, add to db + err = db.StoreMediaMetadata(r.MediaMetadata) + if err != nil { + finalDir := path.Dir(getPathFromMediaMetadata(r.MediaMetadata, cfg.BasePath)) + finalDirErr := os.RemoveAll(finalDir) + if finalDirErr != nil { + logger.Warnf("Failed to remove finalDir (%v): %q\n", finalDir, finalDirErr) + } + delete(activeRemoteRequests.Set, mxcURL) + activeRemoteRequests.Unlock() + return + } + activeRemoteRequestCondition, _ := activeRemoteRequests.Set[mxcURL] + logger.WithFields(log.Fields{ + "Origin": r.MediaMetadata.Origin, + "MediaID": r.MediaMetadata.MediaID, + }).Infof("Signalling other goroutines waiting for us to fetch the file.") + activeRemoteRequestCondition.Broadcast() + delete(activeRemoteRequests.Set, mxcURL) + activeRemoteRequests.Unlock() + + // TODO: generate thumbnails + logger.WithFields(log.Fields{ "MediaID": r.MediaMetadata.MediaID, "Origin": r.MediaMetadata.Origin,