mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-26 00:01:55 -06:00
unix socket support (#2974)
### Pull Request Checklist <!-- Please read https://matrix-org.github.io/dendrite/development/contributing before submitting your pull request --> * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Boris Rybalkin <ribalkin@gmail.com>` I need this for Syncloud project (https://github.com/syncloud/platform) where I run multiple apps behind an nginx on the same RPi like device so unix socket is very convenient to not have port conflicts between apps. Also someone opened this Issue: https://github.com/matrix-org/dendrite/issues/2924 --------- Co-authored-by: kegsay <kegan@matrix.org> Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com>
This commit is contained in:
parent
6c20f8f742
commit
6b1c9eafa9
|
@ -16,6 +16,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
|
"io/fs"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
@ -30,6 +31,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
unixSocket = flag.String("unix-socket", "",
|
||||||
|
"EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)",
|
||||||
|
)
|
||||||
|
unixSocketPermission = flag.Int("unix-socket-permission", 0755,
|
||||||
|
"EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server",
|
||||||
|
)
|
||||||
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
|
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
|
||||||
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
|
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
|
||||||
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
|
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
|
||||||
|
@ -38,8 +45,23 @@ var (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
cfg := setup.ParseFlags(true)
|
cfg := setup.ParseFlags(true)
|
||||||
httpAddr := config.HTTPAddress("http://" + *httpBindAddr)
|
httpAddr := config.ServerAddress{}
|
||||||
httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr)
|
httpsAddr := config.ServerAddress{}
|
||||||
|
if *unixSocket == "" {
|
||||||
|
http, err := config.HTTPAddress("http://" + *httpBindAddr)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Fatalf("Failed to parse http address")
|
||||||
|
}
|
||||||
|
httpAddr = http
|
||||||
|
https, err := config.HTTPAddress("https://" + *httpsBindAddr)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Fatalf("Failed to parse https address")
|
||||||
|
}
|
||||||
|
httpsAddr = https
|
||||||
|
} else {
|
||||||
|
httpAddr = config.UnixSocketAddress(*unixSocket, fs.FileMode(*unixSocketPermission))
|
||||||
|
}
|
||||||
|
|
||||||
options := []basepkg.BaseDendriteOptions{}
|
options := []basepkg.BaseDendriteOptions{}
|
||||||
|
|
||||||
base := basepkg.NewBaseDendrite(cfg, options...)
|
base := basepkg.NewBaseDendrite(cfg, options...)
|
||||||
|
@ -92,7 +114,7 @@ func main() {
|
||||||
base.SetupAndServeHTTP(httpAddr, nil, nil)
|
base.SetupAndServeHTTP(httpAddr, nil, nil)
|
||||||
}()
|
}()
|
||||||
// Handle HTTPS if certificate and key are provided
|
// Handle HTTPS if certificate and key are provided
|
||||||
if *certFile != "" && *keyFile != "" {
|
if *unixSocket == "" && *certFile != "" && *keyFile != "" {
|
||||||
go func() {
|
go func() {
|
||||||
base.SetupAndServeHTTP(httpsAddr, certFile, keyFile)
|
base.SetupAndServeHTTP(httpsAddr, certFile, keyFile)
|
||||||
}()
|
}()
|
||||||
|
|
|
@ -20,9 +20,11 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
|
@ -85,8 +87,6 @@ type BaseDendrite struct {
|
||||||
startupLock sync.Mutex
|
startupLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
const NoListener = ""
|
|
||||||
|
|
||||||
const HTTPServerTimeout = time.Minute * 5
|
const HTTPServerTimeout = time.Minute * 5
|
||||||
|
|
||||||
type BaseDendriteOptions int
|
type BaseDendriteOptions int
|
||||||
|
@ -345,18 +345,17 @@ func (b *BaseDendrite) ConfigureAdminEndpoints() {
|
||||||
// SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs
|
// SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs
|
||||||
// and adds a prometheus handler under /_dendrite/metrics.
|
// and adds a prometheus handler under /_dendrite/metrics.
|
||||||
func (b *BaseDendrite) SetupAndServeHTTP(
|
func (b *BaseDendrite) SetupAndServeHTTP(
|
||||||
externalHTTPAddr config.HTTPAddress,
|
externalHTTPAddr config.ServerAddress,
|
||||||
certFile, keyFile *string,
|
certFile, keyFile *string,
|
||||||
) {
|
) {
|
||||||
// Manually unlocked right before actually serving requests,
|
// Manually unlocked right before actually serving requests,
|
||||||
// as we don't return from this method (defer doesn't work).
|
// as we don't return from this method (defer doesn't work).
|
||||||
b.startupLock.Lock()
|
b.startupLock.Lock()
|
||||||
externalAddr, _ := externalHTTPAddr.Address()
|
|
||||||
|
|
||||||
externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()
|
externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()
|
||||||
|
|
||||||
externalServ := &http.Server{
|
externalServ := &http.Server{
|
||||||
Addr: string(externalAddr),
|
Addr: externalHTTPAddr.Address,
|
||||||
WriteTimeout: HTTPServerTimeout,
|
WriteTimeout: HTTPServerTimeout,
|
||||||
Handler: externalRouter,
|
Handler: externalRouter,
|
||||||
BaseContext: func(_ net.Listener) context.Context {
|
BaseContext: func(_ net.Listener) context.Context {
|
||||||
|
@ -419,7 +418,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
|
||||||
|
|
||||||
b.startupLock.Unlock()
|
b.startupLock.Unlock()
|
||||||
|
|
||||||
if externalAddr != NoListener {
|
if externalHTTPAddr.Enabled() {
|
||||||
go func() {
|
go func() {
|
||||||
var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once
|
var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once
|
||||||
logrus.Infof("Starting external listener on %s", externalServ.Addr)
|
logrus.Infof("Starting external listener on %s", externalServ.Addr)
|
||||||
|
@ -437,9 +436,30 @@ func (b *BaseDendrite) SetupAndServeHTTP(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := externalServ.ListenAndServe(); err != nil {
|
if externalHTTPAddr.IsUnixSocket() {
|
||||||
if err != http.ErrServerClosed {
|
err := os.Remove(externalHTTPAddr.Address)
|
||||||
logrus.WithError(err).Fatal("failed to serve HTTP")
|
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
logrus.WithError(err).Fatal("failed to remove existing unix socket")
|
||||||
|
}
|
||||||
|
listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Fatal("failed to serve unix socket")
|
||||||
|
}
|
||||||
|
err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Fatal("failed to set unix socket permissions")
|
||||||
|
}
|
||||||
|
if err := externalServ.Serve(listener); err != nil {
|
||||||
|
if err != http.ErrServerClosed {
|
||||||
|
logrus.WithError(err).Fatal("failed to serve unix socket")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
if err := externalServ.ListenAndServe(); err != nil {
|
||||||
|
if err != http.ErrServerClosed {
|
||||||
|
logrus.WithError(err).Fatal("failed to serve HTTP")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,10 +2,13 @@ package base_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
"html/template"
|
"html/template"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"path"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -18,7 +21,7 @@ import (
|
||||||
//go:embed static/*.gotmpl
|
//go:embed static/*.gotmpl
|
||||||
var staticContent embed.FS
|
var staticContent embed.FS
|
||||||
|
|
||||||
func TestLandingPage(t *testing.T) {
|
func TestLandingPage_Tcp(t *testing.T) {
|
||||||
// generate the expected result
|
// generate the expected result
|
||||||
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
|
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
|
||||||
expectedRes := &bytes.Buffer{}
|
expectedRes := &bytes.Buffer{}
|
||||||
|
@ -35,7 +38,9 @@ func TestLandingPage(t *testing.T) {
|
||||||
s.Close()
|
s.Close()
|
||||||
|
|
||||||
// start base with the listener and wait for it to be started
|
// start base with the listener and wait for it to be started
|
||||||
go b.SetupAndServeHTTP(config.HTTPAddress(s.URL), nil, nil)
|
address, err := config.HTTPAddress(s.URL)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
go b.SetupAndServeHTTP(address, nil, nil)
|
||||||
time.Sleep(time.Millisecond * 10)
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
|
||||||
// When hitting /, we should be redirected to /_matrix/static, which should contain the landing page
|
// When hitting /, we should be redirected to /_matrix/static, which should contain the landing page
|
||||||
|
@ -55,3 +60,43 @@ func TestLandingPage(t *testing.T) {
|
||||||
// Using .String() for user friendly output
|
// Using .String() for user friendly output
|
||||||
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
|
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLandingPage_UnixSocket(t *testing.T) {
|
||||||
|
// generate the expected result
|
||||||
|
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
|
||||||
|
expectedRes := &bytes.Buffer{}
|
||||||
|
err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{
|
||||||
|
"Version": internal.VersionString(),
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
b, _, _ := testrig.Base(nil)
|
||||||
|
defer b.Close()
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
socket := path.Join(tempDir, "socket")
|
||||||
|
// start base with the listener and wait for it to be started
|
||||||
|
address := config.UnixSocketAddress(socket, 0755)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
go b.SetupAndServeHTTP(address, nil, nil)
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||||
|
return net.Dial("unix", socket)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := client.Get("http://unix/")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
// read the response
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
_, err = buf.ReadFrom(resp.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Using .String() for user friendly output
|
||||||
|
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
|
||||||
|
}
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -131,20 +130,6 @@ func (d DataSource) IsPostgres() bool {
|
||||||
// A Topic in kafka.
|
// A Topic in kafka.
|
||||||
type Topic string
|
type Topic string
|
||||||
|
|
||||||
// An Address to listen on.
|
|
||||||
type Address string
|
|
||||||
|
|
||||||
// An HTTPAddress to listen on, starting with either http:// or https://.
|
|
||||||
type HTTPAddress string
|
|
||||||
|
|
||||||
func (h HTTPAddress) Address() (Address, error) {
|
|
||||||
url, err := url.Parse(string(h))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return Address(url.Host), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileSizeBytes is a file size in bytes
|
// FileSizeBytes is a file size in bytes
|
||||||
type FileSizeBytes int64
|
type FileSizeBytes int64
|
||||||
|
|
||||||
|
|
45
setup/config/config_address.go
Normal file
45
setup/config/config_address.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NetworkTCP = "tcp"
|
||||||
|
NetworkUnix = "unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServerAddress struct {
|
||||||
|
Address string
|
||||||
|
Scheme string
|
||||||
|
UnixSocketPermission fs.FileMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s ServerAddress) Enabled() bool {
|
||||||
|
return s.Address != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s ServerAddress) IsUnixSocket() bool {
|
||||||
|
return s.Scheme == NetworkUnix
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s ServerAddress) Network() string {
|
||||||
|
if s.Scheme == NetworkUnix {
|
||||||
|
return NetworkUnix
|
||||||
|
} else {
|
||||||
|
return NetworkTCP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress {
|
||||||
|
return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HTTPAddress(urlAddress string) (ServerAddress, error) {
|
||||||
|
parsedUrl, err := url.Parse(urlAddress)
|
||||||
|
if err != nil {
|
||||||
|
return ServerAddress{}, err
|
||||||
|
}
|
||||||
|
return ServerAddress{parsedUrl.Host, parsedUrl.Scheme, 0}, nil
|
||||||
|
}
|
25
setup/config/config_address_test.go
Normal file
25
setup/config/config_address_test.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHttpAddress_ParseGood(t *testing.T) {
|
||||||
|
address, err := HTTPAddress("http://localhost:123")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "localhost:123", address.Address)
|
||||||
|
assert.Equal(t, "tcp", address.Network())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHttpAddress_ParseBad(t *testing.T) {
|
||||||
|
_, err := HTTPAddress(":")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnixSocketAddress_Network(t *testing.T) {
|
||||||
|
address := UnixSocketAddress("/tmp", fs.FileMode(0755))
|
||||||
|
assert.Equal(t, "unix", address.Network())
|
||||||
|
}
|
Loading…
Reference in a new issue