This commit is contained in:
Robert Swain 2017-05-26 07:32:09 +00:00 committed by GitHub
commit 1d3d9689df
13 changed files with 1654 additions and 0 deletions

View file

@ -0,0 +1,89 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"net/http"
"os"
"path/filepath"
"strconv"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/mediaapi/config"
"github.com/matrix-org/dendrite/mediaapi/routing"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/Sirupsen/logrus"
)
var (
bindAddr = os.Getenv("BIND_ADDRESS")
dataSource = os.Getenv("DATABASE")
logDir = os.Getenv("LOG_DIR")
serverName = os.Getenv("SERVER_NAME")
basePath = os.Getenv("BASE_PATH")
// Note: if the MAX_FILE_SIZE_BYTES is set to 0, it will be unlimited
maxFileSizeBytesString = os.Getenv("MAX_FILE_SIZE_BYTES")
)
func main() {
common.SetupLogging(logDir)
if bindAddr == "" {
log.Panic("No BIND_ADDRESS environment variable found.")
}
if basePath == "" {
log.Panic("No BASE_PATH environment variable found.")
}
absBasePath, err := filepath.Abs(basePath)
if err != nil {
log.WithError(err).WithField("BASE_PATH", basePath).Panic("BASE_PATH is invalid (must be able to make absolute)")
}
if serverName == "" {
serverName = "localhost"
}
maxFileSizeBytes, err := strconv.ParseInt(maxFileSizeBytesString, 10, 64)
if err != nil {
maxFileSizeBytes = 10 * 1024 * 1024
log.WithError(err).WithField("MAX_FILE_SIZE_BYTES", maxFileSizeBytesString).Warnf("Failed to parse MAX_FILE_SIZE_BYTES. Defaulting to %v bytes.", maxFileSizeBytes)
}
cfg := &config.MediaAPI{
ServerName: gomatrixserverlib.ServerName(serverName),
AbsBasePath: types.Path(absBasePath),
MaxFileSizeBytes: types.FileSizeBytes(maxFileSizeBytes),
DataSource: dataSource,
}
db, err := storage.Open(cfg.DataSource)
if err != nil {
log.WithError(err).Panic("Failed to open database")
}
log.WithFields(log.Fields{
"BASE_PATH": absBasePath,
"BIND_ADDRESS": bindAddr,
"DATABASE": dataSource,
"LOG_DIR": logDir,
"MAX_FILE_SIZE_BYTES": maxFileSizeBytes,
"SERVER_NAME": serverName,
}).Info("Starting mediaapi")
routing.Setup(http.DefaultServeMux, http.DefaultClient, cfg, db)
log.Fatal(http.ListenAndServe(bindAddr, nil))
}

View file

@ -0,0 +1,68 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"fmt"
"io/ioutil"
"log"
"net/http"
"sync"
)
const origin = "matrix.org"
const mediaID = "rZOBfBHnuOoyqBKUIHAaSbcM"
const requestCount = 100
func main() {
httpURL := "http://localhost:7777/api/_matrix/media/v1/download/" + origin + "/" + mediaID
jsonResponses := make(chan string)
var wg sync.WaitGroup
wg.Add(requestCount)
for i := 0; i < requestCount; i++ {
go func() {
defer wg.Done()
res, err := http.Get(httpURL)
if err != nil {
log.Fatal(err)
} else {
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
log.Fatal(err)
} else {
if res.StatusCode != 200 {
jsonResponses <- string(body)
}
}
}
}()
}
errorCount := 0
go func() {
for response := range jsonResponses {
errorCount++
fmt.Println(response)
}
}()
wg.Wait()
fmt.Printf("%v/%v requests were successful\n", requestCount-errorCount, requestCount)
}

View file

@ -0,0 +1,5 @@
# Media API
This server is responsible for serving `/media` requests as per:
http://matrix.org/docs/spec/client_server/r0.2.0.html#id43

View file

@ -0,0 +1,34 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package config
import (
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
// MediaAPI contains the config information necessary to spin up a mediaapi process.
type MediaAPI struct {
// The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'.
ServerName gomatrixserverlib.ServerName `yaml:"server_name"`
// The absolute base path to where media files will be stored.
AbsBasePath types.Path `yaml:"abs_base_path"`
// The maximum file size in bytes that is allowed to be stored on this server.
// Note that remote files larger than this can still be proxied to a client, they will just not be cached.
// Note: if MaxFileSizeBytes is set to 0, the size is unlimited.
MaxFileSizeBytes types.FileSizeBytes `yaml:"max_file_size_bytes"`
// The postgres connection config for connecting to the database e.g a postgres:// URI
DataSource string `yaml:"database"`
}

View file

@ -0,0 +1,284 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fileutils
import (
"bufio"
"crypto/sha256"
"encoding/base64"
"fmt"
"hash"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"path/filepath"
"strings"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/dendrite/mediaapi/types"
)
// RemoveDir removes a directory and logs a warning in case of errors
func RemoveDir(dir types.Path, logger *log.Entry) {
dirErr := os.RemoveAll(string(dir))
if dirErr != nil {
logger.WithError(dirErr).WithField("dir", dir).Warn("Failed to remove directory")
}
}
// createTempDir creates a tmp/<random string> directory within baseDirectory and returns its path
func createTempDir(baseDirectory types.Path) (types.Path, error) {
baseTmpDir := path.Join(string(baseDirectory), "tmp")
if err := os.MkdirAll(baseTmpDir, 0770); err != nil {
return "", fmt.Errorf("Failed to create base temp dir: %v", err)
}
tmpDir, err := ioutil.TempDir(baseTmpDir, "")
if err != nil {
return "", fmt.Errorf("Failed to create temp dir: %v", err)
}
return types.Path(tmpDir), nil
}
// createFileWriter creates a buffered file writer with a new file at directory/filename
// Returns the file handle as it needs to be closed when writing is complete
func createFileWriter(directory types.Path, filename types.Filename) (*bufio.Writer, *os.File, error) {
filePath := path.Join(string(directory), string(filename))
file, err := os.Create(filePath)
if err != nil {
return nil, nil, fmt.Errorf("Failed to create file: %v", err)
}
return bufio.NewWriter(file), file, nil
}
func createTempFileWriter(absBasePath types.Path) (*bufio.Writer, *os.File, types.Path, error) {
tmpDir, err := createTempDir(absBasePath)
if err != nil {
return nil, nil, "", fmt.Errorf("Failed to create temp dir: %q", err)
}
writer, tmpFile, err := createFileWriter(tmpDir, "content")
if err != nil {
return nil, nil, "", fmt.Errorf("Failed to create file writer: %q", err)
}
return writer, tmpFile, tmpDir, nil
}
var (
// ErrFileIsTooLarge indicates that the uploaded file is larger than the configured maximum file size
ErrFileIsTooLarge = fmt.Errorf("file is too large")
errRead = fmt.Errorf("failed to read response from remote server")
errResponse = fmt.Errorf("failed to write file data to response body")
errHash = fmt.Errorf("failed to hash file data")
errWrite = fmt.Errorf("failed to write file to disk")
)
// writeToResponse takes bytesToWrite bytes from buffer and writes them to respWriter
// Returns bytes written and an error. In case of error, the number of bytes written will be 0.
func writeToResponse(respWriter http.ResponseWriter, buffer []byte, bytesToWrite int) (int64, error) {
bytesWritten, respErr := respWriter.Write(buffer[:bytesToWrite])
if bytesWritten != bytesToWrite || (respErr != nil && respErr != io.EOF) {
return 0, errResponse
}
return int64(bytesWritten), nil
}
// writeToDiskAndHasher takes bytesToWrite bytes from buffer and writes them to tmpFileWriter and hasher.
// Returns bytes written and an error. In case of error, including if writing would exceed maxFileSizeBytes,
// the number of bytes written will be 0.
func writeToDiskAndHasher(tmpFileWriter *bufio.Writer, hasher hash.Hash, bytesWritten int64, maxFileSizeBytes types.FileSizeBytes, buffer []byte, bytesToWrite int) (int64, error) {
// if larger than maxFileSizeBytes then stop writing to disk and discard cached file
if bytesWritten+int64(bytesToWrite) > int64(maxFileSizeBytes) {
return 0, ErrFileIsTooLarge
}
// write to hasher and to disk
bytesTemp, writeErr := tmpFileWriter.Write(buffer[:bytesToWrite])
bytesHashed, hashErr := hasher.Write(buffer[:bytesToWrite])
if writeErr != nil && writeErr != io.EOF || bytesTemp != bytesToWrite || bytesTemp != bytesHashed {
return 0, errWrite
} else if hashErr != nil && hashErr != io.EOF {
return 0, errHash
}
return int64(bytesTemp), nil
}
// WriteTempFile writes to a new temporary file
// * creates a temporary file
// * writes data from reqReader to disk and simultaneously hash it
// * the amount of data written to disk and hashed is limited by maxFileSizeBytes
// * if a respWriter is supplied, the data is also simultaneously written to that
// * data written to the respWriter is _not_ limited to maxFileSizeBytes such that
// the homeserver can proxy files larger than it is willing to cache
// Returns all of the hash sum, bytes written to disk, bytes proxied, and temporary directory path, or an error.
func WriteTempFile(reqReader io.Reader, maxFileSizeBytes types.FileSizeBytes, absBasePath types.Path, respWriter http.ResponseWriter) (types.Base64Hash, types.FileSizeBytes, types.FileSizeBytes, types.Path, error) {
// create the temporary file writer
tmpFileWriter, tmpFile, tmpDir, err := createTempFileWriter(absBasePath)
if err != nil {
return "", -1, -1, "", err
}
defer tmpFile.Close()
// The file data is hashed and the hash is returned. The hash is useful as a
// method of deduplicating files to save storage, as well as a way to conduct
// integrity checks on the file data in the repository.
hasher := sha256.New()
// bytesResponded is the total number of bytes written to the response to the client request
// bytesWritten is the total number of bytes written to disk
var bytesResponded, bytesWritten int64 = 0, 0
var bytesTemp int64
var copyError error
// Note: the buffer size is the same as is used in io.Copy()
buffer := make([]byte, 32*1024)
for {
// read from remote request's response body
bytesRead, readErr := reqReader.Read(buffer)
if bytesRead > 0 {
// Note: This code allows proxying files larger than maxFileSizeBytes!
// write to client request's response body
if respWriter != nil {
bytesTemp, copyError = writeToResponse(respWriter, buffer, bytesRead)
bytesResponded += bytesTemp
if copyError != nil {
break
}
}
if copyError == nil {
// Note: if we get here then copyError != ErrFileIsTooLarge && copyError != errWrite
// as if copyError == errResponse || copyError == errWrite then we would have broken
// out of the loop and there are no other cases
bytesTemp, copyError = writeToDiskAndHasher(tmpFileWriter, hasher, bytesWritten, maxFileSizeBytes, buffer, (bytesRead))
bytesWritten += bytesTemp
// If we do not have a respWriter then we are only writing to the hasher and tmpFileWriter. In that case, if we get an error, we need to break.
if respWriter == nil && copyError != nil {
break
}
}
}
if readErr != nil {
if readErr != io.EOF {
copyError = errRead
}
break
}
}
if copyError != nil {
return "", -1, -1, "", copyError
}
tmpFileWriter.Flush()
hash := hasher.Sum(nil)
return types.Base64Hash(base64.URLEncoding.EncodeToString(hash[:])), types.FileSizeBytes(bytesResponded), types.FileSizeBytes(bytesWritten), tmpDir, nil
}
// GetPathFromBase64Hash evaluates the path to a media file from its Base64Hash
// If the Base64Hash is long enough, we split it into pieces, creating up to 2 subdirectories
// for more manageable browsing and use the remainder as the file name.
// For example, if Base64Hash is 'qwerty', the path will be 'q/w/erty'.
func GetPathFromBase64Hash(base64Hash types.Base64Hash, absBasePath types.Path) (string, error) {
var subPath, fileName string
hashLen := len(base64Hash)
switch {
case hashLen < 1:
return "", fmt.Errorf("Invalid filePath (Base64Hash too short): %q", base64Hash)
case hashLen > 255:
return "", fmt.Errorf("Invalid filePath (Base64Hash too long - max 255 characters): %q", base64Hash)
case hashLen < 2:
subPath = ""
fileName = string(base64Hash)
case hashLen < 3:
subPath = string(base64Hash[0:1])
fileName = string(base64Hash[1:])
default:
subPath = path.Join(
string(base64Hash[0:1]),
string(base64Hash[1:2]),
)
fileName = string(base64Hash[2:])
}
filePath, err := filepath.Abs(path.Join(
string(absBasePath),
subPath,
fileName,
))
if err != nil {
return "", fmt.Errorf("Unable to construct filePath: %q", err)
}
// check if the absolute absBasePath is a prefix of the absolute filePath
// if so, no directory escape has occurred and the filePath is valid
// Note: absBasePath is already absolute
if strings.HasPrefix(filePath, string(absBasePath)) == false {
return "", fmt.Errorf("Invalid filePath (not within absBasePath %v): %v", absBasePath, filePath)
}
return filePath, nil
}
// moveFile attempts to move the file src to dst
func moveFile(src types.Path, dst types.Path) error {
dstDir := path.Dir(string(dst))
err := os.MkdirAll(dstDir, 0770)
if err != nil {
return fmt.Errorf("Failed to make directory: %q", err)
}
err = os.Rename(string(src), string(dst))
if err != nil {
return fmt.Errorf("Failed to move directory: %q", err)
}
return nil
}
// MoveFileWithHashCheck checks for hash collisions when moving a temporary file to its destination based on metadata
// Check if destination file exists. As the destination is based on a hash of the file data,
// if it exists and the file size does not match then there is a hash collision for two different files. If
// it exists and the file size matches, it is believable that it is the same file and we can just
// discard the temporary file.
func MoveFileWithHashCheck(tmpDir types.Path, mediaMetadata *types.MediaMetadata, absBasePath types.Path, logger *log.Entry) (string, bool, error) {
duplicate := false
finalPath, err := GetPathFromBase64Hash(mediaMetadata.Base64Hash, absBasePath)
if err != nil {
RemoveDir(tmpDir, logger)
return "", duplicate, fmt.Errorf("failed to get file path from metadata: %q", err)
}
var stat os.FileInfo
if stat, err = os.Stat(finalPath); os.IsExist(err) {
duplicate = true
if stat.Size() == int64(mediaMetadata.FileSizeBytes) {
RemoveDir(tmpDir, logger)
return finalPath, duplicate, nil
}
// Remove the tmpDir as we anyway cannot cache the file on disk due to the hash collision
RemoveDir(tmpDir, logger)
return "", duplicate, fmt.Errorf("downloaded file with hash collision but different file size (%v)", finalPath)
}
err = moveFile(
types.Path(path.Join(string(tmpDir), "content")),
types.Path(finalPath),
)
if err != nil {
RemoveDir(tmpDir, logger)
return "", duplicate, fmt.Errorf("failed to move file to final destination (%v): %q", finalPath, err)
}
return finalPath, duplicate, nil
}

View file

@ -0,0 +1,67 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"sync"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/mediaapi/config"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/mediaapi/writers"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
)
const pathPrefixR0 = "/_matrix/media/v1"
// Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client
// to clients which need to make outbound HTTP requests.
func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg *config.MediaAPI, db *storage.Database) {
apiMux := mux.NewRouter()
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
r0mux.Handle("/upload", makeAPI("upload", func(req *http.Request) util.JSONResponse {
return writers.Upload(req, cfg, db)
}))
activeRemoteRequests := &types.ActiveRemoteRequests{
Set: map[string]*sync.Cond{},
}
r0mux.Handle("/download/{serverName}/{mediaId}",
prometheus.InstrumentHandler("download", http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req = util.RequestWithLogging(req)
// Set common headers returned regardless of the outcome of the request
util.SetCORSHeaders(w)
// Content-Type will be overridden in case of returning file data, else we respond with JSON-formatted errors
w.Header().Set("Content-Type", "application/json")
vars := mux.Vars(req)
writers.Download(w, req, gomatrixserverlib.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, db, activeRemoteRequests)
})),
)
servMux.Handle("/metrics", prometheus.Handler())
servMux.Handle("/api/", http.StripPrefix("/api", apiMux))
}
// make a util.JSONRequestHandler function into an http.Handler.
func makeAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
h := util.NewJSONRequestHandler(f)
return prometheus.InstrumentHandler(metricsName, util.MakeJSONAPI(h))
}

