Takwaiw/dendrite publickey (#2)

* Implementation of MSC 3782 Add publickey login as a new auth type.

Co-authored-by: Tak Wai Wong <takwaiw@gmail.com>
This commit is contained in:
Tak Wai Wong 2022-05-12 16:47:48 -07:00 committed by Tak Wai Wong
parent 6c7d03c925
commit b2717519f8
11 changed files with 36 additions and 56 deletions

View file

@ -30,10 +30,10 @@ import (
type LoginPublicKeyHandler interface { type LoginPublicKeyHandler interface {
AccountExists(ctx context.Context) (string, *jsonerror.MatrixError) AccountExists(ctx context.Context) (string, *jsonerror.MatrixError)
IsValidUserIdForRegistration(userId string) bool
CreateLogin() *Login CreateLogin() *Login
GetSession() string GetSession() string
GetType() string GetType() string
IsValidUserId(userId string) bool
ValidateLoginResponse() (bool, *jsonerror.MatrixError) ValidateLoginResponse() (bool, *jsonerror.MatrixError)
} }

View file

@ -73,6 +73,10 @@ func (pk LoginPublicKeyEthereum) AccountExists(ctx context.Context) (string, *js
return "", jsonerror.Forbidden("the address is incorrect, or the account does not exist.") return "", jsonerror.Forbidden("the address is incorrect, or the account does not exist.")
} }
if !pk.IsValidUserId(localPart) {
return "", jsonerror.InvalidUsername("the username is not valid.")
}
res := userapi.QueryAccountAvailabilityResponse{} res := userapi.QueryAccountAvailabilityResponse{}
if err := pk.userAPI.QueryAccountAvailability(ctx, &userapi.QueryAccountAvailabilityRequest{ if err := pk.userAPI.QueryAccountAvailability(ctx, &userapi.QueryAccountAvailabilityRequest{
Localpart: localPart, Localpart: localPart,
@ -80,7 +84,7 @@ func (pk LoginPublicKeyEthereum) AccountExists(ctx context.Context) (string, *js
return "", jsonerror.Unknown("failed to check availability: " + err.Error()) return "", jsonerror.Unknown("failed to check availability: " + err.Error())
} }
if res.Available { if localPart == "" || res.Available {
return "", jsonerror.Forbidden("the address is incorrect, account does not exist") return "", jsonerror.Forbidden("the address is incorrect, account does not exist")
} }
@ -89,7 +93,7 @@ func (pk LoginPublicKeyEthereum) AccountExists(ctx context.Context) (string, *js
var validChainAgnosticIdRegex = regexp.MustCompile("^eip155=3a[0-9]+=3a0x[0-9a-fA-F]+$") var validChainAgnosticIdRegex = regexp.MustCompile("^eip155=3a[0-9]+=3a0x[0-9a-fA-F]+$")
func (pk LoginPublicKeyEthereum) IsValidUserIdForRegistration(userId string) bool { func (pk LoginPublicKeyEthereum) IsValidUserId(userId string) bool {
// Verify that the user ID is a valid one according to spec. // Verify that the user ID is a valid one according to spec.
// https://github.com/ChainAgnostic/CAIPs/blob/master/CAIPs/caip-10.md // https://github.com/ChainAgnostic/CAIPs/blob/master/CAIPs/caip-10.md
@ -100,9 +104,9 @@ func (pk LoginPublicKeyEthereum) IsValidUserIdForRegistration(userId string) boo
isValid := validChainAgnosticIdRegex.MatchString(userId) isValid := validChainAgnosticIdRegex.MatchString(userId)
// In addition, double check that the user ID for registration // In addition, double check that the user ID
// matches the authentication data in the request. // matches the authentication data in the request.
return isValid && userId == pk.UserId return isValid && strings.ToLower(userId) == pk.UserId
} }
func (pk LoginPublicKeyEthereum) ValidateLoginResponse() (bool, *jsonerror.MatrixError) { func (pk LoginPublicKeyEthereum) ValidateLoginResponse() (bool, *jsonerror.MatrixError) {

View file

@ -246,7 +246,7 @@ func (u *UserInteractive) ResponseWithChallenge(sessionID string, response inter
// Verify returns an error/challenge response to send to the client, or nil if the user is authenticated. // Verify returns an error/challenge response to send to the client, or nil if the user is authenticated.
// `bodyBytes` is the HTTP request body which must contain an `auth` key. // `bodyBytes` is the HTTP request body which must contain an `auth` key.
// Returns the login that was verified for additional checks if required. // Returns the login that was verified for additional checks if required.
func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device *api.Device) (*Login, *util.JSONResponse) { func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte) (*Login, *util.JSONResponse) {
// TODO: rate limit // TODO: rate limit
// "A client should first make a request with no auth parameter. The homeserver returns an HTTP 401 response, with a JSON body" // "A client should first make a request with no auth parameter. The homeserver returns an HTTP 401 response, with a JSON body"

View file

@ -17,11 +17,6 @@ var (
serverName = gomatrixserverlib.ServerName("example.com") serverName = gomatrixserverlib.ServerName("example.com")
// space separated localpart+password -> account // space separated localpart+password -> account
lookup = make(map[string]*api.Account) lookup = make(map[string]*api.Account)
device = &api.Device{
AccessToken: "flibble",
DisplayName: "My Device",
ID: "device_id_goes_here",
}
) )
type fakeAccountDatabase struct { type fakeAccountDatabase struct {
@ -60,7 +55,7 @@ func setup() *UserInteractive {
func TestUserInteractiveChallenge(t *testing.T) { func TestUserInteractiveChallenge(t *testing.T) {
uia := setup() uia := setup()
// no auth key results in a challenge // no auth key results in a challenge
_, errRes := uia.Verify(ctx, []byte(`{}`), device) _, errRes := uia.Verify(ctx, []byte(`{}`))
if errRes == nil { if errRes == nil {
t.Fatalf("Verify succeeded with {} but expected failure") t.Fatalf("Verify succeeded with {} but expected failure")
} }
@ -100,7 +95,7 @@ func TestUserInteractivePasswordLogin(t *testing.T) {
}`), }`),
} }
for _, tc := range testCases { for _, tc := range testCases {
_, errRes := uia.Verify(ctx, tc, device) _, errRes := uia.Verify(ctx, tc)
if errRes != nil { if errRes != nil {
t.Errorf("Verify failed but expected success for request: %s - got %+v", string(tc), errRes) t.Errorf("Verify failed but expected success for request: %s - got %+v", string(tc), errRes)
} }
@ -181,7 +176,7 @@ func TestUserInteractivePasswordBadLogin(t *testing.T) {
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
_, errRes := uia.Verify(ctx, tc.body, device) _, errRes := uia.Verify(ctx, tc.body)
if errRes == nil { if errRes == nil {
t.Errorf("Verify succeeded but expected failure for request: %s", string(tc.body)) t.Errorf("Verify succeeded but expected failure for request: %s", string(tc.body))
continue continue

View file

@ -28,7 +28,7 @@ func Deactivate(
} }
} }
login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, deviceAPI) login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes)
if errRes != nil { if errRes != nil {
return *errRes return *errRes
} }

View file

@ -198,7 +198,7 @@ func DeleteDeviceById(
sessionID = s sessionID = s
} }
login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes, device) login, errRes := userInteractiveAuth.Verify(ctx, bodyBytes)
if errRes != nil { if errRes != nil {
switch data := errRes.JSON.(type) { switch data := errRes.JSON.(type) {
case auth.Challenge: case auth.Challenge:

View file

@ -156,6 +156,13 @@ func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) {
}) })
} }
func (d *sessionsDict) hasSession(sessionID string) bool {
d.RLock()
defer d.RUnlock()
_, ok := d.sessions[sessionID]
return ok
}
// addCompletedSessionStage records that a session has completed an auth stage // addCompletedSessionStage records that a session has completed an auth stage
// also starts a timer to delete the session once done. // also starts a timer to delete the session once done.
func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) { func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) {

View file

@ -62,14 +62,14 @@ func handlePublicKeyRegistration(
return false, authtypes.LoginStagePublicKeyNewRegistration, nil return false, authtypes.LoginStagePublicKeyNewRegistration, nil
} }
if _, ok := sessions.sessions[authHandler.GetSession()]; !ok { if !sessions.hasSession(authHandler.GetSession()) {
return false, "", &util.JSONResponse{ return false, "", &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
JSON: jsonerror.Unknown("the session ID is missing or unknown."), JSON: jsonerror.Unknown("the session ID is missing or unknown."),
} }
} }
isValidUserId := authHandler.IsValidUserIdForRegistration(r.Username) isValidUserId := authHandler.IsValidUserId(r.Username)
if !isValidUserId { if !isValidUserId {
return false, "", &util.JSONResponse{ return false, "", &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,

View file

@ -54,7 +54,7 @@ type ClientAPI struct {
PasswordAuthenticationDisabled bool `yaml:"password_authentication_disabled"` PasswordAuthenticationDisabled bool `yaml:"password_authentication_disabled"`
// Public key authentication // Public key authentication
PublicKeyAuthentication publicKeyAuthentication `yaml:"public_key_authentication"` PublicKeyAuthentication PublicKeyAuthentication `yaml:"public_key_authentication"`
} }
func (c *ClientAPI) Defaults(generate bool) { func (c *ClientAPI) Defaults(generate bool) {

View file

@ -1,52 +1,40 @@
package config package config
import ( import (
"math/rand"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
var nonceLength = 32
type AuthParams interface { type AuthParams interface {
GetParams() interface{} GetParams() interface{}
GetNonce() string
} }
type EthereumAuthParams struct { type EthereumAuthParams struct {
Version uint `json:"version"` Version uint `json:"version"`
ChainIDs []int `json:"chain_ids"` ChainIDs []int `json:"chain_ids"`
Nonce string `json:"nonce"`
} }
func (p EthereumAuthParams) GetParams() interface{} { func (p EthereumAuthParams) GetParams() interface{} {
copyP := p copyP := p
copyP.ChainIDs = make([]int, len(p.ChainIDs)) copyP.ChainIDs = make([]int, len(p.ChainIDs))
copy(copyP.ChainIDs, p.ChainIDs) copy(copyP.ChainIDs, p.ChainIDs)
copyP.Nonce = newNonce(nonceLength)
return copyP return copyP
} }
func (p EthereumAuthParams) GetNonce() string { type EthereumAuthConfig struct {
return p.Nonce
}
type ethereumAuthConfig struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
Version uint `yaml:"version"` Version uint `yaml:"version"`
ChainIDs []int `yaml:"chain_ids"` ChainIDs []int `yaml:"chain_ids"`
} }
type publicKeyAuthentication struct { type PublicKeyAuthentication struct {
Ethereum ethereumAuthConfig `yaml:"ethereum"` Ethereum EthereumAuthConfig `yaml:"ethereum"`
} }
func (pk *publicKeyAuthentication) Enabled() bool { func (pk *PublicKeyAuthentication) Enabled() bool {
return pk.Ethereum.Enabled return pk.Ethereum.Enabled
} }
func (pk *publicKeyAuthentication) GetPublicKeyRegistrationFlows() []authtypes.Flow { func (pk *PublicKeyAuthentication) GetPublicKeyRegistrationFlows() []authtypes.Flow {
var flows []authtypes.Flow var flows []authtypes.Flow
if pk.Ethereum.Enabled { if pk.Ethereum.Enabled {
flows = append(flows, authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypePublicKeyEthereum}}) flows = append(flows, authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypePublicKeyEthereum}})
@ -55,29 +43,15 @@ func (pk *publicKeyAuthentication) GetPublicKeyRegistrationFlows() []authtypes.F
return flows return flows
} }
func (pk *publicKeyAuthentication) GetPublicKeyRegistrationParams() map[string]interface{} { func (pk *PublicKeyAuthentication) GetPublicKeyRegistrationParams() map[string]interface{} {
params := make(map[string]interface{}) params := make(map[string]interface{})
if pk.Ethereum.Enabled { if pk.Ethereum.Enabled {
p := EthereumAuthParams{ p := EthereumAuthParams{
Version: pk.Ethereum.Version, Version: pk.Ethereum.Version,
ChainIDs: pk.Ethereum.ChainIDs, ChainIDs: pk.Ethereum.ChainIDs,
Nonce: "",
} }
params[authtypes.LoginTypePublicKeyEthereum] = p params[authtypes.LoginTypePublicKeyEthereum] = p
} }
return params return params
} }
const lettersAndNumbers = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
func newNonce(n int) string {
nonce := make([]byte, n)
rand.Seed(time.Now().UnixNano())
for i := range nonce {
nonce[i] = lettersAndNumbers[rand.Int63()%int64(len(lettersAndNumbers))]
}
return string(nonce)
}

View file

@ -300,10 +300,10 @@ func Test_UserStatistics(t *testing.T) {
}, },
R30UsersV2: map[string]int64{ R30UsersV2: map[string]int64{
"ios": 0, "ios": 0,
"android": 0, "android": 1,
"web": 0, "web": 1,
"electron": 0, "electron": 0,
"all": 0, "all": 2,
}, },
AllUsers: 6, AllUsers: 6,
NonBridgedUsers: 5, NonBridgedUsers: 5,