// Copyright 2017 Vector Creations Ltd // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package fileutils import ( "bufio" "context" "crypto/sha256" "encoding/base64" "fmt" "io" "os" "path/filepath" "strings" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" ) // GetPathFromBase64Hash evaluates the path to a media file from its Base64Hash // 3 subdirectories are created for more manageable browsing and use the remainder as the file name. // For example, if Base64Hash is 'qwerty', the path will be 'q/w/erty/file'. func GetPathFromBase64Hash(base64Hash types.Base64Hash, absBasePath config.Path) (string, error) { if len(base64Hash) < 3 { return "", fmt.Errorf("invalid filePath (Base64Hash too short - min 3 characters): %q", base64Hash) } if len(base64Hash) > 255 { return "", fmt.Errorf("invalid filePath (Base64Hash too long - max 255 characters): %q", base64Hash) } filePath, err := filepath.Abs(filepath.Join( string(absBasePath), string(base64Hash[0:1]), string(base64Hash[1:2]), string(base64Hash[2:]), "file", )) if err != nil { return "", fmt.Errorf("unable to construct filePath: %w", err) } // check if the absolute absBasePath is a prefix of the absolute filePath // if so, no directory escape has occurred and the filePath is valid // Note: absBasePath is already absolute if !strings.HasPrefix(filePath, string(absBasePath)) { return "", fmt.Errorf("invalid filePath (not within absBasePath %v): %v", absBasePath, filePath) } return filePath, nil } // MoveFileWithHashCheck checks for hash collisions when moving a temporary file to its final path based on metadata // The final path is based on the hash of the file. // If the final path exists and the file size matches, the file does not need to be moved. // In error cases where the file is not a duplicate, the caller may decide to remove the final path. // Returns the final path of the file, whether it is a duplicate and an error. func MoveFileWithHashCheck(tmpDir types.Path, mediaMetadata *types.MediaMetadata, absBasePath config.Path, logger *log.Entry) (types.Path, bool, error) { // Note: in all error and success cases, we need to remove the temporary directory defer RemoveDir(tmpDir, logger) duplicate := false finalPath, err := GetPathFromBase64Hash(mediaMetadata.Base64Hash, absBasePath) if err != nil { return "", duplicate, fmt.Errorf("failed to get file path from metadata: %w", err) } var stat os.FileInfo // Note: The double-negative is intentional as os.IsExist(err) != !os.IsNotExist(err). // The functions are error checkers to be used in different cases. if stat, err = os.Stat(finalPath); !os.IsNotExist(err) { duplicate = true if stat.Size() == int64(mediaMetadata.FileSizeBytes) { return types.Path(finalPath), duplicate, nil } return "", duplicate, fmt.Errorf("downloaded file with hash collision but different file size (%v)", finalPath) } err = moveFile( types.Path(filepath.Join(string(tmpDir), "content")), types.Path(finalPath), ) if err != nil { return "", duplicate, fmt.Errorf("failed to move file to final destination (%v): %w", finalPath, err) } return types.Path(finalPath), duplicate, nil } // RemoveDir removes a directory and logs a warning in case of errors func RemoveDir(dir types.Path, logger *log.Entry) { dirErr := os.RemoveAll(string(dir)) if dirErr != nil { logger.WithError(dirErr).WithField("dir", dir).Warn("failed to remove directory") } } // WriteTempFile writes to a new temporary file. // The file is deleted if there was an error while writing. func WriteTempFile( ctx context.Context, reqReader io.Reader, absBasePath config.Path, ) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) { size = -1 logger := util.GetLogger(ctx) tmpFileWriter, tmpFile, tmpDir, err := createTempFileWriter(absBasePath) if err != nil { return } defer func() { err2 := tmpFile.Close() if err == nil { err = err2 } }() // Hash the file data. The hash will be returned. The hash is useful as a // method of deduplicating files to save storage, as well as a way to conduct // integrity checks on the file data in the repository. hasher := sha256.New() teeReader := io.TeeReader(reqReader, hasher) bytesWritten, err := io.Copy(tmpFileWriter, teeReader) if err != nil && err != io.EOF { RemoveDir(tmpDir, logger) return } err = tmpFileWriter.Flush() if err != nil { RemoveDir(tmpDir, logger) return } hash = types.Base64Hash(base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)[:])) size = types.FileSizeBytes(bytesWritten) path = tmpDir return } // moveFile attempts to move the file src to dst func moveFile(src types.Path, dst types.Path) error { dstDir := filepath.Dir(string(dst)) err := os.MkdirAll(dstDir, 0770) if err != nil { return fmt.Errorf("failed to make directory: %w", err) } err = os.Rename(string(src), string(dst)) if err != nil { return fmt.Errorf("failed to move directory: %w", err) } return nil } func MoveFile(src types.Path, dst types.Path) error { return moveFile(src, dst) } func createTempFileWriter(absBasePath config.Path) (*bufio.Writer, *os.File, types.Path, error) { tmpDir, err := createTempDir(absBasePath) if err != nil { return nil, nil, "", fmt.Errorf("failed to create temp dir: %w", err) } writer, tmpFile, err := createFileWriter(tmpDir) if err != nil { return nil, nil, "", fmt.Errorf("failed to create file writer: %w", err) } return writer, tmpFile, tmpDir, nil } // createTempDir creates a tmp/ directory within baseDirectory and returns its path func createTempDir(baseDirectory config.Path) (types.Path, error) { baseTmpDir := filepath.Join(string(baseDirectory), "tmp") if err := os.MkdirAll(baseTmpDir, 0770); err != nil { return "", fmt.Errorf("failed to create base temp dir: %w", err) } tmpDir, err := os.MkdirTemp(baseTmpDir, "") if err != nil { return "", fmt.Errorf("failed to create temp dir: %w", err) } return types.Path(tmpDir), nil } // createFileWriter creates a buffered file writer with a new file // The caller should flush the writer before closing the file. // Returns the file handle as it needs to be closed when writing is complete func createFileWriter(directory types.Path) (*bufio.Writer, *os.File, error) { filePath := filepath.Join(string(directory), "content") file, err := os.Create(filePath) if err != nil { return nil, nil, fmt.Errorf("failed to create file: %w", err) } return bufio.NewWriter(file), file, nil }