fix review comments

This commit is contained in:
Boris Rybalkin 2023-02-23 13:24:14 +00:00
parent 6f0e9ecbab
commit 3744994ea8
5 changed files with 25 additions and 19 deletions

View file

@ -31,12 +31,16 @@ import (
) )
var ( var (
unixSocket = flag.String("unix-socket", "", "The HTTP listening unix socket for the server") unixSocket = flag.String("unix-socket", "",
unixSocketPermission = flag.Int("unix-socket-permission", 0755, "The HTTP listening unix socket permission for the server") "EXPERIMENTAL(unstable): The HTTP listening unix socket for the server (disables http[s]-bind-address feature)",
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server") )
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server") unixSocketPermission = flag.Int("unix-socket-permission", 0755,
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS") "EXPERIMENTAL(unstable): The HTTP listening unix socket permission for the server",
keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS") )
httpBindAddr = flag.String("http-bind-address", ":8008", "The HTTP listening port for the server")
httpsBindAddr = flag.String("https-bind-address", ":8448", "The HTTPS listening port for the server")
certFile = flag.String("tls-cert", "", "The PEM formatted X509 certificate to use for TLS")
keyFile = flag.String("tls-key", "", "The PEM private key to use for TLS")
) )
func main() { func main() {
@ -44,12 +48,12 @@ func main() {
httpAddr := config.ServerAddress{} httpAddr := config.ServerAddress{}
httpsAddr := config.ServerAddress{} httpsAddr := config.ServerAddress{}
if *unixSocket == "" { if *unixSocket == "" {
http, err := config.HttpAddress("http://" + *httpBindAddr) http, err := config.HTTPAddress("http://" + *httpBindAddr)
if err != nil { if err != nil {
logrus.WithError(err).Fatalf("Failed to parse http address") logrus.WithError(err).Fatalf("Failed to parse http address")
} }
httpAddr = http httpAddr = http
https, err := config.HttpAddress("https://" + *httpsBindAddr) https, err := config.HTTPAddress("https://" + *httpsBindAddr)
if err != nil { if err != nil {
logrus.WithError(err).Fatalf("Failed to parse https address") logrus.WithError(err).Fatalf("Failed to parse https address")
} }

View file

@ -435,10 +435,13 @@ func (b *BaseDendrite) SetupAndServeHTTP(
} }
} else { } else {
if externalHTTPAddr.IsUnixSocket() { if externalHTTPAddr.IsUnixSocket() {
_ = os.Remove(externalHTTPAddr.Address) err := os.RemoveAll(externalHTTPAddr.Address)
if err != nil {
logrus.WithError(err).Fatal("failed to remove existing unix socket")
}
listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address) listener, err := net.Listen(externalHTTPAddr.Network(), externalHTTPAddr.Address)
if err != nil { if err != nil {
logrus.WithError(err).Fatal("failed to serve unix socket HTTP") logrus.WithError(err).Fatal("failed to serve unix socket")
} }
err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission) err = os.Chmod(externalHTTPAddr.Address, externalHTTPAddr.UnixSocketPermission)
if err != nil { if err != nil {
@ -446,7 +449,7 @@ func (b *BaseDendrite) SetupAndServeHTTP(
} }
if err := externalServ.Serve(listener); err != nil { if err := externalServ.Serve(listener); err != nil {
if err != http.ErrServerClosed { if err != http.ErrServerClosed {
logrus.WithError(err).Fatal("failed to serve HTTP") logrus.WithError(err).Fatal("failed to serve unix socket")
} }
} }

View file

@ -2,6 +2,7 @@ package base_test
import ( import (
"bytes" "bytes"
"context"
"embed" "embed"
"html/template" "html/template"
"net" "net"
@ -11,8 +12,6 @@ import (
"testing" "testing"
"time" "time"
"golang.org/x/net/context"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/test/testrig"
@ -39,7 +38,7 @@ func TestLandingPage_Tcp(t *testing.T) {
s.Close() s.Close()
// start base with the listener and wait for it to be started // start base with the listener and wait for it to be started
address, err := config.HttpAddress(s.URL) address, err := config.HTTPAddress(s.URL)
assert.NoError(t, err) assert.NoError(t, err)
go b.SetupAndServeHTTP(address, nil, nil) go b.SetupAndServeHTTP(address, nil, nil)
time.Sleep(time.Millisecond * 10) time.Sleep(time.Millisecond * 10)

View file

@ -6,7 +6,7 @@ import (
) )
const ( const (
NetworkTcp = "tcp" NetworkTCP = "tcp"
NetworkUnix = "unix" NetworkUnix = "unix"
) )
@ -28,7 +28,7 @@ func (s ServerAddress) Network() string {
if s.Scheme == NetworkUnix { if s.Scheme == NetworkUnix {
return NetworkUnix return NetworkUnix
} else { } else {
return NetworkTcp return NetworkTCP
} }
} }
@ -36,7 +36,7 @@ func UnixSocketAddress(path string, perm fs.FileMode) ServerAddress {
return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm} return ServerAddress{Address: path, Scheme: NetworkUnix, UnixSocketPermission: perm}
} }
func HttpAddress(urlAddress string) (ServerAddress, error) { func HTTPAddress(urlAddress string) (ServerAddress, error) {
parsedUrl, err := url.Parse(urlAddress) parsedUrl, err := url.Parse(urlAddress)
if err != nil { if err != nil {
return ServerAddress{}, err return ServerAddress{}, err

View file

@ -8,14 +8,14 @@ import (
) )
func TestHttpAddress_ParseGood(t *testing.T) { func TestHttpAddress_ParseGood(t *testing.T) {
address, err := HttpAddress("http://localhost:123") address, err := HTTPAddress("http://localhost:123")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "localhost:123", address.Address) assert.Equal(t, "localhost:123", address.Address)
assert.Equal(t, "tcp", address.Network()) assert.Equal(t, "tcp", address.Network())
} }
func TestHttpAddress_ParseBad(t *testing.T) { func TestHttpAddress_ParseBad(t *testing.T) {
_, err := HttpAddress(":") _, err := HTTPAddress(":")
assert.Error(t, err) assert.Error(t, err)
} }