verify chainId as part of login flow

This commit is contained in:
Tak Wai Wong 2022-05-12 14:00:10 -07:00
parent 4665941912
commit 9c8fb36a85

View file

@ -33,7 +33,7 @@ import (
) )
type LoginPublicKeyEthereum struct { type LoginPublicKeyEthereum struct {
// Todo: See https://... // https://github.com/tak-hntlabs/matrix-spec-proposals/blob/main/proposals/3782-matrix-publickey-login-spec.md#client-sends-login-request-with-authentication-data
Type string `json:"type"` Type string `json:"type"`
Address string `json:"address"` Address string `json:"address"`
Session string `json:"session"` Session string `json:"session"`
@ -130,7 +130,7 @@ func (pk LoginPublicKeyEthereum) ValidateLoginResponse() (bool, *jsonerror.Matri
// Verify that the hash is valid for the message fields. // Verify that the hash is valid for the message fields.
if !verifyHash(pk.HashFieldsRaw, requiredFields.Hash) { if !verifyHash(pk.HashFieldsRaw, requiredFields.Hash) {
return false, jsonerror.InvalidParam("error verifying message hash") return false, jsonerror.Forbidden("error verifying message hash")
} }
// Unmarshal the hashFields for further validation // Unmarshal the hashFields for further validation
@ -141,12 +141,17 @@ func (pk LoginPublicKeyEthereum) ValidateLoginResponse() (bool, *jsonerror.Matri
// Error if the message is not from the expected public address // Error if the message is not from the expected public address
if pk.Address != requiredFields.From || requiredFields.From != pk.HashFields.Address { if pk.Address != requiredFields.From || requiredFields.From != pk.HashFields.Address {
return false, jsonerror.InvalidParam("address") return false, jsonerror.Forbidden("address")
} }
// Error if the message is not for the home server // Error if the message is not for the home server
if requiredFields.To != pk.HashFields.Domain { if requiredFields.To != pk.HashFields.Domain {
return false, jsonerror.InvalidParam("domain") return false, jsonerror.Forbidden("domain")
}
// Error if the chainId is not supported by the server.
if !contains(pk.config.PublicKeyAuthentication.Ethereum.ChainIDs, authData.ChainId) {
return false, jsonerror.Forbidden("chainId")
} }
// No errors. // No errors.
@ -231,3 +236,12 @@ func verifyHash(rawStr string, expectedHash string) bool {
hashStr := base64.StdEncoding.EncodeToString(hash) hashStr := base64.StdEncoding.EncodeToString(hash)
return expectedHash == hashStr return expectedHash == hashStr
} }
func contains(list []string, element string) bool {
for _, i := range list {
if i == element {
return true
}
}
return false
}