View file

@ -0,0 +1,112 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"database/sql"
"time"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
const mediaSchema = `
-- The media_repository table holds metadata for each media file stored and accessible to the local server,
-- the actual file is stored separately.
CREATE TABLE IF NOT EXISTS media_repository (
-- The id used to refer to the media.
-- For uploads to this server this is a base64-encoded sha256 hash of the file data
-- For media from remote servers, this can be any unique identifier string
media_id TEXT NOT NULL,
-- The origin of the media as requested by the client. Should be a homeserver domain.
media_origin TEXT NOT NULL,
-- The MIME-type of the media file as specified when uploading.
content_type TEXT NOT NULL,
-- The HTTP Content-Disposition header for the media file as specified when uploading.
content_disposition TEXT NOT NULL,
-- Size of the media file in bytes.
file_size_bytes BIGINT NOT NULL,
-- When the content was uploaded in UNIX epoch ms.
creation_ts BIGINT NOT NULL,
-- The file name with which the media was uploaded.
upload_name TEXT NOT NULL,
-- A golang base64 URLEncoding string representation of a SHA-256 hash sum of the file data.
base64hash TEXT NOT NULL,
-- The user who uploaded the file. Should be a Matrix user ID.
user_id TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS media_repository_index ON media_repository (media_id, media_origin);
`
const insertMediaSQL = `
INSERT INTO media_repository (media_id, media_origin, content_type, content_disposition, file_size_bytes, creation_ts, upload_name, base64hash, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
const selectMediaSQL = `
SELECT content_type, content_disposition, file_size_bytes, creation_ts, upload_name, base64hash, user_id FROM media_repository WHERE media_id = $1 AND media_origin = $2
`
type mediaStatements struct {
insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt
}
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(mediaSchema)
if err != nil {
return
}
return statementList{
{&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL},
}.prepare(db)
}
func (s *mediaStatements) insertMedia(mediaMetadata *types.MediaMetadata) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
_, err := s.insertMediaStmt.Exec(
mediaMetadata.MediaID,
mediaMetadata.Origin,
mediaMetadata.ContentType,
mediaMetadata.ContentDisposition,
mediaMetadata.FileSizeBytes,
mediaMetadata.CreationTimestamp,
mediaMetadata.UploadName,
mediaMetadata.Base64Hash,
mediaMetadata.UserID,
)
return err
}
func (s *mediaStatements) selectMedia(mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
MediaID: mediaID,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRow(
mediaMetadata.MediaID, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,
&mediaMetadata.ContentDisposition,
&mediaMetadata.FileSizeBytes,
&mediaMetadata.CreationTimestamp,
&mediaMetadata.UploadName,
&mediaMetadata.Base64Hash,
&mediaMetadata.UserID,
)
return &mediaMetadata, err
}

View file

@ -0,0 +1,37 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// FIXME: This should be made common!
package storage
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View file

@ -0,0 +1,33 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"database/sql"
)
type statements struct {
mediaStatements
}
func (s *statements) prepare(db *sql.DB) error {
var err error
if err = s.mediaStatements.prepare(db); err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,56 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"database/sql"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
// A Database is used to store metadata about a repository of media files.
type Database struct {
statements statements
db *sql.DB
}
// Open a postgres database.
func Open(dataSourceName string) (*Database, error) {
var d Database
var err error
if d.db, err = sql.Open("postgres", dataSourceName); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db); err != nil {
return nil, err
}
return &d, nil
}
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreMediaMetadata(mediaMetadata *types.MediaMetadata) error {
return d.statements.insertMedia(mediaMetadata)
}
// GetMediaMetadata returns metadata about media stored on this server. The media could
// have been uploaded to this server or fetched from another server and cached here.
// Returns sql.ErrNoRows if there is no metadata associated with this media.
func (d *Database) GetMediaMetadata(mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
return d.statements.selectMedia(mediaID, mediaOrigin)
}

View file

@ -0,0 +1,72 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"sync"
"github.com/matrix-org/gomatrixserverlib"
)
// ContentDisposition is an HTTP Content-Disposition header string
type ContentDisposition string
// FileSizeBytes is a file size in bytes
type FileSizeBytes int64
// ContentType is an HTTP Content-Type header string representing the MIME type of a request body
type ContentType string
// Filename is a string representing the name of a file
type Filename string
// Base64Hash is a base64 URLEncoding string representation of a SHA-256 hash sum
type Base64Hash string
// Path is an absolute or relative UNIX filesystem path
type Path string
// MediaID is a string representing the unique identifier for a file (could be a hash but does not have to be)
type MediaID string
// RequestMethod is an HTTP request method i.e. GET, POST, etc
type RequestMethod string
// MatrixUserID is a Matrix user ID string in the form @user:domain e.g. @alice:matrix.org
type MatrixUserID string
// UnixMs is the milliseconds since the Unix epoch
type UnixMs int64
// MediaMetadata is metadata associated with a media file
type MediaMetadata struct {
MediaID MediaID
Origin gomatrixserverlib.ServerName
ContentType ContentType
ContentDisposition ContentDisposition
FileSizeBytes FileSizeBytes
CreationTimestamp UnixMs
UploadName Filename
Base64Hash Base64Hash
UserID MatrixUserID
}
// ActiveRemoteRequests is a lockable map of media URIs requested from remote homeservers
// It is used for ensuring multiple requests for the same file do not clobber each other.
type ActiveRemoteRequests struct {
sync.Mutex
// The string key is an mxc:// URL
Set map[string]*sync.Cond
}

View file

@ -0,0 +1,536 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package writers
import (
"database/sql"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"path"
"strconv"
"strings"
"sync"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/mediaapi/config"
"github.com/matrix-org/dendrite/mediaapi/fileutils"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// downloadRequest metadata included in or derivable from an download request
// https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid
type downloadRequest struct {
MediaMetadata *types.MediaMetadata
Logger *log.Entry
}
// Validate validates the downloadRequest fields
func (r *downloadRequest) Validate() *util.JSONResponse {
// FIXME: the following errors aren't bad JSON, rather just a bad request path
// maybe give the URL pattern in the routing, these are not even possible as the handler would not be hit...?
if r.MediaMetadata.MediaID == "" {
return &util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("mediaId must be a non-empty string"),
}
}
if r.MediaMetadata.Origin == "" {
return &util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("serverName must be a non-empty string"),
}
}
return nil
}
func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse) {
// Marshal JSON response into raw bytes to send as the HTTP body
resBytes, err := json.Marshal(res.JSON)
if err != nil {
r.Logger.WithError(err).Error("Failed to marshal JSONResponse")
// this should never fail to be marshalled so drop err to the floor
res = util.MessageResponse(500, "Internal Server Error")
resBytes, _ = json.Marshal(res.JSON)
}
// Set status code and write the body
w.WriteHeader(res.Code)
r.Logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes))
w.Write(resBytes)
}
// Download implements /download
// Files from this server (i.e. origin == cfg.ServerName) are served directly
// Files from remote servers (i.e. origin != cfg.ServerName) are cached locally.
// If they are present in the cache, they are served directly.
// If they are not present in the cache, they are obtained from the remote server and
// simultaneously served back to the client and written into the cache.
func Download(w http.ResponseWriter, req *http.Request, origin gomatrixserverlib.ServerName, mediaID types.MediaID, cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) {
r := &downloadRequest{
MediaMetadata: &types.MediaMetadata{
MediaID: mediaID,
Origin: origin,
},
Logger: util.GetLogger(req.Context()),
}
// request validation
if req.Method != "GET" {
r.jsonErrorResponse(w, util.JSONResponse{
Code: 405,
JSON: jsonerror.Unknown("request method must be GET"),
})
return
}
if resErr := r.Validate(); resErr != nil {
r.jsonErrorResponse(w, *resErr)
return
}
// check if we have a record of the media in our database
mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
if err == nil {
// If we have a record, we can respond from the local file
r.MediaMetadata = mediaMetadata
r.respondFromLocalFile(w, cfg.AbsBasePath)
return
} else if err == sql.ErrNoRows && r.MediaMetadata.Origin != cfg.ServerName {
// If we do not have a record and the origin is remote, we need to fetch it and respond with that file
mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID)
// The following code using activeRemoteRequests is avoiding duplication of fetches from the remote server in the case
// of multiple simultaneous incoming requests for the same remote file - it will be downloaded once, cached and served
// to all clients.
activeRemoteRequests.Lock()
mediaMetadata, err = db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
if err == nil {
// If we have a record, we can respond from the local file
r.MediaMetadata = mediaMetadata
r.respondFromLocalFile(w, cfg.AbsBasePath)
activeRemoteRequests.Unlock()
return
}
if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok {
r.Logger.WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
}).Info("Waiting for another goroutine to fetch the remote file.")
activeRemoteRequestCondition.Wait()
activeRemoteRequests.Unlock()
activeRemoteRequests.Lock()
mediaMetadata, err = db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
if err == nil {
// If we have a record, we can respond from the local file
r.MediaMetadata = mediaMetadata
r.respondFromLocalFile(w, cfg.AbsBasePath)
} else {
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
}).Warn("Other goroutine failed to fetch remote file.")
r.jsonErrorResponse(w, util.JSONResponse{
Code: 500,
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
})
}
activeRemoteRequests.Unlock()
return
}
activeRemoteRequests.Set[mxcURL] = &sync.Cond{L: activeRemoteRequests}
activeRemoteRequests.Unlock()
r.respondFromRemoteFile(w, cfg.AbsBasePath, cfg.MaxFileSizeBytes, db, activeRemoteRequests)
} else if err == sql.ErrNoRows && r.MediaMetadata.Origin == cfg.ServerName {
// If we do not have a record and the origin is local, the file is not found
r.Logger.WithError(err).Warn("Failed to look up file in database")
r.jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
})
} else {
// Another error from the database
r.Logger.WithError(err).WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
}).Error("Error querying the database.")
r.jsonErrorResponse(w, util.JSONResponse{
Code: 500,
JSON: jsonerror.Unknown("Internal server error"),
})
}
}
func (r *downloadRequest) respondFromLocalFile(w http.ResponseWriter, absBasePath types.Path) {
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
"UploadName": r.MediaMetadata.UploadName,
"Base64Hash": r.MediaMetadata.Base64Hash,
"FileSizeBytes": r.MediaMetadata.FileSizeBytes,
"Content-Type": r.MediaMetadata.ContentType,
"Content-Disposition": r.MediaMetadata.ContentDisposition,
}).Infof("Downloading file")
filePath, err := fileutils.GetPathFromBase64Hash(r.MediaMetadata.Base64Hash, absBasePath)
if err != nil {
// FIXME: Remove erroneous file from database?
r.Logger.WithError(err).Warn("Failed to get file path from metadata")
r.jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
})
return
}
file, err := os.Open(filePath)
if err != nil {
// FIXME: Remove erroneous file from database?
r.Logger.WithError(err).Warn("Failed to open file")
r.jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
})
return
}
stat, err := file.Stat()
if err != nil {
// FIXME: Remove erroneous file from database?
r.Logger.WithError(err).Warn("Failed to stat file")
r.jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
})
return
}
if r.MediaMetadata.FileSizeBytes > 0 && int64(r.MediaMetadata.FileSizeBytes) != stat.Size() {
r.Logger.WithFields(log.Fields{
"fileSizeDatabase": r.MediaMetadata.FileSizeBytes,
"fileSizeDisk": stat.Size(),
}).Warn("File size in database and on-disk differ.")
// FIXME: Remove erroneous file from database?
}
w.Header().Set("Content-Type", string(r.MediaMetadata.ContentType))
w.Header().Set("Content-Length", strconv.FormatInt(stat.Size(), 10))
contentSecurityPolicy := "default-src 'none';" +
" script-src 'none';" +
" plugin-types application/pdf;" +
" style-src 'unsafe-inline';" +
" object-src 'self';"
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
if bytesResponded, err := io.Copy(w, file); err != nil {
r.Logger.WithError(err).Warn("Failed to copy from cache")
if bytesResponded == 0 {
r.jsonErrorResponse(w, util.JSONResponse{
Code: 500,
JSON: jsonerror.NotFound(fmt.Sprintf("Failed to respond with file with media ID %q", r.MediaMetadata.MediaID)),
})
}
// If we have written any data then we have already responded with 200 OK and all we can do is close the connection
return
}
}
func (r *downloadRequest) createRemoteRequest() (*http.Response, *util.JSONResponse) {
urls := getMatrixURLs(r.MediaMetadata.Origin)
r.Logger.WithField("URL", urls[0]).Info("Connecting to remote")
remoteReqAddr := urls[0] + "/_matrix/media/v1/download/" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID)
remoteReq, err := http.NewRequest("GET", remoteReqAddr, nil)
if err != nil {
return nil, &util.JSONResponse{
Code: 500,
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
}
}
remoteReq.Header.Set("Host", string(r.MediaMetadata.Origin))
client := http.Client{}
resp, err := client.Do(remoteReq)
if err != nil {
return nil, &util.JSONResponse{
Code: 502,
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
}
}
if resp.StatusCode != 200 {
r.Logger.WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
"StatusCode": resp.StatusCode,
}).Info("Received error response")
if resp.StatusCode == 404 {
r.Logger.WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
"StatusCode": resp.StatusCode,
}).Warn("Remote server says file does not exist")
return nil, &util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
}
}
return nil, &util.JSONResponse{
Code: 502,
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
}
}
return resp, nil
}
func (r *downloadRequest) closeConnection(w http.ResponseWriter) {
r.Logger.WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
}).Info("Attempting to close the connection.")
hijacker, ok := w.(http.Hijacker)
if ok {
connection, _, hijackErr := hijacker.Hijack()
if hijackErr == nil {
r.Logger.WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
}).Info("Closing")
connection.Close()
} else {
r.Logger.WithError(hijackErr).WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
}).Warn("Error trying to hijack and close connection")
}
}
}
func completeRemoteRequest(activeRemoteRequests *types.ActiveRemoteRequests, mxcURL string) {
if activeRemoteRequestCondition, ok := activeRemoteRequests.Set[mxcURL]; ok {
activeRemoteRequestCondition.Broadcast()
}
delete(activeRemoteRequests.Set, mxcURL)
activeRemoteRequests.Unlock()
}
func (r *downloadRequest) commitFileAndMetadata(tmpDir types.Path, absBasePath types.Path, activeRemoteRequests *types.ActiveRemoteRequests, db *storage.Database, mxcURL string) bool {
updateActiveRemoteRequests := true
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
"Base64Hash": r.MediaMetadata.Base64Hash,
"UploadName": r.MediaMetadata.UploadName,
"FileSizeBytes": r.MediaMetadata.FileSizeBytes,
"Content-Type": r.MediaMetadata.ContentType,
"Content-Disposition": r.MediaMetadata.ContentDisposition,
}).Info("Storing file metadata to media repository database")
// The database is the source of truth so we need to have moved the file first
finalPath, duplicate, err := fileutils.MoveFileWithHashCheck(tmpDir, r.MediaMetadata, absBasePath, r.Logger)
if err != nil {
r.Logger.WithError(err).Error("Failed to move file.")
return updateActiveRemoteRequests
}
if duplicate {
r.Logger.WithField("dst", finalPath).Info("File was stored previously - discarding duplicate")
// Continue on to store the metadata in the database
}
// Writing the metadata to the media repository database and removing the mxcURL from activeRemoteRequests needs to be atomic.
// If it were not atomic, a new request for the same file could come in in routine A and check the database before the INSERT.
// Routine B which was fetching could then have its INSERT complete and remove the mxcURL from the activeRemoteRequests.
// If routine A then checked the activeRemoteRequests it would think it needed to fetch the file when it's already in the database.
// The locking below mitigates this situation.
updateActiveRemoteRequests = false
activeRemoteRequests.Lock()
// FIXME: unlock after timeout of db request
// if written to disk, add to db
err = db.StoreMediaMetadata(r.MediaMetadata)
if err != nil {
// If the file is a duplicate (has the same hash as an existing file) then
// there is valid metadata in the database for that file. As such we only
// remove the file if it is not a duplicate.
if duplicate == false {
finalDir := path.Dir(finalPath)
fileutils.RemoveDir(types.Path(finalDir), r.Logger)
}
completeRemoteRequest(activeRemoteRequests, mxcURL)
return updateActiveRemoteRequests
}
r.Logger.WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
}).Info("Signalling other goroutines waiting for us to fetch the file.")
completeRemoteRequest(activeRemoteRequests, mxcURL)
return updateActiveRemoteRequests
}
func (r *downloadRequest) respondFromRemoteFile(w http.ResponseWriter, absBasePath types.Path, maxFileSizeBytes types.FileSizeBytes, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) {
r.Logger.WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
}).Infof("Fetching remote file")
mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID)
// If we hit an error and we return early, we need to lock, broadcast on the condition, delete the condition and unlock.
// If we return normally we have slightly different locking around the storage of metadata to the database and deletion of the condition.
// As such, this deferred cleanup of the sync.Cond is conditional.
// This approach seems safer than potentially missing this cleanup in error cases.
updateActiveRemoteRequests := true
defer func() {
if updateActiveRemoteRequests {
activeRemoteRequests.Lock()
// Note that completeRemoteRequest unlocks activeRemoteRequests
completeRemoteRequest(activeRemoteRequests, mxcURL)
}
}()
// create request for remote file
resp, errorResponse := r.createRemoteRequest()
if errorResponse != nil {
r.jsonErrorResponse(w, *errorResponse)
return
}
defer resp.Body.Close()
// get metadata from request and set metadata on response
contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
if err != nil {
r.Logger.WithError(err).Warn("Failed to parse content length")
}
r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength)
r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type"))
r.MediaMetadata.ContentDisposition = types.ContentDisposition(resp.Header.Get("Content-Disposition"))
// FIXME: parse from Content-Disposition header if possible, else fall back
//r.MediaMetadata.UploadName = types.Filename()
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
}).Infof("Connected to remote")
w.Header().Set("Content-Type", string(r.MediaMetadata.ContentType))
w.Header().Set("Content-Length", strconv.FormatInt(int64(r.MediaMetadata.FileSizeBytes), 10))
contentSecurityPolicy := "default-src 'none';" +
" script-src 'none';" +
" plugin-types application/pdf;" +
" style-src 'unsafe-inline';" +
" object-src 'self';"
w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
// read the remote request's response body
// simultaneously write it to the incoming request's response body and the temporary file
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
}).Info("Proxying and caching remote file")
// The file data is hashed but is NOT used as the MediaID, unlike in Upload. The hash is useful as a
// method of deduplicating files to save storage, as well as a way to conduct
// integrity checks on the file data in the repository.
// bytesResponded is the total number of bytes written to the response to the client request
// bytesWritten is the total number of bytes written to disk
hash, bytesResponded, bytesWritten, tmpDir, copyError := fileutils.WriteTempFile(resp.Body, maxFileSizeBytes, absBasePath, w)
if copyError != nil {
logFields := log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
}
if copyError == fileutils.ErrFileIsTooLarge {
logFields["MaxFileSizeBytes"] = maxFileSizeBytes
}
r.Logger.WithError(copyError).WithFields(logFields).Warn("Error while transferring file")
fileutils.RemoveDir(tmpDir, r.Logger)
// Note: if we have responded with any data in the body at all then we have already sent 200 OK and we can only abort at this point
if bytesResponded < 1 {
r.jsonErrorResponse(w, util.JSONResponse{
Code: 502,
JSON: jsonerror.Unknown(fmt.Sprintf("File could not be downloaded from remote server")),
})
} else {
// We attempt to bluntly close the connection because that is the
// best thing we can do after we've sent a 200 OK
r.closeConnection(w)
}
return
}
// The file has been fetched. It is moved to its final destination and its metadata is inserted into the database.
// Note: After this point we have responded to the client's request and are just dealing with local caching.
// As we have responded with 200 OK, any errors are ineffectual to the client request and so we just log and return.
// FIXME: Does continuing to do work here that is ineffectual to the client have any bad side effects? Could we fire off the remainder in a separate goroutine to mitigate that?
// It's possible the bytesWritten to the temporary file is different to the reported Content-Length from the remote
// request's response. bytesWritten is therefore used as it is what would be sent to clients when reading from the local
// file.
r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(bytesWritten)
r.MediaMetadata.Base64Hash = hash
r.MediaMetadata.UserID = types.MatrixUserID("@:" + string(r.MediaMetadata.Origin))
updateActiveRemoteRequests = r.commitFileAndMetadata(tmpDir, absBasePath, activeRemoteRequests, db, mxcURL)
// TODO: generate thumbnails
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
"UploadName": r.MediaMetadata.UploadName,
"Base64Hash": r.MediaMetadata.Base64Hash,
"FileSizeBytes": r.MediaMetadata.FileSizeBytes,
"Content-Type": r.MediaMetadata.ContentType,
"Content-Disposition": r.MediaMetadata.ContentDisposition,
}).Infof("Remote file cached")
}
// Given a matrix server name, attempt to discover URLs to contact the server
// on.
func getMatrixURLs(serverName gomatrixserverlib.ServerName) []string {
_, srvs, err := net.LookupSRV("matrix", "tcp", string(serverName))
if err != nil {
return []string{"https://" + string(serverName) + ":8448"}
}
results := make([]string, 0, len(srvs))
for _, srv := range srvs {
if srv == nil {
continue
}
url := []string{"https://", strings.Trim(srv.Target, "."), ":", strconv.Itoa(int(srv.Port))}
results = append(results, strings.Join(url, ""))
}
// TODO: Order based on priority and weight.
return results
}

View file

@ -0,0 +1,261 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package writers
import (
"database/sql"
"fmt"
"net/http"
"net/url"
"path"
"strings"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/mediaapi/config"
"github.com/matrix-org/dendrite/mediaapi/fileutils"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/util"
)
// uploadRequest metadata included in or derivable from an upload request
// https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-media-r0-upload
// NOTE: The members come from HTTP request metadata such as headers, query parameters or can be derived from such
type uploadRequest struct {
MediaMetadata *types.MediaMetadata
Logger *log.Entry
}
// Validate validates the uploadRequest fields
func (r *uploadRequest) Validate(maxFileSizeBytes types.FileSizeBytes) *util.JSONResponse {
// TODO: Any validation to be done on ContentDisposition?
if r.MediaMetadata.FileSizeBytes < 1 {
return &util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown("HTTP Content-Length request header must be greater than zero."),
}
}
if maxFileSizeBytes > 0 && r.MediaMetadata.FileSizeBytes > maxFileSizeBytes {
return &util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown(fmt.Sprintf("HTTP Content-Length is greater than the maximum allowed upload size (%v).", maxFileSizeBytes)),
}
}
// TODO: Check if the Content-Type is a valid type?
if r.MediaMetadata.ContentType == "" {
return &util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown("HTTP Content-Type request header must be set."),
}
}
// TODO: Validate filename - what are the valid characters?
if r.MediaMetadata.UserID != "" {
// TODO: We should put user ID parsing code into gomatrixserverlib and use that instead
// (see https://github.com/matrix-org/gomatrixserverlib/blob/3394e7c7003312043208aa73727d2256eea3d1f6/eventcontent.go#L347 )
// It should be a struct (with pointers into a single string to avoid copying) and
// we should update all refs to use UserID types rather than strings.
// https://github.com/matrix-org/synapse/blob/v0.19.2/synapse/types.py#L92
if len(r.MediaMetadata.UserID) == 0 || r.MediaMetadata.UserID[0] != '@' {
return &util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown("user id must start with '@'"),
}
}
parts := strings.SplitN(string(r.MediaMetadata.UserID[1:]), ":", 2)
if len(parts) != 2 {
return &util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("user id must be in the form @localpart:domain"),
}
}
}
return nil
}
// https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-media-r0-upload
type uploadResponse struct {
ContentURI string `json:"content_uri"`
}
// parseAndValidateRequest parses the incoming upload request to validate and extract
// all the metadata about the media being uploaded. Returns either an uploadRequest or
// an error formatted as a util.JSONResponse
func parseAndValidateRequest(req *http.Request, cfg *config.MediaAPI) (*uploadRequest, *util.JSONResponse) {
if req.Method != "POST" {
return nil, &util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown("HTTP request method must be POST."),
}
}
// FIXME: This will require querying some other component/db but currently
// just accepts a user id for auth
userID, resErr := auth.VerifyAccessToken(req)
if resErr != nil {
return nil, resErr
}
r := &uploadRequest{
MediaMetadata: &types.MediaMetadata{
Origin: cfg.ServerName,
ContentDisposition: types.ContentDisposition(req.Header.Get("Content-Disposition")),
FileSizeBytes: types.FileSizeBytes(req.ContentLength),
ContentType: types.ContentType(req.Header.Get("Content-Type")),
UploadName: types.Filename(req.FormValue("filename")),
UserID: types.MatrixUserID(userID),
},
Logger: util.GetLogger(req.Context()),
}
if resErr = r.Validate(cfg.MaxFileSizeBytes); resErr != nil {
return nil, resErr
}
// FIXME: do we want to always override ContentDisposition here or only if
// there is no Content-Disposition header set?
if len(r.MediaMetadata.UploadName) > 0 {
r.MediaMetadata.ContentDisposition = types.ContentDisposition(
"inline; filename*=utf-8''" + url.PathEscape(string(r.MediaMetadata.UploadName)),
)
}
return r, nil
}
// storeFileAndMetadata first moves a temporary file named content from tmpDir to its
// final path (see getPathFromMediaMetadata for details.) Once the file is moved, the
// metadata about the file is written into the media repository database. This order
// of operations is important as it avoids metadata entering the database before the file
// is ready and if we fail to move the file, it never gets added to the database.
// In case of any error, appropriate files and directories are cleaned up a
// util.JSONResponse error is returned.
func (r *uploadRequest) storeFileAndMetadata(tmpDir types.Path, absBasePath types.Path, db *storage.Database) *util.JSONResponse {
finalPath, duplicate, err := fileutils.MoveFileWithHashCheck(tmpDir, r.MediaMetadata, absBasePath, r.Logger)
if err != nil {
r.Logger.WithError(err).Error("Failed to move file.")
return &util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload")),
}
}
if duplicate {
r.Logger.WithField("dst", finalPath).Info("File was stored previously - discarding duplicate")
}
if err = db.StoreMediaMetadata(r.MediaMetadata); err != nil {
r.Logger.WithError(err).Warn("Failed to store metadata")
// If the file is a duplicate (has the same hash as an existing file) then
// there is valid metadata in the database for that file. As such we only
// remove the file if it is not a duplicate.
if duplicate == false {
fileutils.RemoveDir(types.Path(path.Dir(finalPath)), r.Logger)
}
return &util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload")),
}
}
return nil
}
// Upload implements /upload
//
// This endpoint involves uploading potentially significant amounts of data to the homeserver.
// This implementation supports a configurable maximum file size limit in bytes. If a user tries to upload more than this, they will receive an error that their upload is too large.
// Uploaded files are processed piece-wise to avoid DoS attacks which would starve the server of memory.
// TODO: We should time out requests if they have not received any data within a configured timeout period.
func Upload(req *http.Request, cfg *config.MediaAPI, db *storage.Database) util.JSONResponse {
r, resErr := parseAndValidateRequest(req, cfg)
if resErr != nil {
return *resErr
}
r.Logger.WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"UploadName": r.MediaMetadata.UploadName,
"FileSizeBytes": r.MediaMetadata.FileSizeBytes,
"Content-Type": r.MediaMetadata.ContentType,
"Content-Disposition": r.MediaMetadata.ContentDisposition,
}).Info("Uploading file")
// The file data is hashed and the hash is used as the MediaID. The hash is useful as a
// method of deduplicating files to save storage, as well as a way to conduct
// integrity checks on the file data in the repository.
// bytesWritten is the total number of bytes written to disk
hash, _, bytesWritten, tmpDir, copyError := fileutils.WriteTempFile(req.Body, cfg.MaxFileSizeBytes, cfg.AbsBasePath, nil)
if copyError != nil {
logFields := log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
}
if copyError == fileutils.ErrFileIsTooLarge {
logFields["MaxFileSizeBytes"] = cfg.MaxFileSizeBytes
}
r.Logger.WithError(copyError).WithFields(logFields).Warn("Error while transferring file")
fileutils.RemoveDir(tmpDir, r.Logger)
return util.JSONResponse{
Code: 400,
JSON: jsonerror.Unknown(fmt.Sprintf("Failed to upload")),
}
}
r.MediaMetadata.FileSizeBytes = bytesWritten
r.MediaMetadata.Base64Hash = hash
r.MediaMetadata.MediaID = types.MediaID(hash)
r.Logger.WithFields(log.Fields{
"MediaID": r.MediaMetadata.MediaID,
"Origin": r.MediaMetadata.Origin,
"Base64Hash": r.MediaMetadata.Base64Hash,
"UploadName": r.MediaMetadata.UploadName,
"FileSizeBytes": r.MediaMetadata.FileSizeBytes,
"Content-Type": r.MediaMetadata.ContentType,
"Content-Disposition": r.MediaMetadata.ContentDisposition,
}).Info("File uploaded")
// check if we already have a record of the media in our database and if so, we can remove the temporary directory
mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
if err == nil {
r.MediaMetadata = mediaMetadata
fileutils.RemoveDir(tmpDir, r.Logger)
return util.JSONResponse{
Code: 200,
JSON: uploadResponse{
ContentURI: fmt.Sprintf("mxc://%s/%s", cfg.ServerName, r.MediaMetadata.MediaID),
},
}
} else if err != sql.ErrNoRows {
r.Logger.WithError(err).WithField("MediaID", r.MediaMetadata.MediaID).Warn("Failed to query database")
}
// TODO: generate thumbnails
resErr = r.storeFileAndMetadata(tmpDir, cfg.AbsBasePath, db)
if resErr != nil {
return *resErr
}
return util.JSONResponse{
Code: 200,
JSON: uploadResponse{
ContentURI: fmt.Sprintf("mxc://%s/%s", cfg.ServerName, r.MediaMetadata.MediaID),
},
}
}