From 7cf34af30bc5a97b4045ab239e3aa83fd8f8ac77 Mon Sep 17 00:00:00 2001
From: Robert Swain
Date: Tue, 9 May 2017 19:49:39 +0200
Subject: [PATCH] WIP: Refactoring

---
 .../cmd/dendrite-media-api-server/main.go     |  20 +-
 .../dendrite/mediaapi/config/config.go        |   9 +-
 .../dendrite/mediaapi/routing/routing.go      |  45 +-
 .../dendrite/mediaapi/storage/fileio.go       |  92 ---
 .../dendrite/mediaapi/storage/media.go        |  89 ---
 .../storage/media_repository_table.go         | 107 ++++
 .../dendrite/mediaapi/storage/repository.go   | 283 ---------
 .../dendrite/mediaapi/storage/storage.go      |  20 +-
 .../dendrite/mediaapi/types/types.go          |  57 ++
 .../dendrite/mediaapi/writers/download.go     | 555 ++++++++++--------
 .../dendrite/mediaapi/writers/upload.go       | 190 ++++--
 .../dendrite/mediaapi/writers/utils.go        |  80 +++
 12 files changed, 718 insertions(+), 829 deletions(-)
 delete mode 100644 src/github.com/matrix-org/dendrite/mediaapi/storage/fileio.go
 delete mode 100644 src/github.com/matrix-org/dendrite/mediaapi/storage/media.go
 create mode 100644 src/github.com/matrix-org/dendrite/mediaapi/storage/media_repository_table.go
 delete mode 100644 src/github.com/matrix-org/dendrite/mediaapi/storage/repository.go
 create mode 100644 src/github.com/matrix-org/dendrite/mediaapi/types/types.go
 create mode 100644 src/github.com/matrix-org/dendrite/mediaapi/writers/utils.go

diff --git a/src/github.com/matrix-org/dendrite/cmd/dendrite-media-api-server/main.go b/src/github.com/matrix-org/dendrite/cmd/dendrite-media-api-server/main.go
index 31d8bfb62..aa6ffef81 100644
--- a/src/github.com/matrix-org/dendrite/cmd/dendrite-media-api-server/main.go
+++ b/src/github.com/matrix-org/dendrite/cmd/dendrite-media-api-server/main.go
@@ -27,9 +27,9 @@ import (
 )
 
 var (
-	bindAddr = os.Getenv("BIND_ADDRESS")
-	database = os.Getenv("DATABASE")
-	logDir   = os.Getenv("LOG_DIR")
+	bindAddr   = os.Getenv("BIND_ADDRESS")
+	dataSource = os.Getenv("DATABASE")
+	logDir     = os.Getenv("LOG_DIR")
 )
 
 func main() {
@@ -40,9 +40,10 @@ func main() {
 	}
 
 	cfg := config.MediaAPI{
-		ServerName: "localhost",
-		BasePath:   "/Users/robertsw/dendrite",
-		DataSource: database,
+		ServerName:  "localhost",
+		BasePath:    "/Users/robertsw/dendrite",
+		MaxFileSize: 61440,
+		DataSource:  dataSource,
 	}
 
 	db, err := storage.Open(cfg.DataSource)
@@ -50,13 +51,8 @@ func main() {
 		log.Panicln("Failed to open database:", err)
 	}
 
-	repo := &storage.Repository{
-		StorePrefix: cfg.BasePath,
-		MaxBytes:    61440,
-	}
-
 	log.Info("Starting mediaapi")
-	routing.Setup(http.DefaultServeMux, http.DefaultClient, cfg, db, repo)
+	routing.Setup(http.DefaultServeMux, http.DefaultClient, cfg, db)
 
 	log.Fatal(http.ListenAndServe(bindAddr, nil))
 }
diff --git a/src/github.com/matrix-org/dendrite/mediaapi/config/config.go b/src/github.com/matrix-org/dendrite/mediaapi/config/config.go
index 5900d9d56..2002cd86a 100644
--- a/src/github.com/matrix-org/dendrite/mediaapi/config/config.go
+++ b/src/github.com/matrix-org/dendrite/mediaapi/config/config.go
@@ -14,12 +14,17 @@
 
 package config
 
+import "github.com/matrix-org/dendrite/mediaapi/types"
+
 // MediaAPI contains the config information necessary to spin up a mediaapi process.
 type MediaAPI struct {
 	// The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'.
-	ServerName string `yaml:"server_name"`
+	ServerName types.ServerName `yaml:"server_name"`
 	// The base path to where media files will be stored.
-	BasePath string `yaml:"base_path"`
+	BasePath types.Path `yaml:"base_path"`
+	// The maximum file size in bytes that is allowed to be stored on this server.
+	// Note that remote files larger than this can still be proxied to a client; they will just not be cached.
+	MaxFileSize types.ContentLength `yaml:"max_file_size"`
 	// The postgres connection config for connecting to the database e.g a postgres:// URI
 	DataSource string `yaml:"database"`
 }
diff --git a/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go b/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go
index 5319d9caa..daa598628 100644
--- a/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go
+++ b/src/github.com/matrix-org/dendrite/mediaapi/routing/routing.go
@@ -20,6 +20,7 @@ import (
 	"github.com/gorilla/mux"
 	"github.com/matrix-org/dendrite/mediaapi/config"
 	"github.com/matrix-org/dendrite/mediaapi/storage"
+	"github.com/matrix-org/dendrite/mediaapi/types"
 	"github.com/matrix-org/dendrite/mediaapi/writers"
 	"github.com/matrix-org/util"
 	"github.com/prometheus/client_golang/prometheus"
@@ -27,45 +28,27 @@
 
 const pathPrefixR0 = "/_matrix/media/v1"
 
-type downloadRequestHandler struct {
-	Config         config.MediaAPI
-	Database       *storage.Database
-	DownloadServer writers.DownloadServer
-}
-
-func (handler downloadRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-	util.SetupRequestLogging(req)
-
-	// Set common headers returned regardless of the outcome of the request
-	util.SetCORSHeaders(w)
-	w.Header().Set("Content-Type", "application/json")
-
-	vars := mux.Vars(req)
-	writers.Download(w, req, vars["serverName"], vars["mediaId"], handler.Config, handler.Database, handler.DownloadServer)
-}
-
 // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client
 // to clients which need to make outbound HTTP requests.
-func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg config.MediaAPI, db *storage.Database, repo *storage.Repository) {
+func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg config.MediaAPI, db *storage.Database) {
 	apiMux := mux.NewRouter()
 	r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
 	r0mux.Handle("/upload", make("upload", util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
-		return writers.Upload(req, cfg, db, repo)
+		return writers.Upload(req, cfg, db)
 	})))
 
-	downloadServer := writers.DownloadServer{
-		Repository:      *repo,
-		LocalServerName: cfg.ServerName,
-	}
-
-	handler := downloadRequestHandler{
-		Config:         cfg,
-		Database:       db,
-		DownloadServer: downloadServer,
-	}
-
 	r0mux.Handle("/download/{serverName}/{mediaId}",
-		prometheus.InstrumentHandler("download", handler),
+		prometheus.InstrumentHandler("download", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+			util.SetupRequestLogging(req)
+
+			// Set common headers returned regardless of the outcome of the request
+			util.SetCORSHeaders(w)
+			// TODO: fix comment
+			w.Header().Set("Content-Type", "application/json")
+
+			vars := mux.Vars(req)
+			writers.Download(w, req, types.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, db)
+		})),
 	)
 
 	servMux.Handle("/metrics", prometheus.Handler())
diff --git a/src/github.com/matrix-org/dendrite/mediaapi/storage/fileio.go b/src/github.com/matrix-org/dendrite/mediaapi/storage/fileio.go
deleted file mode 100644
index 5bd87ff8c..000000000
--- a/src/github.com/matrix-org/dendrite/mediaapi/storage/fileio.go
+++ /dev/null
@@ -1,92 +0,0 @@
-// Copyright 2017 Vector Creations Ltd
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "fmt" - "os" - - log "github.com/Sirupsen/logrus" -) - -// LimitedFileWriter writes only a limited number of bytes to a file. -// -// If the callee attempts to write more bytes the file is deleted and further -// writes are silently discarded. -// -// This isn't thread safe. -type LimitedFileWriter struct { - filePath string - file *os.File - writtenBytes uint64 - maxBytes uint64 -} - -// NewLimitedFileWriter creates a new LimitedFileWriter at the given location. -// -// If a file already exists at the location it is immediately truncated. -// -// A maxBytes of 0 or negative is treated as no limit. -func NewLimitedFileWriter(filePath string, maxBytes uint64) (*LimitedFileWriter, error) { - file, err := os.Create(filePath) - if err != nil { - return nil, err - } - - writer := LimitedFileWriter{ - filePath: filePath, - file: file, - maxBytes: maxBytes, - } - - return &writer, nil -} - -// Close closes the underlying file descriptor, if its open. -// -// Any error comes from File.Close -func (writer *LimitedFileWriter) Close() error { - if writer.file != nil { - file := writer.file - writer.file = nil - return file.Close() - } - return nil -} - -func (writer *LimitedFileWriter) Write(p []byte) (n int, err error) { - if writer.maxBytes > 0 && uint64(len(p))+writer.writtenBytes > writer.maxBytes { - if writer.file != nil { - writer.Close() - err = os.Remove(writer.filePath) - if err != nil { - log.Printf("Failed to delete file %v\n", err) - } - } - - return 0, fmt.Errorf("Reached limit") - } - - if writer.file != nil { - n, err = writer.file.Write(p) - writer.writtenBytes += uint64(n) - - if err != nil { - log.Printf("Failed to write to file %v\n", err) - } - } - - return -} diff --git a/src/github.com/matrix-org/dendrite/mediaapi/storage/media.go b/src/github.com/matrix-org/dendrite/mediaapi/storage/media.go deleted file mode 100644 index 8aee283d4..000000000 --- a/src/github.com/matrix-org/dendrite/mediaapi/storage/media.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "database/sql" - "time" -) - -const mediaSchema = ` --- The events table holds metadata for each media upload to the local server, --- the actual file is stored separately. -CREATE TABLE IF NOT EXISTS media_repository ( - -- The id used to refer to the media. - -- This is a base64-encoded sha256 hash of the file data - media_id TEXT PRIMARY KEY, - -- The origin of the media as requested by the client. 
- media_origin TEXT NOT NULL, - -- The MIME-type of the media file. - content_type TEXT NOT NULL, - -- The HTTP Content-Disposition header for the media file. - content_disposition TEXT NOT NULL, - -- Size of the media file in bytes. - file_size BIGINT NOT NULL, - -- When the content was uploaded in ms. - created_ts BIGINT NOT NULL, - -- The name with which the media was uploaded. - upload_name TEXT NOT NULL, - -- The user who uploaded the file. - user_id TEXT NOT NULL, - UNIQUE(media_id, media_origin) -); -` - -const insertMediaSQL = ` -INSERT INTO media_repository (media_id, media_origin, content_type, content_disposition, file_size, created_ts, upload_name, user_id) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) -` - -const selectMediaSQL = ` -SELECT content_type, content_disposition, file_size, upload_name FROM media_repository WHERE media_id = $1 AND media_origin = $2 -` - -type mediaStatements struct { - insertMediaStmt *sql.Stmt - selectMediaStmt *sql.Stmt -} - -func (s *mediaStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(mediaSchema) - if err != nil { - return - } - - return statementList{ - {&s.insertMediaStmt, insertMediaSQL}, - {&s.selectMediaStmt, selectMediaSQL}, - }.prepare(db) -} - -func (s *mediaStatements) insertMedia(mediaID string, mediaOrigin string, contentType string, - contentDisposition string, fileSize int64, uploadName string, userID string) error { - _, err := s.insertMediaStmt.Exec( - mediaID, mediaOrigin, contentType, contentDisposition, fileSize, - int64(time.Now().UnixNano()/1000000), uploadName, userID, - ) - return err -} - -func (s *mediaStatements) selectMedia(mediaID string, mediaOrigin string) (string, string, int64, string, error) { - var contentType string - var contentDisposition string - var fileSize int64 - var uploadName string - err := s.selectMediaStmt.QueryRow(mediaID, mediaOrigin).Scan(&contentType, &contentDisposition, &fileSize, &uploadName) - return string(contentType), string(contentDisposition), int64(fileSize), string(uploadName), err -} diff --git a/src/github.com/matrix-org/dendrite/mediaapi/storage/media_repository_table.go b/src/github.com/matrix-org/dendrite/mediaapi/storage/media_repository_table.go new file mode 100644 index 000000000..11b9064f1 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/mediaapi/storage/media_repository_table.go @@ -0,0 +1,107 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "database/sql" + "time" + + "github.com/matrix-org/dendrite/mediaapi/types" +) + +const mediaSchema = ` +-- The media_repository table holds metadata for each media file stored and accessible to the local server, +-- the actual file is stored separately. +CREATE TABLE IF NOT EXISTS media_repository ( + -- The id used to refer to the media. 
+ -- For uploads to this server this is a base64-encoded sha256 hash of the file data + -- For media from remote servers, this can be any unique identifier string + media_id TEXT NOT NULL, + -- The origin of the media as requested by the client. Should be a homeserver domain. + media_origin TEXT NOT NULL, + -- The MIME-type of the media file as specified when uploading. + content_type TEXT NOT NULL, + -- The HTTP Content-Disposition header for the media file as specified when uploading. + content_disposition TEXT NOT NULL, + -- Size of the media file in bytes. + content_length BIGINT NOT NULL, + -- When the content was uploaded in UNIX epoch ms. + creation_ts BIGINT NOT NULL, + -- The file name with which the media was uploaded. + upload_name TEXT NOT NULL, + -- The user who uploaded the file. Should be a Matrix user ID. + user_id TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS media_repository_index ON media_repository (media_id, media_origin); +` + +const insertMediaSQL = ` +INSERT INTO media_repository (media_id, media_origin, content_type, content_disposition, content_length, creation_ts, upload_name, user_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +` + +const selectMediaSQL = ` +SELECT content_type, content_disposition, content_length, creation_ts, upload_name, user_id FROM media_repository WHERE media_id = $1 AND media_origin = $2 +` + +type mediaStatements struct { + insertMediaStmt *sql.Stmt + selectMediaStmt *sql.Stmt +} + +func (s *mediaStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(mediaSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertMediaStmt, insertMediaSQL}, + {&s.selectMediaStmt, selectMediaSQL}, + }.prepare(db) +} + +func (s *mediaStatements) insertMedia(mediaMetadata *types.MediaMetadata) error { + mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) + _, err := s.insertMediaStmt.Exec( + mediaMetadata.MediaID, + mediaMetadata.Origin, + mediaMetadata.ContentType, + mediaMetadata.ContentDisposition, + mediaMetadata.ContentLength, + mediaMetadata.CreationTimestamp, + mediaMetadata.UploadName, + mediaMetadata.UserID, + ) + return err +} + +func (s *mediaStatements) selectMedia(mediaID types.MediaID, mediaOrigin types.ServerName) (*types.MediaMetadata, error) { + mediaMetadata := types.MediaMetadata{ + MediaID: mediaID, + Origin: mediaOrigin, + } + err := s.selectMediaStmt.QueryRow( + mediaMetadata.MediaID, mediaMetadata.Origin, + ).Scan( + &mediaMetadata.ContentType, + &mediaMetadata.ContentDisposition, + &mediaMetadata.ContentLength, + &mediaMetadata.CreationTimestamp, + &mediaMetadata.UploadName, + &mediaMetadata.UserID, + ) + return &mediaMetadata, err +} diff --git a/src/github.com/matrix-org/dendrite/mediaapi/storage/repository.go b/src/github.com/matrix-org/dendrite/mediaapi/storage/repository.go deleted file mode 100644 index 2378646ae..000000000 --- a/src/github.com/matrix-org/dendrite/mediaapi/storage/repository.go +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "crypto/sha256" - "encoding/base64" - "hash" - "io" - "io/ioutil" - "os" - "path" - - log "github.com/Sirupsen/logrus" -) - -// Description contains various attributes for an image. -type Description struct { - Type string - Length int64 -} - -type repositoryPaths struct { - contentPath string - typePath string -} - -// Repository stores locally uploaded media, and caches remote media that has -// been requested. -type Repository struct { - StorePrefix string - MaxBytes uint64 -} - -// ReaderFromRemoteCache returns a io.ReadCloser with the cached remote content, -// if it exists. Use IsNotExist to check if the error was due to it not existing -// in the cache -func (repo Repository) ReaderFromRemoteCache(host, name string) (io.ReadCloser, *Description, error) { - mediaDir := repo.getDirForRemoteMedia(host, name) - repoPaths := getPathsForMedia(mediaDir) - - return repo.readerFromRepository(repoPaths) -} - -// ReaderFromLocalRepo returns a io.ReadCloser with the locally uploaded content, -// if it exists. Use IsNotExist to check if the error was due to it not existing -// in the cache -func (repo Repository) ReaderFromLocalRepo(name string) (io.ReadCloser, *Description, error) { - mediaDir := repo.getDirForLocalMedia(name) - repoPaths := getPathsForMedia(mediaDir) - - return repo.readerFromRepository(repoPaths) -} - -func (repo Repository) readerFromRepository(repoPaths repositoryPaths) (io.ReadCloser, *Description, error) { - contentTypeBytes, err := ioutil.ReadFile(repoPaths.typePath) - if err != nil { - return nil, nil, err - } - - contentType := string(contentTypeBytes) - - file, err := os.Open(repoPaths.contentPath) - if err != nil { - return nil, nil, err - } - - stat, err := file.Stat() - if err != nil { - return nil, nil, err - } - - descr := Description{ - Type: contentType, - Length: stat.Size(), - } - - return file, &descr, nil -} - -// WriterToLocalRepository returns a RepositoryWriter for writing newly uploaded -// content into the repository. -// -// The returned RepositoryWriter will fail if more than MaxBytes tries to be -// written. -func (repo Repository) WriterToLocalRepository(descr Description) (RepositoryWriter, error) { - return newLocalRepositoryWriter(repo, descr) -} - -// WriterToRemoteCache returns a RepositoryWriter for caching newly downloaded -// remote content. -// -// The returned RepositoryWriter will silently stop writing if more than MaxBytes -// tries to be written and does *not* return an error. -func (repo Repository) WriterToRemoteCache(host, name string, descr Description) (RepositoryWriter, error) { - return newRemoteRepositoryWriter(repo, host, name, descr) -} - -func (repo *Repository) makeTempDir() (string, error) { - tmpDir := path.Join(repo.StorePrefix, "tmp") - os.MkdirAll(tmpDir, 0770) - return ioutil.TempDir(tmpDir, "") -} - -func (repo *Repository) getDirForLocalMedia(name string) string { - return path.Join(repo.StorePrefix, "local", name[:3], name[3:]) -} - -func (repo *Repository) getDirForRemoteMedia(host, sanitizedName string) string { - return path.Join(repo.StorePrefix, "remote", host, sanitizedName[:3], sanitizedName[3:]) -} - -// Get the actual paths for the data and metadata associated with remote media. 
-func getPathsForMedia(dir string) repositoryPaths { - contentPath := path.Join(dir, "content") - typePath := path.Join(dir, "type") - return repositoryPaths{ - contentPath: contentPath, - typePath: typePath, - } -} - -// IsNotExists check if error was due to content not existing in cache. -func IsNotExists(err error) bool { return os.IsNotExist(err) } - -// RepositoryWriter is used to either store into the repository newly uploaded -// media or to cache recently fetched remote media. -type RepositoryWriter interface { - io.WriteCloser - - // Finished should be called when successfully finished writing; otherwise - // the written content will not be committed to the repository. - Finished() (string, error) -} - -type remoteRepositoryWriter struct { - tmpDir string - finalDir string - name string - file io.WriteCloser - erred bool -} - -func newRemoteRepositoryWriter(repo Repository, host, name string, descr Description) (*remoteRepositoryWriter, error) { - tmpFile, tmpDir, err := getTempWriter(repo, descr) - if err != nil { - log.Printf("Failed to create writer: %v\n", err) - return nil, err - } - - return &remoteRepositoryWriter{ - tmpDir: tmpDir, - finalDir: repo.getDirForRemoteMedia(host, name), - name: name, - file: tmpFile, - erred: false, - }, nil -} - -func (writer remoteRepositoryWriter) Write(p []byte) (int, error) { - // Its OK to fail when writing to the remote repo. We just hide the error - // from the layers above - if !writer.erred { - if _, err := writer.file.Write(p); err != nil { - writer.erred = true - } - } - return len(p), nil -} - -func (writer remoteRepositoryWriter) Close() error { - os.RemoveAll(writer.tmpDir) - writer.file.Close() - return nil -} - -func (writer remoteRepositoryWriter) Finished() (string, error) { - var err error - if !writer.erred { - os.MkdirAll(path.Dir(writer.finalDir), 0770) - err = os.Rename(writer.tmpDir, writer.finalDir) - if err != nil { - return "", err - } - } - err = writer.Close() - return writer.name, err -} - -type localRepositoryWriter struct { - repo Repository - tmpDir string - hasher hash.Hash - file io.WriteCloser - finished bool -} - -func newLocalRepositoryWriter(repo Repository, descr Description) (*localRepositoryWriter, error) { - tmpFile, tmpDir, err := getTempWriter(repo, descr) - if err != nil { - return nil, err - } - - return &localRepositoryWriter{ - repo: repo, - tmpDir: tmpDir, - hasher: sha256.New(), - file: tmpFile, - finished: false, - }, nil -} - -func (writer localRepositoryWriter) Write(p []byte) (int, error) { - writer.hasher.Write(p) // Never errors. 
- n, err := writer.file.Write(p) - if err != nil { - writer.Close() - } - return n, err -} - -func (writer localRepositoryWriter) Close() error { - var err error - if !writer.finished { - err = os.RemoveAll(writer.tmpDir) - if err != nil { - return err - } - } - - err = writer.file.Close() - return err -} - -func (writer localRepositoryWriter) Finished() (string, error) { - hash := writer.hasher.Sum(nil) - name := base64.URLEncoding.EncodeToString(hash[:]) - finalDir := writer.repo.getDirForLocalMedia(name) - os.MkdirAll(path.Dir(finalDir), 0770) - err := os.Rename(writer.tmpDir, finalDir) - if err != nil { - log.Println("Failed to move temp directory:", writer.tmpDir, finalDir, err) - return "", err - } - writer.finished = true - writer.Close() - return name, nil -} - -func getTempWriter(repo Repository, descr Description) (io.WriteCloser, string, error) { - tmpDir, err := repo.makeTempDir() - if err != nil { - log.Printf("Failed to create temp dir: %v\n", err) - return nil, "", err - } - - repoPaths := getPathsForMedia(tmpDir) - - if err = ioutil.WriteFile(repoPaths.typePath, []byte(descr.Type), 0660); err != nil { - log.Printf("Failed to create typeFile: %q\n", err) - return nil, "", err - } - - tmpFile, err := NewLimitedFileWriter(repoPaths.contentPath, repo.MaxBytes) - if err != nil { - log.Printf("Failed to create limited file: %v\n", err) - return nil, "", err - } - - return tmpFile, tmpDir, nil -} diff --git a/src/github.com/matrix-org/dendrite/mediaapi/storage/storage.go b/src/github.com/matrix-org/dendrite/mediaapi/storage/storage.go index 72dc0a62f..121a06354 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/storage/storage.go @@ -19,6 +19,7 @@ import ( // Import the postgres database driver. _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/mediaapi/types" ) // A Database is used to store room events and stream offsets. @@ -40,12 +41,19 @@ func Open(dataSourceName string) (*Database, error) { return &d, nil } -// CreateMedia inserts the metadata about the uploaded media into the database. -func (d *Database) CreateMedia(mediaID string, mediaOrigin string, contentType string, contentDisposition string, fileSize int64, uploadName string, userID string) error { - return d.statements.insertMedia(mediaID, mediaOrigin, contentType, contentDisposition, fileSize, uploadName, userID) +// StoreMediaMetadata inserts the metadata about the uploaded media into the database. +func (d *Database) StoreMediaMetadata(mediaMetadata *types.MediaMetadata) error { + return d.statements.insertMedia(mediaMetadata) } -// GetMedia possibly selects the metadata about previously uploaded media from the database. -func (d *Database) GetMedia(mediaID string, mediaOrigin string) (string, string, int64, string, error) { - return d.statements.selectMedia(mediaID, mediaOrigin) +// GetMediaMetadata possibly selects the metadata about previously uploaded media from the database. 
+func (d *Database) GetMediaMetadata(mediaID types.MediaID, mediaOrigin types.ServerName, mediaMetadata *types.MediaMetadata) error { + metadata, err := d.statements.selectMedia(mediaID, mediaOrigin) + mediaMetadata.ContentType = metadata.ContentType + mediaMetadata.ContentDisposition = metadata.ContentDisposition + mediaMetadata.ContentLength = metadata.ContentLength + mediaMetadata.CreationTimestamp = metadata.CreationTimestamp + mediaMetadata.UploadName = metadata.UploadName + mediaMetadata.UserID = metadata.UserID + return err } diff --git a/src/github.com/matrix-org/dendrite/mediaapi/types/types.go b/src/github.com/matrix-org/dendrite/mediaapi/types/types.go new file mode 100644 index 000000000..e1e1a3a44 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/mediaapi/types/types.go @@ -0,0 +1,57 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +// ContentDisposition is an HTTP Content-Disposition header string +type ContentDisposition string + +// ContentLength is an HTTP Content-Length header which is a number of bytes to be expected in a request body +type ContentLength int64 + +// ContentType is an HTTP Content-Type header string representing the MIME type of a request body +type ContentType string + +// Filename is a string representing the name of a file +type Filename string + +// Path is an absolute or relative UNIX filesystem path +type Path string + +// MediaID is a string representing the unique identifier for a file (could be a hash but does not have to be) +type MediaID string + +// ServerName is the host of a matrix homeserver, e.g. matrix.org +type ServerName string + +// RequestMethod is an HTTP request method i.e. GET, POST, etc +type RequestMethod string + +// MatrixUserID is a Matrix user ID string in the form @user:domain e.g. 
@alice:matrix.org +type MatrixUserID string + +// UnixMs is the milliseconds since the Unix epoch +type UnixMs int64 + +// MediaMetadata is metadata associated with a media file +type MediaMetadata struct { + MediaID MediaID + Origin ServerName + ContentType ContentType + ContentDisposition ContentDisposition + ContentLength ContentLength + CreationTimestamp UnixMs + UploadName Filename + UserID MatrixUserID +} diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go index 730746603..ec270e914 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/download.go @@ -15,11 +15,14 @@ package writers import ( + "database/sql" "encoding/json" "fmt" "io" "net" "net/http" + "os" + "path" "strconv" "strings" @@ -27,30 +30,30 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/mediaapi/config" "github.com/matrix-org/dendrite/mediaapi/storage" + "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/util" ) -// DownloadRequest metadata included in or derivable from an upload request +// downloadRequest metadata included in or derivable from an download request // https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-media-r0-download -type DownloadRequest struct { - MediaID string - ServerName string +type downloadRequest struct { + MediaMetadata *types.MediaMetadata } -// Validate validates the DownloadRequest fields -func (r DownloadRequest) Validate() *util.JSONResponse { +// Validate validates the downloadRequest fields +func (r downloadRequest) Validate() *util.JSONResponse { // FIXME: the following errors aren't bad JSON, rather just a bad request path // maybe give the URL pattern in the routing, these are not even possible as the handler would not be hit...? - if r.MediaID == "" { + if r.MediaMetadata.MediaID == "" { return &util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON("mediaId must be a non-empty string"), + Code: 404, + JSON: jsonerror.NotFound("mediaId must be a non-empty string"), } } - if r.ServerName == "" { + if r.MediaMetadata.Origin == "" { return &util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON("serverName must be a non-empty string"), + Code: 404, + JSON: jsonerror.NotFound("serverName must be a non-empty string"), } } return nil @@ -72,10 +75,21 @@ func jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse, logger *log w.Write(resBytes) } +var errFileIsTooLarge = fmt.Errorf("file is too large") +var errRead = fmt.Errorf("failed to read response from remote server") +var errResponse = fmt.Errorf("failed to write file data to response body") +var errWrite = fmt.Errorf("failed to write file to disk") + // Download implements /download -func Download(w http.ResponseWriter, req *http.Request, serverName string, mediaID string, cfg config.MediaAPI, db *storage.Database, downloadServer DownloadServer) { +// Files from this server (i.e. origin == cfg.ServerName) are served directly +// Files from remote servers (i.e. origin != cfg.ServerName) are cached locally. +// If they are present in the cache, they are served directly. +// If they are not present in the cache, they are obtained from the remote server and +// simultaneously served back to the client and written into the cache. 
+func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, mediaID types.MediaID, cfg config.MediaAPI, db *storage.Database) { logger := util.GetLogger(req.Context()) + // request validation if req.Method != "GET" { jsonErrorResponse(w, util.JSONResponse{ Code: 405, @@ -84,9 +98,11 @@ func Download(w http.ResponseWriter, req *http.Request, serverName string, media return } - r := &DownloadRequest{ - MediaID: mediaID, - ServerName: serverName, + r := &downloadRequest{ + MediaMetadata: &types.MediaMetadata{ + MediaID: mediaID, + Origin: origin, + }, } if resErr := r.Validate(); resErr != nil { @@ -94,190 +110,285 @@ func Download(w http.ResponseWriter, req *http.Request, serverName string, media return } - contentType, contentDisposition, fileSize, filename, err := db.GetMedia(r.MediaID, r.ServerName) - if err != nil && strings.Compare(r.ServerName, cfg.ServerName) == 0 { + // check if we have a record of the media in our database + err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin, r.MediaMetadata) + + if err == nil { + // If we have a record, we can respond from the local file + logger.WithFields(log.Fields{ + "MediaID": r.MediaMetadata.MediaID, + "Origin": r.MediaMetadata.Origin, + "UploadName": r.MediaMetadata.UploadName, + "Content-Length": r.MediaMetadata.ContentLength, + "Content-Type": r.MediaMetadata.ContentType, + "Content-Disposition": r.MediaMetadata.ContentDisposition, + }).Infof("Downloading file") + + filePath := getPathFromMediaMetadata(r.MediaMetadata, cfg.BasePath) + file, err := os.Open(filePath) + if err != nil { + // FIXME: Remove erroneous file from database? + jsonErrorResponse(w, util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), + }, logger) + return + } + + stat, err := file.Stat() + if err != nil { + // FIXME: Remove erroneous file from database? + jsonErrorResponse(w, util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), + }, logger) + return + } + + if r.MediaMetadata.ContentLength > 0 && int64(r.MediaMetadata.ContentLength) != stat.Size() { + logger.Warnf("File size in database (%v) and on disk (%v) differ.", r.MediaMetadata.ContentLength, stat.Size()) + // FIXME: Remove erroneous file from database? 
+ } + + w.Header().Set("Content-Type", string(r.MediaMetadata.ContentType)) + w.Header().Set("Content-Length", strconv.FormatInt(stat.Size(), 10)) + contentSecurityPolicy := "default-src 'none';" + + " script-src 'none';" + + " plugin-types application/pdf;" + + " style-src 'unsafe-inline';" + + " object-src 'self';" + w.Header().Set("Content-Security-Policy", contentSecurityPolicy) + + if bytesResponded, err := io.Copy(w, file); err != nil { + logger.Warnf("Failed to copy from cache %v\n", err) + if bytesResponded == 0 { + jsonErrorResponse(w, util.JSONResponse{ + Code: 500, + JSON: jsonerror.NotFound(fmt.Sprintf("Failed to respond with file with media ID %q", r.MediaMetadata.MediaID)), + }, logger) + } + // If we have written any data then we have already responded with 200 OK and all we can do is close the connection + return + } + } else if err == sql.ErrNoRows && r.MediaMetadata.Origin != cfg.ServerName { + // If we do not have a record and the origin is remote, we need to fetch it and respond with that file + logger.WithFields(log.Fields{ + "MediaID": r.MediaMetadata.MediaID, + "Origin": r.MediaMetadata.Origin, + "UploadName": r.MediaMetadata.UploadName, + "Content-Length": r.MediaMetadata.ContentLength, + "Content-Type": r.MediaMetadata.ContentType, + "Content-Disposition": r.MediaMetadata.ContentDisposition, + }).Infof("Fetching remote file") + + // TODO: lock request in hash set + + // FIXME: Only request once (would race if multiple requests for the same remote file) + // Use a hash set based on the origin and media ID (the request URL should be fine...) and synchronise adding / removing members + urls := getMatrixUrls(r.MediaMetadata.Origin) + + logger.Printf("Connecting to remote %q\n", urls[0]) + + remoteReqAddr := urls[0] + "/_matrix/media/v1/download/" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) + remoteReq, err := http.NewRequest("GET", remoteReqAddr, nil) + if err != nil { + jsonErrorResponse(w, util.JSONResponse{ + Code: 500, + JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), + }, logger) + return + } + + remoteReq.Header.Set("Host", string(r.MediaMetadata.Origin)) + + client := http.Client{} + resp, err := client.Do(remoteReq) + if err != nil { + jsonErrorResponse(w, util.JSONResponse{ + Code: 502, + JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), + }, logger) + return + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + logger.Printf("Server responded with %d\n", resp.StatusCode) + if resp.StatusCode == 404 { + jsonErrorResponse(w, util.JSONResponse{ + Code: 404, + JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), + }, logger) + return + } + jsonErrorResponse(w, util.JSONResponse{ + Code: 502, + JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), + }, logger) + return + } + + contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + if err != nil { + logger.Warn("Failed to parse content length") + } + r.MediaMetadata.ContentLength = types.ContentLength(contentLength) + + w.Header().Set("Content-Type", string(r.MediaMetadata.ContentType)) + w.Header().Set("Content-Length", strconv.FormatInt(int64(r.MediaMetadata.ContentLength), 10)) + contentSecurityPolicy := "default-src 
'none';" + + " script-src 'none';" + + " plugin-types application/pdf;" + + " style-src 'unsafe-inline';" + + " object-src 'self';" + w.Header().Set("Content-Security-Policy", contentSecurityPolicy) + + tmpDir, err := createTempDir(cfg.BasePath) + if err != nil { + logger.Infof("Failed to create temp dir %q\n", err) + jsonErrorResponse(w, util.JSONResponse{ + Code: 400, + JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)), + }, logger) + return + } + tmpFile, writer, err := createFileWriter(tmpDir, types.Filename(r.MediaMetadata.MediaID[3:])) + if err != nil { + logger.Infof("Failed to create file writer %q\n", err) + jsonErrorResponse(w, util.JSONResponse{ + Code: 400, + JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)), + }, logger) + return + } + defer tmpFile.Close() + + // bytesResponded is the total number of bytes written to the response to the client request + // bytesWritten is the total number of bytes written to disk + var bytesResponded, bytesWritten int64 = 0, 0 + var fetchError error + // Note: the buffer size is the same as is used in io.Copy() + buffer := make([]byte, 32*1024) + for { + // read from remote request's response body + bytesRead, readErr := resp.Body.Read(buffer) + if bytesRead > 0 { + // write to client request's response body + bytesTemp, respErr := w.Write(buffer[:bytesRead]) + if bytesTemp != bytesRead || (respErr != nil && respErr != io.EOF) { + // TODO: BORKEN + logger.Errorf("bytesTemp %v != bytesRead %v : %v", bytesTemp, bytesRead, respErr) + fetchError = errResponse + break + } + bytesResponded += int64(bytesTemp) + if fetchError == nil || (fetchError != errFileIsTooLarge && fetchError != errWrite) { + // if larger than cfg.MaxFileSize then stop writing to disk and discard cached file + if bytesWritten+int64(len(buffer)) > int64(cfg.MaxFileSize) { + // TODO: WAAAAHNING and clean up temp files + fetchError = errFileIsTooLarge + } else { + // write to disk + bytesTemp, writeErr := writer.Write(buffer) + if writeErr != nil && writeErr != io.EOF { + // TODO: WAAAAHNING and clean up temp files + fetchError = errWrite + } else { + bytesWritten += int64(bytesTemp) + } + } + } + } + if readErr != nil { + if readErr != io.EOF { + fetchError = errRead + break + } + } + } + + writer.Flush() + + if fetchError != nil { + logFields := log.Fields{ + "MediaID": r.MediaMetadata.MediaID, + "Origin": r.MediaMetadata.Origin, + } + if fetchError == errFileIsTooLarge { + logFields["MaxFileSize"] = cfg.MaxFileSize + } + logger.WithFields(logFields).Warnln(fetchError) + tmpDirErr := os.RemoveAll(string(tmpDir)) + if tmpDirErr != nil { + logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) + } + // Note: if we have responded with any data in the body at all then we have already sent 200 OK and we can only abort at this point + if bytesResponded < 1 { + jsonErrorResponse(w, util.JSONResponse{ + Code: 502, + JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)), + }, logger) + } else { + // We attempt to bluntly close the connection because that is the + // best thing we can do after we've sent a 200 OK + logger.Println("Attempting to close the connection.") + hijacker, ok := w.(http.Hijacker) + if ok { + connection, _, hijackErr := hijacker.Hijack() + if hijackErr == nil { + logger.Println("Closing") + connection.Close() + } else { + logger.Printf("Error trying to hijack: %v", hijackErr) + } + } + } + return + } + + // Note: After this point 
we have responded to the client's request and are just dealing with local caching. + // As we have responded with 200 OK, any errors are ineffectual to the client request and so we just log and return. + + // if written to disk, add to db + err = db.StoreMediaMetadata(r.MediaMetadata) + if err != nil { + tmpDirErr := os.RemoveAll(string(tmpDir)) + if tmpDirErr != nil { + logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) + } + return + } + + // TODO: unlock request in hash set + + // TODO: generate thumbnails + + err = moveFile( + types.Path(path.Join(string(tmpDir), "content")), + types.Path(getPathFromMediaMetadata(r.MediaMetadata, cfg.BasePath)), + ) + if err != nil { + tmpDirErr := os.RemoveAll(string(tmpDir)) + if tmpDirErr != nil { + logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) + } + return + } + } else { + // TODO: If we do not have a record and the origin is local, or if we have another error from the database, the file is not found jsonErrorResponse(w, util.JSONResponse{ Code: 404, - JSON: jsonerror.NotFound(fmt.Sprintf("File %q does not exist", r.MediaID)), + JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), }, logger) - return } - - // - read file and respond - logger.WithFields(log.Fields{ - "MediaID": r.MediaID, - "ServerName": r.ServerName, - "Filename": filename, - "Content-Type": contentType, - "Content-Disposition": contentDisposition, - }).Infof("Downloading file") - - logger.WithField("code", 200).Infof("Responding (%d bytes)", fileSize) - - respWriter := httpResponseWriter{resp: w} - if err = downloadServer.getImage(respWriter, r.ServerName, r.MediaID); err != nil { - if respWriter.haveWritten() { - closeConnection(w) - return - } - - errStatus := 500 - switch err { - case errNotFound: - errStatus = 404 - case errProxy: - errStatus = 502 - } - http.Error(w, err.Error(), errStatus) - return - } - - return -} - -// DownloadServer serves and caches remote media. -type DownloadServer struct { - Client http.Client - Repository storage.Repository - LocalServerName string -} - -func (handler *DownloadServer) getImage(w responseWriter, host, name string) error { - var file io.ReadCloser - var descr *storage.Description - var err error - if host == handler.LocalServerName { - file, descr, err = handler.Repository.ReaderFromLocalRepo(name) - } else { - file, descr, err = handler.Repository.ReaderFromRemoteCache(host, name) - } - - if err == nil { - log.Println("Found in Cache") - w.setContentType(descr.Type) - - size := strconv.FormatInt(descr.Length, 10) - w.setContentLength(size) - w.setContentSecurityPolicy() - if _, err = io.Copy(w, file); err != nil { - log.Printf("Failed to copy from cache %v\n", err) - return err - } - w.Close() - return nil - } else if !storage.IsNotExists(err) { - log.Printf("Error looking in cache: %v\n", err) - return err - } - - if host == handler.LocalServerName { - // Its fatal if we can't find local files in our cache. 
- return errNotFound - } - - respBody, desc, err := handler.fetchRemoteMedia(host, name) - if err != nil { - return err - } - - defer respBody.Close() - - w.setContentType(desc.Type) - if desc.Length > 0 { - w.setContentLength(strconv.FormatInt(desc.Length, 10)) - } - - writer, err := handler.Repository.WriterToRemoteCache(host, name, *desc) - if err != nil { - log.Printf("Failed to get cache writer %q\n", err) - return err - } - - defer writer.Close() - - reader := io.TeeReader(respBody, w) - if _, err := io.Copy(writer, reader); err != nil { - log.Printf("Failed to copy %q\n", err) - return err - } - - writer.Finished() - - log.Println("Finished conn") - - return nil -} - -func (handler *DownloadServer) fetchRemoteMedia(host, name string) (io.ReadCloser, *storage.Description, error) { - urls := getMatrixUrls(host) - - log.Printf("Connecting to remote %q\n", urls[0]) - - remoteReq, err := http.NewRequest("GET", urls[0]+"/_matrix/media/v1/download/"+host+"/"+name, nil) - if err != nil { - log.Printf("Failed to connect to remote: %q\n", err) - return nil, nil, err - } - - remoteReq.Header.Set("Host", host) - - resp, err := handler.Client.Do(remoteReq) - if err != nil { - log.Printf("Failed to connect to remote: %q\n", err) - return nil, nil, errProxy - } - - if resp.StatusCode != 200 { - resp.Body.Close() - log.Printf("Server responded with %d\n", resp.StatusCode) - if resp.StatusCode == 404 { - return nil, nil, errNotFound - } - return nil, nil, errProxy - } - - desc := storage.Description{ - Type: resp.Header.Get("Content-Type"), - Length: -1, - } - - length, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - if err == nil { - desc.Length = length - } - - return resp.Body, &desc, nil -} - -// Given a http.ResponseWriter, attempt to force close the connection. -// -// This is useful if you get a fatal error after sending the initial 200 OK -// response. -func closeConnection(w http.ResponseWriter) { - log.Println("Attempting to close connection") - - // We attempt to bluntly close the connection because that is the - // best thing we can do after we've sent a 200 OK - hijack, ok := w.(http.Hijacker) - if ok { - conn, _, err := hijack.Hijack() - if err != nil { - fmt.Printf("Err trying to hijack: %v", err) - return - } - log.Println("Closing") - conn.Close() - return - } - log.Println("Not hijacker") } // Given a matrix server name, attempt to discover URLs to contact the server // on. -func getMatrixUrls(host string) []string { - _, srvs, err := net.LookupSRV("matrix", "tcp", host) +func getMatrixUrls(serverName types.ServerName) []string { + _, srvs, err := net.LookupSRV("matrix", "tcp", string(serverName)) if err != nil { - return []string{"https://" + host + ":8448"} + return []string{"https://" + string(serverName) + ":8448"} } results := make([]string, 0, len(srvs)) @@ -294,65 +405,3 @@ func getMatrixUrls(host string) []string { return results } - -// Given a path of the form '//' extract the host and name. 
-func getMediaIDFromPath(path string) (host, name string, err error) { - parts := strings.Split(path, "/") - if len(parts) != 3 { - err = fmt.Errorf("Invalid path %q", path) - return - } - - host, name = parts[1], parts[2] - - if host == "" || name == "" { - err = fmt.Errorf("Invalid path %q", path) - return - } - - return -} - -type responseWriter interface { - io.WriteCloser - setContentLength(string) - setContentSecurityPolicy() - setContentType(string) - haveWritten() bool -} - -type httpResponseWriter struct { - resp http.ResponseWriter - written bool -} - -func (writer httpResponseWriter) haveWritten() bool { - return writer.written -} - -func (writer httpResponseWriter) Write(p []byte) (n int, err error) { - writer.written = true - return writer.resp.Write(p) -} - -func (writer httpResponseWriter) Close() error { return nil } - -func (writer httpResponseWriter) setContentType(contentType string) { - writer.resp.Header().Set("Content-Type", contentType) -} - -func (writer httpResponseWriter) setContentLength(length string) { - writer.resp.Header().Set("Content-Length", length) -} - -func (writer httpResponseWriter) setContentSecurityPolicy() { - contentSecurityPolicy := "default-src 'none';" + - " script-src 'none';" + - " plugin-types application/pdf;" + - " style-src 'unsafe-inline';" + - " object-src 'self';" - writer.resp.Header().Set("Content-Security-Policy", contentSecurityPolicy) -} - -var errProxy = fmt.Errorf("Failed to contact remote") -var errNotFound = fmt.Errorf("Image not found") diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go index f68168840..7b6b4876d 100644 --- a/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/upload.go @@ -15,9 +15,14 @@ package writers import ( + "crypto/sha256" + "database/sql" + "encoding/base64" "fmt" "io" "net/http" + "os" + "path" "strings" log "github.com/Sirupsen/logrus" @@ -25,58 +30,54 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/mediaapi/config" "github.com/matrix-org/dendrite/mediaapi/storage" + "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/util" ) -// UploadRequest metadata included in or derivable from an upload request +// uploadRequest metadata included in or derivable from an upload request // https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-media-r0-upload -// NOTE: ContentType is an HTTP request header and Filename is passed as a query parameter -type UploadRequest struct { - ContentDisposition string - ContentLength int64 - ContentType string - Filename string - Base64FileHash string - Method string - UserID string +// NOTE: The members come from HTTP request metadata such as headers, query parameters or can be derived from such +type uploadRequest struct { + MediaMetadata *types.MediaMetadata } -// Validate validates the UploadRequest fields -func (r UploadRequest) Validate() *util.JSONResponse { +// Validate validates the uploadRequest fields +func (r uploadRequest) Validate(maxFileSize types.ContentLength) *util.JSONResponse { // TODO: Any validation to be done on ContentDisposition? 
- if r.ContentLength < 1 { + + if r.MediaMetadata.ContentLength < 1 { return &util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON("HTTP Content-Length request header must be greater than zero."), + JSON: jsonerror.Unknown("HTTP Content-Length request header must be greater than zero."), + } + } + if maxFileSize > 0 && r.MediaMetadata.ContentLength > maxFileSize { + return &util.JSONResponse{ + Code: 400, + JSON: jsonerror.Unknown(fmt.Sprintf("HTTP Content-Length is greater than the maximum allowed upload size (%v).", maxFileSize)), } } // TODO: Check if the Content-Type is a valid type? - if r.ContentType == "" { + if r.MediaMetadata.ContentType == "" { return &util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON("HTTP Content-Type request header must be set."), + JSON: jsonerror.Unknown("HTTP Content-Type request header must be set."), } } // TODO: Validate filename - what are the valid characters? - if r.Method != "POST" { - return &util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON("HTTP request method must be POST."), - } - } - if r.UserID != "" { + if r.MediaMetadata.UserID != "" { // TODO: We should put user ID parsing code into gomatrixserverlib and use that instead // (see https://github.com/matrix-org/gomatrixserverlib/blob/3394e7c7003312043208aa73727d2256eea3d1f6/eventcontent.go#L347 ) // It should be a struct (with pointers into a single string to avoid copying) and // we should update all refs to use UserID types rather than strings. // https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/types.py#L92 - if len(r.UserID) == 0 || r.UserID[0] != '@' { + if len(r.MediaMetadata.UserID) == 0 || r.MediaMetadata.UserID[0] != '@' { return &util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON("user id must start with '@'"), + JSON: jsonerror.Unknown("user id must start with '@'"), } } - parts := strings.SplitN(r.UserID[1:], ":", 2) + parts := strings.SplitN(string(r.MediaMetadata.UserID[1:]), ":", 2) if len(parts) != 2 { return &util.JSONResponse{ Code: 400, @@ -93,9 +94,21 @@ type uploadResponse struct { } // Upload implements /upload -func Upload(req *http.Request, cfg config.MediaAPI, db *storage.Database, repo *storage.Repository) util.JSONResponse { +// +// This endpoint involves uploading potentially significant amounts of data to the homeserver. +// This implementation supports a configurable maximum file size limit in bytes. If a user tries to upload more than this, they will receive an error that their upload is too large. +// Uploaded files are processed piece-wise to avoid DoS attacks which would starve the server of memory. +// TODO: Requests time out if they have not received any data within the configured timeout period. 
+func Upload(req *http.Request, cfg config.MediaAPI, db *storage.Database) util.JSONResponse { logger := util.GetLogger(req.Context()) + if req.Method != "POST" { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.Unknown("HTTP request method must be POST."), + } + } + // FIXME: This will require querying some other component/db but currently // just accepts a user id for auth userID, resErr := auth.VerifyAccessToken(req) @@ -103,74 +116,129 @@ func Upload(req *http.Request, cfg config.MediaAPI, db *storage.Database, repo * return *resErr } - r := &UploadRequest{ - ContentDisposition: req.Header.Get("Content-Disposition"), - ContentLength: req.ContentLength, - ContentType: req.Header.Get("Content-Type"), - Filename: req.FormValue("filename"), - Method: req.Method, - UserID: userID, + r := &uploadRequest{ + MediaMetadata: &types.MediaMetadata{ + Origin: cfg.ServerName, + ContentDisposition: types.ContentDisposition(req.Header.Get("Content-Disposition")), + ContentLength: types.ContentLength(req.ContentLength), + ContentType: types.ContentType(req.Header.Get("Content-Type")), + UploadName: types.Filename(req.FormValue("filename")), + UserID: types.MatrixUserID(userID), + }, } - if resErr = r.Validate(); resErr != nil { + // FIXME: if no Content-Disposition then set + + if resErr = r.Validate(cfg.MaxFileSize); resErr != nil { return *resErr } logger.WithFields(log.Fields{ - "ContentType": r.ContentType, - "Filename": r.Filename, - "UserID": r.UserID, + "Origin": r.MediaMetadata.Origin, + "UploadName": r.MediaMetadata.UploadName, + "Content-Length": r.MediaMetadata.ContentLength, + "Content-Type": r.MediaMetadata.ContentType, + "Content-Disposition": r.MediaMetadata.ContentDisposition, }).Info("Uploading file") - // TODO: Store file to disk - // - make path to file - // - progressive writing (could support Content-Length 0 and cut off - // after some max upload size is exceeded) - // - generate id (ideally a hash but a random string to start with) - writer, err := repo.WriterToLocalRepository(storage.Description{ - Type: r.ContentType, - }) + tmpDir, err := createTempDir(cfg.BasePath) if err != nil { - logger.Infof("Failed to get cache writer %q\n", err) + logger.Infof("Failed to create temp dir %q\n", err) return util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("Failed to upload: %q", err)), + JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)), } } + file, writer, err := createFileWriter(tmpDir, "content") + defer file.Close() - defer writer.Close() + // The limited reader restricts how many bytes are read from the body to the specified maximum bytes + // Note: the golang HTTP server closes the request body + limitedBody := io.LimitReader(req.Body, int64(cfg.MaxFileSize)) + hasher := sha256.New() + reader := io.TeeReader(limitedBody, hasher) - if _, err = io.Copy(writer, req.Body); err != nil { + bytesWritten, err := io.Copy(writer, reader) + if err != nil { logger.Infof("Failed to copy %q\n", err) + tmpDirErr := os.RemoveAll(string(tmpDir)) + if tmpDirErr != nil { + logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) + } return util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("Failed to upload: %q", err)), + JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)), } } - r.Base64FileHash, err = writer.Finished() - if err != nil { - return util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("Failed to upload: %q", err)), - } + writer.Flush() + + if bytesWritten != 
int64(r.MediaMetadata.ContentLength) { + logger.Warnf("Bytes uploaded (%v) != claimed Content-Length (%v)", bytesWritten, r.MediaMetadata.ContentLength) + } + + hash := hasher.Sum(nil) + r.MediaMetadata.MediaID = types.MediaID(base64.URLEncoding.EncodeToString(hash[:])) + + logger.WithFields(log.Fields{ + "MediaID": r.MediaMetadata.MediaID, + "Origin": r.MediaMetadata.Origin, + "UploadName": r.MediaMetadata.UploadName, + "Content-Length": r.MediaMetadata.ContentLength, + "Content-Type": r.MediaMetadata.ContentType, + "Content-Disposition": r.MediaMetadata.ContentDisposition, + }).Info("File uploaded") + + // check if we already have a record of the media in our database and if so, we can remove the temporary directory + err = db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin, r.MediaMetadata) + if err == nil { + tmpDirErr := os.RemoveAll(string(tmpDir)) + if tmpDirErr != nil { + logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) + } + return util.JSONResponse{ + Code: 200, + JSON: uploadResponse{ + ContentURI: fmt.Sprintf("mxc://%s/%s", cfg.ServerName, r.MediaMetadata.MediaID), + }, + } + } else if err != nil && err != sql.ErrNoRows { + logger.Warnf("Failed to query database for %v: %q", r.MediaMetadata.MediaID, err) } - // TODO: check if file with hash already exists // TODO: generate thumbnails - err = db.CreateMedia(r.Base64FileHash, cfg.ServerName, r.ContentType, r.ContentDisposition, r.ContentLength, r.Filename, r.UserID) + err = db.StoreMediaMetadata(r.MediaMetadata) if err != nil { + tmpDirErr := os.RemoveAll(string(tmpDir)) + if tmpDirErr != nil { + logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) + } return util.JSONResponse{ Code: 400, - JSON: jsonerror.BadJSON(fmt.Sprintf("Failed to upload: %q", err)), + JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)), + } + } + + err = moveFile( + types.Path(path.Join(string(tmpDir), "content")), + types.Path(getPathFromMediaMetadata(r.MediaMetadata, cfg.BasePath)), + ) + if err != nil { + tmpDirErr := os.RemoveAll(string(tmpDir)) + if tmpDirErr != nil { + logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr) + } + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)), } } return util.JSONResponse{ Code: 200, JSON: uploadResponse{ - ContentURI: fmt.Sprintf("mxc://%s/%s", cfg.ServerName, r.Base64FileHash), + ContentURI: fmt.Sprintf("mxc://%s/%s", cfg.ServerName, r.MediaMetadata.MediaID), }, } } diff --git a/src/github.com/matrix-org/dendrite/mediaapi/writers/utils.go b/src/github.com/matrix-org/dendrite/mediaapi/writers/utils.go new file mode 100644 index 000000000..f38717389 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/mediaapi/writers/utils.go @@ -0,0 +1,80 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package writers
+
+import (
+	"bufio"
+	"io/ioutil"
+	"os"
+	"path"
+
+	log "github.com/Sirupsen/logrus"
+	"github.com/matrix-org/dendrite/mediaapi/types"
+)
+
+// createTempDir creates a tmp/ directory within baseDirectory and returns its path
+func createTempDir(baseDirectory types.Path) (types.Path, error) {
+	baseTmpDir := path.Join(string(baseDirectory), "tmp")
+	err := os.MkdirAll(baseTmpDir, 0770)
+	if err != nil {
+		log.Printf("Failed to create base temp dir: %v\n", err)
+		return "", err
+	}
+	tmpDir, err := ioutil.TempDir(baseTmpDir, "")
+	if err != nil {
+		log.Printf("Failed to create temp dir: %v\n", err)
+		return "", err
+	}
+	return types.Path(tmpDir), nil
+}
+
+// createFileWriter creates a buffered file writer with a new file at directory/filename
+// Returns the file handle as it needs to be closed when writing is complete
+func createFileWriter(directory types.Path, filename types.Filename) (*os.File, *bufio.Writer, error) {
+	filePath := path.Join(string(directory), string(filename))
+	file, err := os.Create(filePath)
+	if err != nil {
+		log.Printf("Failed to create file: %v\n", err)
+		return nil, nil, err
+	}
+
+	return file, bufio.NewWriter(file), nil
+}
+
+func getPathFromMediaMetadata(m *types.MediaMetadata, basePath types.Path) string {
+	return path.Join(
+		string(basePath),
+		string(m.Origin),
+		string(m.MediaID[:3]),
+		string(m.MediaID[3:]),
+	)
+}
+
+// moveFile attempts to move the file src to dst
+func moveFile(src types.Path, dst types.Path) error {
+	dstDir := path.Dir(string(dst))
+
+	err := os.MkdirAll(dstDir, 0770)
+	if err != nil {
+		log.Printf("Failed to make directory: %v\n", dstDir)
+		return err
+	}
+	err = os.Rename(string(src), string(dst))
+	if err != nil {
+		log.Printf("Failed to move directory: %v to %v\n", src, dst)
+		return err
+	}
+	return nil
+}
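
Reviewer note (not part of the patch): a minimal, self-contained sketch of how the on-disk layout produced by getPathFromMediaMetadata above behaves for a locally uploaded file. The base path and media ID literals below are illustrative assumptions, not values used by the code.

package main

import (
	"fmt"
	"path"
)

// mediaPath mirrors getPathFromMediaMetadata above: a file is stored under
// <basePath>/<origin>/<first 3 chars of media ID>/<remaining chars of media ID>.
func mediaPath(basePath, origin, mediaID string) string {
	return path.Join(basePath, origin, mediaID[:3], mediaID[3:])
}

func main() {
	// "AbCdEfGhIjKl" stands in for the base64-encoded sha256 media ID; illustrative only.
	fmt.Println(mediaPath("/var/dendrite/media", "localhost", "AbCdEfGhIjKl"))
	// Output: /var/dendrite/media/localhost/AbC/dEfGhIjKl
	// A local upload with this media ID is returned to clients as mxc://localhost/AbCdEfGhIjKl.
}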