From 9c8fb36a8587f827905f172d091373cfd966f6f0 Mon Sep 17 00:00:00 2001 From: Tak Wai Wong Date: Thu, 12 May 2022 14:00:10 -0700 Subject: [PATCH] verify chainId as part of login flow --- clientapi/auth/login_publickey_ethereum.go | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/clientapi/auth/login_publickey_ethereum.go b/clientapi/auth/login_publickey_ethereum.go index 6d842e57a..2a8cd78cb 100644 --- a/clientapi/auth/login_publickey_ethereum.go +++ b/clientapi/auth/login_publickey_ethereum.go @@ -33,7 +33,7 @@ import ( ) type LoginPublicKeyEthereum struct { - // Todo: See https://... + // https://github.com/tak-hntlabs/matrix-spec-proposals/blob/main/proposals/3782-matrix-publickey-login-spec.md#client-sends-login-request-with-authentication-data Type string `json:"type"` Address string `json:"address"` Session string `json:"session"` @@ -130,7 +130,7 @@ func (pk LoginPublicKeyEthereum) ValidateLoginResponse() (bool, *jsonerror.Matri // Verify that the hash is valid for the message fields. if !verifyHash(pk.HashFieldsRaw, requiredFields.Hash) { - return false, jsonerror.InvalidParam("error verifying message hash") + return false, jsonerror.Forbidden("error verifying message hash") } // Unmarshal the hashFields for further validation @@ -141,12 +141,17 @@ func (pk LoginPublicKeyEthereum) ValidateLoginResponse() (bool, *jsonerror.Matri // Error if the message is not from the expected public address if pk.Address != requiredFields.From || requiredFields.From != pk.HashFields.Address { - return false, jsonerror.InvalidParam("address") + return false, jsonerror.Forbidden("address") } // Error if the message is not for the home server if requiredFields.To != pk.HashFields.Domain { - return false, jsonerror.InvalidParam("domain") + return false, jsonerror.Forbidden("domain") + } + + // Error if the chainId is not supported by the server. + if !contains(pk.config.PublicKeyAuthentication.Ethereum.ChainIDs, authData.ChainId) { + return false, jsonerror.Forbidden("chainId") } // No errors. @@ -231,3 +236,12 @@ func verifyHash(rawStr string, expectedHash string) bool { hashStr := base64.StdEncoding.EncodeToString(hash) return expectedHash == hashStr } + +func contains(list []string, element string) bool { + for _, i := range list { + if i == element { + return true + } + } + return false +}