unix socket support

This commit is contained in:
Boris Rybalkin 2023-02-15 09:26:35 +00:00
parent c8ca23acdb
commit 2cfdf1fe6a
6 changed files with 166 additions and 33 deletions

View file

@ -16,6 +16,7 @@ package main
import (
"flag"
"io/fs"
"github.com/sirupsen/logrus"
@ -30,16 +31,33 @@ import (
)
var (
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS")
unixSocket = flag.String("unix-socket", "", "The HTTP listening unix socket for the server")
unixSocketPermission = flag.Int("unix-socket-permission", 0755, "The HTTP listening unix socket permission for the server")
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS")
)
func main() {
cfg := setup.ParseFlags(true)
httpAddr := config.HTTPAddress("http://" + *httpBindAddr)
httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr)
httpAddr := config.ServerAddress{}
httpsAddr := config.ServerAddress{}
if *unixSocket == "" {
http, err := config.HttpAddress("http://" + *httpBindAddr)
if err != nil {
logrus.WithError(err).Fatalf("Failed to parse http address")
}
httpAddr = http
https, err := config.HttpAddress("https://" + *httpsBindAddr)
if err != nil {
logrus.WithError(err).Fatalf("Failed to parse https address")
}
httpsAddr = https
} else {
httpAddr = config.UnixSocketAddress(*unixSocket, fs.FileMode(*unixSocketPermission))
}
options := []basepkg.BaseDendriteOptions{}
base := basepkg.NewBaseDendrite(cfg, options...)
@ -92,7 +110,7 @@ func main() {
base.SetupAndServeHTTP(httpAddr, nil, nil)
}()
// Handle HTTPS if certificate and key are provided
if *certFile != "" && *keyFile != "" {
if *unixSocket == "" && *certFile != "" && *keyFile != "" {
go func() {
base.SetupAndServeHTTP(httpsAddr, certFile, keyFile)
}()

View file

@ -85,8 +85,6 @@ type BaseDendrite struct {
startupLock sync.Mutex
}
const NoListener = ""
const HTTPServerTimeout = time.Minute * 5
type BaseDendriteOptions int
@ -345,18 +343,17 @@ func (b *BaseDendrite) ConfigureAdminEndpoints() {
// SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs
// and adds a prometheus handler under /_dendrite/metrics.
func (b *BaseDendrite) SetupAndServeHTTP(
externalHTTPAddr config.HTTPAddress,
externalHTTPAddr config.ServerAddress,
certFile, keyFile *string,
) {
// Manually unlocked right before actually serving requests,
// as we don't return from this method (defer doesn't work).
b.startupLock.Lock()
externalAddr, _ := externalHTTPAddr.Address()
externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()
externalServ := &http.Server{
Addr: string(externalAddr),
Addr: externalHTTPAddr.Address,
WriteTimeout: HTTPServerTimeout,
Handler: externalRouter,
BaseContext: func(_ net.Listener) context.Context {
@ -419,7 +416,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
b.startupLock.Unlock()
if externalAddr != NoListener {
if externalHTTPAddr.Enabled() {
go func() {
var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once
logrus.Infof("Starting external listener on %s", externalServ.Addr)
@ -437,9 +434,27 @@ func (b *BaseDendrite) SetupAndServeHTTP(
}
}
} else {
if err := externalServ.ListenAndServe(); err != nil {
if err != http.ErrServerClosed {
logrus.WithError(err).Fatal("failed to serve HTTP")
if externalHTTPAddr.IsUnixSocket() {
_ = os.Remove(externalHTTPAddr.Address)
listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address)
if err != nil {
logrus.WithError(err).Fatal("failed to serve unix socket HTTP")
}
err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission)
if err != nil {
logrus.WithError(err).Fatal("failed to set unix socket permissions")
}
if err := externalServ.Serve(listener); err != nil {
if err != http.ErrServerClosed {
logrus.WithError(err).Fatal("failed to serve HTTP")
}
}
} else {
if err := externalServ.ListenAndServe(); err != nil {
if err != http.ErrServerClosed {
logrus.WithError(err).Fatal("failed to serve HTTP")
}
}
}
}

View file

@ -4,11 +4,15 @@ import (
"bytes"
"embed"
"html/template"
"net"
"net/http"
"net/http/httptest"
"path"
"testing"
"time"
"golang.org/x/net/context"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test/testrig"
@ -18,7 +22,7 @@ import (
//go:embed static/*.gotmpl
var staticContent embed.FS
func TestLandingPage(t *testing.T) {
func TestLandingPage_Tcp(t *testing.T) {
// generate the expected result
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
expectedRes := &bytes.Buffer{}
@ -35,7 +39,9 @@ func TestLandingPage(t *testing.T) {
s.Close()
// start base with the listener and wait for it to be started
go b.SetupAndServeHTTP(config.HTTPAddress(s.URL), nil, nil)
address, err := config.HttpAddress(s.URL)
assert.NoError(t, err)
go b.SetupAndServeHTTP(address, nil, nil)
time.Sleep(time.Millisecond * 10)
// When hitting /, we should be redirected to /_matrix/static, which should contain the landing page
@ -55,3 +61,43 @@ func TestLandingPage(t *testing.T) {
// Using .String() for user friendly output
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
}
func TestLandingPage_UnixSocket(t *testing.T) {
// generate the expected result
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
expectedRes := &bytes.Buffer{}
err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{
"Version": internal.VersionString(),
})
assert.NoError(t, err)
b, _, _ := testrig.Base(nil)
defer b.Close()
tempDir := t.TempDir()
socket := path.Join(tempDir, "socket")
// start base with the listener and wait for it to be started
address := config.UnixSocketAddress(socket, 0755)
assert.NoError(t, err)
go b.SetupAndServeHTTP(address, nil, nil)
time.Sleep(time.Millisecond * 100)
client := &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", socket)
},
},
}
resp, err := client.Get("http://unix/")
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// read the response
buf := &bytes.Buffer{}
_, err = buf.ReadFrom(resp.Body)
assert.NoError(t, err)
// Using .String() for user friendly output
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
}

View file

@ -19,7 +19,6 @@ import (
"encoding/pem"
"fmt"
"io"
"net/url"
"os"
"path/filepath"
"regexp"
@ -131,20 +130,6 @@ func (d DataSource) IsPostgres() bool {
// A Topic in kafka.
type Topic string
// An Address to listen on.
type Address string
// An HTTPAddress to listen on, starting with either http:// or https://.
type HTTPAddress string
func (h HTTPAddress) Address() (Address, error) {
url, err := url.Parse(string(h))
if err != nil {
return "", err
}
return Address(url.Host), nil
}
// FileSizeBytes is a file size in bytes
type FileSizeBytes int64

View file

@ -0,0 +1,45 @@
package config
import (
"io/fs"
"net/url"
)
const (
NetworkTcp = "tcp"
NetworkUnix = "unix"
)
type ServerAddress struct {
Address string
Scheme string
UnixSocketPermission fs.FileMode
}
func (s ServerAddress) Enabled() bool {
return s.Address != ""
}
func (s ServerAddress) IsUnixSocket() bool {
return s.Scheme == NetworkUnix
}
func (s ServerAddress) Network() string {
if s.Scheme == NetworkUnix {
return NetworkUnix
} else {
return NetworkTcp
}
}
func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress {
return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm}
}
func HttpAddress(urlAddress string) (ServerAddress, error) {
parsedUrl, err := url.Parse(urlAddress)
if err != nil {
return ServerAddress{}, err
}
return ServerAddress{parsedUrl.Host, parsedUrl.Scheme, 0}, nil
}

View file

@ -0,0 +1,24 @@
package config
import (
"github.com/stretchr/testify/assert"
"io/fs"
"testing"
)
func TestHttpAddress_ParseGood(t *testing.T) {
address, err := HttpAddress("http://localhost:123")
assert.NoError(t, err)
assert.Equal(t, "localhost:123", address.Address)
assert.Equal(t, "tcp", address.Network())
}
func TestHttpAddress_ParseBad(t *testing.T) {
_, err := HttpAddress(":")
assert.Error(t, err)
}
func TestUnixSocketAddress_Network(t *testing.T) {
address := UnixSocketAddress("/tmp", fs.FileMode(0755))
assert.Equal(t, "unix", address.Network())
}