From a70493828ab74e85284a8cb5ef3dcc347cacd0c5 Mon Sep 17 00:00:00 2001 From: Till Faelligen Date: Wed, 27 Apr 2022 13:13:30 +0200 Subject: [PATCH] Add more device tests, fix numeric localpart query --- userapi/storage/interface.go | 1 - userapi/storage/postgres/accounts_table.go | 6 +-- userapi/storage/shared/storage.go | 15 ------ userapi/storage/sqlite3/accounts_table.go | 5 +- userapi/storage/storage_test.go | 61 ++++++++++++++++------ 5 files changed, 52 insertions(+), 36 deletions(-) diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 50595cd74..a4562cf19 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -68,7 +68,6 @@ type Device interface { CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error - RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 92311d56d..f86812f17 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -47,8 +47,6 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- TODO: -- upgraded_ts, devices, any email reset stuff? ); --- Create sequence for autogenerated numeric usernames -CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; ` const insertAccountSQL = "" + @@ -67,7 +65,7 @@ const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT nextval('numeric_username_seq')" + "SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'" type accountsStatements struct { insertAccountStmt *sql.Stmt @@ -178,5 +176,5 @@ func (s *accountsStatements) SelectNewNumericLocalpart( stmt = sqlutil.TxStmt(txn, stmt) } err = stmt.QueryRowContext(ctx).Scan(&id) - return + return id + 1, err } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 72ae96ecc..f7212e030 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -577,21 +577,6 @@ func (d *Database) UpdateDevice( }) } -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - // RemoveDevices revokes one or more devices by deleting the entry in the database // matching with the given device IDs and user ID localpart. // If the devices don't exist, it will not return an error diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 560491c0b..e19f42d96 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -65,7 +65,7 @@ const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" const selectNewNumericLocalpartSQL = "" + - "SELECT COUNT(localpart) FROM account_accounts" + "SELECT MAX(CAST(localpart AS INT)) FROM account_accounts WHERE CAST(localpart AS INT) <> 0" type accountsStatements struct { db *sql.DB @@ -178,5 +178,8 @@ func (s *accountsStatements) SelectNewNumericLocalpart( stmt = sqlutil.TxStmt(txn, stmt) } err = stmt.QueryRowContext(ctx).Scan(&id) + if err == sql.ErrNoRows { + return 1, nil + } return id + 1, err } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 4ba45fa89..e6c7d35fc 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "strconv" "testing" "time" @@ -23,7 +22,10 @@ import ( const loginTokenLifetime = time.Minute -var openIDLifetimeMS = time.Minute.Milliseconds() +var ( + openIDLifetimeMS = time.Minute.Milliseconds() + ctx = context.Background() +) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { connStr, close := test.PrepareDBConnectionString(t, dbType) @@ -38,7 +40,6 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun // Tests storing and getting account data func Test_AccountData(t *testing.T) { - ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() @@ -70,11 +71,9 @@ func Test_AccountData(t *testing.T) { // Tests the creation of accounts func Test_Accounts(t *testing.T) { - ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() - _ = close alice := test.NewUser() aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -102,8 +101,8 @@ func Test_Accounts(t *testing.T) { // get guest account numeric aliceLocalpart first, err := db.GetNewNumericLocalpart(ctx) assert.NoError(t, err, "failed to get new numeric localpart") - // SQLite requires a new user to be created, as it doesn't have a sequence and uses the count(localpart) instead - _, err = db.CreateAccount(ctx, strconv.Itoa(int(first)), "testing", "", api.AccountTypeAdmin) + // Create a new account to verify the numeric localpart is updated + _, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest) assert.NoError(t, err, "failed to create account") second, err := db.GetNewNumericLocalpart(ctx) assert.NoError(t, err) @@ -129,7 +128,6 @@ func Test_Accounts(t *testing.T) { } func Test_Devices(t *testing.T) { - ctx := context.Background() alice := test.NewUser() localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -162,6 +160,7 @@ func Test_Devices(t *testing.T) { // Get devices devices, err := db.GetDevicesByLocalpart(ctx, localpart) assert.NoError(t, err, "unable to get devices by localpart") + assert.Equal(t, 2, len(devices)) deviceIDs := make([]string, 0, len(devices)) for _, dev := range devices { deviceIDs = append(deviceIDs, dev.ID) @@ -170,11 +169,49 @@ func Test_Devices(t *testing.T) { devices2, err := db.GetDevicesByID(ctx, deviceIDs) assert.NoError(t, err, "unable to get devices by id") assert.Equal(t, devices, devices2) + + // Update device + newName := "new display name" + err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName) + assert.NoError(t, err, "unable to update device displayname") + err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1") + assert.NoError(t, err, "unable to update device last seen") + + deviceWithID.DisplayName = newName + deviceWithID.LastSeenIP = "127.0.0.1" + deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second))) + devices, err = db.GetDevicesByLocalpart(ctx, localpart) + assert.NoError(t, err, "unable to get device by id") + assert.Equal(t, 2, len(devices)) + assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName) + assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP) + truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second) + assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime)) + + // create one more device and remove the devices step by step + newDeviceID := util.RandomString(16) + accessToken = util.RandomString(16) + _, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "") + assert.NoError(t, err, "unable to create new device") + + devices, err = db.GetDevicesByLocalpart(ctx, localpart) + assert.NoError(t, err, "unable to get device by id") + assert.Equal(t, 3, len(devices)) + + err = db.RemoveDevices(ctx, localpart, deviceIDs) + assert.NoError(t, err, "unable to remove devices") + devices, err = db.GetDevicesByLocalpart(ctx, localpart) + assert.NoError(t, err, "unable to get device by id") + assert.Equal(t, 1, len(devices)) + + deleted, err := db.RemoveAllDevices(ctx, localpart, "") + assert.NoError(t, err, "unable to remove all devices") + assert.Equal(t, 1, len(deleted)) + assert.Equal(t, newDeviceID, deleted[0].ID) }) } func Test_KeyBackup(t *testing.T) { - ctx := context.Background() alice := test.NewUser() room := test.NewRoom(t, alice) @@ -254,7 +291,6 @@ func Test_KeyBackup(t *testing.T) { } func Test_LoginToken(t *testing.T) { - ctx := context.Background() alice := test.NewUser() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) @@ -285,7 +321,6 @@ func Test_LoginToken(t *testing.T) { } func Test_OpenID(t *testing.T) { - ctx := context.Background() alice := test.NewUser() token := util.RandomString(24) @@ -306,7 +341,6 @@ func Test_OpenID(t *testing.T) { } func Test_Profile(t *testing.T) { - ctx := context.Background() alice := test.NewUser() aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -345,7 +379,6 @@ func Test_Profile(t *testing.T) { } func Test_Pusher(t *testing.T) { - ctx := context.Background() alice := test.NewUser() aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -397,7 +430,6 @@ func Test_Pusher(t *testing.T) { } func Test_ThreePID(t *testing.T) { - ctx := context.Background() alice := test.NewUser() aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) @@ -435,7 +467,6 @@ func Test_ThreePID(t *testing.T) { } func Test_Notification(t *testing.T) { - ctx := context.Background() alice := test.NewUser() aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err)