mirror of
https://github.com/matrix-org/dendrite.git
synced 2026-01-21 13:03:09 -06:00
Add AccountData tests, fix issue with Postgres
This commit is contained in:
parent
0646a73758
commit
911043ce10
|
|
@ -29,7 +29,7 @@ import (
|
||||||
type KeyChange struct {
|
type KeyChange struct {
|
||||||
Topic string
|
Topic string
|
||||||
JetStream nats.JetStreamContext
|
JetStream nats.JetStreamContext
|
||||||
DB storage.KeyserverDatabase
|
DB storage.KeyChangeDatabase
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProduceKeyChanges creates new change events for each key
|
// ProduceKeyChanges creates new change events for each key
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ type JetStreamPublisher interface {
|
||||||
|
|
||||||
// SyncAPI produces messages for the Sync API server to consume.
|
// SyncAPI produces messages for the Sync API server to consume.
|
||||||
type SyncAPI struct {
|
type SyncAPI struct {
|
||||||
db storage.Database
|
db storage.Notification
|
||||||
producer JetStreamPublisher
|
producer JetStreamPublisher
|
||||||
clientDataTopic string
|
clientDataTopic string
|
||||||
notificationDataTopic string
|
notificationDataTopic string
|
||||||
|
|
|
||||||
|
|
@ -145,7 +145,14 @@ type Database interface {
|
||||||
KeyserverDatabase
|
KeyserverDatabase
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type KeyChangeDatabase interface {
|
||||||
|
// StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
|
||||||
|
// `userID` is the the user who has changed their keys in some way.
|
||||||
|
StoreKeyChange(ctx context.Context, userID string) (int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
type KeyserverDatabase interface {
|
type KeyserverDatabase interface {
|
||||||
|
KeyChangeDatabase
|
||||||
// ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
|
// ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
|
||||||
// of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
|
// of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
|
||||||
ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||||
|
|
@ -185,10 +192,6 @@ type KeyserverDatabase interface {
|
||||||
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||||
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
||||||
|
|
||||||
// StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
|
|
||||||
// `userID` is the the user who has changed their keys in some way.
|
|
||||||
StoreKeyChange(ctx context.Context, userID string) (int64, error)
|
|
||||||
|
|
||||||
// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
|
// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
|
||||||
// A to offset of types.OffsetNewest means no upper limit.
|
// A to offset of types.OffsetNewest means no upper limit.
|
||||||
// Returns the offset of the latest key change.
|
// Returns the offset of the latest key change.
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData(
|
||||||
roomID, dataType string, content json.RawMessage,
|
roomID, dataType string, content json.RawMessage,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||||
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
// Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage
|
||||||
|
// when passing the data to trigger "NOT NULL" constraint
|
||||||
|
var data *json.RawMessage
|
||||||
|
if len(content) > 0 {
|
||||||
|
data = &content
|
||||||
|
}
|
||||||
|
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,9 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/userapi/producers"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/nats-io/nats.go"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
|
@ -38,6 +40,13 @@ const (
|
||||||
|
|
||||||
type apiTestOpts struct {
|
type apiTestOpts struct {
|
||||||
loginTokenLifetime time.Duration
|
loginTokenLifetime time.Duration
|
||||||
|
serverName string
|
||||||
|
}
|
||||||
|
|
||||||
|
type dummyProducer struct{}
|
||||||
|
|
||||||
|
func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) {
|
||||||
|
return &nats.PubAck{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) {
|
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) {
|
||||||
|
|
@ -46,9 +55,13 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
|
||||||
}
|
}
|
||||||
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
|
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
|
||||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
sName := serverName
|
||||||
|
if opts.serverName != "" {
|
||||||
|
sName = gomatrixserverlib.ServerName(opts.serverName)
|
||||||
|
}
|
||||||
accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
|
accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
|
||||||
ConnectionString: config.DataSource(connStr),
|
ConnectionString: config.DataSource(connStr),
|
||||||
}, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
|
}, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create account DB: %s", err)
|
t.Fatalf("failed to create account DB: %s", err)
|
||||||
}
|
}
|
||||||
|
|
@ -56,14 +69,17 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
|
||||||
cfg := &config.UserAPI{
|
cfg := &config.UserAPI{
|
||||||
Matrix: &config.Global{
|
Matrix: &config.Global{
|
||||||
SigningIdentity: gomatrixserverlib.SigningIdentity{
|
SigningIdentity: gomatrixserverlib.SigningIdentity{
|
||||||
ServerName: serverName,
|
ServerName: sName,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "")
|
||||||
|
|
||||||
return &internal.UserInternalAPI{
|
return &internal.UserInternalAPI{
|
||||||
DB: accountDB,
|
DB: accountDB,
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
|
SyncProducer: syncProducer,
|
||||||
}, accountDB, func() {
|
}, accountDB, func() {
|
||||||
close()
|
close()
|
||||||
baseclose()
|
baseclose()
|
||||||
|
|
@ -332,3 +348,87 @@ func TestQueryAccountByLocalpart(t *testing.T) {
|
||||||
testCases(t, intAPI)
|
testCases(t, intAPI)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountData(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
inputData *api.InputAccountDataRequest
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "not a local user",
|
||||||
|
inputData: &api.InputAccountDataRequest{UserID: "@notlocal:example.com"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "local user missing datatype",
|
||||||
|
inputData: &api.InputAccountDataRequest{UserID: alice.ID},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing json",
|
||||||
|
inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: nil},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with json",
|
||||||
|
inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "room data",
|
||||||
|
inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}"), RoomID: "!dummy:test"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ignored users",
|
||||||
|
inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.ignored_user_list", AccountData: []byte("{}")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "m.fully_read",
|
||||||
|
inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.fully_read", AccountData: []byte("{}")},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
res := api.InputAccountDataResponse{}
|
||||||
|
err := intAPI.InputAccountData(ctx, tc.inputData, &res)
|
||||||
|
if tc.wantErr && err == nil {
|
||||||
|
t.Fatalf("expected an error, but got none")
|
||||||
|
}
|
||||||
|
if !tc.wantErr && err != nil {
|
||||||
|
t.Fatalf("expected no error, but got: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// query the data again and compare
|
||||||
|
queryRes := api.QueryAccountDataResponse{}
|
||||||
|
queryReq := api.QueryAccountDataRequest{
|
||||||
|
UserID: tc.inputData.UserID,
|
||||||
|
DataType: tc.inputData.DataType,
|
||||||
|
RoomID: tc.inputData.RoomID,
|
||||||
|
}
|
||||||
|
err = intAPI.QueryAccountData(ctx, &queryReq, &queryRes)
|
||||||
|
if err != nil && !tc.wantErr {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// verify global data
|
||||||
|
if tc.inputData.RoomID == "" {
|
||||||
|
if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.GlobalAccountData[tc.inputData.DataType]) {
|
||||||
|
t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.GlobalAccountData[tc.inputData.DataType]))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// verify room data
|
||||||
|
if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]) {
|
||||||
|
t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue