unix socket support (#2974)

### Pull Request Checklist

<!-- Please read
https://matrix-org.github.io/dendrite/development/contributing before
submitting your pull request -->

* [x] I have added Go unit tests or [Complement integration
tests](https://github.com/matrix-org/complement) for this PR _or_ I have
justified why this PR doesn't need tests
* [x] Pull request includes a [sign off below using a legally
identifiable
name](https://matrix-org.github.io/dendrite/development/contributing#sign-off)
_or_ I have already signed off privately

Signed-off-by: `Boris Rybalkin <ribalkin@gmail.com>`

I need this for Syncloud project (https://github.com/syncloud/platform)
where I run multiple apps behind an nginx on the same RPi like device so
unix socket is very convenient to not have port conflicts between apps.
Also someone opened this Issue:
https://github.com/matrix-org/dendrite/issues/2924

---------

Co-authored-by: kegsay <kegan@matrix.org>
Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com>
This commit is contained in:
Boris Rybalkin 2023-03-01 21:57:30 +00:00 committed by GitHub
parent 6c20f8f742
commit 6b1c9eafa9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 171 additions and 29 deletions

View file

@ -16,6 +16,7 @@ package main
import ( import (
"flag" "flag"
"io/fs"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -30,6 +31,12 @@ import (
) )
var ( var (
unixSocket = flag.String("unix-socket", "",
"EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)",
)
unixSocketPermission = flag.Int("unix-socket-permission", 0755,
"EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server",
)
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
@ -38,8 +45,23 @@ var (
func main() { func main() {
cfg := setup.ParseFlags(true) cfg := setup.ParseFlags(true)
httpAddr := config.HTTPAddress("http://" + *httpBindAddr) httpAddr := config.ServerAddress{}
httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr) httpsAddr := config.ServerAddress{}
if *unixSocket == "" {
http, err := config.HTTPAddress("http://" + *httpBindAddr)
if err != nil {
logrus.WithError(err).Fatalf("Failed to parse http address")
}
httpAddr = http
https, err := config.HTTPAddress("https://" + *httpsBindAddr)
if err != nil {
logrus.WithError(err).Fatalf("Failed to parse https address")
}
httpsAddr = https
} else {
httpAddr = config.UnixSocketAddress(*unixSocket, fs.FileMode(*unixSocketPermission))
}
options := []basepkg.BaseDendriteOptions{} options := []basepkg.BaseDendriteOptions{}
base := basepkg.NewBaseDendrite(cfg, options...) base := basepkg.NewBaseDendrite(cfg, options...)
@ -92,7 +114,7 @@ func main() {
base.SetupAndServeHTTP(httpAddr, nil, nil) base.SetupAndServeHTTP(httpAddr, nil, nil)
}() }()
// Handle HTTPS if certificate and key are provided // Handle HTTPS if certificate and key are provided
if *certFile != "" && *keyFile != "" { if *unixSocket == "" && *certFile != "" && *keyFile != "" {
go func() { go func() {
base.SetupAndServeHTTP(httpsAddr, certFile, keyFile) base.SetupAndServeHTTP(httpsAddr, certFile, keyFile)
}() }()

View file

@ -20,9 +20,11 @@ import (
"database/sql" "database/sql"
"embed" "embed"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"html/template" "html/template"
"io" "io"
"io/fs"
"net" "net"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
@ -85,8 +87,6 @@ type BaseDendrite struct {
startupLock sync.Mutex startupLock sync.Mutex
} }
const NoListener = ""
const HTTPServerTimeout = time.Minute * 5 const HTTPServerTimeout = time.Minute * 5
type BaseDendriteOptions int type BaseDendriteOptions int
@ -345,18 +345,17 @@ func (b *BaseDendrite) ConfigureAdminEndpoints() {
// SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs // SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs
// and adds a prometheus handler under /_dendrite/metrics. // and adds a prometheus handler under /_dendrite/metrics.
func (b *BaseDendrite) SetupAndServeHTTP( func (b *BaseDendrite) SetupAndServeHTTP(
externalHTTPAddr config.HTTPAddress, externalHTTPAddr config.ServerAddress,
certFile, keyFile *string, certFile, keyFile *string,
) { ) {
// Manually unlocked right before actually serving requests, // Manually unlocked right before actually serving requests,
// as we don't return from this method (defer doesn't work). // as we don't return from this method (defer doesn't work).
b.startupLock.Lock() b.startupLock.Lock()
externalAddr, _ := externalHTTPAddr.Address()
externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()
externalServ := &http.Server{ externalServ := &http.Server{
Addr: string(externalAddr), Addr: externalHTTPAddr.Address,
WriteTimeout: HTTPServerTimeout, WriteTimeout: HTTPServerTimeout,
Handler: externalRouter, Handler: externalRouter,
BaseContext: func(_ net.Listener) context.Context { BaseContext: func(_ net.Listener) context.Context {
@ -419,7 +418,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
b.startupLock.Unlock() b.startupLock.Unlock()
if externalAddr != NoListener { if externalHTTPAddr.Enabled() {
go func() { go func() {
var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once
logrus.Infof("Starting external listener on %s", externalServ.Addr) logrus.Infof("Starting external listener on %s", externalServ.Addr)
@ -436,6 +435,26 @@ func (b *BaseDendrite) SetupAndServeHTTP(
logrus.WithError(err).Fatal("failed to serve HTTPS") logrus.WithError(err).Fatal("failed to serve HTTPS")
} }
} }
} else {
if externalHTTPAddr.IsUnixSocket() {
err := os.Remove(externalHTTPAddr.Address)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
logrus.WithError(err).Fatal("failed to remove existing unix socket")
}
listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address)
if err != nil {
logrus.WithError(err).Fatal("failed to serve unix socket")
}
err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission)
if err != nil {
logrus.WithError(err).Fatal("failed to set unix socket permissions")
}
if err := externalServ.Serve(listener); err != nil {
if err != http.ErrServerClosed {
logrus.WithError(err).Fatal("failed to serve unix socket")
}
}
} else { } else {
if err := externalServ.ListenAndServe(); err != nil { if err := externalServ.ListenAndServe(); err != nil {
if err != http.ErrServerClosed { if err != http.ErrServerClosed {
@ -443,6 +462,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
} }
} }
} }
}
logrus.Infof("Stopped external listener on %s", externalServ.Addr) logrus.Infof("Stopped external listener on %s", externalServ.Addr)
}() }()
} }

View file

@ -2,10 +2,13 @@ package base_test
import ( import (
"bytes" "bytes"
"context"
"embed" "embed"
"html/template" "html/template"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path"
"testing" "testing"
"time" "time"
@ -18,7 +21,7 @@ import (
//go:embed static/*.gotmpl //go:embed static/*.gotmpl
var staticContent embed.FS var staticContent embed.FS
func TestLandingPage(t *testing.T) { func TestLandingPage_Tcp(t *testing.T) {
// generate the expected result // generate the expected result
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
expectedRes := &bytes.Buffer{} expectedRes := &bytes.Buffer{}
@ -35,7 +38,9 @@ func TestLandingPage(t *testing.T) {
s.Close() s.Close()
// start base with the listener and wait for it to be started // start base with the listener and wait for it to be started
go b.SetupAndServeHTTP(config.HTTPAddress(s.URL), nil, nil) address, err := config.HTTPAddress(s.URL)
assert.NoError(t, err)
go b.SetupAndServeHTTP(address, nil, nil)
time.Sleep(time.Millisecond * 10) time.Sleep(time.Millisecond * 10)
// When hitting /, we should be redirected to /_matrix/static, which should contain the landing page // When hitting /, we should be redirected to /_matrix/static, which should contain the landing page
@ -55,3 +60,43 @@ func TestLandingPage(t *testing.T) {
// Using .String() for user friendly output // Using .String() for user friendly output
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch") assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
} }
func TestLandingPage_UnixSocket(t *testing.T) {
// generate the expected result
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
expectedRes := &bytes.Buffer{}
err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{
"Version": internal.VersionString(),
})
assert.NoError(t, err)
b, _, _ := testrig.Base(nil)
defer b.Close()
tempDir := t.TempDir()
socket := path.Join(tempDir, "socket")
// start base with the listener and wait for it to be started
address := config.UnixSocketAddress(socket, 0755)
assert.NoError(t, err)
go b.SetupAndServeHTTP(address, nil, nil)
time.Sleep(time.Millisecond * 100)
client := &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", socket)
},
},
}
resp, err := client.Get("http://unix/")
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// read the response
buf := &bytes.Buffer{}
_, err = buf.ReadFrom(resp.Body)
assert.NoError(t, err)
// Using .String() for user friendly output
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
}

View file

@ -19,7 +19,6 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -131,20 +130,6 @@ func (d DataSource) IsPostgres() bool {
// A Topic in kafka. // A Topic in kafka.
type Topic string type Topic string
// An Address to listen on.
type Address string
// An HTTPAddress to listen on, starting with either http:// or https://.
type HTTPAddress string
func (h HTTPAddress) Address() (Address, error) {
url, err := url.Parse(string(h))
if err != nil {
return "", err
}
return Address(url.Host), nil
}
// FileSizeBytes is a file size in bytes // FileSizeBytes is a file size in bytes
type FileSizeBytes int64 type FileSizeBytes int64

View file

@ -0,0 +1,45 @@
package config
import (
"io/fs"
"net/url"
)
const (
NetworkTCP = "tcp"
NetworkUnix = "unix"
)
type ServerAddress struct {
Address string
Scheme string
UnixSocketPermission fs.FileMode
}
func (s ServerAddress) Enabled() bool {
return s.Address != ""
}
func (s ServerAddress) IsUnixSocket() bool {
return s.Scheme == NetworkUnix
}
func (s ServerAddress) Network() string {
if s.Scheme == NetworkUnix {
return NetworkUnix
} else {
return NetworkTCP
}
}
func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress {
return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm}
}
func HTTPAddress(urlAddress string) (ServerAddress, error) {
parsedUrl, err := url.Parse(urlAddress)
if err != nil {
return ServerAddress{}, err
}
return ServerAddress{parsedUrl.Host, parsedUrl.Scheme, 0}, nil
}

View file

@ -0,0 +1,25 @@
package config
import (
"io/fs"
"testing"
"github.com/stretchr/testify/assert"
)
func TestHttpAddress_ParseGood(t *testing.T) {
address, err := HTTPAddress("http://localhost:123")
assert.NoError(t, err)
assert.Equal(t, "localhost:123", address.Address)
assert.Equal(t, "tcp", address.Network())
}
func TestHttpAddress_ParseBad(t *testing.T) {
_, err := HTTPAddress(":")
assert.Error(t, err)
}
func TestUnixSocketAddress_Network(t *testing.T) {
address := UnixSocketAddress("/tmp", fs.FileMode(0755))
assert.Equal(t, "unix", address.Network())
}