mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-03 12:13:09 -06:00
Add more device tests, fix numeric localpart query
This commit is contained in:
parent
3347bc81f5
commit
a70493828a
|
|
@ -68,7 +68,6 @@ type Device interface {
|
||||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
|
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
|
||||||
RemoveDevice(ctx context.Context, deviceID, localpart string) error
|
|
||||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
||||||
|
|
|
||||||
|
|
@ -47,8 +47,6 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- upgraded_ts, devices, any email reset stuff?
|
-- upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
-- Create sequence for autogenerated numeric usernames
|
|
||||||
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
|
|
@ -67,7 +65,7 @@ const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT nextval('numeric_username_seq')"
|
"SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
|
|
@ -178,5 +176,5 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||||
return
|
return id + 1, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -577,21 +577,6 @@ func (d *Database) UpdateDevice(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveDevice revokes a device by deleting the entry in the database
|
|
||||||
// matching with the given device ID and user ID localpart.
|
|
||||||
// If the device doesn't exist, it will not return an error
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveDevice(
|
|
||||||
ctx context.Context, deviceID, localpart string,
|
|
||||||
) error {
|
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
|
||||||
if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||||
// matching with the given device IDs and user ID localpart.
|
// matching with the given device IDs and user ID localpart.
|
||||||
// If the devices don't exist, it will not return an error
|
// If the devices don't exist, it will not return an error
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT COUNT(localpart) FROM account_accounts"
|
"SELECT MAX(CAST(localpart AS INT)) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
|
@ -178,5 +178,8 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
return id + 1, err
|
return id + 1, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -23,7 +22,10 @@ import (
|
||||||
|
|
||||||
const loginTokenLifetime = time.Minute
|
const loginTokenLifetime = time.Minute
|
||||||
|
|
||||||
var openIDLifetimeMS = time.Minute.Milliseconds()
|
var (
|
||||||
|
openIDLifetimeMS = time.Minute.Milliseconds()
|
||||||
|
ctx = context.Background()
|
||||||
|
)
|
||||||
|
|
||||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
|
@ -38,7 +40,6 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun
|
||||||
|
|
||||||
// Tests storing and getting account data
|
// Tests storing and getting account data
|
||||||
func Test_AccountData(t *testing.T) {
|
func Test_AccountData(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, close := mustCreateDatabase(t, dbType)
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
|
|
@ -70,11 +71,9 @@ func Test_AccountData(t *testing.T) {
|
||||||
|
|
||||||
// Tests the creation of accounts
|
// Tests the creation of accounts
|
||||||
func Test_Accounts(t *testing.T) {
|
func Test_Accounts(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, close := mustCreateDatabase(t, dbType)
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
_ = close
|
|
||||||
alice := test.NewUser()
|
alice := test.NewUser()
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
@ -102,8 +101,8 @@ func Test_Accounts(t *testing.T) {
|
||||||
// get guest account numeric aliceLocalpart
|
// get guest account numeric aliceLocalpart
|
||||||
first, err := db.GetNewNumericLocalpart(ctx)
|
first, err := db.GetNewNumericLocalpart(ctx)
|
||||||
assert.NoError(t, err, "failed to get new numeric localpart")
|
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||||
// SQLite requires a new user to be created, as it doesn't have a sequence and uses the count(localpart) instead
|
// Create a new account to verify the numeric localpart is updated
|
||||||
_, err = db.CreateAccount(ctx, strconv.Itoa(int(first)), "testing", "", api.AccountTypeAdmin)
|
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
|
||||||
assert.NoError(t, err, "failed to create account")
|
assert.NoError(t, err, "failed to create account")
|
||||||
second, err := db.GetNewNumericLocalpart(ctx)
|
second, err := db.GetNewNumericLocalpart(ctx)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
@ -129,7 +128,6 @@ func Test_Accounts(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Devices(t *testing.T) {
|
func Test_Devices(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
alice := test.NewUser()
|
alice := test.NewUser()
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
@ -162,6 +160,7 @@ func Test_Devices(t *testing.T) {
|
||||||
// Get devices
|
// Get devices
|
||||||
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
|
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
|
||||||
assert.NoError(t, err, "unable to get devices by localpart")
|
assert.NoError(t, err, "unable to get devices by localpart")
|
||||||
|
assert.Equal(t, 2, len(devices))
|
||||||
deviceIDs := make([]string, 0, len(devices))
|
deviceIDs := make([]string, 0, len(devices))
|
||||||
for _, dev := range devices {
|
for _, dev := range devices {
|
||||||
deviceIDs = append(deviceIDs, dev.ID)
|
deviceIDs = append(deviceIDs, dev.ID)
|
||||||
|
|
@ -170,11 +169,49 @@ func Test_Devices(t *testing.T) {
|
||||||
devices2, err := db.GetDevicesByID(ctx, deviceIDs)
|
devices2, err := db.GetDevicesByID(ctx, deviceIDs)
|
||||||
assert.NoError(t, err, "unable to get devices by id")
|
assert.NoError(t, err, "unable to get devices by id")
|
||||||
assert.Equal(t, devices, devices2)
|
assert.Equal(t, devices, devices2)
|
||||||
|
|
||||||
|
// Update device
|
||||||
|
newName := "new display name"
|
||||||
|
err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
|
||||||
|
assert.NoError(t, err, "unable to update device displayname")
|
||||||
|
err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1")
|
||||||
|
assert.NoError(t, err, "unable to update device last seen")
|
||||||
|
|
||||||
|
deviceWithID.DisplayName = newName
|
||||||
|
deviceWithID.LastSeenIP = "127.0.0.1"
|
||||||
|
deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second)))
|
||||||
|
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||||
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
|
assert.Equal(t, 2, len(devices))
|
||||||
|
assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName)
|
||||||
|
assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP)
|
||||||
|
truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second)
|
||||||
|
assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime))
|
||||||
|
|
||||||
|
// create one more device and remove the devices step by step
|
||||||
|
newDeviceID := util.RandomString(16)
|
||||||
|
accessToken = util.RandomString(16)
|
||||||
|
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
|
||||||
|
assert.NoError(t, err, "unable to create new device")
|
||||||
|
|
||||||
|
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||||
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
|
assert.Equal(t, 3, len(devices))
|
||||||
|
|
||||||
|
err = db.RemoveDevices(ctx, localpart, deviceIDs)
|
||||||
|
assert.NoError(t, err, "unable to remove devices")
|
||||||
|
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||||
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
|
assert.Equal(t, 1, len(devices))
|
||||||
|
|
||||||
|
deleted, err := db.RemoveAllDevices(ctx, localpart, "")
|
||||||
|
assert.NoError(t, err, "unable to remove all devices")
|
||||||
|
assert.Equal(t, 1, len(deleted))
|
||||||
|
assert.Equal(t, newDeviceID, deleted[0].ID)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_KeyBackup(t *testing.T) {
|
func Test_KeyBackup(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
alice := test.NewUser()
|
alice := test.NewUser()
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
|
|
||||||
|
|
@ -254,7 +291,6 @@ func Test_KeyBackup(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_LoginToken(t *testing.T) {
|
func Test_LoginToken(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
alice := test.NewUser()
|
alice := test.NewUser()
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, close := mustCreateDatabase(t, dbType)
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
|
|
@ -285,7 +321,6 @@ func Test_LoginToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_OpenID(t *testing.T) {
|
func Test_OpenID(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
alice := test.NewUser()
|
alice := test.NewUser()
|
||||||
token := util.RandomString(24)
|
token := util.RandomString(24)
|
||||||
|
|
||||||
|
|
@ -306,7 +341,6 @@ func Test_OpenID(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Profile(t *testing.T) {
|
func Test_Profile(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
alice := test.NewUser()
|
alice := test.NewUser()
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
@ -345,7 +379,6 @@ func Test_Profile(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Pusher(t *testing.T) {
|
func Test_Pusher(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
alice := test.NewUser()
|
alice := test.NewUser()
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
@ -397,7 +430,6 @@ func Test_Pusher(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_ThreePID(t *testing.T) {
|
func Test_ThreePID(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
alice := test.NewUser()
|
alice := test.NewUser()
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
@ -435,7 +467,6 @@ func Test_ThreePID(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Notification(t *testing.T) {
|
func Test_Notification(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
alice := test.NewUser()
|
alice := test.NewUser()
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue