WIP: Refactoring

Robert Swain 2017-05-09 19:49:39 +02:00
parent a24b3e7810
commit 7cf34af30b
12 changed files with 718 additions and 829 deletions

View file

@@ -28,7 +28,7 @@ import (
 var (
     bindAddr = os.Getenv("BIND_ADDRESS")
-    database = os.Getenv("DATABASE")
+    dataSource = os.Getenv("DATABASE")
     logDir   = os.Getenv("LOG_DIR")
 )
@@ -42,7 +42,8 @@ func main() {
     cfg := config.MediaAPI{
         ServerName: "localhost",
         BasePath:   "/Users/robertsw/dendrite",
-        DataSource: database,
+        MaxFileSize: 61440,
+        DataSource:  dataSource,
     }
 
     db, err := storage.Open(cfg.DataSource)
@@ -50,13 +51,8 @@ func main() {
         log.Panicln("Failed to open database:", err)
     }
 
-    repo := &storage.Repository{
-        StorePrefix: cfg.BasePath,
-        MaxBytes:    61440,
-    }
-
     log.Info("Starting mediaapi")
-    routing.Setup(http.DefaultServeMux, http.DefaultClient, cfg, db, repo)
+    routing.Setup(http.DefaultServeMux, http.DefaultClient, cfg, db)
     log.Fatal(http.ListenAndServe(bindAddr, nil))
 }

View file

@@ -14,12 +14,17 @@
 package config
 
+import "github.com/matrix-org/dendrite/mediaapi/types"
+
 // MediaAPI contains the config information necessary to spin up a mediaapi process.
 type MediaAPI struct {
     // The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'.
-    ServerName string `yaml:"server_name"`
+    ServerName types.ServerName `yaml:"server_name"`
     // The base path to where media files will be stored.
-    BasePath string `yaml:"base_path"`
+    BasePath types.Path `yaml:"base_path"`
+    // The maximum file size in bytes that is allowed to be stored on this server.
+    // Note that remote files larger than this can still be proxied to a client, they will just not be cached.
+    MaxFileSize types.ContentLength `yaml:"max_file_size"`
     // The postgres connection config for connecting to the database e.g a postgres:// URI
     DataSource string `yaml:"database"`
 }

View file

@@ -20,6 +20,7 @@ import (
     "github.com/gorilla/mux"
     "github.com/matrix-org/dendrite/mediaapi/config"
     "github.com/matrix-org/dendrite/mediaapi/storage"
+    "github.com/matrix-org/dendrite/mediaapi/types"
     "github.com/matrix-org/dendrite/mediaapi/writers"
     "github.com/matrix-org/util"
     "github.com/prometheus/client_golang/prometheus"
@@ -27,45 +28,27 @@ import (
 const pathPrefixR0 = "/_matrix/media/v1"
 
-type downloadRequestHandler struct {
-    Config         config.MediaAPI
-    Database       *storage.Database
-    DownloadServer writers.DownloadServer
-}
-
-func (handler downloadRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-    util.SetupRequestLogging(req)
-    // Set common headers returned regardless of the outcome of the request
-    util.SetCORSHeaders(w)
-    w.Header().Set("Content-Type", "application/json")
-    vars := mux.Vars(req)
-    writers.Download(w, req, vars["serverName"], vars["mediaId"], handler.Config, handler.Database, handler.DownloadServer)
-}
-
 // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client
 // to clients which need to make outbound HTTP requests.
-func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg config.MediaAPI, db *storage.Database, repo *storage.Repository) {
+func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg config.MediaAPI, db *storage.Database) {
     apiMux := mux.NewRouter()
     r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
     r0mux.Handle("/upload", make("upload", util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse {
-        return writers.Upload(req, cfg, db, repo)
+        return writers.Upload(req, cfg, db)
     })))
-    downloadServer := writers.DownloadServer{
-        Repository:      *repo,
-        LocalServerName: cfg.ServerName,
-    }
-    handler := downloadRequestHandler{
-        Config:         cfg,
-        Database:       db,
-        DownloadServer: downloadServer,
-    }
     r0mux.Handle("/download/{serverName}/{mediaId}",
-        prometheus.InstrumentHandler("download", handler),
+        prometheus.InstrumentHandler("download", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+            util.SetupRequestLogging(req)
+            // Set common headers returned regardless of the outcome of the request
+            util.SetCORSHeaders(w)
+            // TODO: fix comment
+            w.Header().Set("Content-Type", "application/json")
+            vars := mux.Vars(req)
+            writers.Download(w, req, types.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, db)
+        })),
     )
     servMux.Handle("/metrics", prometheus.Handler())

View file

@ -1,92 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"fmt"
"os"
log "github.com/Sirupsen/logrus"
)
// LimitedFileWriter writes only a limited number of bytes to a file.
//
// If the callee attempts to write more bytes the file is deleted and further
// writes are silently discarded.
//
// This isn't thread safe.
type LimitedFileWriter struct {
filePath string
file *os.File
writtenBytes uint64
maxBytes uint64
}
// NewLimitedFileWriter creates a new LimitedFileWriter at the given location.
//
// If a file already exists at the location it is immediately truncated.
//
// A maxBytes of 0 or negative is treated as no limit.
func NewLimitedFileWriter(filePath string, maxBytes uint64) (*LimitedFileWriter, error) {
file, err := os.Create(filePath)
if err != nil {
return nil, err
}
writer := LimitedFileWriter{
filePath: filePath,
file: file,
maxBytes: maxBytes,
}
return &writer, nil
}
// Close closes the underlying file descriptor, if its open.
//
// Any error comes from File.Close
func (writer *LimitedFileWriter) Close() error {
if writer.file != nil {
file := writer.file
writer.file = nil
return file.Close()
}
return nil
}
func (writer *LimitedFileWriter) Write(p []byte) (n int, err error) {
if writer.maxBytes > 0 && uint64(len(p))+writer.writtenBytes > writer.maxBytes {
if writer.file != nil {
writer.Close()
err = os.Remove(writer.filePath)
if err != nil {
log.Printf("Failed to delete file %v\n", err)
}
}
return 0, fmt.Errorf("Reached limit")
}
if writer.file != nil {
n, err = writer.file.Write(p)
writer.writtenBytes += uint64(n)
if err != nil {
log.Printf("Failed to write to file %v\n", err)
}
}
return
}

View file

@ -1,89 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"database/sql"
"time"
)
const mediaSchema = `
-- The events table holds metadata for each media upload to the local server,
-- the actual file is stored separately.
CREATE TABLE IF NOT EXISTS media_repository (
-- The id used to refer to the media.
-- This is a base64-encoded sha256 hash of the file data
media_id TEXT PRIMARY KEY,
-- The origin of the media as requested by the client.
media_origin TEXT NOT NULL,
-- The MIME-type of the media file.
content_type TEXT NOT NULL,
-- The HTTP Content-Disposition header for the media file.
content_disposition TEXT NOT NULL,
-- Size of the media file in bytes.
file_size BIGINT NOT NULL,
-- When the content was uploaded in ms.
created_ts BIGINT NOT NULL,
-- The name with which the media was uploaded.
upload_name TEXT NOT NULL,
-- The user who uploaded the file.
user_id TEXT NOT NULL,
UNIQUE(media_id, media_origin)
);
`
const insertMediaSQL = `
INSERT INTO media_repository (media_id, media_origin, content_type, content_disposition, file_size, created_ts, upload_name, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`
const selectMediaSQL = `
SELECT content_type, content_disposition, file_size, upload_name FROM media_repository WHERE media_id = $1 AND media_origin = $2
`
type mediaStatements struct {
insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt
}
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(mediaSchema)
if err != nil {
return
}
return statementList{
{&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL},
}.prepare(db)
}
func (s *mediaStatements) insertMedia(mediaID string, mediaOrigin string, contentType string,
contentDisposition string, fileSize int64, uploadName string, userID string) error {
_, err := s.insertMediaStmt.Exec(
mediaID, mediaOrigin, contentType, contentDisposition, fileSize,
int64(time.Now().UnixNano()/1000000), uploadName, userID,
)
return err
}
func (s *mediaStatements) selectMedia(mediaID string, mediaOrigin string) (string, string, int64, string, error) {
var contentType string
var contentDisposition string
var fileSize int64
var uploadName string
err := s.selectMediaStmt.QueryRow(mediaID, mediaOrigin).Scan(&contentType, &contentDisposition, &fileSize, &uploadName)
return string(contentType), string(contentDisposition), int64(fileSize), string(uploadName), err
}

View file

@ -0,0 +1,107 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"database/sql"
"time"
"github.com/matrix-org/dendrite/mediaapi/types"
)
const mediaSchema = `
-- The media_repository table holds metadata for each media file stored and accessible to the local server,
-- the actual file is stored separately.
CREATE TABLE IF NOT EXISTS media_repository (
-- The id used to refer to the media.
-- For uploads to this server this is a base64-encoded sha256 hash of the file data
-- For media from remote servers, this can be any unique identifier string
media_id TEXT NOT NULL,
-- The origin of the media as requested by the client. Should be a homeserver domain.
media_origin TEXT NOT NULL,
-- The MIME-type of the media file as specified when uploading.
content_type TEXT NOT NULL,
-- The HTTP Content-Disposition header for the media file as specified when uploading.
content_disposition TEXT NOT NULL,
-- Size of the media file in bytes.
content_length BIGINT NOT NULL,
-- When the content was uploaded in UNIX epoch ms.
creation_ts BIGINT NOT NULL,
-- The file name with which the media was uploaded.
upload_name TEXT NOT NULL,
-- The user who uploaded the file. Should be a Matrix user ID.
user_id TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS media_repository_index ON media_repository (media_id, media_origin);
`
const insertMediaSQL = `
INSERT INTO media_repository (media_id, media_origin, content_type, content_disposition, content_length, creation_ts, upload_name, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`
const selectMediaSQL = `
SELECT content_type, content_disposition, content_length, creation_ts, upload_name, user_id FROM media_repository WHERE media_id = $1 AND media_origin = $2
`
type mediaStatements struct {
insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt
}
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(mediaSchema)
if err != nil {
return
}
return statementList{
{&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL},
}.prepare(db)
}
func (s *mediaStatements) insertMedia(mediaMetadata *types.MediaMetadata) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
_, err := s.insertMediaStmt.Exec(
mediaMetadata.MediaID,
mediaMetadata.Origin,
mediaMetadata.ContentType,
mediaMetadata.ContentDisposition,
mediaMetadata.ContentLength,
mediaMetadata.CreationTimestamp,
mediaMetadata.UploadName,
mediaMetadata.UserID,
)
return err
}
func (s *mediaStatements) selectMedia(mediaID types.MediaID, mediaOrigin types.ServerName) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
MediaID: mediaID,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRow(
mediaMetadata.MediaID, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,
&mediaMetadata.ContentDisposition,
&mediaMetadata.ContentLength,
&mediaMetadata.CreationTimestamp,
&mediaMetadata.UploadName,
&mediaMetadata.UserID,
)
return &mediaMetadata, err
}
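For reference, a minimal standalone sketch (not part of the commit) of how the unique index on (media_id, media_origin) behaves. It assumes a reachable Postgres instance with a hypothetical DSN, assumes the mediaapi has already created the media_repository table, and reuses the insert SQL above verbatim.

package main

import (
    "database/sql"
    "log"

    // Postgres driver, as used by the storage package.
    _ "github.com/lib/pq"
)

const insertMediaSQL = `
INSERT INTO media_repository (media_id, media_origin, content_type, content_disposition, content_length, creation_ts, upload_name, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`

func main() {
    // Hypothetical DSN; the real one comes from the DATABASE environment variable.
    db, err := sql.Open("postgres", "postgres://dendrite@localhost/mediaapi?sslmode=disable")
    if err != nil {
        log.Fatal(err)
    }

    // The second insert of the same (media_id, media_origin) pair violates
    // media_repository_index, so it returns a unique-violation error.
    for i := 0; i < 2; i++ {
        _, err := db.Exec(insertMediaSQL,
            "someMediaID", "localhost", "text/plain", "inline",
            11, 1494352179000, "hello.txt", "@alice:localhost")
        log.Printf("insert %d: err=%v", i+1, err)
    }
}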

View file

@ -1,283 +0,0 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"crypto/sha256"
"encoding/base64"
"hash"
"io"
"io/ioutil"
"os"
"path"
log "github.com/Sirupsen/logrus"
)
// Description contains various attributes for an image.
type Description struct {
Type string
Length int64
}
type repositoryPaths struct {
contentPath string
typePath string
}
// Repository stores locally uploaded media, and caches remote media that has
// been requested.
type Repository struct {
StorePrefix string
MaxBytes uint64
}
// ReaderFromRemoteCache returns a io.ReadCloser with the cached remote content,
// if it exists. Use IsNotExist to check if the error was due to it not existing
// in the cache
func (repo Repository) ReaderFromRemoteCache(host, name string) (io.ReadCloser, *Description, error) {
mediaDir := repo.getDirForRemoteMedia(host, name)
repoPaths := getPathsForMedia(mediaDir)
return repo.readerFromRepository(repoPaths)
}
// ReaderFromLocalRepo returns a io.ReadCloser with the locally uploaded content,
// if it exists. Use IsNotExist to check if the error was due to it not existing
// in the cache
func (repo Repository) ReaderFromLocalRepo(name string) (io.ReadCloser, *Description, error) {
mediaDir := repo.getDirForLocalMedia(name)
repoPaths := getPathsForMedia(mediaDir)
return repo.readerFromRepository(repoPaths)
}
func (repo Repository) readerFromRepository(repoPaths repositoryPaths) (io.ReadCloser, *Description, error) {
contentTypeBytes, err := ioutil.ReadFile(repoPaths.typePath)
if err != nil {
return nil, nil, err
}
contentType := string(contentTypeBytes)
file, err := os.Open(repoPaths.contentPath)
if err != nil {
return nil, nil, err
}
stat, err := file.Stat()
if err != nil {
return nil, nil, err
}
descr := Description{
Type: contentType,
Length: stat.Size(),
}
return file, &descr, nil
}
// WriterToLocalRepository returns a RepositoryWriter for writing newly uploaded
// content into the repository.
//
// The returned RepositoryWriter will fail if more than MaxBytes tries to be
// written.
func (repo Repository) WriterToLocalRepository(descr Description) (RepositoryWriter, error) {
return newLocalRepositoryWriter(repo, descr)
}
// WriterToRemoteCache returns a RepositoryWriter for caching newly downloaded
// remote content.
//
// The returned RepositoryWriter will silently stop writing if more than MaxBytes
// tries to be written and does *not* return an error.
func (repo Repository) WriterToRemoteCache(host, name string, descr Description) (RepositoryWriter, error) {
return newRemoteRepositoryWriter(repo, host, name, descr)
}
func (repo *Repository) makeTempDir() (string, error) {
tmpDir := path.Join(repo.StorePrefix, "tmp")
os.MkdirAll(tmpDir, 0770)
return ioutil.TempDir(tmpDir, "")
}
func (repo *Repository) getDirForLocalMedia(name string) string {
return path.Join(repo.StorePrefix, "local", name[:3], name[3:])
}
func (repo *Repository) getDirForRemoteMedia(host, sanitizedName string) string {
return path.Join(repo.StorePrefix, "remote", host, sanitizedName[:3], sanitizedName[3:])
}
// Get the actual paths for the data and metadata associated with remote media.
func getPathsForMedia(dir string) repositoryPaths {
contentPath := path.Join(dir, "content")
typePath := path.Join(dir, "type")
return repositoryPaths{
contentPath: contentPath,
typePath: typePath,
}
}
// IsNotExists check if error was due to content not existing in cache.
func IsNotExists(err error) bool { return os.IsNotExist(err) }
// RepositoryWriter is used to either store into the repository newly uploaded
// media or to cache recently fetched remote media.
type RepositoryWriter interface {
io.WriteCloser
// Finished should be called when successfully finished writing; otherwise
// the written content will not be committed to the repository.
Finished() (string, error)
}
type remoteRepositoryWriter struct {
tmpDir string
finalDir string
name string
file io.WriteCloser
erred bool
}
func newRemoteRepositoryWriter(repo Repository, host, name string, descr Description) (*remoteRepositoryWriter, error) {
tmpFile, tmpDir, err := getTempWriter(repo, descr)
if err != nil {
log.Printf("Failed to create writer: %v\n", err)
return nil, err
}
return &remoteRepositoryWriter{
tmpDir: tmpDir,
finalDir: repo.getDirForRemoteMedia(host, name),
name: name,
file: tmpFile,
erred: false,
}, nil
}
func (writer remoteRepositoryWriter) Write(p []byte) (int, error) {
// Its OK to fail when writing to the remote repo. We just hide the error
// from the layers above
if !writer.erred {
if _, err := writer.file.Write(p); err != nil {
writer.erred = true
}
}
return len(p), nil
}
func (writer remoteRepositoryWriter) Close() error {
os.RemoveAll(writer.tmpDir)
writer.file.Close()
return nil
}
func (writer remoteRepositoryWriter) Finished() (string, error) {
var err error
if !writer.erred {
os.MkdirAll(path.Dir(writer.finalDir), 0770)
err = os.Rename(writer.tmpDir, writer.finalDir)
if err != nil {
return "", err
}
}
err = writer.Close()
return writer.name, err
}
type localRepositoryWriter struct {
repo Repository
tmpDir string
hasher hash.Hash
file io.WriteCloser
finished bool
}
func newLocalRepositoryWriter(repo Repository, descr Description) (*localRepositoryWriter, error) {
tmpFile, tmpDir, err := getTempWriter(repo, descr)
if err != nil {
return nil, err
}
return &localRepositoryWriter{
repo: repo,
tmpDir: tmpDir,
hasher: sha256.New(),
file: tmpFile,
finished: false,
}, nil
}
func (writer localRepositoryWriter) Write(p []byte) (int, error) {
writer.hasher.Write(p) // Never errors.
n, err := writer.file.Write(p)
if err != nil {
writer.Close()
}
return n, err
}
func (writer localRepositoryWriter) Close() error {
var err error
if !writer.finished {
err = os.RemoveAll(writer.tmpDir)
if err != nil {
return err
}
}
err = writer.file.Close()
return err
}
func (writer localRepositoryWriter) Finished() (string, error) {
hash := writer.hasher.Sum(nil)
name := base64.URLEncoding.EncodeToString(hash[:])
finalDir := writer.repo.getDirForLocalMedia(name)
os.MkdirAll(path.Dir(finalDir), 0770)
err := os.Rename(writer.tmpDir, finalDir)
if err != nil {
log.Println("Failed to move temp directory:", writer.tmpDir, finalDir, err)
return "", err
}
writer.finished = true
writer.Close()
return name, nil
}
func getTempWriter(repo Repository, descr Description) (io.WriteCloser, string, error) {
tmpDir, err := repo.makeTempDir()
if err != nil {
log.Printf("Failed to create temp dir: %v\n", err)
return nil, "", err
}
repoPaths := getPathsForMedia(tmpDir)
if err = ioutil.WriteFile(repoPaths.typePath, []byte(descr.Type), 0660); err != nil {
log.Printf("Failed to create typeFile: %q\n", err)
return nil, "", err
}
tmpFile, err := NewLimitedFileWriter(repoPaths.contentPath, repo.MaxBytes)
if err != nil {
log.Printf("Failed to create limited file: %v\n", err)
return nil, "", err
}
return tmpFile, tmpDir, nil
}

View file

@@ -19,6 +19,7 @@ import (
     // Import the postgres database driver.
     _ "github.com/lib/pq"
+    "github.com/matrix-org/dendrite/mediaapi/types"
 )
 
 // A Database is used to store room events and stream offsets.
@@ -40,12 +41,19 @@ func Open(dataSourceName string) (*Database, error) {
     return &d, nil
 }
 
-// CreateMedia inserts the metadata about the uploaded media into the database.
-func (d *Database) CreateMedia(mediaID string, mediaOrigin string, contentType string, contentDisposition string, fileSize int64, uploadName string, userID string) error {
-    return d.statements.insertMedia(mediaID, mediaOrigin, contentType, contentDisposition, fileSize, uploadName, userID)
+// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
+func (d *Database) StoreMediaMetadata(mediaMetadata *types.MediaMetadata) error {
+    return d.statements.insertMedia(mediaMetadata)
 }
 
-// GetMedia possibly selects the metadata about previously uploaded media from the database.
-func (d *Database) GetMedia(mediaID string, mediaOrigin string) (string, string, int64, string, error) {
-    return d.statements.selectMedia(mediaID, mediaOrigin)
+// GetMediaMetadata possibly selects the metadata about previously uploaded media from the database.
+func (d *Database) GetMediaMetadata(mediaID types.MediaID, mediaOrigin types.ServerName, mediaMetadata *types.MediaMetadata) error {
+    metadata, err := d.statements.selectMedia(mediaID, mediaOrigin)
+    mediaMetadata.ContentType = metadata.ContentType
+    mediaMetadata.ContentDisposition = metadata.ContentDisposition
+    mediaMetadata.ContentLength = metadata.ContentLength
+    mediaMetadata.CreationTimestamp = metadata.CreationTimestamp
+    mediaMetadata.UploadName = metadata.UploadName
+    mediaMetadata.UserID = metadata.UserID
+    return err
 }
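A hedged sketch (not part of the commit) of how the reworked storage API might be exercised. The DSN comes from the same DATABASE environment variable that main.go reads, and the metadata values are hypothetical.

package main

import (
    "log"
    "os"

    "github.com/matrix-org/dendrite/mediaapi/storage"
    "github.com/matrix-org/dendrite/mediaapi/types"
)

func main() {
    db, err := storage.Open(os.Getenv("DATABASE"))
    if err != nil {
        log.Panicln("Failed to open database:", err)
    }

    // Hypothetical metadata; StoreMediaMetadata fills in CreationTimestamp itself.
    meta := &types.MediaMetadata{
        MediaID:       types.MediaID("someMediaID"),
        Origin:        types.ServerName("localhost"),
        ContentType:   types.ContentType("text/plain"),
        ContentLength: types.ContentLength(11),
        UploadName:    types.Filename("hello.txt"),
        UserID:        types.MatrixUserID("@alice:localhost"),
    }
    if err := db.StoreMediaMetadata(meta); err != nil {
        log.Println("store failed:", err)
    }

    // Read it back into a fresh struct; a miss surfaces as sql.ErrNoRows.
    fetched := &types.MediaMetadata{}
    if err := db.GetMediaMetadata(meta.MediaID, meta.Origin, fetched); err != nil {
        log.Println("lookup failed:", err)
    }
    log.Printf("fetched %q uploaded by %s", fetched.UploadName, fetched.UserID)
}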

View file

@ -0,0 +1,57 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
// ContentDisposition is an HTTP Content-Disposition header string
type ContentDisposition string
// ContentLength is an HTTP Content-Length header which is a number of bytes to be expected in a request body
type ContentLength int64
// ContentType is an HTTP Content-Type header string representing the MIME type of a request body
type ContentType string
// Filename is a string representing the name of a file
type Filename string
// Path is an absolute or relative UNIX filesystem path
type Path string
// MediaID is a string representing the unique identifier for a file (could be a hash but does not have to be)
type MediaID string
// ServerName is the host of a matrix homeserver, e.g. matrix.org
type ServerName string
// RequestMethod is an HTTP request method, e.g. GET, POST, etc
type RequestMethod string
// MatrixUserID is a Matrix user ID string in the form @user:domain e.g. @alice:matrix.org
type MatrixUserID string
// UnixMs is the milliseconds since the Unix epoch
type UnixMs int64
// MediaMetadata is metadata associated with a media file
type MediaMetadata struct {
MediaID MediaID
Origin ServerName
ContentType ContentType
ContentDisposition ContentDisposition
ContentLength ContentLength
CreationTimestamp UnixMs
UploadName Filename
UserID MatrixUserID
}
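A small illustration (not part of the commit, values hypothetical) of why the typed strings are useful: plain strings need explicit conversions, so a MediaID cannot silently be passed where a ServerName is expected, and the same types are what upload.go uses when it builds the mxc:// content URI.

package main

import (
    "fmt"

    "github.com/matrix-org/dendrite/mediaapi/types"
)

func main() {
    origin := types.ServerName("matrix.org")
    mediaID := types.MediaID("someMediaID")

    // upload.go formats the content URI for the upload response the same way.
    fmt.Printf("mxc://%s/%s\n", origin, mediaID)
}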

View file

@@ -15,11 +15,14 @@
 package writers
 
 import (
+    "database/sql"
     "encoding/json"
     "fmt"
     "io"
     "net"
     "net/http"
+    "os"
+    "path"
     "strconv"
     "strings"
@@ -27,30 +30,30 @@ import (
     "github.com/matrix-org/dendrite/clientapi/jsonerror"
     "github.com/matrix-org/dendrite/mediaapi/config"
     "github.com/matrix-org/dendrite/mediaapi/storage"
+    "github.com/matrix-org/dendrite/mediaapi/types"
     "github.com/matrix-org/util"
 )
 
-// DownloadRequest metadata included in or derivable from an upload request
+// downloadRequest metadata included in or derivable from a download request
 // https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-media-r0-download
-type DownloadRequest struct {
-    MediaID    string
-    ServerName string
+type downloadRequest struct {
+    MediaMetadata *types.MediaMetadata
 }
 
-// Validate validates the DownloadRequest fields
-func (r DownloadRequest) Validate() *util.JSONResponse {
+// Validate validates the downloadRequest fields
+func (r downloadRequest) Validate() *util.JSONResponse {
     // FIXME: the following errors aren't bad JSON, rather just a bad request path
     // maybe give the URL pattern in the routing, these are not even possible as the handler would not be hit...?
-    if r.MediaID == "" {
+    if r.MediaMetadata.MediaID == "" {
         return &util.JSONResponse{
-            Code: 400,
-            JSON: jsonerror.BadJSON("mediaId must be a non-empty string"),
+            Code: 404,
+            JSON: jsonerror.NotFound("mediaId must be a non-empty string"),
         }
     }
-    if r.ServerName == "" {
+    if r.MediaMetadata.Origin == "" {
         return &util.JSONResponse{
-            Code: 400,
-            JSON: jsonerror.BadJSON("serverName must be a non-empty string"),
+            Code: 404,
+            JSON: jsonerror.NotFound("serverName must be a non-empty string"),
         }
     }
     return nil
@@ -72,10 +75,21 @@ func jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse, logger *log
     w.Write(resBytes)
 }
 
+var errFileIsTooLarge = fmt.Errorf("file is too large")
+var errRead = fmt.Errorf("failed to read response from remote server")
+var errResponse = fmt.Errorf("failed to write file data to response body")
+var errWrite = fmt.Errorf("failed to write file to disk")
+
 // Download implements /download
-func Download(w http.ResponseWriter, req *http.Request, serverName string, mediaID string, cfg config.MediaAPI, db *storage.Database, downloadServer DownloadServer) {
+// Files from this server (i.e. origin == cfg.ServerName) are served directly
+// Files from remote servers (i.e. origin != cfg.ServerName) are cached locally.
+// If they are present in the cache, they are served directly.
+// If they are not present in the cache, they are obtained from the remote server and
+// simultaneously served back to the client and written into the cache.
+func Download(w http.ResponseWriter, req *http.Request, origin types.ServerName, mediaID types.MediaID, cfg config.MediaAPI, db *storage.Database) {
     logger := util.GetLogger(req.Context())
+    // request validation
     if req.Method != "GET" {
         jsonErrorResponse(w, util.JSONResponse{
             Code: 405,
@@ -84,9 +98,11 @@ func Download(w http.ResponseWriter, req *http.Request, serverName string, media
         return
     }
 
-    r := &DownloadRequest{
-        MediaID:    mediaID,
-        ServerName: serverName,
+    r := &downloadRequest{
+        MediaMetadata: &types.MediaMetadata{
+            MediaID: mediaID,
+            Origin:  origin,
+        },
     }
 
     if resErr := r.Validate(); resErr != nil {
@ -94,190 +110,285 @@ func Download(w http.ResponseWriter, req *http.Request, serverName string, media
return return
} }
contentType, contentDisposition, fileSize, filename, err := db.GetMedia(r.MediaID, r.ServerName) // check if we have a record of the media in our database
if err != nil && strings.Compare(r.ServerName, cfg.ServerName) == 0 { err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin, r.MediaMetadata)
if err == nil {
// If we have a record, we can respond from the local file
logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
"UploadName": r.MediaMetadata.UploadName,
"Content-Length": r.MediaMetadata.ContentLength,
"Content-Type": r.MediaMetadata.ContentType,
"Content-Disposition": r.MediaMetadata.ContentDisposition,
}).Infof("Downloading file")
filePath := getPathFromMediaMetadata(r.MediaMetadata, cfg.BasePath)
file, err := os.Open(filePath)
if err != nil {
// FIXME: Remove erroneous file from database?
jsonErrorResponse(w, util.JSONResponse{ jsonErrorResponse(w, util.JSONResponse{
Code: 404, Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File %q does not exist", r.MediaID)), JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
}, logger) }, logger)
return return
} }
// - read file and respond stat, err := file.Stat()
if err != nil {
// FIXME: Remove erroneous file from database?
jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
}, logger)
return
}
if r.MediaMetadata.ContentLength > 0 && int64(r.MediaMetadata.ContentLength) != stat.Size() {
logger.Warnf("File size in database (%v) and on disk (%v) differ.", r.MediaMetadata.ContentLength, stat.Size())
// FIXME: Remove erroneous file from database?
}
w.Header().Set("Content-Type", string(r.MediaMetadata.ContentType))
w.Header().Set("Content-Length", strconv.FormatInt(stat.Size(), 10))
contentSecurityPolicy := "default-src 'none';" +
" script-src 'none';" +
" plugin-types application/pdf;" +
" style-src 'unsafe-inline';" +
" object-src 'self';"
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
if bytesResponded, err := io.Copy(w, file); err != nil {
logger.Warnf("Failed to copy from cache %v\n", err)
if bytesResponded == 0 {
jsonErrorResponse(w, util.JSONResponse{
Code: 500,
JSON: jsonerror.NotFound(fmt.Sprintf("Failed to respond with file with media ID %q", r.MediaMetadata.MediaID)),
}, logger)
}
// If we have written any data then we have already responded with 200 OK and all we can do is close the connection
return
}
} else if err == sql.ErrNoRows && r.MediaMetadata.Origin != cfg.ServerName {
// If we do not have a record and the origin is remote, we need to fetch it and respond with that file
logger.WithFields(log.Fields{ logger.WithFields(log.Fields{
"MediaID": r.MediaID, "MediaID": r.MediaMetadata.MediaID,
"ServerName": r.ServerName, "Origin": r.MediaMetadata.Origin,
"Filename": filename, "UploadName": r.MediaMetadata.UploadName,
"Content-Type": contentType, "Content-Length": r.MediaMetadata.ContentLength,
"Content-Disposition": contentDisposition, "Content-Type": r.MediaMetadata.ContentType,
}).Infof("Downloading file") "Content-Disposition": r.MediaMetadata.ContentDisposition,
}).Infof("Fetching remote file")
logger.WithField("code", 200).Infof("Responding (%d bytes)", fileSize) // TODO: lock request in hash set
respWriter := httpResponseWriter{resp: w} // FIXME: Only request once (would race if multiple requests for the same remote file)
if err = downloadServer.getImage(respWriter, r.ServerName, r.MediaID); err != nil { // Use a hash set based on the origin and media ID (the request URL should be fine...) and synchronise adding / removing members
if respWriter.haveWritten() { urls := getMatrixUrls(r.MediaMetadata.Origin)
closeConnection(w)
logger.Printf("Connecting to remote %q\n", urls[0])
remoteReqAddr := urls[0] + "/_matrix/media/v1/download/" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID)
remoteReq, err := http.NewRequest("GET", remoteReqAddr, nil)
if err != nil {
jsonErrorResponse(w, util.JSONResponse{
Code: 500,
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
}, logger)
return return
} }
errStatus := 500 remoteReq.Header.Set("Host", string(r.MediaMetadata.Origin))
switch err {
case errNotFound: client := http.Client{}
errStatus = 404 resp, err := client.Do(remoteReq)
case errProxy: if err != nil {
errStatus = 502 jsonErrorResponse(w, util.JSONResponse{
} Code: 502,
http.Error(w, err.Error(), errStatus) JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
}, logger)
return return
} }
defer resp.Body.Close()
return
}
// DownloadServer serves and caches remote media.
type DownloadServer struct {
Client http.Client
Repository storage.Repository
LocalServerName string
}
func (handler *DownloadServer) getImage(w responseWriter, host, name string) error {
var file io.ReadCloser
var descr *storage.Description
var err error
if host == handler.LocalServerName {
file, descr, err = handler.Repository.ReaderFromLocalRepo(name)
} else {
file, descr, err = handler.Repository.ReaderFromRemoteCache(host, name)
}
if err == nil {
log.Println("Found in Cache")
w.setContentType(descr.Type)
size := strconv.FormatInt(descr.Length, 10)
w.setContentLength(size)
w.setContentSecurityPolicy()
if _, err = io.Copy(w, file); err != nil {
log.Printf("Failed to copy from cache %v\n", err)
return err
}
w.Close()
return nil
} else if !storage.IsNotExists(err) {
log.Printf("Error looking in cache: %v\n", err)
return err
}
if host == handler.LocalServerName {
// Its fatal if we can't find local files in our cache.
return errNotFound
}
respBody, desc, err := handler.fetchRemoteMedia(host, name)
if err != nil {
return err
}
defer respBody.Close()
w.setContentType(desc.Type)
if desc.Length > 0 {
w.setContentLength(strconv.FormatInt(desc.Length, 10))
}
writer, err := handler.Repository.WriterToRemoteCache(host, name, *desc)
if err != nil {
log.Printf("Failed to get cache writer %q\n", err)
return err
}
defer writer.Close()
reader := io.TeeReader(respBody, w)
if _, err := io.Copy(writer, reader); err != nil {
log.Printf("Failed to copy %q\n", err)
return err
}
writer.Finished()
log.Println("Finished conn")
return nil
}
func (handler *DownloadServer) fetchRemoteMedia(host, name string) (io.ReadCloser, *storage.Description, error) {
urls := getMatrixUrls(host)
log.Printf("Connecting to remote %q\n", urls[0])
remoteReq, err := http.NewRequest("GET", urls[0]+"/_matrix/media/v1/download/"+host+"/"+name, nil)
if err != nil {
log.Printf("Failed to connect to remote: %q\n", err)
return nil, nil, err
}
remoteReq.Header.Set("Host", host)
resp, err := handler.Client.Do(remoteReq)
if err != nil {
log.Printf("Failed to connect to remote: %q\n", err)
return nil, nil, errProxy
}
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
resp.Body.Close() logger.Printf("Server responded with %d\n", resp.StatusCode)
log.Printf("Server responded with %d\n", resp.StatusCode)
if resp.StatusCode == 404 { if resp.StatusCode == 404 {
return nil, nil, errNotFound jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
}, logger)
return
} }
return nil, nil, errProxy jsonErrorResponse(w, util.JSONResponse{
Code: 502,
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
}, logger)
return
} }
desc := storage.Description{ contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
Type: resp.Header.Get("Content-Type"), if err != nil {
Length: -1, logger.Warn("Failed to parse content length")
}
r.MediaMetadata.ContentLength = types.ContentLength(contentLength)
w.Header().Set("Content-Type", string(r.MediaMetadata.ContentType))
w.Header().Set("Content-Length", strconv.FormatInt(int64(r.MediaMetadata.ContentLength), 10))
contentSecurityPolicy := "default-src 'none';" +
" script-src 'none';" +
" plugin-types application/pdf;" +
" style-src 'unsafe-inline';" +
" object-src 'self';"
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
tmpDir, err := createTempDir(cfg.BasePath)
if err != nil {
logger.Infof("Failed to create temp dir %q\n", err)
jsonErrorResponse(w, util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)),
}, logger)
return
}
tmpFile, writer, err := createFileWriter(tmpDir, types.Filename(r.MediaMetadata.MediaID[3:]))
if err != nil {
logger.Infof("Failed to create file writer %q\n", err)
jsonErrorResponse(w, util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)),
}, logger)
return
}
defer tmpFile.Close()
// bytesResponded is the total number of bytes written to the response to the client request
// bytesWritten is the total number of bytes written to disk
var bytesResponded, bytesWritten int64 = 0, 0
var fetchError error
// Note: the buffer size is the same as is used in io.Copy()
buffer := make([]byte, 32*1024)
for {
// read from remote request's response body
bytesRead, readErr := resp.Body.Read(buffer)
if bytesRead > 0 {
// write to client request's response body
bytesTemp, respErr := w.Write(buffer[:bytesRead])
if bytesTemp != bytesRead || (respErr != nil && respErr != io.EOF) {
// TODO: BORKEN
logger.Errorf("bytesTemp %v != bytesRead %v : %v", bytesTemp, bytesRead, respErr)
fetchError = errResponse
break
}
bytesResponded += int64(bytesTemp)
if fetchError == nil || (fetchError != errFileIsTooLarge && fetchError != errWrite) {
// if larger than cfg.MaxFileSize then stop writing to disk and discard cached file
if bytesWritten+int64(len(buffer)) > int64(cfg.MaxFileSize) {
// TODO: WAAAAHNING and clean up temp files
fetchError = errFileIsTooLarge
} else {
// write to disk
bytesTemp, writeErr := writer.Write(buffer)
if writeErr != nil && writeErr != io.EOF {
// TODO: WAAAAHNING and clean up temp files
fetchError = errWrite
} else {
bytesWritten += int64(bytesTemp)
}
}
}
}
if readErr != nil {
if readErr != io.EOF {
fetchError = errRead
break
}
}
} }
length, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) writer.Flush()
if err == nil {
desc.Length = length if fetchError != nil {
logFields := log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
} }
if fetchError == errFileIsTooLarge {
return resp.Body, &desc, nil logFields["MaxFileSize"] = cfg.MaxFileSize
} }
logger.WithFields(logFields).Warnln(fetchError)
// Given a http.ResponseWriter, attempt to force close the connection. tmpDirErr := os.RemoveAll(string(tmpDir))
// if tmpDirErr != nil {
// This is useful if you get a fatal error after sending the initial 200 OK logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
// response. }
func closeConnection(w http.ResponseWriter) { // Note: if we have responded with any data in the body at all then we have already sent 200 OK and we can only abort at this point
log.Println("Attempting to close connection") if bytesResponded < 1 {
jsonErrorResponse(w, util.JSONResponse{
Code: 502,
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
}, logger)
} else {
// We attempt to bluntly close the connection because that is the // We attempt to bluntly close the connection because that is the
// best thing we can do after we've sent a 200 OK // best thing we can do after we've sent a 200 OK
hijack, ok := w.(http.Hijacker) logger.Println("Attempting to close the connection.")
hijacker, ok := w.(http.Hijacker)
if ok { if ok {
conn, _, err := hijack.Hijack() connection, _, hijackErr := hijacker.Hijack()
if hijackErr == nil {
logger.Println("Closing")
connection.Close()
} else {
logger.Printf("Error trying to hijack: %v", hijackErr)
}
}
}
return
}
// Note: After this point we have responded to the client's request and are just dealing with local caching.
// As we have responded with 200 OK, any errors are ineffectual to the client request and so we just log and return.
// if written to disk, add to db
err = db.StoreMediaMetadata(r.MediaMetadata)
if err != nil { if err != nil {
fmt.Printf("Err trying to hijack: %v", err) tmpDirErr := os.RemoveAll(string(tmpDir))
if tmpDirErr != nil {
logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
}
return return
} }
log.Println("Closing")
conn.Close() // TODO: unlock request in hash set
// TODO: generate thumbnails
err = moveFile(
types.Path(path.Join(string(tmpDir), "content")),
types.Path(getPathFromMediaMetadata(r.MediaMetadata, cfg.BasePath)),
)
if err != nil {
tmpDirErr := os.RemoveAll(string(tmpDir))
if tmpDirErr != nil {
logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
}
return return
} }
log.Println("Not hijacker") } else {
// TODO: If we do not have a record and the origin is local, or if we have another error from the database, the file is not found
jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
}, logger)
}
} }
// Given a matrix server name, attempt to discover URLs to contact the server // Given a matrix server name, attempt to discover URLs to contact the server
// on. // on.
func getMatrixUrls(host string) []string { func getMatrixUrls(serverName types.ServerName) []string {
_, srvs, err := net.LookupSRV("matrix", "tcp", host) _, srvs, err := net.LookupSRV("matrix", "tcp", string(serverName))
if err != nil { if err != nil {
return []string{"https://" + host + ":8448"} return []string{"https://" + string(serverName) + ":8448"}
} }
results := make([]string, 0, len(srvs)) results := make([]string, 0, len(srvs))
@ -294,65 +405,3 @@ func getMatrixUrls(host string) []string {
return results return results
} }
// Given a path of the form '/<host>/<name>' extract the host and name.
func getMediaIDFromPath(path string) (host, name string, err error) {
parts := strings.Split(path, "/")
if len(parts) != 3 {
err = fmt.Errorf("Invalid path %q", path)
return
}
host, name = parts[1], parts[2]
if host == "" || name == "" {
err = fmt.Errorf("Invalid path %q", path)
return
}
return
}
type responseWriter interface {
io.WriteCloser
setContentLength(string)
setContentSecurityPolicy()
setContentType(string)
haveWritten() bool
}
type httpResponseWriter struct {
resp http.ResponseWriter
written bool
}
func (writer httpResponseWriter) haveWritten() bool {
return writer.written
}
func (writer httpResponseWriter) Write(p []byte) (n int, err error) {
writer.written = true
return writer.resp.Write(p)
}
func (writer httpResponseWriter) Close() error { return nil }
func (writer httpResponseWriter) setContentType(contentType string) {
writer.resp.Header().Set("Content-Type", contentType)
}
func (writer httpResponseWriter) setContentLength(length string) {
writer.resp.Header().Set("Content-Length", length)
}
func (writer httpResponseWriter) setContentSecurityPolicy() {
contentSecurityPolicy := "default-src 'none';" +
" script-src 'none';" +
" plugin-types application/pdf;" +
" style-src 'unsafe-inline';" +
" object-src 'self';"
writer.resp.Header().Set("Content-Security-Policy", contentSecurityPolicy)
}
var errProxy = fmt.Errorf("Failed to contact remote")
var errNotFound = fmt.Errorf("Image not found")
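A client-side sketch (not part of the commit) of the download route registered in routing.go. The address is hypothetical, since the server binds to whatever BIND_ADDRESS is set to.

package main

import (
    "io"
    "log"
    "net/http"
    "os"
)

func main() {
    url := "http://localhost:7774/_matrix/media/v1/download/localhost/someMediaID"

    resp, err := http.Get(url)
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()

    // On a hit the handler copies the stored Content-Type and Content-Length
    // (plus a Content-Security-Policy) onto the response; on a miss it writes
    // a JSON error with a 404 status.
    log.Println(resp.Status, resp.Header.Get("Content-Type"))

    if _, err := io.Copy(os.Stdout, resp.Body); err != nil {
        log.Fatal(err)
    }
}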

View file

@@ -15,9 +15,14 @@
 package writers
 
 import (
+    "crypto/sha256"
+    "database/sql"
+    "encoding/base64"
     "fmt"
     "io"
     "net/http"
+    "os"
+    "path"
     "strings"
 
     log "github.com/Sirupsen/logrus"
@@ -25,58 +30,54 @@ import (
     "github.com/matrix-org/dendrite/clientapi/jsonerror"
     "github.com/matrix-org/dendrite/mediaapi/config"
     "github.com/matrix-org/dendrite/mediaapi/storage"
+    "github.com/matrix-org/dendrite/mediaapi/types"
     "github.com/matrix-org/util"
 )
 
-// UploadRequest metadata included in or derivable from an upload request
+// uploadRequest metadata included in or derivable from an upload request
 // https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-media-r0-upload
-// NOTE: ContentType is an HTTP request header and Filename is passed as a query parameter
-type UploadRequest struct {
-    ContentDisposition string
-    ContentLength      int64
-    ContentType        string
-    Filename           string
-    Base64FileHash     string
-    Method             string
-    UserID             string
+// NOTE: The members come from HTTP request metadata such as headers, query parameters or can be derived from such
+type uploadRequest struct {
+    MediaMetadata *types.MediaMetadata
 }
 
-// Validate validates the UploadRequest fields
-func (r UploadRequest) Validate() *util.JSONResponse {
+// Validate validates the uploadRequest fields
+func (r uploadRequest) Validate(maxFileSize types.ContentLength) *util.JSONResponse {
     // TODO: Any validation to be done on ContentDisposition?
-    if r.ContentLength < 1 {
+    if r.MediaMetadata.ContentLength < 1 {
         return &util.JSONResponse{
             Code: 400,
-            JSON: jsonerror.BadJSON("HTTP Content-Length request header must be greater than zero."),
+            JSON: jsonerror.Unknown("HTTP Content-Length request header must be greater than zero."),
+        }
+    }
+    if maxFileSize > 0 && r.MediaMetadata.ContentLength > maxFileSize {
+        return &util.JSONResponse{
+            Code: 400,
+            JSON: jsonerror.Unknown(fmt.Sprintf("HTTP Content-Length is greater than the maximum allowed upload size (%v).", maxFileSize)),
         }
     }
     // TODO: Check if the Content-Type is a valid type?
-    if r.ContentType == "" {
+    if r.MediaMetadata.ContentType == "" {
         return &util.JSONResponse{
             Code: 400,
-            JSON: jsonerror.BadJSON("HTTP Content-Type request header must be set."),
+            JSON: jsonerror.Unknown("HTTP Content-Type request header must be set."),
         }
     }
// TODO: Validate filename - what are the valid characters? // TODO: Validate filename - what are the valid characters?
if r.Method != "POST" { if r.MediaMetadata.UserID != "" {
return &util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("HTTP request method must be POST."),
}
}
if r.UserID != "" {
// TODO: We should put user ID parsing code into gomatrixserverlib and use that instead // TODO: We should put user ID parsing code into gomatrixserverlib and use that instead
// (see https://github.com/matrix-org/gomatrixserverlib/blob/3394e7c7003312043208aa73727d2256eea3d1f6/eventcontent.go#L347 ) // (see https://github.com/matrix-org/gomatrixserverlib/blob/3394e7c7003312043208aa73727d2256eea3d1f6/eventcontent.go#L347 )
// It should be a struct (with pointers into a single string to avoid copying) and // It should be a struct (with pointers into a single string to avoid copying) and
// we should update all refs to use UserID types rather than strings. // we should update all refs to use UserID types rather than strings.
// https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/types.py#L92 // https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/types.py#L92
if len(r.UserID) == 0 || r.UserID[0] != '@' { if len(r.MediaMetadata.UserID) == 0 || r.MediaMetadata.UserID[0] != '@' {
return &util.JSONResponse{ return &util.JSONResponse{
Code: 400, Code: 400,
JSON: jsonerror.BadJSON("user id must start with '@'"), JSON: jsonerror.Unknown("user id must start with '@'"),
} }
} }
parts := strings.SplitN(r.UserID[1:], ":", 2) parts := strings.SplitN(string(r.MediaMetadata.UserID[1:]), ":", 2)
if len(parts) != 2 { if len(parts) != 2 {
return &util.JSONResponse{ return &util.JSONResponse{
Code: 400, Code: 400,
@ -93,9 +94,21 @@ type uploadResponse struct {
} }
// Upload implements /upload // Upload implements /upload
func Upload(req *http.Request, cfg config.MediaAPI, db *storage.Database, repo *storage.Repository) util.JSONResponse { //
// This endpoint involves uploading potentially significant amounts of data to the homeserver.
// This implementation supports a configurable maximum file size limit in bytes. If a user tries to upload more than this, they will receive an error that their upload is too large.
// Uploaded files are processed piece-wise to avoid DoS attacks which would starve the server of memory.
// TODO: Requests time out if they have not received any data within the configured timeout period.
func Upload(req *http.Request, cfg config.MediaAPI, db *storage.Database) util.JSONResponse {
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
if req.Method != "POST" {
return util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown("HTTP request method must be POST."),
}
}
// FIXME: This will require querying some other component/db but currently // FIXME: This will require querying some other component/db but currently
// just accepts a user id for auth // just accepts a user id for auth
userID, resErr := auth.VerifyAccessToken(req) userID, resErr := auth.VerifyAccessToken(req)
@ -103,74 +116,129 @@ func Upload(req *http.Request, cfg config.MediaAPI, db *storage.Database, repo *
return *resErr return *resErr
} }
r := &UploadRequest{ r := &uploadRequest{
ContentDisposition: req.Header.Get("Content-Disposition"), MediaMetadata: &types.MediaMetadata{
ContentLength: req.ContentLength, Origin: cfg.ServerName,
ContentType: req.Header.Get("Content-Type"), ContentDisposition: types.ContentDisposition(req.Header.Get("Content-Disposition")),
Filename: req.FormValue("filename"), ContentLength: types.ContentLength(req.ContentLength),
Method: req.Method, ContentType: types.ContentType(req.Header.Get("Content-Type")),
UserID: userID, UploadName: types.Filename(req.FormValue("filename")),
UserID: types.MatrixUserID(userID),
},
} }
if resErr = r.Validate(); resErr != nil { // FIXME: if no Content-Disposition then set
if resErr = r.Validate(cfg.MaxFileSize); resErr != nil {
return *resErr return *resErr
} }
logger.WithFields(log.Fields{ logger.WithFields(log.Fields{
"ContentType": r.ContentType, "Origin": r.MediaMetadata.Origin,
"Filename": r.Filename, "UploadName": r.MediaMetadata.UploadName,
"UserID": r.UserID, "Content-Length": r.MediaMetadata.ContentLength,
"Content-Type": r.MediaMetadata.ContentType,
"Content-Disposition": r.MediaMetadata.ContentDisposition,
}).Info("Uploading file") }).Info("Uploading file")
// TODO: Store file to disk tmpDir, err := createTempDir(cfg.BasePath)
// - make path to file
// - progressive writing (could support Content-Length 0 and cut off
// after some max upload size is exceeded)
// - generate id (ideally a hash but a random string to start with)
writer, err := repo.WriterToLocalRepository(storage.Description{
Type: r.ContentType,
})
if err != nil { if err != nil {
logger.Infof("Failed to get cache writer %q\n", err) logger.Infof("Failed to create temp dir %q\n", err)
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 400,
JSON: jsonerror.BadJSON(fmt.Sprintf("Failed to upload: %q", err)), JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)),
} }
} }
file, writer, err := createFileWriter(tmpDir, "content")
defer file.Close()
defer writer.Close() // The limited reader restricts how many bytes are read from the body to the specified maximum bytes
// Note: the golang HTTP server closes the request body
limitedBody := io.LimitReader(req.Body, int64(cfg.MaxFileSize))
hasher := sha256.New()
reader := io.TeeReader(limitedBody, hasher)
if _, err = io.Copy(writer, req.Body); err != nil { bytesWritten, err := io.Copy(writer, reader)
if err != nil {
logger.Infof("Failed to copy %q\n", err) logger.Infof("Failed to copy %q\n", err)
tmpDirErr := os.RemoveAll(string(tmpDir))
if tmpDirErr != nil {
logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
}
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 400,
JSON: jsonerror.BadJSON(fmt.Sprintf("Failed to upload: %q", err)), JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)),
} }
} }
r.Base64FileHash, err = writer.Finished() writer.Flush()
if err != nil {
if bytesWritten != int64(r.MediaMetadata.ContentLength) {
logger.Warnf("Bytes uploaded (%v) != claimed Content-Length (%v)", bytesWritten, r.MediaMetadata.ContentLength)
}
hash := hasher.Sum(nil)
r.MediaMetadata.MediaID = types.MediaID(base64.URLEncoding.EncodeToString(hash[:]))
logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
"UploadName": r.MediaMetadata.UploadName,
"Content-Length": r.MediaMetadata.ContentLength,
"Content-Type": r.MediaMetadata.ContentType,
"Content-Disposition": r.MediaMetadata.ContentDisposition,
}).Info("File uploaded")
// check if we already have a record of the media in our database and if so, we can remove the temporary directory
err = db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin, r.MediaMetadata)
if err == nil {
tmpDirErr := os.RemoveAll(string(tmpDir))
if tmpDirErr != nil {
logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
}
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 200,
JSON: jsonerror.BadJSON(fmt.Sprintf("Failed to upload: %q", err)), JSON: uploadResponse{
ContentURI: fmt.Sprintf("mxc://%s/%s", cfg.ServerName, r.MediaMetadata.MediaID),
},
} }
} else if err != nil && err != sql.ErrNoRows {
logger.Warnf("Failed to query database for %v: %q", r.MediaMetadata.MediaID, err)
} }
// TODO: check if file with hash already exists
// TODO: generate thumbnails // TODO: generate thumbnails
err = db.CreateMedia(r.Base64FileHash, cfg.ServerName, r.ContentType, r.ContentDisposition, r.ContentLength, r.Filename, r.UserID) err = db.StoreMediaMetadata(r.MediaMetadata)
if err != nil { if err != nil {
tmpDirErr := os.RemoveAll(string(tmpDir))
if tmpDirErr != nil {
logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
}
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 400,
JSON: jsonerror.BadJSON(fmt.Sprintf("Failed to upload: %q", err)), JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)),
}
}
err = moveFile(
types.Path(path.Join(string(tmpDir), "content")),
types.Path(getPathFromMediaMetadata(r.MediaMetadata, cfg.BasePath)),
)
if err != nil {
tmpDirErr := os.RemoveAll(string(tmpDir))
if tmpDirErr != nil {
logger.Warnf("Failed to remove tmpDir (%v): %q\n", tmpDir, tmpDirErr)
}
return util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload: %q", err)),
} }
} }
return util.JSONResponse{ return util.JSONResponse{
Code: 200, Code: 200,
JSON: uploadResponse{ JSON: uploadResponse{
ContentURI: fmt.Sprintf("mxc://%s/%s", cfg.ServerName, r.Base64FileHash), ContentURI: fmt.Sprintf("mxc://%s/%s", cfg.ServerName, r.MediaMetadata.MediaID),
}, },
} }
} }
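A client-side sketch (not part of the commit) of the upload route. The address and payload are hypothetical, and whatever credentials auth.VerifyAccessToken expects are omitted here.

package main

import (
    "bytes"
    "io"
    "log"
    "net/http"
    "os"
)

func main() {
    body := bytes.NewBufferString("hello world")
    url := "http://localhost:7774/_matrix/media/v1/upload?filename=hello.txt"

    req, err := http.NewRequest("POST", url, body)
    if err != nil {
        log.Fatal(err)
    }
    // Content-Type and Content-Length are validated server-side; a request
    // larger than cfg.MaxFileSize is rejected with a 400 JSON error.
    req.Header.Set("Content-Type", "text/plain")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()

    // A successful upload responds with the mxc://<server_name>/<media_id>
    // content URI for the stored file.
    if _, err := io.Copy(os.Stdout, resp.Body); err != nil {
        log.Fatal(err)
    }
}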

View file

@ -0,0 +1,80 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package writers
import (
"bufio"
"io/ioutil"
"os"
"path"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/dendrite/mediaapi/types"
)
// createTempDir creates a tmp/<random string> directory within baseDirectory and returns its path
func createTempDir(baseDirectory types.Path) (types.Path, error) {
baseTmpDir := path.Join(string(baseDirectory), "tmp")
err := os.MkdirAll(baseTmpDir, 0770)
if err != nil {
log.Printf("Failed to create base temp dir: %v\n", err)
return "", err
}
tmpDir, err := ioutil.TempDir(baseTmpDir, "")
if err != nil {
log.Printf("Failed to create temp dir: %v\n", err)
return "", err
}
return types.Path(tmpDir), nil
}
// createFileWriter creates a buffered file writer with a new file at directory/filename
// Returns the file handle as it needs to be closed when writing is complete
func createFileWriter(directory types.Path, filename types.Filename) (*os.File, *bufio.Writer, error) {
filePath := path.Join(string(directory), string(filename))
file, err := os.Create(filePath)
if err != nil {
log.Printf("Failed to create file: %v\n", err)
return nil, nil, err
}
return file, bufio.NewWriter(file), nil
}
func getPathFromMediaMetadata(m *types.MediaMetadata, basePath types.Path) string {
return path.Join(
string(basePath),
string(m.Origin),
string(m.MediaID[:3]),
string(m.MediaID[3:]),
)
}
// moveFile attempts to move the file src to dst
func moveFile(src types.Path, dst types.Path) error {
dstDir := path.Dir(string(dst))
err := os.MkdirAll(dstDir, 0770)
if err != nil {
log.Printf("Failed to make directory: %v\n", dstDir)
return err
}
err = os.Rename(string(src), string(dst))
if err != nil {
log.Printf("Failed to move directory: %v to %v\n", src, dst)
return err
}
return nil
}
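Since getPathFromMediaMetadata is unexported, here is a small standalone sketch (not part of the commit, values hypothetical) that mirrors it to show the resulting on-disk layout.

package main

import (
    "fmt"
    "path"
)

func main() {
    // Hypothetical values for BasePath, Origin and MediaID.
    basePath := "/var/dendrite/media"
    origin := "matrix.org"
    mediaID := "someMediaID"

    // Same layout as getPathFromMediaMetadata: files are fanned out under
    // basePath/origin using the first three characters of the media ID.
    fmt.Println(path.Join(basePath, origin, mediaID[:3], mediaID[3:]))
    // Output: /var/dendrite/media/matrix.org/som/eMediaID
}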