mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Add tests for stream ID generation
This commit is contained in:
parent
cad8da67a9
commit
3febdd1a40
|
|
@ -102,8 +102,16 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int, err error) {
|
||||
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&streamID)
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||
// nullable if there are no results
|
||||
var nullStream sql.NullInt32
|
||||
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
if nullStream.Valid {
|
||||
streamID = nullStream.Int32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userIDToStreamID[userID] = streamID
|
||||
userIDToStreamID[userID] = int(streamID)
|
||||
}
|
||||
// set the stream IDs for each key
|
||||
for i := range keys {
|
||||
|
|
|
|||
|
|
@ -127,8 +127,16 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int, err error) {
|
||||
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&streamID)
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||
// nullable if there are no results
|
||||
var nullStream sql.NullInt32
|
||||
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
if nullStream.Valid {
|
||||
streamID = nullStream.Int32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
|
@ -77,3 +78,84 @@ func TestKeyChangesUpperLimit(t *testing.T) {
|
|||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
}
|
||||
|
||||
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
|
||||
// and that they are returned correctly when querying for device keys.
|
||||
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||
db, err := NewDatabase("file::memory:", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to NewDatabase: %s", err)
|
||||
}
|
||||
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
||||
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||
msgs := []api.DeviceMessage{
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 1
|
||||
},
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: bob,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 1 as this is a different user
|
||||
},
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: "another_device",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 2 as this is a 2nd device key
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||
}
|
||||
if msgs[1].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||
}
|
||||
if msgs[2].StreamID != 2 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||
}
|
||||
|
||||
// updating a device sets the next stream ID for that user
|
||||
msgs = []api.DeviceMessage{
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v2"}`),
|
||||
},
|
||||
// StreamID: 3
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 3 {
|
||||
t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||
}
|
||||
|
||||
// Querying for device keys returns the latest stream IDs
|
||||
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||
}
|
||||
wantStreamIDs := map[string]int{
|
||||
"AAA": 3,
|
||||
"another_device": 2,
|
||||
}
|
||||
if len(msgs) != len(wantStreamIDs) {
|
||||
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
|
||||
}
|
||||
for _, m := range msgs {
|
||||
if m.StreamID != wantStreamIDs[m.DeviceID] {
|
||||
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ type OneTimeKeys interface {
|
|||
type DeviceKeys interface {
|
||||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int, err error)
|
||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue