Add test for QueryDeviceMessages
(#2773)
Adds tests for `QueryDeviceMessages` and also includes some optimizations to reduce allocations in the DB layer.
This commit is contained in:
parent
453b50e1d3
commit
b9d0e9f7ed
|
@ -212,15 +212,13 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
maxStreamID := int64(0)
|
maxStreamID := int64(0)
|
||||||
|
// remove deleted devices
|
||||||
|
var result []api.DeviceMessage
|
||||||
for _, m := range msgs {
|
for _, m := range msgs {
|
||||||
if m.StreamID > maxStreamID {
|
if m.StreamID > maxStreamID {
|
||||||
maxStreamID = m.StreamID
|
maxStreamID = m.StreamID
|
||||||
}
|
}
|
||||||
}
|
if m.KeyJSON == nil || len(m.KeyJSON) == 0 {
|
||||||
// remove deleted devices
|
|
||||||
var result []api.DeviceMessage
|
|
||||||
for _, m := range msgs {
|
|
||||||
if m.KeyJSON == nil {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
result = append(result, m)
|
result = append(result, m)
|
||||||
|
|
156
keyserver/internal/internal_test.go
Normal file
156
keyserver/internal/internal_test.go
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
package internal_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/internal"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||||
|
t.Helper()
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := storage.NewDatabase(nil, &config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create new user db: %v", err)
|
||||||
|
}
|
||||||
|
return db, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_QueryDeviceMessages(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
type args struct {
|
||||||
|
req *api.QueryDeviceMessagesRequest
|
||||||
|
res *api.QueryDeviceMessagesResponse
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
want *api.QueryDeviceMessagesResponse
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no existing keys",
|
||||||
|
args: args{
|
||||||
|
req: &api.QueryDeviceMessagesRequest{
|
||||||
|
UserID: "@doesNotExist:localhost",
|
||||||
|
},
|
||||||
|
res: &api.QueryDeviceMessagesResponse{},
|
||||||
|
},
|
||||||
|
want: &api.QueryDeviceMessagesResponse{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "existing user returns devices",
|
||||||
|
args: args{
|
||||||
|
req: &api.QueryDeviceMessagesRequest{
|
||||||
|
UserID: alice.ID,
|
||||||
|
},
|
||||||
|
res: &api.QueryDeviceMessagesResponse{},
|
||||||
|
},
|
||||||
|
want: &api.QueryDeviceMessagesResponse{
|
||||||
|
StreamID: 6,
|
||||||
|
Devices: []api.DeviceMessage{
|
||||||
|
{
|
||||||
|
Type: api.TypeDeviceKeyUpdate, StreamID: 5, DeviceKeys: &api.DeviceKeys{
|
||||||
|
DeviceID: "myDevice",
|
||||||
|
DisplayName: "first device",
|
||||||
|
UserID: alice.ID,
|
||||||
|
KeyJSON: []byte("ghi"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: api.TypeDeviceKeyUpdate, StreamID: 6, DeviceKeys: &api.DeviceKeys{
|
||||||
|
DeviceID: "mySecondDevice",
|
||||||
|
DisplayName: "second device",
|
||||||
|
UserID: alice.ID,
|
||||||
|
KeyJSON: []byte("jkl"),
|
||||||
|
}, // streamID 6
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceMessages := []api.DeviceMessage{
|
||||||
|
{ // not the user we're looking for
|
||||||
|
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||||
|
UserID: "@doesNotExist:localhost",
|
||||||
|
},
|
||||||
|
// streamID 1 for this user
|
||||||
|
},
|
||||||
|
{ // empty keyJSON will be ignored
|
||||||
|
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||||
|
DeviceID: "myDevice",
|
||||||
|
UserID: alice.ID,
|
||||||
|
}, // streamID 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||||
|
DeviceID: "myDevice",
|
||||||
|
UserID: alice.ID,
|
||||||
|
KeyJSON: []byte("abc"),
|
||||||
|
}, // streamID 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||||
|
DeviceID: "myDevice",
|
||||||
|
UserID: alice.ID,
|
||||||
|
KeyJSON: []byte("def"),
|
||||||
|
}, // streamID 3
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||||
|
DeviceID: "myDevice",
|
||||||
|
UserID: alice.ID,
|
||||||
|
KeyJSON: []byte(""),
|
||||||
|
}, // streamID 4
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||||
|
DeviceID: "myDevice",
|
||||||
|
DisplayName: "first device",
|
||||||
|
UserID: alice.ID,
|
||||||
|
KeyJSON: []byte("ghi"),
|
||||||
|
}, // streamID 5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||||
|
DeviceID: "mySecondDevice",
|
||||||
|
UserID: alice.ID,
|
||||||
|
KeyJSON: []byte("jkl"),
|
||||||
|
DisplayName: "second device",
|
||||||
|
}, // streamID 6
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, closeDB := mustCreateDatabase(t, dbType)
|
||||||
|
defer closeDB()
|
||||||
|
if err := db.StoreLocalDeviceKeys(ctx, deviceMessages); err != nil {
|
||||||
|
t.Fatalf("failed to store local devicesKeys")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := &internal.KeyInternalAPI{
|
||||||
|
DB: db,
|
||||||
|
}
|
||||||
|
if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
got := tt.args.res
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("QueryDeviceMessages(): got:\n%+v, want:\n%+v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
@ -204,20 +205,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
deviceIDMap[d] = true
|
deviceIDMap[d] = true
|
||||||
}
|
}
|
||||||
var result []api.DeviceMessage
|
var result []api.DeviceMessage
|
||||||
|
var displayName sql.NullString
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
dk := api.DeviceMessage{
|
dk := api.DeviceMessage{
|
||||||
Type: api.TypeDeviceKeyUpdate,
|
Type: api.TypeDeviceKeyUpdate,
|
||||||
DeviceKeys: &api.DeviceKeys{},
|
DeviceKeys: &api.DeviceKeys{
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
dk.UserID = userID
|
if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
|
||||||
var keyJSON string
|
|
||||||
var streamID int64
|
|
||||||
var displayName sql.NullString
|
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
|
||||||
dk.StreamID = streamID
|
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
dk.DisplayName = displayName.String
|
dk.DisplayName = displayName.String
|
||||||
}
|
}
|
||||||
|
|
|
@ -137,21 +137,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||||
var result []api.DeviceMessage
|
var result []api.DeviceMessage
|
||||||
|
var displayName sql.NullString
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
dk := api.DeviceMessage{
|
dk := api.DeviceMessage{
|
||||||
Type: api.TypeDeviceKeyUpdate,
|
Type: api.TypeDeviceKeyUpdate,
|
||||||
DeviceKeys: &api.DeviceKeys{},
|
DeviceKeys: &api.DeviceKeys{
|
||||||
|
UserID: userID,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
dk.Type = api.TypeDeviceKeyUpdate
|
if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
|
||||||
dk.UserID = userID
|
|
||||||
var keyJSON string
|
|
||||||
var streamID int64
|
|
||||||
var displayName sql.NullString
|
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
|
||||||
dk.StreamID = streamID
|
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
dk.DisplayName = displayName.String
|
dk.DisplayName = displayName.String
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue