mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-01 11:13:12 -06:00
165 lines
6.3 KiB
Go
165 lines
6.3 KiB
Go
package storage_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/matrix-org/dendrite/setup/config"
|
|
"github.com/matrix-org/dendrite/test"
|
|
"github.com/matrix-org/dendrite/userapi/api"
|
|
"github.com/matrix-org/dendrite/userapi/storage"
|
|
"github.com/matrix-org/gomatrixserverlib"
|
|
"github.com/matrix-org/util"
|
|
"github.com/stretchr/testify/assert"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
|
db, err := storage.NewUserAPIDatabase(&config.DatabaseOptions{
|
|
ConnectionString: config.DataSource(connStr),
|
|
}, "localhost", bcrypt.MinCost, time.Minute.Milliseconds(), time.Minute, "_server")
|
|
if err != nil {
|
|
t.Fatalf("NewUserAPIDatabase returned %s", err)
|
|
}
|
|
return db, close
|
|
}
|
|
|
|
// Tests storing and getting account data
|
|
func Test_AccountData(t *testing.T) {
|
|
ctx := context.Background()
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
db, close := mustCreateDatabase(t, dbType)
|
|
defer close()
|
|
alice := test.NewUser()
|
|
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
|
assert.NoError(t, err)
|
|
|
|
room := test.NewRoom(t, alice)
|
|
events := room.Events()
|
|
|
|
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
|
|
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
|
|
assert.NoError(t, err, "unable to save account data")
|
|
|
|
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
|
|
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
|
assert.NoError(t, err, "unable to save account data")
|
|
|
|
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
|
|
assert.NoError(t, err, "unable to get account data by type")
|
|
assert.Equal(t, contentRoom, accountData)
|
|
|
|
globalData, roomData, err := db.GetAccountData(ctx, localpart)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
|
|
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
|
|
})
|
|
}
|
|
|
|
// Tests the creation of accounts
|
|
func Test_Accounts(t *testing.T) {
|
|
ctx := context.Background()
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
db, close := mustCreateDatabase(t, dbType)
|
|
defer close()
|
|
_ = close
|
|
alice := test.NewUser()
|
|
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
|
assert.NoError(t, err)
|
|
|
|
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
|
assert.NoError(t, err, "failed to create account")
|
|
// verify the newly create account is the same as returned by CreateAccount
|
|
accGet, err := db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
|
|
assert.Equal(t, accAlice, accGet)
|
|
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
|
|
assert.Equal(t, accAlice, accGet)
|
|
|
|
// check account availability
|
|
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
|
|
assert.NoError(t, err, "failed to checkout account availability")
|
|
assert.Equal(t, false, available)
|
|
|
|
available, err = db.CheckAccountAvailability(ctx, "unusedname")
|
|
assert.NoError(t, err, "failed to checkout account availability")
|
|
assert.Equal(t, true, available)
|
|
|
|
// get guest account numeric aliceLocalpart
|
|
first, err := db.GetNewNumericLocalpart(ctx)
|
|
assert.NoError(t, err)
|
|
// SQLite requires a new user to be created, as it doesn't have a sequence and uses the count(localpart) instead
|
|
_, err = db.CreateAccount(ctx, strconv.Itoa(int(first)), "testing", "", api.AccountTypeAdmin)
|
|
assert.NoError(t, err, "failed to create account")
|
|
second, err := db.GetNewNumericLocalpart(ctx)
|
|
assert.NoError(t, err)
|
|
assert.Greater(t, second, first)
|
|
|
|
// update password for alice
|
|
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
|
|
assert.NoError(t, err, "failed to update password")
|
|
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
|
assert.Equal(t, accAlice, accGet)
|
|
|
|
// deactivate account
|
|
err = db.DeactivateAccount(ctx, aliceLocalpart)
|
|
assert.NoError(t, err, "failed to deactivate account")
|
|
// This should fail now, as the account is deactivated
|
|
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
|
assert.Error(t, err, "expected an error, got none")
|
|
|
|
_, err = db.GetAccountByLocalpart(ctx, "unusename")
|
|
assert.Error(t, err, "expected an error for non existent localpart")
|
|
})
|
|
}
|
|
|
|
func Test_Devices(t *testing.T) {
|
|
ctx := context.Background()
|
|
alice := test.NewUser()
|
|
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
|
assert.NoError(t, err)
|
|
deviceID := util.RandomString(8)
|
|
accessToken := util.RandomString(16)
|
|
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
db, close := mustCreateDatabase(t, dbType)
|
|
defer close()
|
|
|
|
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
|
|
assert.NoError(t, err, "unable to create deviceWithoutID")
|
|
|
|
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
|
|
assert.NoError(t, err, "unable to get device by id")
|
|
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
|
|
|
|
gotDeviceAccessToken, err := db.GetDeviceByAccessToken(ctx, accessToken)
|
|
assert.NoError(t, err, "unable to get device by access token")
|
|
assert.Equal(t, deviceWithID.ID, gotDeviceAccessToken.ID) // GetDeviceByAccessToken doesn't populate all fields
|
|
|
|
// create a device without existing device ID
|
|
accessToken = util.RandomString(16)
|
|
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
|
|
assert.NoError(t, err, "unable to create deviceWithoutID")
|
|
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
|
|
assert.NoError(t, err, "unable to get device by id")
|
|
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
|
|
|
|
// Get devices
|
|
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
|
|
assert.NoError(t, err, "unable to get devices by localpart")
|
|
deviceIDs := make([]string, 0, len(devices))
|
|
for _, dev := range devices {
|
|
deviceIDs = append(deviceIDs, dev.ID)
|
|
}
|
|
|
|
devices2, err := db.GetDevicesByID(ctx, deviceIDs)
|
|
assert.NoError(t, err, "unable to get devices by id")
|
|
assert.Equal(t, devices, devices2)
|
|
t.Logf("%+v", devices)
|
|
})
|
|
}
|