Merge branch 'master' into neilalexander/config

This commit is contained in:
Neil Alexander 2020-07-30 13:58:28 +01:00
commit 930ced1102
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
33 changed files with 908 additions and 259 deletions

View file

@ -29,7 +29,9 @@ func main() {
rsAPI := base.RoomserverHTTPClient()
syncapi.AddPublicRoutes(base.PublicAPIMux, base.KafkaConsumer, userAPI, rsAPI, federation, &cfg.SyncAPI)
syncapi.AddPublicRoutes(
base.PublicAPIMux, base.KafkaConsumer, userAPI, rsAPI, base.KeyServerHTTPClient(), base.CurrentStateAPIClient(),
federation, &cfg.SyncAPI)
base.SetupAndServeHTTP(string(base.Cfg.SyncAPI.Bind), string(base.Cfg.SyncAPI.Listen))

View file

@ -77,6 +77,7 @@ func (m *Monolith) AddAllPublicRoutes(publicMux *mux.Router) {
)
mediaapi.AddPublicRoutes(publicMux, &m.Config.MediaAPI, m.UserAPI, m.Client)
syncapi.AddPublicRoutes(
publicMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI, m.FedClient, &m.Config.SyncAPI,
publicMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI,
m.KeyAPI, m.StateAPI, m.FedClient, &m.Config.SyncAPI,
)
}

View file

@ -26,6 +26,7 @@ type KeyInternalAPI interface {
// PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
}
// KeyError is returned if there was a problem performing/querying the server
@ -131,3 +132,21 @@ type QueryKeysResponse struct {
// Set if there was a fatal error processing this query
Error *KeyError
}
type QueryKeyChangesRequest struct {
// The partition which had key events sent to
Partition int32
// The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
Offset int64
}
type QueryKeyChangesResponse struct {
// The set of users who have had their keys change.
UserIDs []string
// The partition being served - useful if the partition is unknown at request time
Partition int32
// The latest offset represented in this response.
Offset int64
// Set if there was a problem handling the request.
Error *KeyError
}

View file

@ -40,6 +40,21 @@ type KeyInternalAPI struct {
Producer *producers.KeyChange
}
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
if req.Partition < 0 {
req.Partition = a.Producer.DefaultPartition()
}
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Partition, req.Offset)
if err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
}
}
res.Offset = latest
res.Partition = req.Partition
res.UserIDs = userIDs
}
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
res.KeyErrors = make(map[string]map[string]*api.KeyError)
a.uploadDeviceKeys(ctx, req, res)

View file

@ -29,6 +29,7 @@ const (
PerformUploadKeysPath = "/keyserver/performUploadKeys"
PerformClaimKeysPath = "/keyserver/performClaimKeys"
QueryKeysPath = "/keyserver/queryKeys"
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
)
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
@ -101,3 +102,20 @@ func (h *httpKeyInternalAPI) QueryKeys(
}
}
}
func (h *httpKeyInternalAPI) QueryKeyChanges(
ctx context.Context,
request *api.QueryKeyChangesRequest,
response *api.QueryKeyChangesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyChanges")
defer span.Finish()
apiURL := h.apiURL + QueryKeyChangesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
}

View file

@ -58,4 +58,15 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryKeyChangesPath,
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
request := api.QueryKeyChangesRequest{}
response := api.QueryKeyChangesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryKeyChanges(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View file

@ -46,6 +46,7 @@ func NewInternalAPI(
keyChangeProducer := &producers.KeyChange{
Topic: string(cfg.Matrix.Kafka.Topics.OutputKeyChangeEvent),
Producer: producer,
DB: db,
}
return &internal.KeyInternalAPI{
DB: db,

View file

@ -15,10 +15,12 @@
package producers
import (
"context"
"encoding/json"
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/sirupsen/logrus"
)
@ -26,6 +28,16 @@ import (
type KeyChange struct {
Topic string
Producer sarama.SyncProducer
DB storage.Database
}
// DefaultPartition returns the default partition this process is sending key changes to.
// NB: A keyserver MUST send key changes to only 1 partition or else query operations will
// become inconsistent. Partitions can be sharded (e.g by hash of user ID of key change) but
// then all keyservers must be queried to calculate the entire set of key changes between
// two sync tokens.
func (p *KeyChange) DefaultPartition() int32 {
return 0
}
// ProduceKeyChanges creates new change events for each key
@ -46,6 +58,10 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
if err != nil {
return err
}
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
if err != nil {
return err
}
logrus.WithFields(logrus.Fields{
"user_id": key.UserID,
"device_id": key.DeviceID,

View file

@ -43,4 +43,12 @@ type Database interface {
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
// StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed
// their keys in some way.
StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
// KeyChanges returns a list of user IDs who have modified their keys from the offset given.
// Returns the offset of the latest key change.
KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
}

View file

@ -0,0 +1,97 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
var keyChangesSchema = `
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
partition BIGINT NOT NULL,
log_offset BIGINT NOT NULL,
user_id TEXT NOT NULL,
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
);
`
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
const upsertKeyChangeSQL = "" +
"INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" +
" VALUES ($1, $2, $3)" +
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" +
" DO UPDATE SET user_id = $3"
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
// take the max offset value as the latest offset.
const selectKeyChangesSQL = "" +
"SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 GROUP BY user_id"
type keyChangesStatements struct {
db *sql.DB
upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt
}
func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{
db: db,
}
_, err := db.Exec(keyChangesSchema)
if err != nil {
return nil, err
}
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
return nil, err
}
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
return err
}
func (s *keyChangesStatements) SelectKeyChanges(
ctx context.Context, partition int32, fromOffset int64,
) (userIDs []string, latestOffset int64, err error) {
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
for rows.Next() {
var userID string
var offset int64
if err := rows.Scan(&userID, &offset); err != nil {
return nil, 0, err
}
if offset > latestOffset {
latestOffset = offset
}
userIDs = append(userIDs, userID)
}
return
}

View file

@ -35,9 +35,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
if err != nil {
return nil, err
}
kc, err := NewPostgresKeyChangesTable(db)
if err != nil {
return nil, err
}
return &shared.Database{
DB: db,
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
}, nil
}

View file

@ -28,6 +28,7 @@ type Database struct {
DB *sql.DB
OneTimeKeysTable tables.OneTimeKeys
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
}
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
@ -72,3 +73,11 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st
})
return result, err
}
func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID)
}
func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) {
return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset)
}

View file

@ -0,0 +1,98 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
var keyChangesSchema = `
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
partition BIGINT NOT NULL,
offset BIGINT NOT NULL,
-- The key owner
user_id TEXT NOT NULL,
UNIQUE (partition, offset)
);
`
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
const upsertKeyChangeSQL = "" +
"INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
" VALUES ($1, $2, $3)" +
" ON CONFLICT (partition, offset)" +
" DO UPDATE SET user_id = $3"
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
// take the max offset value as the latest offset.
const selectKeyChangesSQL = "" +
"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 GROUP BY user_id"
type keyChangesStatements struct {
db *sql.DB
upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt
}
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{
db: db,
}
_, err := db.Exec(keyChangesSchema)
if err != nil {
return nil, err
}
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
return nil, err
}
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
return err
}
func (s *keyChangesStatements) SelectKeyChanges(
ctx context.Context, partition int32, fromOffset int64,
) (userIDs []string, latestOffset int64, err error) {
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
for rows.Next() {
var userID string
var offset int64
if err := rows.Scan(&userID, &offset); err != nil {
return nil, 0, err
}
if offset > latestOffset {
latestOffset = offset
}
userIDs = append(userIDs, userID)
}
return
}

View file

@ -33,9 +33,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
if err != nil {
return nil, err
}
kc, err := NewSqliteKeyChangesTable(db)
if err != nil {
return nil, err
}
return &shared.Database{
DB: db,
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
}, nil
}

View file

@ -0,0 +1,57 @@
package storage
import (
"context"
"reflect"
"testing"
)
var ctx = context.Background()
func MustNotError(t *testing.T, err error) {
t.Helper()
if err == nil {
return
}
t.Fatalf("operation failed: %s", err)
}
func TestKeyChanges(t *testing.T) {
db, err := NewDatabase("file::memory:", nil)
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
userIDs, latest, err := db.KeyChanges(ctx, 0, 1)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != 2 {
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
}
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
}
func TestKeyChangesNoDupes(t *testing.T) {
db, err := NewDatabase("file::memory:", nil)
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
userIDs, latest, err := db.KeyChanges(ctx, 0, 0)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != 2 {
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
}
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
}

View file

@ -35,3 +35,8 @@ type DeviceKeys interface {
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
}
type KeyChanges interface {
InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
SelectKeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
}

View file

@ -91,7 +91,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
}).Panicf("could not save account data")
}
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0))
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0, nil))
return nil
}

View file

@ -106,7 +106,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
s.notifier.OnNewSendToDevice(
output.UserID,
[]string{output.DeviceID},
types.NewStreamToken(0, streamPos),
types.NewStreamToken(0, streamPos, nil),
)
return nil

View file

@ -65,7 +65,7 @@ func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
s.notifier.OnNewEvent(
nil, roomID, nil,
types.NewStreamToken(0, types.StreamPosition(latestSyncPosition)),
types.NewStreamToken(0, types.StreamPosition(latestSyncPosition), nil),
)
})
@ -94,6 +94,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
}
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos))
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos, nil))
return nil
}

View file

@ -23,7 +23,9 @@ import (
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/api"
syncinternal "github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
syncapi "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
@ -31,13 +33,14 @@ import (
// OutputKeyChangeEventConsumer consumes events that originated in the key server.
type OutputKeyChangeEventConsumer struct {
keyChangeConsumer *internal.ContinualConsumer
db storage.Database
serverName gomatrixserverlib.ServerName // our server name
currentStateAPI currentstateAPI.CurrentStateInternalAPI
// keyAPI api.KeyInternalAPI
keyChangeConsumer *internal.ContinualConsumer
db storage.Database
serverName gomatrixserverlib.ServerName // our server name
currentStateAPI currentstateAPI.CurrentStateInternalAPI
keyAPI api.KeyInternalAPI
partitionToOffset map[int32]int64
partitionToOffsetMu sync.Mutex
notifier *syncapi.Notifier
}
// NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
@ -46,6 +49,8 @@ func NewOutputKeyChangeEventConsumer(
serverName gomatrixserverlib.ServerName,
topic string,
kafkaConsumer sarama.Consumer,
n *syncapi.Notifier,
keyAPI api.KeyInternalAPI,
currentStateAPI currentstateAPI.CurrentStateInternalAPI,
store storage.Database,
) *OutputKeyChangeEventConsumer {
@ -60,9 +65,11 @@ func NewOutputKeyChangeEventConsumer(
keyChangeConsumer: &consumer,
db: store,
serverName: serverName,
keyAPI: keyAPI,
currentStateAPI: currentStateAPI,
partitionToOffset: make(map[int32]int64),
partitionToOffsetMu: sync.Mutex{},
notifier: n,
}
consumer.ProcessMessage = s.onMessage
@ -107,36 +114,22 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
return err
}
// TODO: f.e queryRes.UserIDsToCount : notify users by waking up streams
return nil
}
// Catchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information.
func (s *OutputKeyChangeEventConsumer) Catchup(
ctx context.Context, userID string, res *types.Response, tok types.StreamingToken,
) (hasNew bool, err error) {
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
newlyJoinedRooms := joinedRooms(res, userID)
newlyLeftRooms := leftRooms(res)
if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 {
changed, left, err := s.trackChangedUsers(ctx, userID, newlyJoinedRooms, newlyLeftRooms)
if err != nil {
return false, err
}
res.DeviceLists.Changed = changed
res.DeviceLists.Left = left
hasNew = len(changed) > 0 || len(left) > 0
posUpdate := types.NewStreamToken(0, 0, map[string]*types.LogPosition{
syncinternal.DeviceListLogName: &types.LogPosition{
Offset: msg.Offset,
Partition: msg.Partition,
},
})
for userID := range queryRes.UserIDsToCount {
s.notifier.OnNewKeyChange(posUpdate, userID, output.UserID)
}
// TODO: now also track users who we already share rooms with but who have updated their devices between the two tokens
return
return nil
}
func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.HeaderedEvent) {
// work out who we are now sharing rooms with which we previously were not and notify them about the joining
// users keys:
changed, _, err := s.trackChangedUsers(context.Background(), *ev.StateKey(), []string{ev.RoomID()}, nil)
changed, _, err := syncinternal.TrackChangedUsers(context.Background(), s.currentStateAPI, *ev.StateKey(), []string{ev.RoomID()}, nil)
if err != nil {
log.WithError(err).Error("OnJoinEvent: failed to work out changed users")
return
@ -149,7 +142,7 @@ func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.Headere
func (s *OutputKeyChangeEventConsumer) OnLeaveEvent(ev *gomatrixserverlib.HeaderedEvent) {
// work out who we are no longer sharing any rooms with and notify them about the leaving user
_, left, err := s.trackChangedUsers(context.Background(), *ev.StateKey(), nil, []string{ev.RoomID()})
_, left, err := syncinternal.TrackChangedUsers(context.Background(), s.currentStateAPI, *ev.StateKey(), nil, []string{ev.RoomID()})
if err != nil {
log.WithError(err).Error("OnLeaveEvent: failed to work out left users")
return
@ -160,129 +153,3 @@ func (s *OutputKeyChangeEventConsumer) OnLeaveEvent(ev *gomatrixserverlib.Header
}
}
// nolint:gocyclo
func (s *OutputKeyChangeEventConsumer) trackChangedUsers(
ctx context.Context, userID string, newlyJoinedRooms, newlyLeftRooms []string,
) (changed, left []string, err error) {
// process leaves first, then joins afterwards so if we join/leave/join/leave we err on the side of including users.
// Leave algorithm:
// - Get set of users and number of times they appear in rooms prior to leave. - QuerySharedUsersRequest with 'IncludeRoomID'.
// - Get users in newly left room. - QueryCurrentState
// - Loop set of users and decrement by 1 for each user in newly left room.
// - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync.
var queryRes currentstateAPI.QuerySharedUsersResponse
err = s.currentStateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{
UserID: userID,
IncludeRoomIDs: newlyLeftRooms,
}, &queryRes)
if err != nil {
return nil, nil, err
}
var stateRes currentstateAPI.QueryBulkStateContentResponse
err = s.currentStateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{
RoomIDs: newlyLeftRooms,
StateTuples: []gomatrixserverlib.StateKeyTuple{
{
EventType: gomatrixserverlib.MRoomMember,
StateKey: "*",
},
},
AllowWildcards: true,
}, &stateRes)
if err != nil {
return nil, nil, err
}
for _, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != gomatrixserverlib.Join {
continue
}
queryRes.UserIDsToCount[tuple.StateKey]--
}
}
for userID, count := range queryRes.UserIDsToCount {
if count <= 0 {
left = append(left, userID) // left is returned
}
}
// Join algorithm:
// - Get the set of all joined users prior to joining room - QuerySharedUsersRequest with 'ExcludeRoomID'.
// - Get users in newly joined room - QueryCurrentState
// - Loop set of users in newly joined room, do they appear in the set of users prior to joining?
// - If yes: then they already shared a room in common, do nothing.
// - If no: then they are a brand new user so inform BOTH parties of this via 'changed=[...]'
err = s.currentStateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{
UserID: userID,
ExcludeRoomIDs: newlyJoinedRooms,
}, &queryRes)
if err != nil {
return nil, left, err
}
err = s.currentStateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{
RoomIDs: newlyJoinedRooms,
StateTuples: []gomatrixserverlib.StateKeyTuple{
{
EventType: gomatrixserverlib.MRoomMember,
StateKey: "*",
},
},
AllowWildcards: true,
}, &stateRes)
if err != nil {
return nil, left, err
}
for _, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != gomatrixserverlib.Join {
continue
}
// new user who we weren't previously sharing rooms with
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
changed = append(changed, tuple.StateKey) // changed is returned
}
}
}
return changed, left, nil
}
func joinedRooms(res *types.Response, userID string) []string {
var roomIDs []string
for roomID, join := range res.Rooms.Join {
// we would expect to see our join event somewhere if we newly joined the room.
// Normal events get put in the join section so it's not enough to know the room ID is present in 'join'.
newlyJoined := membershipEventPresent(join.State.Events, userID)
if newlyJoined {
roomIDs = append(roomIDs, roomID)
continue
}
newlyJoined = membershipEventPresent(join.Timeline.Events, userID)
if newlyJoined {
roomIDs = append(roomIDs, roomID)
}
}
return roomIDs
}
func leftRooms(res *types.Response) []string {
roomIDs := make([]string, len(res.Rooms.Leave))
i := 0
for roomID := range res.Rooms.Leave {
roomIDs[i] = roomID
i++
}
return roomIDs
}
func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID string) bool {
for _, ev := range events {
// it's enough to know that we have our member event here, don't need to check membership content
// as it's implied by being in the respective section of the sync response.
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
return true
}
}
return false
}

View file

@ -158,7 +158,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}).Panicf("roomserver output log: write event failure")
return nil
}
s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0))
s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil))
return nil
}
@ -176,7 +176,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure")
return nil
}
s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0))
s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0, nil))
return nil
}
@ -194,7 +194,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
}
// Notify any active sync requests that the invite has been retired.
// Invites share the same stream counter as PDUs
s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0))
s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0, nil))
return nil
}

View file

@ -0,0 +1,219 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"github.com/Shopify/sarama"
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/keyserver/api"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
const DeviceListLogName = "dl"
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information.
func DeviceListCatchup(
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
userID string, res *types.Response, tok types.StreamingToken,
) (newTok *types.StreamingToken, hasNew bool, err error) {
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
newlyJoinedRooms := joinedRooms(res, userID)
newlyLeftRooms := leftRooms(res)
if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 {
changed, left, err := TrackChangedUsers(ctx, stateAPI, userID, newlyJoinedRooms, newlyLeftRooms)
if err != nil {
return nil, false, err
}
res.DeviceLists.Changed = changed
res.DeviceLists.Left = left
hasNew = len(changed) > 0 || len(left) > 0
}
// now also track users who we already share rooms with but who have updated their devices between the two tokens
var partition int32
var offset int64
// Extract partition/offset from sync token
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
logOffset := tok.Log(DeviceListLogName)
if logOffset != nil {
partition = logOffset.Partition
offset = logOffset.Offset
} else {
partition = -1
offset = sarama.OffsetOldest
}
var queryRes api.QueryKeyChangesResponse
keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{
Partition: partition,
Offset: offset,
}, &queryRes)
if queryRes.Error != nil {
// don't fail the catchup because we may have got useful information by tracking membership
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
return
}
userSet := make(map[string]bool)
for _, userID := range res.DeviceLists.Changed {
userSet[userID] = true
}
for _, userID := range queryRes.UserIDs {
if !userSet[userID] {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
hasNew = true
}
}
// Make a new streaming token using the new offset
tok.SetLog(DeviceListLogName, &types.LogPosition{
Offset: queryRes.Offset,
Partition: queryRes.Partition,
})
newTok = &tok
return
}
// TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response.
// nolint:gocyclo
func TrackChangedUsers(
ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, newlyJoinedRooms, newlyLeftRooms []string,
) (changed, left []string, err error) {
// process leaves first, then joins afterwards so if we join/leave/join/leave we err on the side of including users.
// Leave algorithm:
// - Get set of users and number of times they appear in rooms prior to leave. - QuerySharedUsersRequest with 'IncludeRoomID'.
// - Get users in newly left room. - QueryCurrentState
// - Loop set of users and decrement by 1 for each user in newly left room.
// - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync.
var queryRes currentstateAPI.QuerySharedUsersResponse
err = stateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{
UserID: userID,
IncludeRoomIDs: newlyLeftRooms,
}, &queryRes)
if err != nil {
return nil, nil, err
}
var stateRes currentstateAPI.QueryBulkStateContentResponse
err = stateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{
RoomIDs: newlyLeftRooms,
StateTuples: []gomatrixserverlib.StateKeyTuple{
{
EventType: gomatrixserverlib.MRoomMember,
StateKey: "*",
},
},
AllowWildcards: true,
}, &stateRes)
if err != nil {
return nil, nil, err
}
for _, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != gomatrixserverlib.Join {
continue
}
queryRes.UserIDsToCount[tuple.StateKey]--
}
}
for userID, count := range queryRes.UserIDsToCount {
if count <= 0 {
left = append(left, userID) // left is returned
}
}
// Join algorithm:
// - Get the set of all joined users prior to joining room - QuerySharedUsersRequest with 'ExcludeRoomID'.
// - Get users in newly joined room - QueryCurrentState
// - Loop set of users in newly joined room, do they appear in the set of users prior to joining?
// - If yes: then they already shared a room in common, do nothing.
// - If no: then they are a brand new user so inform BOTH parties of this via 'changed=[...]'
err = stateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{
UserID: userID,
ExcludeRoomIDs: newlyJoinedRooms,
}, &queryRes)
if err != nil {
return nil, left, err
}
err = stateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{
RoomIDs: newlyJoinedRooms,
StateTuples: []gomatrixserverlib.StateKeyTuple{
{
EventType: gomatrixserverlib.MRoomMember,
StateKey: "*",
},
},
AllowWildcards: true,
}, &stateRes)
if err != nil {
return nil, left, err
}
for _, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != gomatrixserverlib.Join {
continue
}
// new user who we weren't previously sharing rooms with
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
changed = append(changed, tuple.StateKey) // changed is returned
}
}
}
return changed, left, nil
}
func joinedRooms(res *types.Response, userID string) []string {
var roomIDs []string
for roomID, join := range res.Rooms.Join {
// we would expect to see our join event somewhere if we newly joined the room.
// Normal events get put in the join section so it's not enough to know the room ID is present in 'join'.
newlyJoined := membershipEventPresent(join.State.Events, userID)
if newlyJoined {
roomIDs = append(roomIDs, roomID)
continue
}
newlyJoined = membershipEventPresent(join.Timeline.Events, userID)
if newlyJoined {
roomIDs = append(roomIDs, roomID)
}
}
return roomIDs
}
func leftRooms(res *types.Response) []string {
roomIDs := make([]string, len(res.Rooms.Leave))
i := 0
for roomID := range res.Rooms.Leave {
roomIDs[i] = roomID
i++
}
return roomIDs
}
func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID string) bool {
for _, ev := range events {
// it's enough to know that we have our member event here, don't need to check membership content
// as it's implied by being in the respective section of the sync response.
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
return true
}
}
return false
}

View file

@ -1,4 +1,4 @@
package consumers
package internal
import (
"context"
@ -7,14 +7,29 @@ import (
"testing"
"github.com/matrix-org/dendrite/currentstateserver/api"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
var (
syncingUser = "@alice:localhost"
emptyToken = types.NewStreamToken(0, 0, nil)
)
type mockKeyAPI struct{}
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) {
}
// PerformClaimKeys claims one-time keys for use in pre-key messages
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) {
}
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) {
}
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
}
type mockCurrentStateAPI struct {
roomIDToJoinedMembers map[string][]string
}
@ -144,18 +159,17 @@ func leaveResponseWithRooms(syncResponse *types.Response, userID string, roomIDs
func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
newShareUser := "@bill:localhost"
newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNewUser:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
newlyJoinedRoom: {syncingUser, newShareUser},
"!another:room": {syncingUser},
},
}, nil)
syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
}, syncingUser, syncResponse, emptyToken)
if err != nil {
t.Fatalf("Catchup returned an error: %s", err)
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
assertCatchup(t, hasNew, syncResponse, wantCatchup{
hasNew: true,
@ -167,18 +181,17 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
removeUser := "@bill:localhost"
newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareLeftUser:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
syncResponse := types.NewResponse()
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
newlyLeftRoom: {removeUser},
"!another:room": {syncingUser},
},
}, nil)
syncResponse := types.NewResponse()
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
}, syncingUser, syncResponse, emptyToken)
if err != nil {
t.Fatalf("Catchup returned an error: %s", err)
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
assertCatchup(t, hasNew, syncResponse, wantCatchup{
hasNew: true,
@ -190,16 +203,15 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
existingUser := "@bob:localhost"
newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNoNewUsers:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
newlyJoinedRoom: {syncingUser, existingUser},
"!another:room": {syncingUser, existingUser},
},
}, nil)
syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
}, syncingUser, syncResponse, emptyToken)
if err != nil {
t.Fatalf("Catchup returned an error: %s", err)
}
@ -212,18 +224,17 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
existingUser := "@bob:localhost"
newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareNoUsers:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
syncResponse := types.NewResponse()
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
newlyLeftRoom: {existingUser},
"!another:room": {syncingUser, existingUser},
},
}, nil)
syncResponse := types.NewResponse()
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
}, syncingUser, syncResponse, emptyToken)
if err != nil {
t.Fatalf("Catchup returned an error: %s", err)
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
assertCatchup(t, hasNew, syncResponse, wantCatchup{
hasNew: false,
@ -234,11 +245,6 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
existingUser := "@bob1:localhost"
roomID := "!TestKeyChangeCatchupNoNewJoinsButMessages:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
roomID: {syncingUser, existingUser},
},
}, nil)
syncResponse := types.NewResponse()
empty := ""
roomStateEvents := []gomatrixserverlib.ClientEvent{
@ -280,9 +286,13 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
jr.Timeline.Events = roomTimelineEvents
syncResponse.Rooms.Join[roomID] = jr
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
roomID: {syncingUser, existingUser},
},
}, syncingUser, syncResponse, emptyToken)
if err != nil {
t.Fatalf("Catchup returned an error: %s", err)
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
assertCatchup(t, hasNew, syncResponse, wantCatchup{
hasNew: false,
@ -297,18 +307,17 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
newlyLeftUser2 := "@debra:localhost"
newlyJoinedRoom := "!join:bar"
newlyLeftRoom := "!left:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
newlyJoinedRoom: {syncingUser, newShareUser, newShareUser2},
newlyLeftRoom: {newlyLeftUser, newlyLeftUser2},
"!another:room": {syncingUser},
},
}, nil)
syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
}, syncingUser, syncResponse, emptyToken)
if err != nil {
t.Fatalf("Catchup returned an error: %s", err)
}
@ -333,12 +342,6 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
newShareUser := "@berta:localhost"
newShareUser2 := "@bobby:localhost"
roomID := "!join:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
roomID: {newShareUser, newShareUser2},
"!another:room": {syncingUser},
},
}, nil)
syncResponse := types.NewResponse()
roomEvents := []gomatrixserverlib.ClientEvent{
{
@ -393,9 +396,14 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
lr.Timeline.Events = roomEvents
syncResponse.Rooms.Leave[roomID] = lr
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
roomID: {newShareUser, newShareUser2},
"!another:room": {syncingUser},
},
}, syncingUser, syncResponse, emptyToken)
if err != nil {
t.Fatalf("Catchup returned an error: %s", err)
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
assertCatchup(t, hasNew, syncResponse, wantCatchup{
hasNew: true,

View file

@ -434,7 +434,7 @@ func (d *Database) syncPositionTx(
if maxInviteID > maxEventID {
maxEventID = maxInviteID
}
sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition()))
sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition()), nil)
return
}
@ -731,7 +731,7 @@ func (d *Database) CompleteSync(
// Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse(
types.NewStreamToken(0, 0), toPos, joinedRoomIDs, res,
types.NewStreamToken(0, 0, nil), toPos, joinedRoomIDs, res,
)
if err != nil {
return nil, err

View file

@ -166,7 +166,7 @@ func TestSyncResponse(t *testing.T) {
Name: "IncrementalSync penultimate",
DoSync: func() (*types.Response, error) {
from := types.NewStreamToken( // pretend we are at the penultimate event
positions[len(positions)-2], types.StreamPosition(0),
positions[len(positions)-2], types.StreamPosition(0), nil,
)
res := types.NewResponse()
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
@ -179,7 +179,7 @@ func TestSyncResponse(t *testing.T) {
Name: "IncrementalSync limited",
DoSync: func() (*types.Response, error) {
from := types.NewStreamToken( // pretend we are 10 events behind
positions[len(positions)-11], types.StreamPosition(0),
positions[len(positions)-11], types.StreamPosition(0), nil,
)
res := types.NewResponse()
// limit is set to 5
@ -222,7 +222,7 @@ func TestSyncResponse(t *testing.T) {
if err != nil {
st.Fatalf("failed to do sync: %s", err)
}
next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition())
next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition(), nil)
if res.NextBatch != next.String() {
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
}
@ -246,7 +246,7 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err)
}
from := types.NewStreamToken(
positions[len(positions)-2], types.StreamPosition(0),
positions[len(positions)-2], types.StreamPosition(0), nil,
)
res := types.NewResponse()
@ -291,7 +291,7 @@ func TestGetEventsInRangeWithStreamToken(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err)
}
// head towards the beginning of time
to := types.NewStreamToken(0, 0)
to := types.NewStreamToken(0, 0, nil)
// backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
@ -534,14 +534,14 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point there should be no messages. We haven't sent anything
// yet.
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0))
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0, nil))
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("first call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0))
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0, nil))
if err != nil {
return
}
@ -559,14 +559,14 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil))
if err != nil {
t.Fatal(err)
}
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
t.Fatal("second call should have one update")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil))
if err != nil {
return
}
@ -574,35 +574,35 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil))
if err != nil {
t.Fatal(err)
}
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("third call should have one update still")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil))
if err != nil {
return
}
// At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1))
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1, nil))
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
t.Fatal("fourth call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1))
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1, nil))
if err != nil {
return
}
// At this point we should still have no updates, because no new updates have been
// sent.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2))
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2, nil))
if err != nil {
t.Fatal(err)
}
@ -639,7 +639,7 @@ func TestInviteBehaviour(t *testing.T) {
}
// both invite events should appear in a new sync
beforeRetireRes := types.NewResponse()
beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false)
beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false)
if err != nil {
t.Fatalf("IncrementalSync failed: %s", err)
}
@ -654,7 +654,7 @@ func TestInviteBehaviour(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err)
}
res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false)
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false)
if err != nil {
t.Fatalf("IncrementalSync failed: %s", err)
}

View file

@ -132,6 +132,16 @@ func (n *Notifier) OnNewSendToDevice(
n.wakeupUserDevice(userID, deviceIDs, latestPos)
}
func (n *Notifier) OnNewKeyChange(
posUpdate types.StreamingToken, wakeUserID, keyChangeUserID string,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.wakeupUsers([]string{wakeUserID}, latestPos)
}
// GetListener returns a UserStreamListener that can be used to wait for
// updates for a user. Must be closed.
// notify for anything before sincePos

View file

@ -32,11 +32,11 @@ var (
randomMessageEvent gomatrixserverlib.HeaderedEvent
aliceInviteBobEvent gomatrixserverlib.HeaderedEvent
bobLeaveEvent gomatrixserverlib.HeaderedEvent
syncPositionVeryOld = types.NewStreamToken(5, 0)
syncPositionBefore = types.NewStreamToken(11, 0)
syncPositionAfter = types.NewStreamToken(12, 0)
syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1)
syncPositionAfter2 = types.NewStreamToken(13, 0)
syncPositionVeryOld = types.NewStreamToken(5, 0, nil)
syncPositionBefore = types.NewStreamToken(11, 0, nil)
syncPositionAfter = types.NewStreamToken(12, 0, nil)
syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1, nil)
syncPositionAfter2 = types.NewStreamToken(13, 0, nil)
)
var (

View file

@ -65,7 +65,7 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
since = &tok
}
if since == nil {
tok := types.NewStreamToken(0, 0)
tok := types.NewStreamToken(0, 0, nil)
since = &tok
}
timelineLimit := DefaultTimelineLimit

View file

@ -22,6 +22,9 @@ import (
"time"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
@ -35,11 +38,16 @@ type RequestPool struct {
db storage.Database
userAPI userapi.UserInternalAPI
notifier *Notifier
keyAPI keyapi.KeyInternalAPI
stateAPI currentstateAPI.CurrentStateInternalAPI
}
// NewRequestPool makes a new RequestPool
func NewRequestPool(db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI) *RequestPool {
return &RequestPool{db, userAPI, n}
func NewRequestPool(
db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI, keyAPI keyapi.KeyInternalAPI,
stateAPI currentstateAPI.CurrentStateInternalAPI,
) *RequestPool {
return &RequestPool{db, userAPI, n, keyAPI, stateAPI}
}
// OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be
@ -138,7 +146,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
res = types.NewResponse()
since := types.NewStreamToken(0, 0)
since := types.NewStreamToken(0, 0, nil)
if req.since != nil {
since = *req.since
}
@ -164,6 +172,10 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
if err != nil {
return
}
res, err = rp.appendDeviceLists(res, req.device.UserID, since)
if err != nil {
return
}
// Before we return the sync response, make sure that we take action on
// any send-to-device database updates or deletions that we need to do.
@ -192,6 +204,22 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
return
}
func (rp *RequestPool) appendDeviceLists(
data *types.Response, userID string, since types.StreamingToken,
) (*types.Response, error) {
// TODO: Currently this code will race which may result in duplicates but not missing data.
// This happens because, whilst we are told the range to fetch here (since / latest) the
// QueryKeyChanges API only exposes a "from" value (on purpose to avoid racing, which then
// returns the latest position with which the response has authority on). We'd need to tweak
// the API to expose a "to" value to fix this.
_, _, err := internal.DeviceListCatchup(context.Background(), rp.keyAPI, rp.stateAPI, userID, data, since)
if err != nil {
return nil, err
}
return data, nil
}
// nolint:gocyclo
func (rp *RequestPool) appendAccountData(
data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition,

View file

@ -21,7 +21,9 @@ import (
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
currentstateapi "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/internal/config"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@ -39,6 +41,8 @@ func AddPublicRoutes(
consumer sarama.Consumer,
userAPI userapi.UserInternalAPI,
rsAPI api.RoomserverInternalAPI,
keyAPI keyapi.KeyInternalAPI,
currentStateAPI currentstateapi.CurrentStateInternalAPI,
federation *gomatrixserverlib.FederationClient,
cfg *config.SyncAPI,
) {
@ -58,7 +62,7 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start notifier")
}
requestPool := sync.NewRequestPool(syncDB, notifier, userAPI)
requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, currentStateAPI)
roomConsumer := consumers.NewOutputRoomEventConsumer(
cfg, consumer, notifier, syncDB, rsAPI,
@ -88,5 +92,13 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start send-to-device consumer")
}
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
cfg.Matrix.ServerName, string(cfg.Kafka.Topics.OutputKeyChangeEvent),
consumer, notifier, keyAPI, currentStateAPI, syncDB,
)
if err = keyChangeConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start key change consumer")
}
routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg)
}

View file

@ -39,6 +39,23 @@ var (
// StreamPosition represents the offset in the sync stream a client is at.
type StreamPosition int64
// LogPosition represents the offset in a Kafka log a client is at.
type LogPosition struct {
Partition int32
Offset int64
}
// IsAfter returns true if this position is after `lp`.
func (p *LogPosition) IsAfter(lp *LogPosition) bool {
if lp == nil {
return false
}
if p.Partition != lp.Partition {
return false
}
return p.Offset > lp.Offset
}
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamEvent struct {
gomatrixserverlib.HeaderedEvent
@ -90,6 +107,19 @@ const (
type StreamingToken struct {
syncToken
logs map[string]*LogPosition
}
func (t *StreamingToken) SetLog(name string, lp *LogPosition) {
t.logs[name] = lp
}
func (t *StreamingToken) Log(name string) *LogPosition {
l, ok := t.logs[name]
if !ok {
return nil
}
return l
}
func (t *StreamingToken) PDUPosition() StreamPosition {
@ -99,7 +129,15 @@ func (t *StreamingToken) EDUPosition() StreamPosition {
return t.Positions[1]
}
func (t *StreamingToken) String() string {
return t.syncToken.String()
logStrings := []string{
t.syncToken.String(),
}
for name, lp := range t.logs {
logStr := fmt.Sprintf("%s-%d-%d", name, lp.Partition, lp.Offset)
logStrings = append(logStrings, logStr)
}
// E.g s11_22_33.dl0-134.ab1-441
return strings.Join(logStrings, ".")
}
// IsAfter returns true if ANY position in this token is greater than `other`.
@ -109,12 +147,22 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool {
return true
}
}
for name := range t.logs {
otherLog := other.Log(name)
if otherLog == nil {
continue
}
if t.logs[name].IsAfter(otherLog) {
return true
}
}
return false
}
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
// If the latter StreamingToken contains a field that is not 0, it is considered an update,
// and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called.
// If the other token has a log, they will replace any existing log on this token.
func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) {
ret.Type = t.Type
ret.Positions = make([]StreamPosition, len(t.Positions))
@ -125,6 +173,13 @@ func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken)
}
ret.Positions[i] = other.Positions[i]
}
for name := range t.logs {
otherLog := other.Log(name)
if otherLog == nil {
continue
}
t.logs[name] = otherLog
}
return ret
}
@ -139,7 +194,7 @@ func (t *TopologyToken) PDUPosition() StreamPosition {
return t.Positions[1]
}
func (t *TopologyToken) StreamToken() StreamingToken {
return NewStreamToken(t.PDUPosition(), 0)
return NewStreamToken(t.PDUPosition(), 0, nil)
}
func (t *TopologyToken) String() string {
return t.syncToken.String()
@ -174,9 +229,9 @@ func (t *TopologyToken) Decrement() {
// error if the token couldn't be parsed into an int64, or if the token type
// isn't a known type (returns ErrInvalidSyncTokenType in the latter
// case).
func newSyncTokenFromString(s string) (token *syncToken, err error) {
func newSyncTokenFromString(s string) (token *syncToken, categories []string, err error) {
if len(s) == 0 {
return nil, ErrInvalidSyncTokenLen
return nil, nil, ErrInvalidSyncTokenLen
}
token = new(syncToken)
@ -185,16 +240,17 @@ func newSyncTokenFromString(s string) (token *syncToken, err error) {
switch t := SyncTokenType(s[:1]); t {
case SyncTokenTypeStream, SyncTokenTypeTopology:
token.Type = t
positions = strings.Split(s[1:], "_")
categories = strings.Split(s[1:], ".")
positions = strings.Split(categories[0], "_")
default:
return nil, ErrInvalidSyncTokenType
return nil, nil, ErrInvalidSyncTokenType
}
for _, pos := range positions {
if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil {
return nil, err
return nil, nil, err
} else if posInt < 0 {
return nil, errors.New("negative position not allowed")
return nil, nil, errors.New("negative position not allowed")
} else {
token.Positions = append(token.Positions, StreamPosition(posInt))
}
@ -215,7 +271,7 @@ func NewTopologyToken(depth, streamPos StreamPosition) TopologyToken {
}
}
func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
t, err := newSyncTokenFromString(tok)
t, _, err := newSyncTokenFromString(tok)
if err != nil {
return
}
@ -233,16 +289,20 @@ func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
}
// NewStreamToken creates a new sync token for /sync
func NewStreamToken(pduPos, eduPos StreamPosition) StreamingToken {
func NewStreamToken(pduPos, eduPos StreamPosition, logs map[string]*LogPosition) StreamingToken {
if logs == nil {
logs = make(map[string]*LogPosition)
}
return StreamingToken{
syncToken: syncToken{
Type: SyncTokenTypeStream,
Positions: []StreamPosition{pduPos, eduPos},
},
logs: logs,
}
}
func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
t, err := newSyncTokenFromString(tok)
t, categories, err := newSyncTokenFromString(tok)
if err != nil {
return
}
@ -254,8 +314,35 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions))
return
}
logs := make(map[string]*LogPosition)
if len(categories) > 1 {
// dl-0-1234
// $log_name-$partition-$offset
for _, logStr := range categories[1:] {
segments := strings.Split(logStr, "-")
if len(segments) != 3 {
err = fmt.Errorf("token %s - invalid log: %s", tok, logStr)
return
}
var partition int64
partition, err = strconv.ParseInt(segments[1], 10, 32)
if err != nil {
return
}
var offset int64
offset, err = strconv.ParseInt(segments[2], 10, 64)
if err != nil {
return
}
logs[segments[0]] = &LogPosition{
Partition: int32(partition),
Offset: offset,
}
}
}
return StreamingToken{
syncToken: *t,
logs: logs,
}, nil
}

View file

@ -1,11 +1,61 @@
package types
import "testing"
import (
"reflect"
"testing"
)
func TestNewSyncTokenWithLogs(t *testing.T) {
tests := map[string]*StreamingToken{
"s4_0": &StreamingToken{
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
logs: make(map[string]*LogPosition),
},
"s4_0.dl-0-123": &StreamingToken{
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
logs: map[string]*LogPosition{
"dl": &LogPosition{
Partition: 0,
Offset: 123,
},
},
},
"s4_0.dl-0-123.ab-1-14419482332": &StreamingToken{
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
logs: map[string]*LogPosition{
"ab": &LogPosition{
Partition: 1,
Offset: 14419482332,
},
"dl": &LogPosition{
Partition: 0,
Offset: 123,
},
},
},
}
for tok, want := range tests {
got, err := NewStreamTokenFromString(tok)
if err != nil {
if want == nil {
continue // error expected
}
t.Errorf("%s errored: %s", tok, err)
continue
}
if !reflect.DeepEqual(got, *want) {
t.Errorf("%s mismatch: got %v want %v", tok, got, want)
}
if got.String() != tok {
t.Errorf("%s reserialisation mismatch: got %s want %s", tok, got.String(), tok)
}
}
}
func TestNewSyncTokenFromString(t *testing.T) {
shouldPass := map[string]syncToken{
"s4_0": NewStreamToken(4, 0).syncToken,
"s3_1": NewStreamToken(3, 1).syncToken,
"s4_0": NewStreamToken(4, 0, nil).syncToken,
"s3_1": NewStreamToken(3, 1, nil).syncToken,
"t3_1": NewTopologyToken(3, 1).syncToken,
}
@ -21,7 +71,7 @@ func TestNewSyncTokenFromString(t *testing.T) {
}
for test, expected := range shouldPass {
result, err := newSyncTokenFromString(test)
result, _, err := newSyncTokenFromString(test)
if err != nil {
t.Error(err)
}
@ -31,7 +81,7 @@ func TestNewSyncTokenFromString(t *testing.T) {
}
for _, test := range shouldFail {
if _, err := newSyncTokenFromString(test); err == nil {
if _, _, err := newSyncTokenFromString(test); err == nil {
t.Errorf("input '%v' should have errored but didn't", test)
}
}

View file

@ -127,6 +127,7 @@ Can query specific device keys using POST
query for user with no keys returns empty key dict
Can claim one time key using POST
Can claim remote one time key using POST
Local device key changes appear in v2 /sync
Can add account data
Can add account data to room
Can get account data without syncing