diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 860aa6d38..dc0180da6 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -15,8 +15,6 @@ package routing import ( - "encoding/json" - "fmt" "net/http" "context" @@ -47,7 +45,7 @@ type loginIdentifier struct { User string `json:"user"` } -type loginWithPasswordRequest struct { +type passwordRequest struct { Identifier loginIdentifier `json:"identifier"` User string `json:"user"` // deprecated in favour of identifier Password string `json:"password"` @@ -57,11 +55,6 @@ type loginWithPasswordRequest struct { DeviceID *string `json:"device_id"` } -type loginRequest struct { - Type string `json:"type"` - loginWithPasswordRequest -} - type loginResponse struct { UserID string `json:"user_id"` AccessToken string `json:"access_token"` @@ -94,24 +87,12 @@ func Login( if resErr != nil { return *resErr } - - j, _ := json.MarshalIndent(temp, "", " ") - fmt.Println(string(j)) - - var r loginRequest - json.Unmarshal(j, &r) - - switch r.Type { - case "m.login.password": - j, _ := json.MarshalIndent(r, "", " ") - fmt.Printf("LOGIN REQUEST: %+v\n", string(j)) - switch r.Identifier.Type { - case "m.id.user": - if r.Identifier.User == "" { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("'user' must be supplied."), - } + switch r.Identifier.Type { + case "m.id.user": + if r.Identifier.User == "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("'user' must be supplied."), } } acc, errJSON = r.processUsernamePasswordLoginRequest(req, accountDB, cfg, r.Identifier.User) @@ -139,7 +120,7 @@ func Login( return jsonerror.InternalServerError() } - dev, err := getDevice(req.Context(), r.loginWithPasswordRequest, deviceDB, acc, token) + dev, err := getDevice(req.Context(), r, deviceDB, acc, token) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -166,7 +147,7 @@ func Login( // getDevice returns a new or existing device func getDevice( ctx context.Context, - r loginWithPasswordRequest, + r passwordRequest, deviceDB devices.Database, acc *api.Account, token string, diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 2076aa135..ddc80ccfa 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -51,18 +51,6 @@ var ( func main() { flag.Parse() - // Build both ends of a HTTP multiplex. - httpServer := &http.Server{ - Addr: ":0", - TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, - ReadTimeout: 15 * time.Second, - WriteTimeout: 45 * time.Second, - IdleTimeout: 60 * time.Second, - BaseContext: func(_ net.Listener) context.Context { - return context.Background() - }, - } - ygg, err := yggconn.Setup(*instanceName, *instancePeer, ".") if err != nil { panic(err) @@ -131,7 +119,7 @@ func main() { Config: base.Cfg, AccountDB: accountDB, DeviceDB: deviceDB, - Client: createClient(ygg), + Client: ygg.CreateClient(base), FedClient: federation, KeyRing: keyRing, KafkaConsumer: base.KafkaConsumer, diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/client.go b/cmd/dendrite-demo-yggdrasil/yggconn/client.go index 36ea32973..399993e3e 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/client.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/client.go @@ -11,10 +11,25 @@ import ( "time" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/convert" - "github.com/matrix-org/dendrite/internal/basecomponent" + "github.com/matrix-org/dendrite/internal/setup" "github.com/matrix-org/gomatrixserverlib" ) +func (n *Node) yggdialer(_, address string) (net.Conn, error) { + tokens := strings.Split(address, ":") + raw, err := hex.DecodeString(tokens[0]) + if err != nil { + return nil, fmt.Errorf("hex.DecodeString: %w", err) + } + converted := convert.Ed25519PublicKeyToCurve25519(ed25519.PublicKey(raw)) + convhex := hex.EncodeToString(converted) + return n.Dial("curve25519", convhex) +} + +func (n *Node) yggdialerctx(ctx context.Context, network, address string) (net.Conn, error) { + return n.yggdialer(network, address) +} + type yggroundtripper struct { inner *http.Transport } @@ -24,29 +39,32 @@ func (y *yggroundtripper) RoundTrip(req *http.Request) (*http.Response, error) { return y.inner.RoundTrip(req) } -func (n *Node) CreateFederationClient( - base *basecomponent.BaseDendrite, -) *gomatrixserverlib.FederationClient { - yggdialer := func(_, address string) (net.Conn, error) { - tokens := strings.Split(address, ":") - raw, err := hex.DecodeString(tokens[0]) - if err != nil { - return nil, fmt.Errorf("hex.DecodeString: %w", err) - } - converted := convert.Ed25519PublicKeyToCurve25519(ed25519.PublicKey(raw)) - convhex := hex.EncodeToString(converted) - return n.Dial("curve25519", convhex) - } - yggdialerctx := func(ctx context.Context, network, address string) (net.Conn, error) { - return yggdialer(network, address) - } +func (n *Node) CreateClient( + base *setup.BaseDendrite, +) *gomatrixserverlib.Client { tr := &http.Transport{} tr.RegisterProtocol( "matrix", &yggroundtripper{ inner: &http.Transport{ ResponseHeaderTimeout: 15 * time.Second, IdleConnTimeout: 60 * time.Second, - DialContext: yggdialerctx, + DialContext: n.yggdialerctx, + }, + }, + ) + return gomatrixserverlib.NewClientWithTransport(tr) +} + +func (n *Node) CreateFederationClient( + base *setup.BaseDendrite, +) *gomatrixserverlib.FederationClient { + tr := &http.Transport{} + tr.RegisterProtocol( + "matrix", &yggroundtripper{ + inner: &http.Transport{ + ResponseHeaderTimeout: 15 * time.Second, + IdleConnTimeout: 60 * time.Second, + DialContext: n.yggdialerctx, }, }, )