mediaapi: Hack in /download from gotest code

This commit is contained in:
Robert Swain 2017-04-27 17:40:57 +02:00
parent c1e5974872
commit 81706408bd
4 changed files with 440 additions and 0 deletions

View file

@ -15,8 +15,10 @@
package routing package routing
import ( import (
"context"
"net/http" "net/http"
log "github.com/Sirupsen/logrus"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/mediaapi/config" "github.com/matrix-org/dendrite/mediaapi/config"
"github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/storage"
@ -27,6 +29,48 @@ import (
const pathPrefixR0 = "/_matrix/media/v1" const pathPrefixR0 = "/_matrix/media/v1"
type contextKeys string
const ctxValueLogger = contextKeys("logger")
const ctxValueRequestID = contextKeys("requestid")
type Fudge struct {
Config config.MediaAPI
Database *storage.Database
DownloadServer writers.DownloadServer
}
func (fudge Fudge) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// NOTE: The code below is from util.Protect and respond but this is the only
// API that needs a different form of it to be able to pass the
// http.ResponseWriter to the handler
reqID := util.RandomString(12)
// Set a Logger and request ID on the context
ctx := context.WithValue(req.Context(), ctxValueLogger, log.WithFields(log.Fields{
"req.method": req.Method,
"req.path": req.URL.Path,
"req.id": reqID,
}))
ctx = context.WithValue(ctx, ctxValueRequestID, reqID)
req = req.WithContext(ctx)
logger := util.GetLogger(req.Context())
logger.Print("Incoming request")
if req.Method == "OPTIONS" {
util.SetCORSHeaders(w)
w.WriteHeader(200)
return
}
// Set common headers returned regardless of the outcome of the request
util.SetCORSHeaders(w)
w.Header().Set("Content-Type", "application/json")
vars := mux.Vars(req)
writers.Download(w, req, vars["serverName"], vars["mediaId"], fudge.Config, fudge.Database, fudge.DownloadServer)
}
// Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client // Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client
// to clients which need to make outbound HTTP requests. // to clients which need to make outbound HTTP requests.
func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg config.MediaAPI, db *storage.Database, repo *storage.Repository) { func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg config.MediaAPI, db *storage.Database, repo *storage.Repository) {
@ -36,6 +80,21 @@ func Setup(servMux *http.ServeMux, httpClient *http.Client, cfg config.MediaAPI,
return writers.Upload(req, cfg, db, repo) return writers.Upload(req, cfg, db, repo)
}))) })))
downloadServer := writers.DownloadServer{
Repository: *repo,
LocalServerName: cfg.ServerName,
}
fudge := Fudge{
Config: cfg,
Database: db,
DownloadServer: downloadServer,
}
r0mux.Handle("/download/{serverName}/{mediaId}",
prometheus.InstrumentHandler("download", fudge),
)
servMux.Handle("/metrics", prometheus.Handler()) servMux.Handle("/metrics", prometheus.Handler())
servMux.Handle("/api/", http.StripPrefix("/api", apiMux)) servMux.Handle("/api/", http.StripPrefix("/api", apiMux))
} }

View file

@ -49,8 +49,13 @@ INSERT INTO media_repository (media_id, media_origin, content_type, content_disp
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
` `
const selectMediaSQL = `
SELECT content_type, content_disposition, file_size, upload_name FROM media_repository WHERE media_id = $1 AND media_origin = $2
`
type mediaStatements struct { type mediaStatements struct {
insertMediaStmt *sql.Stmt insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt
} }
func (s *mediaStatements) prepare(db *sql.DB) (err error) { func (s *mediaStatements) prepare(db *sql.DB) (err error) {
@ -61,6 +66,7 @@ func (s *mediaStatements) prepare(db *sql.DB) (err error) {
return statementList{ return statementList{
{&s.insertMediaStmt, insertMediaSQL}, {&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL},
}.prepare(db) }.prepare(db)
} }
@ -72,3 +78,12 @@ func (s *mediaStatements) insertMedia(mediaID string, mediaOrigin string, conten
) )
return err return err
} }
func (s *mediaStatements) selectMedia(mediaID string, mediaOrigin string) (string, string, int64, string, error) {
var contentType string
var contentDisposition string
var fileSize int64
var uploadName string
err := s.selectMediaStmt.QueryRow(mediaID, mediaOrigin).Scan(&contentType, &contentDisposition, &fileSize, &uploadName)
return string(contentType), string(contentDisposition), int64(fileSize), string(uploadName), err
}

View file

