Add UserAPI storage tests (#2384)
* Add tests for parts of the userapi storage * Add tests for keybackup * Add LoginToken tests * Add OpenID tests * Add profile tests * Add pusher tests * Add ThreePID tests * Add notification tests * Add more device tests, fix numeric localpart query * Fix failing CI * Fix numeric local part query
This commit is contained in:
parent
d7cc187ec0
commit
f023cdf8c4
1
go.mod
1
go.mod
|
@ -47,6 +47,7 @@ require (
|
|||
github.com/pressly/goose v2.7.0+incompatible
|
||||
github.com/prometheus/client_golang v1.12.1
|
||||
github.com/sirupsen/logrus v1.8.1
|
||||
github.com/stretchr/testify v1.7.0
|
||||
github.com/tidwall/gjson v1.14.0
|
||||
github.com/tidwall/sjson v1.2.4
|
||||
github.com/uber/jaeger-client-go v2.30.0+incompatible
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
@ -56,8 +57,6 @@ import (
|
|||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
// BaseDendrite is a base for creating new instances of dendrite. It parses
|
||||
|
@ -273,7 +272,7 @@ func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
|
|||
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
||||
// be called once per component.
|
||||
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
|
||||
db, err := userdb.NewDatabase(
|
||||
db, err := userdb.NewUserAPIDatabase(
|
||||
&b.Cfg.UserAPI.AccountDatabase,
|
||||
b.Cfg.Global.ServerName,
|
||||
b.Cfg.UserAPI.BCryptCost,
|
||||
|
|
|
@ -27,18 +27,24 @@ import (
|
|||
type Profile interface {
|
||||
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
||||
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
|
||||
SetDisplayName(ctx context.Context, localpart string, displayName string) error
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
Profile
|
||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||
type Account interface {
|
||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
||||
}
|
||||
|
||||
type AccountData interface {
|
||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
|
@ -46,26 +52,9 @@ type Database interface {
|
|||
// If no account data could be found, returns nil
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
|
||||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
|
||||
// Key backups
|
||||
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
||||
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
|
||||
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
|
||||
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
|
||||
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
||||
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
||||
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
||||
}
|
||||
|
||||
type Device interface {
|
||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
||||
|
@ -79,11 +68,22 @@ type Database interface {
|
|||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
|
||||
RemoveDevice(ctx context.Context, deviceID, localpart string) error
|
||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
||||
}
|
||||
|
||||
type KeyBackup interface {
|
||||
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
|
||||
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
|
||||
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
|
||||
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
|
||||
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
||||
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
||||
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
||||
}
|
||||
|
||||
type LoginToken interface {
|
||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||
// determined by the loginTokenLifetime given to the Database constructor.
|
||||
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
|
||||
|
@ -94,21 +94,50 @@ type Database interface {
|
|||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
||||
// May return sql.ErrNoRows.
|
||||
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
|
||||
}
|
||||
|
||||
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
|
||||
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
|
||||
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error)
|
||||
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
||||
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
|
||||
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
|
||||
DeleteOldNotifications(ctx context.Context) error
|
||||
type OpenID interface {
|
||||
CreateOpenIDToken(ctx context.Context, token, userID string) (exp int64, err error)
|
||||
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
}
|
||||
|
||||
type Pusher interface {
|
||||
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
|
||||
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
|
||||
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
|
||||
RemovePushers(ctx context.Context, appid, pushkey string) error
|
||||
}
|
||||
|
||||
type ThreePID interface {
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||
}
|
||||
|
||||
type Notification interface {
|
||||
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
|
||||
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
|
||||
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error)
|
||||
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
||||
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
|
||||
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
|
||||
DeleteOldNotifications(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
Account
|
||||
AccountData
|
||||
Device
|
||||
KeyBackup
|
||||
LoginToken
|
||||
Notification
|
||||
OpenID
|
||||
Profile
|
||||
Pusher
|
||||
ThreePID
|
||||
}
|
||||
|
||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||
// a third-party identifier which is already associated to a local user.
|
||||
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
||||
|
|
|
@ -47,8 +47,6 @@ CREATE TABLE IF NOT EXISTS account_accounts (
|
|||
-- TODO:
|
||||
-- upgraded_ts, devices, any email reset stuff?
|
||||
);
|
||||
-- Create sequence for autogenerated numeric usernames
|
||||
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
|
||||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
|
@ -67,7 +65,7 @@ const selectPasswordHashSQL = "" +
|
|||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT nextval('numeric_username_seq')"
|
||||
"SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'"
|
||||
|
||||
type accountsStatements struct {
|
||||
insertAccountStmt *sql.Stmt
|
||||
|
@ -178,5 +176,5 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
|||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
return
|
||||
return id + 1, err
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ const selectDeviceByIDSQL = "" +
|
|||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
|
@ -93,7 +93,7 @@ const deleteDevicesSQL = "" +
|
|||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
||||
|
@ -235,16 +235,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
||||
var devices []api.Device
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var lastseents sql.NullInt64
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
if lastseents.Valid {
|
||||
dev.LastSeenTS = lastseents.Int64
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
|
@ -262,10 +266,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var lastseents sql.NullInt64
|
||||
var id, displayname, ip, useragent sql.NullString
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
||||
if err != nil {
|
||||
return devices, err
|
||||
|
|
|
@ -577,21 +577,6 @@ func (d *Database) UpdateDevice(
|
|||
})
|
||||
}
|
||||
|
||||
// RemoveDevice revokes a device by deleting the entry in the database
|
||||
// matching with the given device ID and user ID localpart.
|
||||
// If the device doesn't exist, it will not return an error
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveDevice(
|
||||
ctx context.Context, deviceID, localpart string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||
// matching with the given device IDs and user ID localpart.
|
||||
// If the devices don't exist, it will not return an error
|
||||
|
|
|
@ -65,7 +65,7 @@ const selectPasswordHashSQL = "" +
|
|||
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COUNT(localpart) FROM account_accounts"
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
|
||||
|
||||
type accountsStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -121,6 +121,7 @@ func (s *accountsStatements) InsertAccount(
|
|||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
ServerName: s.serverName,
|
||||
AppServiceID: appserviceID,
|
||||
AccountType: accountType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -177,5 +178,8 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
|
|||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
return
|
||||
if err == sql.ErrNoRows {
|
||||
return 1, nil
|
||||
}
|
||||
return id + 1, err
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ const selectDeviceByIDSQL = "" +
|
|||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
|
@ -78,7 +78,7 @@ const deleteDevicesSQL = "" +
|
|||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
|
||||
|
@ -235,10 +235,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
|||
return devices, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var lastseents sql.NullInt64
|
||||
var id, displayname, ip, useragent sql.NullString
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
|
||||
if err != nil {
|
||||
return devices, err
|
||||
|
@ -279,16 +279,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
||||
var devices []api.Device
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
||||
var lastseents sql.NullInt64
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
if lastseents.Valid {
|
||||
dev.LastSeenTS = lastseents.Int64
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
|
|
|
@ -28,9 +28,9 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||
)
|
||||
|
||||
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||
// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||
// and sets postgres connection parameters
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
|
||||
func NewUserAPIDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
|
|
539
userapi/storage/storage_test.go
Normal file
539
userapi/storage/storage_test.go
Normal file
|
@ -0,0 +1,539 @@
|
|||
package storage_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const loginTokenLifetime = time.Minute
|
||||
|
||||
var (
|
||||
openIDLifetimeMS = time.Minute.Milliseconds()
|
||||
ctx = context.Background()
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewUserAPIDatabase(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
|
||||
if err != nil {
|
||||
t.Fatalf("NewUserAPIDatabase returned %s", err)
|
||||
}
|
||||
return db, close
|
||||
}
|
||||
|
||||
// Tests storing and getting account data
|
||||
func Test_AccountData(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
room := test.NewRoom(t, alice)
|
||||
events := room.Events()
|
||||
|
||||
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
|
||||
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
|
||||
assert.NoError(t, err, "unable to save account data")
|
||||
|
||||
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
|
||||
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
||||
assert.NoError(t, err, "unable to save account data")
|
||||
|
||||
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
|
||||
assert.NoError(t, err, "unable to get account data by type")
|
||||
assert.Equal(t, contentRoom, accountData)
|
||||
|
||||
globalData, roomData, err := db.GetAccountData(ctx, localpart)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
|
||||
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
|
||||
})
|
||||
}
|
||||
|
||||
// Tests the creation of accounts
|
||||
func Test_Accounts(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser()
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
// verify the newly create account is the same as returned by CreateAccount
|
||||
var accGet *api.Account
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
|
||||
assert.NoError(t, err, "failed to get account by password")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "failed to get account by localpart")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// check account availability
|
||||
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "failed to checkout account availability")
|
||||
assert.Equal(t, false, available)
|
||||
|
||||
available, err = db.CheckAccountAvailability(ctx, "unusedname")
|
||||
assert.NoError(t, err, "failed to checkout account availability")
|
||||
assert.Equal(t, true, available)
|
||||
|
||||
// get guest account numeric aliceLocalpart
|
||||
first, err := db.GetNewNumericLocalpart(ctx)
|
||||
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||
// Create a new account to verify the numeric localpart is updated
|
||||
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
second, err := db.GetNewNumericLocalpart(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, second, first)
|
||||
|
||||
// update password for alice
|
||||
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
|
||||
assert.NoError(t, err, "failed to update password")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||
assert.NoError(t, err, "failed to get account by new password")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// deactivate account
|
||||
err = db.DeactivateAccount(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "failed to deactivate account")
|
||||
// This should fail now, as the account is deactivated
|
||||
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||
assert.Error(t, err, "expected an error, got none")
|
||||
|
||||
_, err = db.GetAccountByLocalpart(ctx, "unusename")
|
||||
assert.Error(t, err, "expected an error for non existent localpart")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Devices(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
deviceID := util.RandomString(8)
|
||||
accessToken := util.RandomString(16)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||
|
||||
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
|
||||
|
||||
gotDeviceAccessToken, err := db.GetDeviceByAccessToken(ctx, accessToken)
|
||||
assert.NoError(t, err, "unable to get device by access token")
|
||||
assert.Equal(t, deviceWithID.ID, gotDeviceAccessToken.ID) // GetDeviceByAccessToken doesn't populate all fields
|
||||
|
||||
// create a device without existing device ID
|
||||
accessToken = util.RandomString(16)
|
||||
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
|
||||
|
||||
// Get devices
|
||||
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
|
||||
assert.NoError(t, err, "unable to get devices by localpart")
|
||||
assert.Equal(t, 2, len(devices))
|
||||
deviceIDs := make([]string, 0, len(devices))
|
||||
for _, dev := range devices {
|
||||
deviceIDs = append(deviceIDs, dev.ID)
|
||||
}
|
||||
|
||||
devices2, err := db.GetDevicesByID(ctx, deviceIDs)
|
||||
assert.NoError(t, err, "unable to get devices by id")
|
||||
assert.Equal(t, devices, devices2)
|
||||
|
||||
// Update device
|
||||
newName := "new display name"
|
||||
err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
|
||||
assert.NoError(t, err, "unable to update device displayname")
|
||||
err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1")
|
||||
assert.NoError(t, err, "unable to update device last seen")
|
||||
|
||||
deviceWithID.DisplayName = newName
|
||||
deviceWithID.LastSeenIP = "127.0.0.1"
|
||||
deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second)))
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, 2, len(devices))
|
||||
assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName)
|
||||
assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP)
|
||||
truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second)
|
||||
assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime))
|
||||
|
||||
// create one more device and remove the devices step by step
|
||||
newDeviceID := util.RandomString(16)
|
||||
accessToken = util.RandomString(16)
|
||||
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create new device")
|
||||
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, 3, len(devices))
|
||||
|
||||
err = db.RemoveDevices(ctx, localpart, deviceIDs)
|
||||
assert.NoError(t, err, "unable to remove devices")
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, 1, len(devices))
|
||||
|
||||
deleted, err := db.RemoveAllDevices(ctx, localpart, "")
|
||||
assert.NoError(t, err, "unable to remove all devices")
|
||||
assert.Equal(t, 1, len(deleted))
|
||||
assert.Equal(t, newDeviceID, deleted[0].ID)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_KeyBackup(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
room := test.NewRoom(t, alice)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
wantAuthData := json.RawMessage("my auth data")
|
||||
wantVersion, err := db.CreateKeyBackup(ctx, alice.ID, "dummyAlgo", wantAuthData)
|
||||
assert.NoError(t, err, "unable to create key backup")
|
||||
// get key backup by version
|
||||
gotVersion, gotAlgo, gotAuthData, _, _, err := db.GetKeyBackup(ctx, alice.ID, wantVersion)
|
||||
assert.NoError(t, err, "unable to get key backup")
|
||||
assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
|
||||
assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
|
||||
assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
|
||||
|
||||
// get any key backup
|
||||
gotVersion, gotAlgo, gotAuthData, _, _, err = db.GetKeyBackup(ctx, alice.ID, "")
|
||||
assert.NoError(t, err, "unable to get key backup")
|
||||
assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
|
||||
assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
|
||||
assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
|
||||
|
||||
err = db.UpdateKeyBackupAuthData(ctx, alice.ID, wantVersion, json.RawMessage("my updated auth data"))
|
||||
assert.NoError(t, err, "unable to update key backup auth data")
|
||||
|
||||
uploads := []api.InternalKeyBackupSession{
|
||||
{
|
||||
KeyBackupSession: api.KeyBackupSession{
|
||||
IsVerified: true,
|
||||
SessionData: wantAuthData,
|
||||
},
|
||||
RoomID: room.ID,
|
||||
SessionID: "1",
|
||||
},
|
||||
{
|
||||
KeyBackupSession: api.KeyBackupSession{},
|
||||
RoomID: room.ID,
|
||||
SessionID: "2",
|
||||
},
|
||||
}
|
||||
count, _, err := db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads)
|
||||
assert.NoError(t, err, "unable to upsert backup keys")
|
||||
assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
|
||||
|
||||
// do it again to update a key
|
||||
uploads[1].IsVerified = true
|
||||
count, _, err = db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads[1:])
|
||||
assert.NoError(t, err, "unable to upsert backup keys")
|
||||
assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
|
||||
|
||||
// get backup keys by session id
|
||||
gotBackupKeys, err := db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "1")
|
||||
assert.NoError(t, err, "unable to get backup keys")
|
||||
assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
|
||||
|
||||
// get backup keys by room id
|
||||
gotBackupKeys, err = db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "")
|
||||
assert.NoError(t, err, "unable to get backup keys")
|
||||
assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
|
||||
|
||||
gotCount, err := db.CountBackupKeys(ctx, wantVersion, alice.ID)
|
||||
assert.NoError(t, err, "unable to get backup keys count")
|
||||
assert.Equal(t, count, gotCount, "unexpected backup count")
|
||||
|
||||
// finally delete a key
|
||||
exists, err := db.DeleteKeyBackup(ctx, alice.ID, wantVersion)
|
||||
assert.NoError(t, err, "unable to delete key backup")
|
||||
assert.True(t, exists)
|
||||
|
||||
// this key should not exist
|
||||
exists, err = db.DeleteKeyBackup(ctx, alice.ID, "3")
|
||||
assert.NoError(t, err, "unable to delete key backup")
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_LoginToken(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
// create a new token
|
||||
wantLoginToken := &api.LoginTokenData{UserID: alice.ID}
|
||||
|
||||
gotMetadata, err := db.CreateLoginToken(ctx, wantLoginToken)
|
||||
assert.NoError(t, err, "unable to create login token")
|
||||
assert.NotNil(t, gotMetadata)
|
||||
assert.Equal(t, time.Now().Add(loginTokenLifetime).Truncate(loginTokenLifetime), gotMetadata.Expiration.Truncate(loginTokenLifetime))
|
||||
|
||||
// get the new token
|
||||
gotLoginToken, err := db.GetLoginTokenDataByToken(ctx, gotMetadata.Token)
|
||||
assert.NoError(t, err, "unable to get login token")
|
||||
assert.NotNil(t, gotLoginToken)
|
||||
assert.Equal(t, wantLoginToken, gotLoginToken, "unexpected login token")
|
||||
|
||||
// remove the login token again
|
||||
err = db.RemoveLoginToken(ctx, gotMetadata.Token)
|
||||
assert.NoError(t, err, "unable to remove login token")
|
||||
|
||||
// check if the token was actually deleted
|
||||
_, err = db.GetLoginTokenDataByToken(ctx, gotMetadata.Token)
|
||||
assert.Error(t, err, "expected an error, but got none")
|
||||
})
|
||||
}
|
||||
|
||||
func Test_OpenID(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
token := util.RandomString(24)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
|
||||
expires, err := db.CreateOpenIDToken(ctx, token, alice.ID)
|
||||
assert.NoError(t, err, "unable to create OpenID token")
|
||||
assert.Equal(t, expiresAtMS, expires)
|
||||
|
||||
attributes, err := db.GetOpenIDTokenAttributes(ctx, token)
|
||||
assert.NoError(t, err, "unable to get OpenID token attributes")
|
||||
assert.Equal(t, alice.ID, attributes.UserID)
|
||||
assert.Equal(t, expiresAtMS, attributes.ExpiresAtMS)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Profile(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
// create account, which also creates a profile
|
||||
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "unable to get profile by localpart")
|
||||
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
|
||||
assert.Equal(t, wantProfile, gotProfile)
|
||||
|
||||
// set avatar & displayname
|
||||
wantProfile.DisplayName = "Alice"
|
||||
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||
err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
|
||||
assert.NoError(t, err, "unable to set displayname")
|
||||
err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
||||
assert.NoError(t, err, "unable to set avatar url")
|
||||
// verify profile
|
||||
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "unable to get profile by localpart")
|
||||
assert.Equal(t, wantProfile, gotProfile)
|
||||
|
||||
// search profiles
|
||||
searchRes, err := db.SearchProfiles(ctx, "Alice", 2)
|
||||
assert.NoError(t, err, "unable to search profiles")
|
||||
assert.Equal(t, 1, len(searchRes))
|
||||
assert.Equal(t, *wantProfile, searchRes[0])
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Pusher(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
appID := util.RandomString(8)
|
||||
var pushKeys []string
|
||||
var gotPushers []api.Pusher
|
||||
for i := 0; i < 2; i++ {
|
||||
pushKey := util.RandomString(8)
|
||||
|
||||
wantPusher := api.Pusher{
|
||||
PushKey: pushKey,
|
||||
Kind: api.HTTPKind,
|
||||
AppID: appID,
|
||||
AppDisplayName: util.RandomString(8),
|
||||
DeviceDisplayName: util.RandomString(8),
|
||||
ProfileTag: util.RandomString(8),
|
||||
Language: util.RandomString(2),
|
||||
}
|
||||
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart)
|
||||
assert.NoError(t, err, "unable to upsert pusher")
|
||||
|
||||
// check it was actually persisted
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, i+1, len(gotPushers))
|
||||
assert.Equal(t, wantPusher, gotPushers[i])
|
||||
pushKeys = append(pushKeys, pushKey)
|
||||
}
|
||||
|
||||
// remove single pusher
|
||||
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart)
|
||||
assert.NoError(t, err, "unable to remove pusher")
|
||||
gotPushers, err := db.GetPushers(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, 1, len(gotPushers))
|
||||
|
||||
// remove last pusher
|
||||
err = db.RemovePushers(ctx, appID, pushKeys[1])
|
||||
assert.NoError(t, err, "unable to remove pusher")
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, 0, len(gotPushers))
|
||||
})
|
||||
}
|
||||
|
||||
func Test_ThreePID(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
threePID := util.RandomString(8)
|
||||
medium := util.RandomString(8)
|
||||
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium)
|
||||
assert.NoError(t, err, "unable to save threepid association")
|
||||
|
||||
// get the stored threepid
|
||||
gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
|
||||
assert.NoError(t, err, "unable to get localpart for threepid")
|
||||
assert.Equal(t, aliceLocalpart, gotLocalpart)
|
||||
|
||||
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||
assert.Equal(t, 1, len(threepids))
|
||||
assert.Equal(t, authtypes.ThreePID{
|
||||
Address: threePID,
|
||||
Medium: medium,
|
||||
}, threepids[0])
|
||||
|
||||
// remove threepid association
|
||||
err = db.RemoveThreePIDAssociation(ctx, threePID, medium)
|
||||
assert.NoError(t, err, "unexpected error")
|
||||
|
||||
// verify it was deleted
|
||||
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
||||
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||
assert.Equal(t, 0, len(threepids))
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Notification(t *testing.T) {
|
||||
alice := test.NewUser()
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
room := test.NewRoom(t, alice)
|
||||
room2 := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
// generate some dummy notifications
|
||||
for i := 0; i < 10; i++ {
|
||||
eventID := util.RandomString(16)
|
||||
roomID := room.ID
|
||||
ts := time.Now()
|
||||
if i > 5 {
|
||||
roomID = room2.ID
|
||||
// create some old notifications to test DeleteOldNotifications
|
||||
ts = ts.AddDate(0, -2, 0)
|
||||
}
|
||||
notification := &api.Notification{
|
||||
Actions: []*pushrules.Action{
|
||||
{},
|
||||
},
|
||||
Event: gomatrixserverlib.ClientEvent{
|
||||
Content: gomatrixserverlib.RawJSON("{}"),
|
||||
},
|
||||
Read: false,
|
||||
RoomID: roomID,
|
||||
TS: gomatrixserverlib.AsTimestamp(ts),
|
||||
}
|
||||
err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification)
|
||||
assert.NoError(t, err, "unable to insert notification")
|
||||
}
|
||||
|
||||
// get notifications
|
||||
count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications)
|
||||
assert.NoError(t, err, "unable to get notification count")
|
||||
assert.Equal(t, int64(10), count)
|
||||
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications)
|
||||
assert.NoError(t, err, "unable to get notifications")
|
||||
assert.Equal(t, int64(10), count)
|
||||
assert.Equal(t, 10, len(notifs))
|
||||
// ... for a specific room
|
||||
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||
assert.NoError(t, err, "unable to get notifications for room")
|
||||
assert.Equal(t, int64(4), total)
|
||||
|
||||
// mark notification as read
|
||||
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true)
|
||||
assert.NoError(t, err, "unable to set notifications read")
|
||||
assert.True(t, affected)
|
||||
|
||||
// this should delete 2 notifications
|
||||
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8)
|
||||
assert.NoError(t, err, "unable to set notifications read")
|
||||
assert.True(t, affected)
|
||||
|
||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||
assert.NoError(t, err, "unable to get notifications for room")
|
||||
assert.Equal(t, int64(2), total)
|
||||
|
||||
// delete old notifications
|
||||
err = db.DeleteOldNotifications(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// this should now return 0 notifications
|
||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||
assert.NoError(t, err, "unable to get notifications for room")
|
||||
assert.Equal(t, int64(0), total)
|
||||
})
|
||||
}
|
|
@ -23,7 +23,7 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
func NewDatabase(
|
||||
func NewUserAPIDatabase(
|
||||
dbProperties *config.DatabaseOptions,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
bcryptCost int,
|
||||
|
|
|
@ -52,7 +52,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s
|
|||
MaxOpenConnections: 1,
|
||||
MaxIdleConnections: 1,
|
||||
}
|
||||
accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
|
||||
accountDB, err := storage.NewUserAPIDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create account DB: %s", err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue