unix socket support

This commit is contained in:
Boris Rybalkin 2023-02-15 09:26:35 +00:00
parent c8ca23acdb
commit 2cfdf1fe6a
6 changed files with 166 additions and 33 deletions

View file

@ -16,6 +16,7 @@ package main
import ( import (
"flag" "flag"
"io/fs"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -30,16 +31,33 @@ import (
) )
var ( var (
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") unixSocket = flag.String("unix-socket", "", "The HTTP listening unix socket for the server")
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") unixSocketPermission = flag.Int("unix-socket-permission", 0755, "The HTTP listening unix socket permission for the server")
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS")
) )
func main() { func main() {
cfg := setup.ParseFlags(true) cfg := setup.ParseFlags(true)
httpAddr := config.HTTPAddress("http://" + *httpBindAddr) httpAddr := config.ServerAddress{}
httpsAddr := config.HTTPAddress("https://" + *httpsBindAddr) httpsAddr := config.ServerAddress{}
if *unixSocket == "" {
http, err := config.HttpAddress("http://" + *httpBindAddr)
if err != nil {
logrus.WithError(err).Fatalf("Failed to parse http address")
}
httpAddr = http
https, err := config.HttpAddress("https://" + *httpsBindAddr)
if err != nil {
logrus.WithError(err).Fatalf("Failed to parse https address")
}
httpsAddr = https
} else {
httpAddr = config.UnixSocketAddress(*unixSocket, fs.FileMode(*unixSocketPermission))
}
options := []basepkg.BaseDendriteOptions{} options := []basepkg.BaseDendriteOptions{}
base := basepkg.NewBaseDendrite(cfg, options...) base := basepkg.NewBaseDendrite(cfg, options...)
@ -92,7 +110,7 @@ func main() {
base.SetupAndServeHTTP(httpAddr, nil, nil) base.SetupAndServeHTTP(httpAddr, nil, nil)
}() }()
// Handle HTTPS if certificate and key are provided // Handle HTTPS if certificate and key are provided
if *certFile != "" && *keyFile != "" { if *unixSocket == "" && *certFile != "" && *keyFile != "" {
go func() { go func() {
base.SetupAndServeHTTP(httpsAddr, certFile, keyFile) base.SetupAndServeHTTP(httpsAddr, certFile, keyFile)
}() }()

View file

@ -85,8 +85,6 @@ type BaseDendrite struct {
startupLock sync.Mutex startupLock sync.Mutex
} }
const NoListener = ""
const HTTPServerTimeout = time.Minute * 5 const HTTPServerTimeout = time.Minute * 5
type BaseDendriteOptions int type BaseDendriteOptions int
@ -345,18 +343,17 @@ func (b *BaseDendrite) ConfigureAdminEndpoints() {
// SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs // SetupAndServeHTTP sets up the HTTP server to serve client & federation APIs
// and adds a prometheus handler under /_dendrite/metrics. // and adds a prometheus handler under /_dendrite/metrics.
func (b *BaseDendrite) SetupAndServeHTTP( func (b *BaseDendrite) SetupAndServeHTTP(
externalHTTPAddr config.HTTPAddress, externalHTTPAddr config.ServerAddress,
certFile, keyFile *string, certFile, keyFile *string,
) { ) {
// Manually unlocked right before actually serving requests, // Manually unlocked right before actually serving requests,
// as we don't return from this method (defer doesn't work). // as we don't return from this method (defer doesn't work).
b.startupLock.Lock() b.startupLock.Lock()
externalAddr, _ := externalHTTPAddr.Address()
externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath()
externalServ := &http.Server{ externalServ := &http.Server{
Addr: string(externalAddr), Addr: externalHTTPAddr.Address,
WriteTimeout: HTTPServerTimeout, WriteTimeout: HTTPServerTimeout,
Handler: externalRouter, Handler: externalRouter,
BaseContext: func(_ net.Listener) context.Context { BaseContext: func(_ net.Listener) context.Context {
@ -419,7 +416,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
b.startupLock.Unlock() b.startupLock.Unlock()
if externalAddr != NoListener { if externalHTTPAddr.Enabled() {
go func() { go func() {
var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once var externalShutdown atomic.Bool // RegisterOnShutdown can be called more than once
logrus.Infof("Starting external listener on %s", externalServ.Addr) logrus.Infof("Starting external listener on %s", externalServ.Addr)
@ -437,9 +434,27 @@ func (b *BaseDendrite) SetupAndServeHTTP(
} }
} }
} else { } else {
if err := externalServ.ListenAndServe(); err != nil { if externalHTTPAddr.IsUnixSocket() {
if err != http.ErrServerClosed { _ = os.Remove(externalHTTPAddr.Address)
logrus.WithError(err).Fatal("failed to serve HTTP") listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address)
if err != nil {
logrus.WithError(err).Fatal("failed to serve unix socket HTTP")
}
err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission)
if err != nil {
logrus.WithError(err).Fatal("failed to set unix socket permissions")
}
if err := externalServ.Serve(listener); err != nil {
if err != http.ErrServerClosed {
logrus.WithError(err).Fatal("failed to serve HTTP")
}
}
} else {
if err := externalServ.ListenAndServe(); err != nil {
if err != http.ErrServerClosed {
logrus.WithError(err).Fatal("failed to serve HTTP")
}
} }
} }
} }

View file

@ -4,11 +4,15 @@ import (
"bytes" "bytes"
"embed" "embed"
"html/template" "html/template"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path"
"testing" "testing"
"time" "time"
"golang.org/x/net/context"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
@ -18,7 +22,7 @@ import (
//go:embed static/*.gotmpl //go:embed static/*.gotmpl
var staticContent embed.FS var staticContent embed.FS
func TestLandingPage(t *testing.T) { func TestLandingPage_Tcp(t *testing.T) {
// generate the expected result // generate the expected result
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
expectedRes := &bytes.Buffer{} expectedRes := &bytes.Buffer{}
@ -35,7 +39,9 @@ func TestLandingPage(t *testing.T) {
s.Close() s.Close()
// start base with the listener and wait for it to be started // start base with the listener and wait for it to be started
go b.SetupAndServeHTTP(config.HTTPAddress(s.URL), nil, nil) address, err := config.HttpAddress(s.URL)
assert.NoError(t, err)
go b.SetupAndServeHTTP(address, nil, nil)
time.Sleep(time.Millisecond * 10) time.Sleep(time.Millisecond * 10)
// When hitting /, we should be redirected to /_matrix/static, which should contain the landing page // When hitting /, we should be redirected to /_matrix/static, which should contain the landing page
@ -55,3 +61,43 @@ func TestLandingPage(t *testing.T) {
// Using .String() for user friendly output // Using .String() for user friendly output
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch") assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
} }
func TestLandingPage_UnixSocket(t *testing.T) {
// generate the expected result
tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl"))
expectedRes := &bytes.Buffer{}
err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{
"Version": internal.VersionString(),
})
assert.NoError(t, err)
b, _, _ := testrig.Base(nil)
defer b.Close()
tempDir := t.TempDir()
socket := path.Join(tempDir, "socket")
// start base with the listener and wait for it to be started
address := config.UnixSocketAddress(socket, 0755)
assert.NoError(t, err)
go b.SetupAndServeHTTP(address, nil, nil)
time.Sleep(time.Millisecond * 100)
client := &http.Client{
Transport: &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", socket)
},
},
}
resp, err := client.Get("http://unix/")
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// read the response
buf := &bytes.Buffer{}
_, err = buf.ReadFrom(resp.Body)
assert.NoError(t, err)
// Using .String() for user friendly output
assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch")
}

View file

@ -19,7 +19,6 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -131,20 +130,6 @@ func (d DataSource) IsPostgres() bool {
// A Topic in kafka. // A Topic in kafka.
type Topic string type Topic string
// An Address to listen on.
type Address string
// An HTTPAddress to listen on, starting with either http:// or https://.
type HTTPAddress string
func (h HTTPAddress) Address() (Address, error) {
url, err := url.Parse(string(h))
if err != nil {
return "", err
}
return Address(url.Host), nil
}
// FileSizeBytes is a file size in bytes // FileSizeBytes is a file size in bytes
type FileSizeBytes int64 type FileSizeBytes int64

View file

@ -0,0 +1,45 @@
package config
import (
"io/fs"
"net/url"
)
const (
NetworkTcp = "tcp"
NetworkUnix = "unix"
)
type ServerAddress struct {
Address string
Scheme string
UnixSocketPermission fs.FileMode
}
func (s ServerAddress) Enabled() bool {
return s.Address != ""
}
func (s ServerAddress) IsUnixSocket() bool {
return s.Scheme == NetworkUnix
}
func (s ServerAddress) Network() string {
if s.Scheme == NetworkUnix {
return NetworkUnix
} else {
return NetworkTcp
}
}
func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress {
return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm}
}
func HttpAddress(urlAddress string) (ServerAddress, error) {
parsedUrl, err := url.Parse(urlAddress)
if err != nil {
return ServerAddress{}, err
}
return ServerAddress{parsedUrl.Host, parsedUrl.Scheme, 0}, nil
}

View file

@ -0,0 +1,24 @@
package config
import (
"github.com/stretchr/testify/assert"
"io/fs"
"testing"
)
func TestHttpAddress_ParseGood(t *testing.T) {
address, err := HttpAddress("http://localhost:123")
assert.NoError(t, err)
assert.Equal(t, "localhost:123", address.Address)
assert.Equal(t, "tcp", address.Network())
}
func TestHttpAddress_ParseBad(t *testing.T) {
_, err := HttpAddress(":")
assert.Error(t, err)
}
func TestUnixSocketAddress_Network(t *testing.T) {
address := UnixSocketAddress("/tmp", fs.FileMode(0755))
assert.Equal(t, "unix", address.Network())
}