@ -44,3 +44,8 @@ func Open(dataSourceName string) (*Database, error) {
func (d *Database) CreateMedia(mediaID string, mediaOrigin string, contentType string, contentDisposition string, fileSize int64, uploadName string, userID string) error { func (d *Database) CreateMedia(mediaID string, mediaOrigin string, contentType string, contentDisposition string, fileSize int64, uploadName string, userID string) error {
return d.statements.insertMedia(mediaID, mediaOrigin, contentType, contentDisposition, fileSize, uploadName, userID) return d.statements.insertMedia(mediaID, mediaOrigin, contentType, contentDisposition, fileSize, uploadName, userID)
} }
// GetMedia possibly selects the metadata about previously uploaded media from the database.
func (d *Database) GetMedia(mediaID string, mediaOrigin string) (string, string, int64, string, error) {
return d.statements.selectMedia(mediaID, mediaOrigin)
}

View file

@ -0,0 +1,361 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package writers
import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/mediaapi/config"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/util"
)
// DownloadRequest metadata included in or derivable from an upload request
// https://matrix.org/docs/spec/client_server/r0.2.0.html#post-matrix-media-r0-download
type DownloadRequest struct {
MediaID string
ServerName string
}
// Validate validates the DownloadRequest fields
func (r DownloadRequest) Validate() *util.JSONResponse {
// FIXME: the following errors aren't bad JSON, rather just a bad request path
// maybe give the URL pattern in the routing, these are not even possible as the handler would not be hit...?
if r.MediaID == "" {
return &util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("mediaId must be a non-empty string"),
}
}
if r.ServerName == "" {
return &util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("serverName must be a non-empty string"),
}
}
return nil
}
func jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse, logger *log.Entry) {
// Marshal JSON response into raw bytes to send as the HTTP body
resBytes, err := json.Marshal(res.JSON)
if err != nil {
logger.WithError(err).Error("Failed to marshal JSONResponse")
// this should never fail to be marshalled so drop err to the floor
res = util.MessageResponse(500, "Internal Server Error")
resBytes, _ = json.Marshal(res.JSON)
}
// Set status code and write the body
w.WriteHeader(res.Code)
logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes))
w.Write(resBytes)
}
// Download implements /upload
func Download(w http.ResponseWriter, req *http.Request, serverName string, mediaID string, cfg config.MediaAPI, db *storage.Database, downloadServer DownloadServer) {
logger := util.GetLogger(req.Context())
r := &DownloadRequest{
MediaID: mediaID,
ServerName: serverName,
}
if resErr := r.Validate(); resErr != nil {
jsonErrorResponse(w, *resErr, logger)
return
}
// TODO:
// - query db to look up content type and disposition and whether we have the file
logger.Warnln(r.MediaID, r.ServerName, cfg.ServerName)
contentType, contentDisposition, fileSize, filename, err := db.GetMedia(r.MediaID, r.ServerName)
if err != nil {
if strings.Compare(r.ServerName, cfg.ServerName) != 0 {
// TODO: get remote file from remote server
jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("NOT YET IMPLEMENTED")),
}, logger)
return
}
jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File %q does not exist", r.MediaID)),
}, logger)
return
}
// - read file and respond
logger.WithFields(log.Fields{
"MediaID": r.MediaID,
"ServerName": r.ServerName,
"Filename": filename,
"Content-Type": contentType,
"Content-Disposition": contentDisposition,
}).Infof("Downloading file")
logger.WithField("code", 200).Infof("Responding (%d bytes)", fileSize)
respWriter := httpResponseWriter{resp: w}
if err = downloadServer.getImage(respWriter, r.ServerName, r.MediaID); err != nil {
if respWriter.haveWritten() {
closeConnection(w)
return
}
errStatus := 500
switch err {
case errNotFound:
errStatus = 404
case errProxy:
errStatus = 502
}
http.Error(w, err.Error(), errStatus)
return
}
return
}
// DownloadServer serves and caches remote media.
type DownloadServer struct {
Client http.Client
Repository storage.Repository
LocalServerName string
}
func (handler *DownloadServer) getImage(w responseWriter, host, name string) error {
var file io.ReadCloser
var descr *storage.Description
var err error
if host == handler.LocalServerName {
file, descr, err = handler.Repository.ReaderFromLocalRepo(name)
} else {
file, descr, err = handler.Repository.ReaderFromRemoteCache(host, name)
}
if err == nil {
log.Println("Found in Cache")
w.setContentType(descr.Type)
size := strconv.FormatInt(descr.Length, 10)
w.setContentLength(size)
w.setContentSecurityPolicy()
if _, err = io.Copy(w, file); err != nil {
log.Printf("Failed to copy from cache %v\n", err)
return err
}
w.Close()
return nil
} else if !storage.IsNotExists(err) {
log.Printf("Error looking in cache: %v\n", err)
return err
}
if host == handler.LocalServerName {
// Its fatal if we can't find local files in our cache.
return errNotFound
}
respBody, desc, err := handler.fetchRemoteMedia(host, name)
if err != nil {
return err
}
defer respBody.Close()
w.setContentType(desc.Type)
if desc.Length > 0 {
w.setContentLength(strconv.FormatInt(desc.Length, 10))
}
writer, err := handler.Repository.WriterToRemoteCache(host, name, *desc)
if err != nil {
log.Printf("Failed to get cache writer %q\n", err)
return err
}
defer writer.Close()
reader := io.TeeReader(respBody, w)
if _, err := io.Copy(writer, reader); err != nil {
log.Printf("Failed to copy %q\n", err)
return err
}
writer.Finished()
log.Println("Finished conn")
return nil
}
func (handler *DownloadServer) fetchRemoteMedia(host, name string) (io.ReadCloser, *storage.Description, error) {
urls := getMatrixUrls(host)
log.Printf("Connecting to remote %q\n", urls[0])
remoteReq, err := http.NewRequest("GET", urls[0]+"/_matrix/media/v1/download/"+host+"/"+name, nil)
if err != nil {
log.Printf("Failed to connect to remote: %q\n", err)
return nil, nil, err
}
remoteReq.Header.Set("Host", host)
resp, err := handler.Client.Do(remoteReq)
if err != nil {
log.Printf("Failed to connect to remote: %q\n", err)
return nil, nil, errProxy
}
if resp.StatusCode != 200 {
resp.Body.Close()
log.Printf("Server responded with %d\n", resp.StatusCode)
if resp.StatusCode == 404 {
return nil, nil, errNotFound
}
return nil, nil, errProxy
}
desc := storage.Description{
Type: resp.Header.Get("Content-Type"),
Length: -1,
}
length, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
if err == nil {
desc.Length = length
}
return resp.Body, &desc, nil
}
// Given a http.ResponseWriter, attempt to force close the connection.
//
// This is useful if you get a fatal error after sending the initial 200 OK
// response.
func closeConnection(w http.ResponseWriter) {
log.Println("Attempting to close connection")
// We attempt to bluntly close the connection because that is the
// best thing we can do after we've sent a 200 OK
hijack, ok := w.(http.Hijacker)
if ok {
conn, _, err := hijack.Hijack()
if err != nil {
fmt.Printf("Err trying to hijack: %v", err)
return
}
log.Println("Closing")
conn.Close()
return
}
log.Println("Not hijacker")
}
// Given a matrix server name, attempt to discover URLs to contact the server
// on.
func getMatrixUrls(host string) []string {
_, srvs, err := net.LookupSRV("matrix", "tcp", host)
if err != nil {
return []string{"https://" + host + ":8448"}
}
results := make([]string, 0, len(srvs))
for _, srv := range srvs {
if srv == nil {
continue
}
url := []string{"https://", strings.Trim(srv.Target, "."), ":", strconv.Itoa(int(srv.Port))}
results = append(results, strings.Join(url, ""))
}
// TODO: Order based on priority and weight.
return results
}
// Given a path of the form '/<host>/<name>' extract the host and name.
func getMediaIDFromPath(path string) (host, name string, err error) {
parts := strings.Split(path, "/")
if len(parts) != 3 {
err = fmt.Errorf("Invalid path %q", path)
return
}
host, name = parts[1], parts[2]
if host == "" || name == "" {
err = fmt.Errorf("Invalid path %q", path)
return
}
return
}
type responseWriter interface {
io.WriteCloser
setContentLength(string)
setContentSecurityPolicy()
setContentType(string)
haveWritten() bool
}
type httpResponseWriter struct {
resp http.ResponseWriter
written bool
}
func (writer httpResponseWriter) haveWritten() bool {
return writer.written
}
func (writer httpResponseWriter) Write(p []byte) (n int, err error) {
writer.written = true
return writer.resp.Write(p)
}
func (writer httpResponseWriter) Close() error { return nil }
func (writer httpResponseWriter) setContentType(contentType string) {
writer.resp.Header().Set("Content-Type", contentType)
}
func (writer httpResponseWriter) setContentLength(length string) {
writer.resp.Header().Set("Content-Length", length)
}
func (writer httpResponseWriter) setContentSecurityPolicy() {
contentSecurityPolicy := "default-src 'none';" +
" script-src 'none';" +
" plugin-types application/pdf;" +
" style-src 'unsafe-inline';" +
" object-src 'self';"
writer.resp.Header().Set("Content-Security-Policy", contentSecurityPolicy)
}
var errProxy = fmt.Errorf("Failed to contact remote")
var errNotFound = fmt.Errorf("Image not found")