From 4b977cefd5df5f223537b864608e2569af1ae6c9 Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Wed, 27 Apr 2022 09:46:46 +0200 Subject: [PATCH] Add LoginToken tests --- userapi/storage/interface.go | 29 +++++++++++++++------------ userapi/storage/storage_test.go | 35 ++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 9dfb87ab3..160c99425 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -84,19 +84,7 @@ type KeyBackup interface { CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error) } -type Database interface { - Account - AccountData - Device - KeyBackup - Profile - SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) - RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) - GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) - GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) - CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) - GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) - +type LoginToken interface { // CreateLoginToken generates a token, stores and returns it. The lifetime is // determined by the loginTokenLifetime given to the Database constructor. CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) @@ -107,6 +95,21 @@ type Database interface { // GetLoginTokenDataByToken returns the data associated with the given token. // May return sql.ErrNoRows. GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) +} + +type Database interface { + Account + AccountData + Device + KeyBackup + LoginToken + Profile + SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) + RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) + GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) + GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error) + GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 54cf7def8..1460423e6 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -18,11 +18,13 @@ import ( "golang.org/x/crypto/bcrypt" ) +const loginTokenLifetime = time.Minute + func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := storage.NewUserAPIDatabase(&config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, "localhost", bcrypt.MinCost, time.Minute.Milliseconds(), time.Minute, "_server") + }, "localhost", bcrypt.MinCost, time.Minute.Milliseconds(), loginTokenLifetime, "_server") if err != nil { t.Fatalf("NewUserAPIDatabase returned %s", err) } @@ -246,3 +248,34 @@ func Test_KeyBackup(t *testing.T) { assert.False(t, exists) }) } + +func Test_LoginToken(t *testing.T) { + ctx := context.Background() + alice := test.NewUser() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + + // create a new token + wantLoginToken := &api.LoginTokenData{UserID: alice.ID} + + gotMetadata, err := db.CreateLoginToken(ctx, wantLoginToken) + assert.NoError(t, err, "unable to create login token") + assert.NotNil(t, gotMetadata) + assert.Equal(t, time.Now().Add(loginTokenLifetime).Truncate(loginTokenLifetime), gotMetadata.Expiration.Truncate(loginTokenLifetime)) + + // get the new token + gotLoginToken, err := db.GetLoginTokenDataByToken(ctx, gotMetadata.Token) + assert.NoError(t, err, "unable to get login token") + assert.NotNil(t, gotLoginToken) + assert.Equal(t, wantLoginToken, gotLoginToken, "unexpected login token") + + // remove the login token again + err = db.RemoveLoginToken(ctx, gotMetadata.Token) + assert.NoError(t, err, "unable to remove login token") + + // check if the token was actually deleted + _, err = db.GetLoginTokenDataByToken(ctx, gotMetadata.Token) + assert.Error(t, err, "expected an error, but got none") + }) +}