mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 03:03:10 -06:00
Add tests for parts of the userapi storage
This commit is contained in:
parent
e8be2b234f
commit
cf5acdd16f
|
|
@ -21,6 +21,7 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
|
@ -55,8 +56,6 @@ import (
|
|||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
// BaseDendrite is a base for creating new instances of dendrite. It parses
|
||||
|
|
@ -272,7 +271,7 @@ func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
|
|||
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
||||
// be called once per component.
|
||||
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
|
||||
db, err := userdb.NewDatabase(
|
||||
db, err := userdb.NewUserAPIDatabase(
|
||||
&b.Cfg.UserAPI.AccountDatabase,
|
||||
b.Cfg.Global.ServerName,
|
||||
b.Cfg.UserAPI.BCryptCost,
|
||||
|
|
|
|||
|
|
@ -32,13 +32,19 @@ type Profile interface {
|
|||
SetDisplayName(ctx context.Context, localpart string, displayName string) error
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
Profile
|
||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||
type Account interface {
|
||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
}
|
||||
|
||||
type AccountData interface {
|
||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
|
|
@ -46,26 +52,9 @@ type Database interface {
|
|||
// If no account data could be found, returns nil
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
|
||||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
|
||||
// Key backups
|
||||
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
||||
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
|
||||
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
|
||||
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
|
||||
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
||||
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
||||
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
||||
}
|
||||
|
||||
type Device interface {
|
||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
||||
|
|
@ -83,6 +72,28 @@ type Database interface {
|
|||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
Account
|
||||
AccountData
|
||||
Device
|
||||
Profile
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
|
||||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
|
||||
// Key backups
|
||||
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
||||
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
|
||||
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
|
||||
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
|
||||
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
||||
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
||||
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
||||
|
||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||
// determined by the loginTokenLifetime given to the Database constructor.
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ const selectDeviceByIDSQL = "" +
|
|||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
|
|
@ -93,7 +93,7 @@ const deleteDevicesSQL = "" +
|
|||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
||||
|
|
@ -235,16 +235,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
||||
var devices []api.Device
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var lastseents sql.NullInt64
|
||||
var displayName sql.NullString
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
if lastseents.Valid {
|
||||
dev.LastSeenTS = lastseents.Int64
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
|
|
@ -262,10 +266,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")
|
||||
|
||||
var dev api.Device
|
||||
var lastseents sql.NullInt64
|
||||
var id, displayname, ip, useragent sql.NullString
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var lastseents sql.NullInt64
|
||||
var id, displayname, ip, useragent sql.NullString
|
||||
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
||||
if err != nil {
|
||||
return devices, err
|
||||
|
|
|
|||
|
|
@ -121,6 +121,7 @@ func (s *accountsStatements) InsertAccount(
|
|||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
ServerName: s.serverName,
|
||||
AppServiceID: appserviceID,
|
||||
AccountType: accountType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -177,5 +178,5 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
|||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
return
|
||||
return id + 1, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ const selectDeviceByIDSQL = "" +
|
|||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
|
|
@ -78,7 +78,7 @@ const deleteDevicesSQL = "" +
|
|||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
||||
|
|
@ -235,10 +235,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
|||
return devices, err
|
||||
}
|
||||
|
||||
var dev api.Device
|
||||
var lastseents sql.NullInt64
|
||||
var id, displayname, ip, useragent sql.NullString
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var lastseents sql.NullInt64
|
||||
var id, displayname, ip, useragent sql.NullString
|
||||
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
||||
if err != nil {
|
||||
return devices, err
|
||||
|
|
@ -279,16 +279,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
||||
var devices []api.Device
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var displayName sql.NullString
|
||||
var lastseents sql.NullInt64
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
if lastseents.Valid {
|
||||
dev.LastSeenTS = lastseents.Int64
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||
)
|
||||
|
||||
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||
// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||
// and sets postgres connection parameters
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
|
||||
func NewUserAPIDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
|
|
|
|||
164
userapi/storage/storage_test.go
Normal file
164
userapi/storage/storage_test.go
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
package storage_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewUserAPIDatabase(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, "localhost", bcrypt.MinCost, time.Minute.Milliseconds(), time.Minute, "_server")
|
||||
if err != nil {
|
||||
t.Fatalf("NewUserAPIDatabase returned %s", err)
|
||||
}
|
||||
return db, close
|
||||
}
|
||||
|
||||
// Tests storing and getting account data
|
||||
func Test_AccountData(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
room := test.NewRoom(t, alice)
|
||||
events := room.Events()
|
||||
|
||||
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
|
||||
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
|
||||
assert.NoError(t, err, "unable to save account data")
|
||||
|
||||
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
|
||||
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
||||
assert.NoError(t, err, "unable to save account data")
|
||||
|
||||
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
|
||||
assert.NoError(t, err, "unable to get account data by type")
|
||||
assert.Equal(t, contentRoom, accountData)
|
||||
|
||||
globalData, roomData, err := db.GetAccountData(ctx, localpart)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
|
||||
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
|
||||
})
|
||||
}
|
||||
|
||||
// Tests the creation of accounts
|
||||
func Test_Accounts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
_ = close
|
||||
alice := test.NewUser()
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
// verify the newly create account is the same as returned by CreateAccount
|
||||
accGet, err := db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// check account availability
|
||||
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "failed to checkout account availability")
|
||||
assert.Equal(t, false, available)
|
||||
|
||||
available, err = db.CheckAccountAvailability(ctx, "unusedname")
|
||||
assert.NoError(t, err, "failed to checkout account availability")
|
||||
assert.Equal(t, true, available)
|
||||
|
||||
// get guest account numeric aliceLocalpart
|
||||
first, err := db.GetNewNumericLocalpart(ctx)
|
||||
assert.NoError(t, err)
|
||||
// SQLite requires a new user to be created, as it doesn't have a sequence and uses the count(localpart) instead
|
||||
_, err = db.CreateAccount(ctx, strconv.Itoa(int(first)), "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
second, err := db.GetNewNumericLocalpart(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, second, first)
|
||||
|
||||
// update password for alice
|
||||
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
|
||||
assert.NoError(t, err, "failed to update password")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// deactivate account
|
||||
err = db.DeactivateAccount(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "failed to deactivate account")
|
||||
// This should fail now, as the account is deactivated
|
||||
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||
assert.Error(t, err, "expected an error, got none")
|
||||
|
||||
_, err = db.GetAccountByLocalpart(ctx, "unusename")
|
||||
assert.Error(t, err, "expected an error for non existent localpart")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Devices(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
alice := test.NewUser()
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
deviceID := util.RandomString(8)
|
||||
accessToken := util.RandomString(16)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||
|
||||
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
|
||||
|
||||
gotDeviceAccessToken, err := db.GetDeviceByAccessToken(ctx, accessToken)
|
||||
assert.NoError(t, err, "unable to get device by access token")
|
||||
assert.Equal(t, deviceWithID.ID, gotDeviceAccessToken.ID) // GetDeviceByAccessToken doesn't populate all fields
|
||||
|
||||
// create a device without existing device ID
|
||||
accessToken = util.RandomString(16)
|
||||
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
|
||||
|
||||
// Get devices
|
||||
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
|
||||
assert.NoError(t, err, "unable to get devices by localpart")
|
||||
deviceIDs := make([]string, 0, len(devices))
|
||||
for _, dev := range devices {
|
||||
deviceIDs = append(deviceIDs, dev.ID)
|
||||
}
|
||||
|
||||
devices2, err := db.GetDevicesByID(ctx, deviceIDs)
|
||||
assert.NoError(t, err, "unable to get devices by id")
|
||||
assert.Equal(t, devices, devices2)
|
||||
t.Logf("%+v", devices)
|
||||
})
|
||||
}
|
||||
|
|
@ -52,7 +52,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s
|
|||
MaxOpenConnections: 1,
|
||||
MaxIdleConnections: 1,
|
||||
}
|
||||
accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
|
||||
accountDB, err := storage.NewUserAPIDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create account DB: %s", err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue