Merge branch 'master' into neilalexander/config

This commit is contained in:
Neil Alexander 2020-07-30 13:58:28 +01:00
commit 930ced1102
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
33 changed files with 908 additions and 259 deletions

View file

@ -29,7 +29,9 @@ func main() {
rsAPI := base.RoomserverHTTPClient() rsAPI := base.RoomserverHTTPClient()
syncapi.AddPublicRoutes(base.PublicAPIMux, base.KafkaConsumer, userAPI, rsAPI, federation, &cfg.SyncAPI) syncapi.AddPublicRoutes(
base.PublicAPIMux, base.KafkaConsumer, userAPI, rsAPI, base.KeyServerHTTPClient(), base.CurrentStateAPIClient(),
federation, &cfg.SyncAPI)
base.SetupAndServeHTTP(string(base.Cfg.SyncAPI.Bind), string(base.Cfg.SyncAPI.Listen)) base.SetupAndServeHTTP(string(base.Cfg.SyncAPI.Bind), string(base.Cfg.SyncAPI.Listen))

View file

@ -77,6 +77,7 @@ func (m *Monolith) AddAllPublicRoutes(publicMux *mux.Router) {
) )
mediaapi.AddPublicRoutes(publicMux, &m.Config.MediaAPI, m.UserAPI, m.Client) mediaapi.AddPublicRoutes(publicMux, &m.Config.MediaAPI, m.UserAPI, m.Client)
syncapi.AddPublicRoutes( syncapi.AddPublicRoutes(
publicMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI, m.FedClient, &m.Config.SyncAPI, publicMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI,
m.KeyAPI, m.StateAPI, m.FedClient, &m.Config.SyncAPI,
) )
} }

View file

@ -26,6 +26,7 @@ type KeyInternalAPI interface {
// PerformClaimKeys claims one-time keys for use in pre-key messages // PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
} }
// KeyError is returned if there was a problem performing/querying the server // KeyError is returned if there was a problem performing/querying the server
@ -131,3 +132,21 @@ type QueryKeysResponse struct {
// Set if there was a fatal error processing this query // Set if there was a fatal error processing this query
Error *KeyError Error *KeyError
} }
type QueryKeyChangesRequest struct {
// The partition which had key events sent to
Partition int32
// The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
Offset int64
}
type QueryKeyChangesResponse struct {
// The set of users who have had their keys change.
UserIDs []string
// The partition being served - useful if the partition is unknown at request time
Partition int32
// The latest offset represented in this response.
Offset int64
// Set if there was a problem handling the request.
Error *KeyError
}

View file

@ -40,6 +40,21 @@ type KeyInternalAPI struct {
Producer *producers.KeyChange Producer *producers.KeyChange
} }
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
if req.Partition < 0 {
req.Partition = a.Producer.DefaultPartition()
}
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Partition, req.Offset)
if err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
}
}
res.Offset = latest
res.Partition = req.Partition
res.UserIDs = userIDs
}
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
res.KeyErrors = make(map[string]map[string]*api.KeyError) res.KeyErrors = make(map[string]map[string]*api.KeyError)
a.uploadDeviceKeys(ctx, req, res) a.uploadDeviceKeys(ctx, req, res)

View file

@ -29,6 +29,7 @@ const (
PerformUploadKeysPath = "/keyserver/performUploadKeys" PerformUploadKeysPath = "/keyserver/performUploadKeys"
PerformClaimKeysPath = "/keyserver/performClaimKeys" PerformClaimKeysPath = "/keyserver/performClaimKeys"
QueryKeysPath = "/keyserver/queryKeys" QueryKeysPath = "/keyserver/queryKeys"
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
) )
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API. // NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
@ -101,3 +102,20 @@ func (h *httpKeyInternalAPI) QueryKeys(
} }
} }
} }
func (h *httpKeyInternalAPI) QueryKeyChanges(
ctx context.Context,
request *api.QueryKeyChangesRequest,
response *api.QueryKeyChangesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyChanges")
defer span.Finish()
apiURL := h.apiURL + QueryKeyChangesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
}

View file

@ -58,4 +58,15 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryKeyChangesPath,
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
request := api.QueryKeyChangesRequest{}
response := api.QueryKeyChangesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QueryKeyChanges(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -46,6 +46,7 @@ func NewInternalAPI(
keyChangeProducer := &producers.KeyChange{ keyChangeProducer := &producers.KeyChange{
Topic: string(cfg.Matrix.Kafka.Topics.OutputKeyChangeEvent), Topic: string(cfg.Matrix.Kafka.Topics.OutputKeyChangeEvent),
Producer: producer, Producer: producer,
DB: db,
} }
return &internal.KeyInternalAPI{ return &internal.KeyInternalAPI{
DB: db, DB: db,

View file

@ -15,10 +15,12 @@
package producers package producers
import ( import (
"context"
"encoding/json" "encoding/json"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -26,6 +28,16 @@ import (
type KeyChange struct { type KeyChange struct {
Topic string Topic string
Producer sarama.SyncProducer Producer sarama.SyncProducer
DB storage.Database
}
// DefaultPartition returns the default partition this process is sending key changes to.
// NB: A keyserver MUST send key changes to only 1 partition or else query operations will
// become inconsistent. Partitions can be sharded (e.g by hash of user ID of key change) but
// then all keyservers must be queried to calculate the entire set of key changes between
// two sync tokens.
func (p *KeyChange) DefaultPartition() int32 {
return 0
} }
// ProduceKeyChanges creates new change events for each key // ProduceKeyChanges creates new change events for each key
@ -46,6 +58,10 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
if err != nil { if err != nil {
return err return err
} }
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
if err != nil {
return err
}
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"user_id": key.UserID, "user_id": key.UserID,
"device_id": key.DeviceID, "device_id": key.DeviceID,

View file

@ -43,4 +43,12 @@ type Database interface {
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
// StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed
// their keys in some way.
StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
// KeyChanges returns a list of user IDs who have modified their keys from the offset given.
// Returns the offset of the latest key change.
KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
} }

View file

@ -0,0 +1,97 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
var keyChangesSchema = `
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
partition BIGINT NOT NULL,
log_offset BIGINT NOT NULL,
user_id TEXT NOT NULL,
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
);
`
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
const upsertKeyChangeSQL = "" +
"INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" +
" VALUES ($1, $2, $3)" +
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" +
" DO UPDATE SET user_id = $3"
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
// take the max offset value as the latest offset.
const selectKeyChangesSQL = "" +
"SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 GROUP BY user_id"
type keyChangesStatements struct {
db *sql.DB
upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt
}
func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{
db: db,
}
_, err := db.Exec(keyChangesSchema)
if err != nil {
return nil, err
}
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
return nil, err
}
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
return err
}
func (s *keyChangesStatements) SelectKeyChanges(
ctx context.Context, partition int32, fromOffset int64,
) (userIDs []string, latestOffset int64, err error) {
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
for rows.Next() {
var userID string
var offset int64
if err := rows.Scan(&userID, &offset); err != nil {
return nil, 0, err
}
if offset > latestOffset {
latestOffset = offset
}
userIDs = append(userIDs, userID)
}
return
}

View file

@ -35,9 +35,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
kc, err := NewPostgresKeyChangesTable(db)
if err != nil {
return nil, err
}
return &shared.Database{ return &shared.Database{
DB: db, DB: db,
OneTimeKeysTable: otk, OneTimeKeysTable: otk,
DeviceKeysTable: dk, DeviceKeysTable: dk,
KeyChangesTable: kc,
}, nil }, nil
} }

View file

@ -28,6 +28,7 @@ type Database struct {
DB *sql.DB DB *sql.DB
OneTimeKeysTable tables.OneTimeKeys OneTimeKeysTable tables.OneTimeKeys
DeviceKeysTable tables.DeviceKeys DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
} }
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
@ -72,3 +73,11 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st
}) })
return result, err return result, err
} }
func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID)
}
func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) {
return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset)
}

View file

@ -0,0 +1,98 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
var keyChangesSchema = `
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
partition BIGINT NOT NULL,
offset BIGINT NOT NULL,
-- The key owner
user_id TEXT NOT NULL,
UNIQUE (partition, offset)
);
`
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
const upsertKeyChangeSQL = "" +
"INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
" VALUES ($1, $2, $3)" +
" ON CONFLICT (partition, offset)" +
" DO UPDATE SET user_id = $3"
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
// take the max offset value as the latest offset.
const selectKeyChangesSQL = "" +
"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 GROUP BY user_id"
type keyChangesStatements struct {
db *sql.DB
upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt
}
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{
db: db,
}
_, err := db.Exec(keyChangesSchema)
if err != nil {
return nil, err
}
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
return nil, err
}
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
return err
}
func (s *keyChangesStatements) SelectKeyChanges(
ctx context.Context, partition int32, fromOffset int64,
) (userIDs []string, latestOffset int64, err error) {
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
for rows.Next() {
var userID string
var offset int64
if err := rows.Scan(&userID, &offset); err != nil {
return nil, 0, err
}
if offset > latestOffset {
latestOffset = offset
}
userIDs = append(userIDs, userID)
}
return
}

View file

@ -33,9 +33,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
kc, err := NewSqliteKeyChangesTable(db)
if err != nil {
return nil, err
}
return &shared.Database{ return &shared.Database{
DB: db, DB: db,
OneTimeKeysTable: otk, OneTimeKeysTable: otk,
DeviceKeysTable: dk, DeviceKeysTable: dk,
KeyChangesTable: kc,
}, nil }, nil
} }

View file

@ -0,0 +1,57 @@
package storage
import (
"context"
"reflect"
"testing"
)
var ctx = context.Background()
func MustNotError(t *testing.T, err error) {
t.Helper()
if err == nil {
return
}
t.Fatalf("operation failed: %s", err)
}
func TestKeyChanges(t *testing.T) {
db, err := NewDatabase("file::memory:", nil)
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
userIDs, latest, err := db.KeyChanges(ctx, 0, 1)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != 2 {
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
}
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
}
func TestKeyChangesNoDupes(t *testing.T) {
db, err := NewDatabase("file::memory:", nil)
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
userIDs, latest, err := db.KeyChanges(ctx, 0, 0)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != 2 {
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
}
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
}

View file

@ -35,3 +35,8 @@ type DeviceKeys interface {
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
} }
type KeyChanges interface {
InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
SelectKeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
}

View file

@ -91,7 +91,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
}).Panicf("could not save account data") }).Panicf("could not save account data")
} }
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0)) s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0, nil))
return nil return nil
} }

View file

@ -106,7 +106,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
s.notifier.OnNewSendToDevice( s.notifier.OnNewSendToDevice(
output.UserID, output.UserID,
[]string{output.DeviceID}, []string{output.DeviceID},
types.NewStreamToken(0, streamPos), types.NewStreamToken(0, streamPos, nil),
) )
return nil return nil

View file

@ -65,7 +65,7 @@ func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
s.notifier.OnNewEvent( s.notifier.OnNewEvent(
nil, roomID, nil, nil, roomID, nil,
types.NewStreamToken(0, types.StreamPosition(latestSyncPosition)), types.NewStreamToken(0, types.StreamPosition(latestSyncPosition), nil),
) )
}) })
@ -94,6 +94,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
} }
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos)) s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos, nil))
return nil return nil
} }

View file

@ -23,7 +23,9 @@ import (
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
syncinternal "github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
syncapi "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -31,13 +33,14 @@ import (
// OutputKeyChangeEventConsumer consumes events that originated in the key server. // OutputKeyChangeEventConsumer consumes events that originated in the key server.
type OutputKeyChangeEventConsumer struct { type OutputKeyChangeEventConsumer struct {
keyChangeConsumer *internal.ContinualConsumer keyChangeConsumer *internal.ContinualConsumer
db storage.Database db storage.Database
serverName gomatrixserverlib.ServerName // our server name serverName gomatrixserverlib.ServerName // our server name
currentStateAPI currentstateAPI.CurrentStateInternalAPI currentStateAPI currentstateAPI.CurrentStateInternalAPI
// keyAPI api.KeyInternalAPI keyAPI api.KeyInternalAPI
partitionToOffset map[int32]int64 partitionToOffset map[int32]int64
partitionToOffsetMu sync.Mutex partitionToOffsetMu sync.Mutex
notifier *syncapi.Notifier
} }
// NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
@ -46,6 +49,8 @@ func NewOutputKeyChangeEventConsumer(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
topic string, topic string,
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
n *syncapi.Notifier,
keyAPI api.KeyInternalAPI,
currentStateAPI currentstateAPI.CurrentStateInternalAPI, currentStateAPI currentstateAPI.CurrentStateInternalAPI,
store storage.Database, store storage.Database,
) *OutputKeyChangeEventConsumer { ) *OutputKeyChangeEventConsumer {
@ -60,9 +65,11 @@ func NewOutputKeyChangeEventConsumer(
keyChangeConsumer: &consumer, keyChangeConsumer: &consumer,
db: store, db: store,
serverName: serverName, serverName: serverName,
keyAPI: keyAPI,
currentStateAPI: currentStateAPI, currentStateAPI: currentStateAPI,
partitionToOffset: make(map[int32]int64), partitionToOffset: make(map[int32]int64),
partitionToOffsetMu: sync.Mutex{}, partitionToOffsetMu: sync.Mutex{},
notifier: n,
} }
consumer.ProcessMessage = s.onMessage consumer.ProcessMessage = s.onMessage
@ -107,36 +114,22 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
return err return err
} }
// TODO: f.e queryRes.UserIDsToCount : notify users by waking up streams // TODO: f.e queryRes.UserIDsToCount : notify users by waking up streams
return nil posUpdate := types.NewStreamToken(0, 0, map[string]*types.LogPosition{
} syncinternal.DeviceListLogName: &types.LogPosition{
Offset: msg.Offset,
// Catchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response Partition: msg.Partition,
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST },
// be already filled in with join/leave information. })
func (s *OutputKeyChangeEventConsumer) Catchup( for userID := range queryRes.UserIDsToCount {
ctx context.Context, userID string, res *types.Response, tok types.StreamingToken, s.notifier.OnNewKeyChange(posUpdate, userID, output.UserID)
) (hasNew bool, err error) {
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
newlyJoinedRooms := joinedRooms(res, userID)
newlyLeftRooms := leftRooms(res)
if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 {
changed, left, err := s.trackChangedUsers(ctx, userID, newlyJoinedRooms, newlyLeftRooms)
if err != nil {
return false, err
}
res.DeviceLists.Changed = changed
res.DeviceLists.Left = left
hasNew = len(changed) > 0 || len(left) > 0
} }
return nil
// TODO: now also track users who we already share rooms with but who have updated their devices between the two tokens
return
} }
func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.HeaderedEvent) { func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.HeaderedEvent) {
// work out who we are now sharing rooms with which we previously were not and notify them about the joining // work out who we are now sharing rooms with which we previously were not and notify them about the joining
// users keys: // users keys:
changed, _, err := s.trackChangedUsers(context.Background(), *ev.StateKey(), []string{ev.RoomID()}, nil) changed, _, err := syncinternal.TrackChangedUsers(context.Background(), s.currentStateAPI, *ev.StateKey(), []string{ev.RoomID()}, nil)
if err != nil { if err != nil {
log.WithError(err).Error("OnJoinEvent: failed to work out changed users") log.WithError(err).Error("OnJoinEvent: failed to work out changed users")
return return
@ -149,7 +142,7 @@ func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.Headere
func (s *OutputKeyChangeEventConsumer) OnLeaveEvent(ev *gomatrixserverlib.HeaderedEvent) { func (s *OutputKeyChangeEventConsumer) OnLeaveEvent(ev *gomatrixserverlib.HeaderedEvent) {
// work out who we are no longer sharing any rooms with and notify them about the leaving user // work out who we are no longer sharing any rooms with and notify them about the leaving user
_, left, err := s.trackChangedUsers(context.Background(), *ev.StateKey(), nil, []string{ev.RoomID()}) _, left, err := syncinternal.TrackChangedUsers(context.Background(), s.currentStateAPI, *ev.StateKey(), nil, []string{ev.RoomID()})
if err != nil { if err != nil {
log.WithError(err).Error("OnLeaveEvent: failed to work out left users") log.WithError(err).Error("OnLeaveEvent: failed to work out left users")
return return
@ -160,129 +153,3 @@ func (s *OutputKeyChangeEventConsumer) OnLeaveEvent(ev *gomatrixserverlib.Header
} }
} }
// nolint:gocyclo
func (s *OutputKeyChangeEventConsumer) trackChangedUsers(
ctx context.Context, userID string, newlyJoinedRooms, newlyLeftRooms []string,
) (changed, left []string, err error) {
// process leaves first, then joins afterwards so if we join/leave/join/leave we err on the side of including users.
// Leave algorithm:
// - Get set of users and number of times they appear in rooms prior to leave. - QuerySharedUsersRequest with 'IncludeRoomID'.
// - Get users in newly left room. - QueryCurrentState
// - Loop set of users and decrement by 1 for each user in newly left room.
// - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync.
var queryRes currentstateAPI.QuerySharedUsersResponse
err = s.currentStateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{
UserID: userID,
IncludeRoomIDs: newlyLeftRooms,
}, &queryRes)
if err != nil {
return nil, nil, err
}
var stateRes currentstateAPI.QueryBulkStateContentResponse
err = s.currentStateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{
RoomIDs: newlyLeftRooms,
StateTuples: []gomatrixserverlib.StateKeyTuple{
{
EventType: gomatrixserverlib.MRoomMember,
StateKey: "*",
},
},
AllowWildcards: true,
}, &stateRes)
if err != nil {
return nil, nil, err
}
for _, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != gomatrixserverlib.Join {
continue
}
queryRes.UserIDsToCount[tuple.StateKey]--
}
}
for userID, count := range queryRes.UserIDsToCount {
if count <= 0 {
left = append(left, userID) // left is returned
}
}
// Join algorithm:
// - Get the set of all joined users prior to joining room - QuerySharedUsersRequest with 'ExcludeRoomID'.
// - Get users in newly joined room - QueryCurrentState
// - Loop set of users in newly joined room, do they appear in the set of users prior to joining?
// - If yes: then they already shared a room in common, do nothing.
// - If no: then they are a brand new user so inform BOTH parties of this via 'changed=[...]'
err = s.currentStateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{
UserID: userID,
ExcludeRoomIDs: newlyJoinedRooms,
}, &queryRes)
if err != nil {
return nil, left, err
}
err = s.currentStateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{
RoomIDs: newlyJoinedRooms,
StateTuples: []gomatrixserverlib.StateKeyTuple{
{
EventType: gomatrixserverlib.MRoomMember,
StateKey: "*",
},
},
AllowWildcards: true,
}, &stateRes)
if err != nil {
return nil, left, err
}
for _, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != gomatrixserverlib.Join {
continue
}
// new user who we weren't previously sharing rooms with
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
changed = append(changed, tuple.StateKey) // changed is returned
}
}
}
return changed, left, nil
}
func joinedRooms(res *types.Response, userID string) []string {
var roomIDs []string
for roomID, join := range res.Rooms.Join {
// we would expect to see our join event somewhere if we newly joined the room.
// Normal events get put in the join section so it's not enough to know the room ID is present in 'join'.
newlyJoined := membershipEventPresent(join.State.Events, userID)
if newlyJoined {
roomIDs = append(roomIDs, roomID)
continue
}
newlyJoined = membershipEventPresent(join.Timeline.Events, userID)
if newlyJoined {
roomIDs = append(roomIDs, roomID)
}
}
return roomIDs
}
func leftRooms(res *types.Response) []string {
roomIDs := make([]string, len(res.Rooms.Leave))
i := 0
for roomID := range res.Rooms.Leave {
roomIDs[i] = roomID
i++
}
return roomIDs
}
func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID string) bool {
for _, ev := range events {
// it's enough to know that we have our member event here, don't need to check membership content
// as it's implied by being in the respective section of the sync response.
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
return true
}
}
return false
}

View file

@ -158,7 +158,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}).Panicf("roomserver output log: write event failure") }).Panicf("roomserver output log: write event failure")
return nil return nil
} }
s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0)) s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil))
return nil return nil
} }
@ -176,7 +176,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure") }).Panicf("roomserver output log: write invite failure")
return nil return nil
} }
s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0)) s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0, nil))
return nil return nil
} }
@ -194,7 +194,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
} }
// Notify any active sync requests that the invite has been retired. // Notify any active sync requests that the invite has been retired.
// Invites share the same stream counter as PDUs // Invites share the same stream counter as PDUs
s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0)) s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0, nil))
return nil return nil
} }

View file

@ -0,0 +1,219 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"github.com/Shopify/sarama"
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/keyserver/api"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
const DeviceListLogName = "dl"
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information.
func DeviceListCatchup(
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
userID string, res *types.Response, tok types.StreamingToken,
) (newTok *types.StreamingToken, hasNew bool, err error) {
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
newlyJoinedRooms := joinedRooms(res, userID)
newlyLeftRooms := leftRooms(res)
if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 {
changed, left, err := TrackChangedUsers(ctx, stateAPI, userID, newlyJoinedRooms, newlyLeftRooms)
if err != nil {
return nil, false, err
}
res.DeviceLists.Changed = changed
res.DeviceLists.Left = left
hasNew = len(changed) > 0 || len(left) > 0
}
// now also track users who we already share rooms with but who have updated their devices between the two tokens
var partition int32
var offset int64
// Extract partition/offset from sync token
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
logOffset := tok.Log(DeviceListLogName)
if logOffset != nil {
partition = logOffset.Partition
offset = logOffset.Offset
} else {
partition = -1
offset = sarama.OffsetOldest
}
var queryRes api.QueryKeyChangesResponse
keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{
Partition: partition,
Offset: offset,
}, &queryRes)
if queryRes.Error != nil {
// don't fail the catchup because we may have got useful information by tracking membership
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
return
}
userSet := make(map[string]bool)
for _, userID := range res.DeviceLists.Changed {
userSet[userID] = true
}
for _, userID := range queryRes.UserIDs {
if !userSet[userID] {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
hasNew = true
}
}
// Make a new streaming token using the new offset
tok.SetLog(DeviceListLogName, &types.LogPosition{
Offset: queryRes.Offset,
Partition: queryRes.Partition,
})
newTok = &tok
return
}
// TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response.
// nolint:gocyclo
func TrackChangedUsers(
ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, newlyJoinedRooms, newlyLeftRooms []string,
) (changed, left []string, err error) {
// process leaves first, then joins afterwards so if we join/leave/join/leave we err on the side of including users.
// Leave algorithm:
// - Get set of users and number of times they appear in rooms prior to leave. - QuerySharedUsersRequest with 'IncludeRoomID'.
// - Get users in newly left room. - QueryCurrentState
// - Loop set of users and decrement by 1 for each user in newly left room.
// - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync.
var queryRes currentstateAPI.QuerySharedUsersResponse
err = stateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{
UserID: userID,
IncludeRoomIDs: newlyLeftRooms,
}, &queryRes)
if err != nil {
return nil, nil, err
}
var stateRes currentstateAPI.QueryBulkStateContentResponse
err = stateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{
RoomIDs: newlyLeftRooms,
StateTuples: []gomatrixserverlib.StateKeyTuple{
{
EventType: gomatrixserverlib.MRoomMember,
StateKey: "*",
},
},
AllowWildcards: true,
}, &stateRes)
if err != nil {
return nil, nil, err
}
for _, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != gomatrixserverlib.Join {
continue
}
queryRes.UserIDsToCount[tuple.StateKey]--
}
}
for userID, count := range queryRes.UserIDsToCount {
if count <= 0 {
left = append(left, userID) // left is returned
}
}
// Join algorithm:
// - Get the set of all joined users prior to joining room - QuerySharedUsersRequest with 'ExcludeRoomID'.
// - Get users in newly joined room - QueryCurrentState
// - Loop set of users in newly joined room, do they appear in the set of users prior to joining?
// - If yes: then they already shared a room in common, do nothing.
// - If no: then they are a brand new user so inform BOTH parties of this via 'changed=[...]'
err = stateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{
UserID: userID,
ExcludeRoomIDs: newlyJoinedRooms,
}, &queryRes)
if err != nil {
return nil, left, err
}
err = stateAPI.QueryBulkStateContent(ctx, &currentstateAPI.QueryBulkStateContentRequest{
RoomIDs: newlyJoinedRooms,
StateTuples: []gomatrixserverlib.StateKeyTuple{
{
EventType: gomatrixserverlib.MRoomMember,
StateKey: "*",
},
},
AllowWildcards: true,
}, &stateRes)
if err != nil {
return nil, left, err
}
for _, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != gomatrixserverlib.Join {
continue
}
// new user who we weren't previously sharing rooms with
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
changed = append(changed, tuple.StateKey) // changed is returned
}
}
}
return changed, left, nil
}
func joinedRooms(res *types.Response, userID string) []string {
var roomIDs []string
for roomID, join := range res.Rooms.Join {
// we would expect to see our join event somewhere if we newly joined the room.
// Normal events get put in the join section so it's not enough to know the room ID is present in 'join'.
newlyJoined := membershipEventPresent(join.State.Events, userID)
if newlyJoined {
roomIDs = append(roomIDs, roomID)
continue
}
newlyJoined = membershipEventPresent(join.Timeline.Events, userID)
if newlyJoined {
roomIDs = append(roomIDs, roomID)
}
}
return roomIDs
}
func leftRooms(res *types.Response) []string {
roomIDs := make([]string, len(res.Rooms.Leave))
i := 0
for roomID := range res.Rooms.Leave {
roomIDs[i] = roomID
i++
}
return roomIDs
}
func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID string) bool {
for _, ev := range events {
// it's enough to know that we have our member event here, don't need to check membership content
// as it's implied by being in the respective section of the sync response.
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
return true
}
}
return false
}

View file

@ -1,4 +1,4 @@
package consumers package internal
import ( import (
"context" "context"
@ -7,14 +7,29 @@ import (
"testing" "testing"
"github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/currentstateserver/api"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
var ( var (
syncingUser = "@alice:localhost" syncingUser = "@alice:localhost"
emptyToken = types.NewStreamToken(0, 0, nil)
) )
type mockKeyAPI struct{}
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) {
}
// PerformClaimKeys claims one-time keys for use in pre-key messages
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) {
}
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) {
}
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
}
type mockCurrentStateAPI struct { type mockCurrentStateAPI struct {
roomIDToJoinedMembers map[string][]string roomIDToJoinedMembers map[string][]string
} }
@ -144,18 +159,17 @@ func leaveResponseWithRooms(syncResponse *types.Response, userID string, roomIDs
func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) { func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
newShareUser := "@bill:localhost" newShareUser := "@bill:localhost"
newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNewUser:bar" newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNewUser:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{ roomIDToJoinedMembers: map[string][]string{
newlyJoinedRoom: {syncingUser, newShareUser}, newlyJoinedRoom: {syncingUser, newShareUser},
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
}, nil) }, syncingUser, syncResponse, emptyToken)
syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
assertCatchup(t, hasNew, syncResponse, wantCatchup{ assertCatchup(t, hasNew, syncResponse, wantCatchup{
hasNew: true, hasNew: true,
@ -167,18 +181,17 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) { func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
removeUser := "@bill:localhost" removeUser := "@bill:localhost"
newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareLeftUser:bar" newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareLeftUser:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ syncResponse := types.NewResponse()
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{ roomIDToJoinedMembers: map[string][]string{
newlyLeftRoom: {removeUser}, newlyLeftRoom: {removeUser},
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
}, nil) }, syncingUser, syncResponse, emptyToken)
syncResponse := types.NewResponse()
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
assertCatchup(t, hasNew, syncResponse, wantCatchup{ assertCatchup(t, hasNew, syncResponse, wantCatchup{
hasNew: true, hasNew: true,
@ -190,16 +203,15 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) { func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
existingUser := "@bob:localhost" existingUser := "@bob:localhost"
newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNoNewUsers:bar" newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNoNewUsers:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{ roomIDToJoinedMembers: map[string][]string{
newlyJoinedRoom: {syncingUser, existingUser}, newlyJoinedRoom: {syncingUser, existingUser},
"!another:room": {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser},
}, },
}, nil) }, syncingUser, syncResponse, emptyToken)
syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("Catchup returned an error: %s", err)
} }
@ -212,18 +224,17 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) { func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
existingUser := "@bob:localhost" existingUser := "@bob:localhost"
newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareNoUsers:bar" newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareNoUsers:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ syncResponse := types.NewResponse()
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{ roomIDToJoinedMembers: map[string][]string{
newlyLeftRoom: {existingUser}, newlyLeftRoom: {existingUser},
"!another:room": {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser},
}, },
}, nil) }, syncingUser, syncResponse, emptyToken)
syncResponse := types.NewResponse()
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
assertCatchup(t, hasNew, syncResponse, wantCatchup{ assertCatchup(t, hasNew, syncResponse, wantCatchup{
hasNew: false, hasNew: false,
@ -234,11 +245,6 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) { func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
existingUser := "@bob1:localhost" existingUser := "@bob1:localhost"
roomID := "!TestKeyChangeCatchupNoNewJoinsButMessages:bar" roomID := "!TestKeyChangeCatchupNoNewJoinsButMessages:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
roomID: {syncingUser, existingUser},
},
}, nil)
syncResponse := types.NewResponse() syncResponse := types.NewResponse()
empty := "" empty := ""
roomStateEvents := []gomatrixserverlib.ClientEvent{ roomStateEvents := []gomatrixserverlib.ClientEvent{
@ -280,9 +286,13 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
jr.Timeline.Events = roomTimelineEvents jr.Timeline.Events = roomTimelineEvents
syncResponse.Rooms.Join[roomID] = jr syncResponse.Rooms.Join[roomID] = jr
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
roomID: {syncingUser, existingUser},
},
}, syncingUser, syncResponse, emptyToken)
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
assertCatchup(t, hasNew, syncResponse, wantCatchup{ assertCatchup(t, hasNew, syncResponse, wantCatchup{
hasNew: false, hasNew: false,
@ -297,18 +307,17 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
newlyLeftUser2 := "@debra:localhost" newlyLeftUser2 := "@debra:localhost"
newlyJoinedRoom := "!join:bar" newlyJoinedRoom := "!join:bar"
newlyLeftRoom := "!left:bar" newlyLeftRoom := "!left:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{ syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{ roomIDToJoinedMembers: map[string][]string{
newlyJoinedRoom: {syncingUser, newShareUser, newShareUser2}, newlyJoinedRoom: {syncingUser, newShareUser, newShareUser2},
newlyLeftRoom: {newlyLeftUser, newlyLeftUser2}, newlyLeftRoom: {newlyLeftUser, newlyLeftUser2},
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
}, nil) }, syncingUser, syncResponse, emptyToken)
syncResponse := types.NewResponse()
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("Catchup returned an error: %s", err)
} }
@ -333,12 +342,6 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
newShareUser := "@berta:localhost" newShareUser := "@berta:localhost"
newShareUser2 := "@bobby:localhost" newShareUser2 := "@bobby:localhost"
roomID := "!join:bar" roomID := "!join:bar"
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
roomID: {newShareUser, newShareUser2},
"!another:room": {syncingUser},
},
}, nil)
syncResponse := types.NewResponse() syncResponse := types.NewResponse()
roomEvents := []gomatrixserverlib.ClientEvent{ roomEvents := []gomatrixserverlib.ClientEvent{
{ {
@ -393,9 +396,14 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
lr.Timeline.Events = roomEvents lr.Timeline.Events = roomEvents
syncResponse.Rooms.Leave[roomID] = lr syncResponse.Rooms.Leave[roomID] = lr
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0)) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
roomIDToJoinedMembers: map[string][]string{
roomID: {newShareUser, newShareUser2},
"!another:room": {syncingUser},
},
}, syncingUser, syncResponse, emptyToken)
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
assertCatchup(t, hasNew, syncResponse, wantCatchup{ assertCatchup(t, hasNew, syncResponse, wantCatchup{
hasNew: true, hasNew: true,

View file

@ -434,7 +434,7 @@ func (d *Database) syncPositionTx(
if maxInviteID > maxEventID { if maxInviteID > maxEventID {
maxEventID = maxInviteID maxEventID = maxInviteID
} }
sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition())) sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition()), nil)
return return
} }
@ -731,7 +731,7 @@ func (d *Database) CompleteSync(
// Use a zero value SyncPosition for fromPos so all EDU states are added. // Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse( err = d.addEDUDeltaToResponse(
types.NewStreamToken(0, 0), toPos, joinedRoomIDs, res, types.NewStreamToken(0, 0, nil), toPos, joinedRoomIDs, res,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -166,7 +166,7 @@ func TestSyncResponse(t *testing.T) {
Name: "IncrementalSync penultimate", Name: "IncrementalSync penultimate",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
from := types.NewStreamToken( // pretend we are at the penultimate event from := types.NewStreamToken( // pretend we are at the penultimate event
positions[len(positions)-2], types.StreamPosition(0), positions[len(positions)-2], types.StreamPosition(0), nil,
) )
res := types.NewResponse() res := types.NewResponse()
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false) return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
@ -179,7 +179,7 @@ func TestSyncResponse(t *testing.T) {
Name: "IncrementalSync limited", Name: "IncrementalSync limited",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
from := types.NewStreamToken( // pretend we are 10 events behind from := types.NewStreamToken( // pretend we are 10 events behind
positions[len(positions)-11], types.StreamPosition(0), positions[len(positions)-11], types.StreamPosition(0), nil,
) )
res := types.NewResponse() res := types.NewResponse()
// limit is set to 5 // limit is set to 5
@ -222,7 +222,7 @@ func TestSyncResponse(t *testing.T) {
if err != nil { if err != nil {
st.Fatalf("failed to do sync: %s", err) st.Fatalf("failed to do sync: %s", err)
} }
next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition()) next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition(), nil)
if res.NextBatch != next.String() { if res.NextBatch != next.String() {
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
} }
@ -246,7 +246,7 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err) t.Fatalf("failed to get SyncPosition: %s", err)
} }
from := types.NewStreamToken( from := types.NewStreamToken(
positions[len(positions)-2], types.StreamPosition(0), positions[len(positions)-2], types.StreamPosition(0), nil,
) )
res := types.NewResponse() res := types.NewResponse()
@ -291,7 +291,7 @@ func TestGetEventsInRangeWithStreamToken(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err) t.Fatalf("failed to get SyncPosition: %s", err)
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewStreamToken(0, 0) to := types.NewStreamToken(0, 0, nil)
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true) paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
@ -534,14 +534,14 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point there should be no messages. We haven't sent anything // At this point there should be no messages. We haven't sent anything
// yet. // yet.
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0)) events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0, nil))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 { if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("first call should have no updates") t.Fatal("first call should have no updates")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0, nil))
if err != nil { if err != nil {
return return
} }
@ -559,14 +559,14 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should get exactly one message. We're sending the sync position // At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated // that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at. // in the database to reflect that this was the sync position we sent the message at.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 { if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
t.Fatal("second call should have one update") t.Fatal("second call should have one update")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil))
if err != nil { if err != nil {
return return
} }
@ -574,35 +574,35 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should still have one message because we haven't progressed the // At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying // sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position. // with the same position.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos)) events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 { if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("third call should have one update still") t.Fatal("third call should have one update still")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil))
if err != nil { if err != nil {
return return
} }
// At this point we should now have no updates, because we've progressed the sync // At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again. // position. Therefore the update from before will not be sent again.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1)) events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1, nil))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 { if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
t.Fatal("fourth call should have no updates") t.Fatal("fourth call should have no updates")
} }
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1)) err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1, nil))
if err != nil { if err != nil {
return return
} }
// At this point we should still have no updates, because no new updates have been // At this point we should still have no updates, because no new updates have been
// sent. // sent.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2)) events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2, nil))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -639,7 +639,7 @@ func TestInviteBehaviour(t *testing.T) {
} }
// both invite events should appear in a new sync // both invite events should appear in a new sync
beforeRetireRes := types.NewResponse() beforeRetireRes := types.NewResponse()
beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false) beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false)
if err != nil { if err != nil {
t.Fatalf("IncrementalSync failed: %s", err) t.Fatalf("IncrementalSync failed: %s", err)
} }
@ -654,7 +654,7 @@ func TestInviteBehaviour(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err) t.Fatalf("failed to get SyncPosition: %s", err)
} }
res := types.NewResponse() res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false) res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false)
if err != nil { if err != nil {
t.Fatalf("IncrementalSync failed: %s", err) t.Fatalf("IncrementalSync failed: %s", err)
} }

View file

@ -132,6 +132,16 @@ func (n *Notifier) OnNewSendToDevice(
n.wakeupUserDevice(userID, deviceIDs, latestPos) n.wakeupUserDevice(userID, deviceIDs, latestPos)
} }
func (n *Notifier) OnNewKeyChange(
posUpdate types.StreamingToken, wakeUserID, keyChangeUserID string,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.wakeupUsers([]string{wakeUserID}, latestPos)
}
// GetListener returns a UserStreamListener that can be used to wait for // GetListener returns a UserStreamListener that can be used to wait for
// updates for a user. Must be closed. // updates for a user. Must be closed.
// notify for anything before sincePos // notify for anything before sincePos

View file

@ -32,11 +32,11 @@ var (
randomMessageEvent gomatrixserverlib.HeaderedEvent randomMessageEvent gomatrixserverlib.HeaderedEvent
aliceInviteBobEvent gomatrixserverlib.HeaderedEvent aliceInviteBobEvent gomatrixserverlib.HeaderedEvent
bobLeaveEvent gomatrixserverlib.HeaderedEvent bobLeaveEvent gomatrixserverlib.HeaderedEvent
syncPositionVeryOld = types.NewStreamToken(5, 0) syncPositionVeryOld = types.NewStreamToken(5, 0, nil)
syncPositionBefore = types.NewStreamToken(11, 0) syncPositionBefore = types.NewStreamToken(11, 0, nil)
syncPositionAfter = types.NewStreamToken(12, 0) syncPositionAfter = types.NewStreamToken(12, 0, nil)
syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1) syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1, nil)
syncPositionAfter2 = types.NewStreamToken(13, 0) syncPositionAfter2 = types.NewStreamToken(13, 0, nil)
) )
var ( var (

View file

@ -65,7 +65,7 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
since = &tok since = &tok
} }
if since == nil { if since == nil {
tok := types.NewStreamToken(0, 0) tok := types.NewStreamToken(0, 0, nil)
since = &tok since = &tok
} }
timelineLimit := DefaultTimelineLimit timelineLimit := DefaultTimelineLimit

View file

@ -22,6 +22,9 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -35,11 +38,16 @@ type RequestPool struct {
db storage.Database db storage.Database
userAPI userapi.UserInternalAPI userAPI userapi.UserInternalAPI
notifier *Notifier notifier *Notifier
keyAPI keyapi.KeyInternalAPI
stateAPI currentstateAPI.CurrentStateInternalAPI
} }
// NewRequestPool makes a new RequestPool // NewRequestPool makes a new RequestPool
func NewRequestPool(db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI) *RequestPool { func NewRequestPool(
return &RequestPool{db, userAPI, n} db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI, keyAPI keyapi.KeyInternalAPI,
stateAPI currentstateAPI.CurrentStateInternalAPI,
) *RequestPool {
return &RequestPool{db, userAPI, n, keyAPI, stateAPI}
} }
// OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be // OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be
@ -138,7 +146,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) { func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
res = types.NewResponse() res = types.NewResponse()
since := types.NewStreamToken(0, 0) since := types.NewStreamToken(0, 0, nil)
if req.since != nil { if req.since != nil {
since = *req.since since = *req.since
} }
@ -164,6 +172,10 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
if err != nil { if err != nil {
return return
} }
res, err = rp.appendDeviceLists(res, req.device.UserID, since)
if err != nil {
return
}
// Before we return the sync response, make sure that we take action on // Before we return the sync response, make sure that we take action on
// any send-to-device database updates or deletions that we need to do. // any send-to-device database updates or deletions that we need to do.
@ -192,6 +204,22 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
return return
} }
func (rp *RequestPool) appendDeviceLists(
data *types.Response, userID string, since types.StreamingToken,
) (*types.Response, error) {
// TODO: Currently this code will race which may result in duplicates but not missing data.
// This happens because, whilst we are told the range to fetch here (since / latest) the
// QueryKeyChanges API only exposes a "from" value (on purpose to avoid racing, which then
// returns the latest position with which the response has authority on). We'd need to tweak
// the API to expose a "to" value to fix this.
_, _, err := internal.DeviceListCatchup(context.Background(), rp.keyAPI, rp.stateAPI, userID, data, since)
if err != nil {
return nil, err
}
return data, nil
}
// nolint:gocyclo // nolint:gocyclo
func (rp *RequestPool) appendAccountData( func (rp *RequestPool) appendAccountData(
data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition, data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition,

View file

@ -21,7 +21,9 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
currentstateapi "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -39,6 +41,8 @@ func AddPublicRoutes(
consumer sarama.Consumer, consumer sarama.Consumer,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
keyAPI keyapi.KeyInternalAPI,
currentStateAPI currentstateapi.CurrentStateInternalAPI,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
cfg *config.SyncAPI, cfg *config.SyncAPI,
) { ) {
@ -58,7 +62,7 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start notifier") logrus.WithError(err).Panicf("failed to start notifier")
} }
requestPool := sync.NewRequestPool(syncDB, notifier, userAPI) requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, currentStateAPI)
roomConsumer := consumers.NewOutputRoomEventConsumer( roomConsumer := consumers.NewOutputRoomEventConsumer(
cfg, consumer, notifier, syncDB, rsAPI, cfg, consumer, notifier, syncDB, rsAPI,
@ -88,5 +92,13 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start send-to-device consumer") logrus.WithError(err).Panicf("failed to start send-to-device consumer")
} }
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
cfg.Matrix.ServerName, string(cfg.Kafka.Topics.OutputKeyChangeEvent),
consumer, notifier, keyAPI, currentStateAPI, syncDB,
)
if err = keyChangeConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start key change consumer")
}
routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg) routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg)
} }

View file

@ -39,6 +39,23 @@ var (
// StreamPosition represents the offset in the sync stream a client is at. // StreamPosition represents the offset in the sync stream a client is at.
type StreamPosition int64 type StreamPosition int64
// LogPosition represents the offset in a Kafka log a client is at.
type LogPosition struct {
Partition int32
Offset int64
}
// IsAfter returns true if this position is after `lp`.
func (p *LogPosition) IsAfter(lp *LogPosition) bool {
if lp == nil {
return false
}
if p.Partition != lp.Partition {
return false
}
return p.Offset > lp.Offset
}
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. // StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamEvent struct { type StreamEvent struct {
gomatrixserverlib.HeaderedEvent gomatrixserverlib.HeaderedEvent
@ -90,6 +107,19 @@ const (
type StreamingToken struct { type StreamingToken struct {
syncToken syncToken
logs map[string]*LogPosition
}
func (t *StreamingToken) SetLog(name string, lp *LogPosition) {
t.logs[name] = lp
}
func (t *StreamingToken) Log(name string) *LogPosition {
l, ok := t.logs[name]
if !ok {
return nil
}
return l
} }
func (t *StreamingToken) PDUPosition() StreamPosition { func (t *StreamingToken) PDUPosition() StreamPosition {
@ -99,7 +129,15 @@ func (t *StreamingToken) EDUPosition() StreamPosition {
return t.Positions[1] return t.Positions[1]
} }
func (t *StreamingToken) String() string { func (t *StreamingToken) String() string {
return t.syncToken.String() logStrings := []string{
t.syncToken.String(),
}
for name, lp := range t.logs {
logStr := fmt.Sprintf("%s-%d-%d", name, lp.Partition, lp.Offset)
logStrings = append(logStrings, logStr)
}
// E.g s11_22_33.dl0-134.ab1-441
return strings.Join(logStrings, ".")
} }
// IsAfter returns true if ANY position in this token is greater than `other`. // IsAfter returns true if ANY position in this token is greater than `other`.
@ -109,12 +147,22 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool {
return true return true
} }
} }
for name := range t.logs {
otherLog := other.Log(name)
if otherLog == nil {
continue
}
if t.logs[name].IsAfter(otherLog) {
return true
}
}
return false return false
} }
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
// If the latter StreamingToken contains a field that is not 0, it is considered an update, // If the latter StreamingToken contains a field that is not 0, it is considered an update,
// and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called. // and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called.
// If the other token has a log, they will replace any existing log on this token.
func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) { func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) {
ret.Type = t.Type ret.Type = t.Type
ret.Positions = make([]StreamPosition, len(t.Positions)) ret.Positions = make([]StreamPosition, len(t.Positions))
@ -125,6 +173,13 @@ func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken)
} }
ret.Positions[i] = other.Positions[i] ret.Positions[i] = other.Positions[i]
} }
for name := range t.logs {
otherLog := other.Log(name)
if otherLog == nil {
continue
}
t.logs[name] = otherLog
}
return ret return ret
} }
@ -139,7 +194,7 @@ func (t *TopologyToken) PDUPosition() StreamPosition {
return t.Positions[1] return t.Positions[1]
} }
func (t *TopologyToken) StreamToken() StreamingToken { func (t *TopologyToken) StreamToken() StreamingToken {
return NewStreamToken(t.PDUPosition(), 0) return NewStreamToken(t.PDUPosition(), 0, nil)
} }
func (t *TopologyToken) String() string { func (t *TopologyToken) String() string {
return t.syncToken.String() return t.syncToken.String()
@ -174,9 +229,9 @@ func (t *TopologyToken) Decrement() {
// error if the token couldn't be parsed into an int64, or if the token type // error if the token couldn't be parsed into an int64, or if the token type
// isn't a known type (returns ErrInvalidSyncTokenType in the latter // isn't a known type (returns ErrInvalidSyncTokenType in the latter
// case). // case).
func newSyncTokenFromString(s string) (token *syncToken, err error) { func newSyncTokenFromString(s string) (token *syncToken, categories []string, err error) {
if len(s) == 0 { if len(s) == 0 {
return nil, ErrInvalidSyncTokenLen return nil, nil, ErrInvalidSyncTokenLen
} }
token = new(syncToken) token = new(syncToken)
@ -185,16 +240,17 @@ func newSyncTokenFromString(s string) (token *syncToken, err error) {
switch t := SyncTokenType(s[:1]); t { switch t := SyncTokenType(s[:1]); t {
case SyncTokenTypeStream, SyncTokenTypeTopology: case SyncTokenTypeStream, SyncTokenTypeTopology:
token.Type = t token.Type = t
positions = strings.Split(s[1:], "_") categories = strings.Split(s[1:], ".")
positions = strings.Split(categories[0], "_")
default: default:
return nil, ErrInvalidSyncTokenType return nil, nil, ErrInvalidSyncTokenType
} }
for _, pos := range positions { for _, pos := range positions {
if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil { if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil {
return nil, err return nil, nil, err
} else if posInt < 0 { } else if posInt < 0 {
return nil, errors.New("negative position not allowed") return nil, nil, errors.New("negative position not allowed")
} else { } else {
token.Positions = append(token.Positions, StreamPosition(posInt)) token.Positions = append(token.Positions, StreamPosition(posInt))
} }
@ -215,7 +271,7 @@ func NewTopologyToken(depth, streamPos StreamPosition) TopologyToken {
} }
} }
func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) { func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
t, err := newSyncTokenFromString(tok) t, _, err := newSyncTokenFromString(tok)
if err != nil { if err != nil {
return return
} }
@ -233,16 +289,20 @@ func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
} }
// NewStreamToken creates a new sync token for /sync // NewStreamToken creates a new sync token for /sync
func NewStreamToken(pduPos, eduPos StreamPosition) StreamingToken { func NewStreamToken(pduPos, eduPos StreamPosition, logs map[string]*LogPosition) StreamingToken {
if logs == nil {
logs = make(map[string]*LogPosition)
}
return StreamingToken{ return StreamingToken{
syncToken: syncToken{ syncToken: syncToken{
Type: SyncTokenTypeStream, Type: SyncTokenTypeStream,
Positions: []StreamPosition{pduPos, eduPos}, Positions: []StreamPosition{pduPos, eduPos},
}, },
logs: logs,
} }
} }
func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
t, err := newSyncTokenFromString(tok) t, categories, err := newSyncTokenFromString(tok)
if err != nil { if err != nil {
return return
} }
@ -254,8 +314,35 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions)) err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions))
return return
} }
logs := make(map[string]*LogPosition)
if len(categories) > 1 {
// dl-0-1234
// $log_name-$partition-$offset
for _, logStr := range categories[1:] {
segments := strings.Split(logStr, "-")
if len(segments) != 3 {
err = fmt.Errorf("token %s - invalid log: %s", tok, logStr)
return
}
var partition int64
partition, err = strconv.ParseInt(segments[1], 10, 32)
if err != nil {
return
}
var offset int64
offset, err = strconv.ParseInt(segments[2], 10, 64)
if err != nil {
return
}
logs[segments[0]] = &LogPosition{
Partition: int32(partition),
Offset: offset,
}
}
}
return StreamingToken{ return StreamingToken{
syncToken: *t, syncToken: *t,
logs: logs,
}, nil }, nil
} }

View file

@ -1,11 +1,61 @@
package types package types
import "testing" import (
"reflect"
"testing"
)
func TestNewSyncTokenWithLogs(t *testing.T) {
tests := map[string]*StreamingToken{
"s4_0": &StreamingToken{
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
logs: make(map[string]*LogPosition),
},
"s4_0.dl-0-123": &StreamingToken{
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
logs: map[string]*LogPosition{
"dl": &LogPosition{
Partition: 0,
Offset: 123,
},
},
},
"s4_0.dl-0-123.ab-1-14419482332": &StreamingToken{
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
logs: map[string]*LogPosition{
"ab": &LogPosition{
Partition: 1,
Offset: 14419482332,
},
"dl": &LogPosition{
Partition: 0,
Offset: 123,
},
},
},
}
for tok, want := range tests {
got, err := NewStreamTokenFromString(tok)
if err != nil {
if want == nil {
continue // error expected
}
t.Errorf("%s errored: %s", tok, err)
continue
}
if !reflect.DeepEqual(got, *want) {
t.Errorf("%s mismatch: got %v want %v", tok, got, want)
}
if got.String() != tok {
t.Errorf("%s reserialisation mismatch: got %s want %s", tok, got.String(), tok)
}
}
}
func TestNewSyncTokenFromString(t *testing.T) { func TestNewSyncTokenFromString(t *testing.T) {
shouldPass := map[string]syncToken{ shouldPass := map[string]syncToken{
"s4_0": NewStreamToken(4, 0).syncToken, "s4_0": NewStreamToken(4, 0, nil).syncToken,
"s3_1": NewStreamToken(3, 1).syncToken, "s3_1": NewStreamToken(3, 1, nil).syncToken,
"t3_1": NewTopologyToken(3, 1).syncToken, "t3_1": NewTopologyToken(3, 1).syncToken,
} }
@ -21,7 +71,7 @@ func TestNewSyncTokenFromString(t *testing.T) {
} }
for test, expected := range shouldPass { for test, expected := range shouldPass {
result, err := newSyncTokenFromString(test) result, _, err := newSyncTokenFromString(test)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -31,7 +81,7 @@ func TestNewSyncTokenFromString(t *testing.T) {
} }
for _, test := range shouldFail { for _, test := range shouldFail {
if _, err := newSyncTokenFromString(test); err == nil { if _, _, err := newSyncTokenFromString(test); err == nil {
t.Errorf("input '%v' should have errored but didn't", test) t.Errorf("input '%v' should have errored but didn't", test)
} }
} }

View file

@ -127,6 +127,7 @@ Can query specific device keys using POST
query for user with no keys returns empty key dict query for user with no keys returns empty key dict
Can claim one time key using POST Can claim one time key using POST
Can claim remote one time key using POST Can claim remote one time key using POST
Local device key changes appear in v2 /sync
Can add account data Can add account data
Can add account data to room Can add account data to room
Can get account data without syncing Can get account data without syncing