mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Merge branch 'master' into neilalexander/config
This commit is contained in:
commit
930ced1102
|
|
@ -29,7 +29,9 @@ func main() {
|
||||||
|
|
||||||
rsAPI := base.RoomserverHTTPClient()
|
rsAPI := base.RoomserverHTTPClient()
|
||||||
|
|
||||||
syncapi.AddPublicRoutes(base.PublicAPIMux, base.KafkaConsumer, userAPI, rsAPI, federation, &cfg.SyncAPI)
|
syncapi.AddPublicRoutes(
|
||||||
|
base.PublicAPIMux, base.KafkaConsumer, userAPI, rsAPI, base.KeyServerHTTPClient(), base.CurrentStateAPIClient(),
|
||||||
|
federation, &cfg.SyncAPI)
|
||||||
|
|
||||||
base.SetupAndServeHTTP(string(base.Cfg.SyncAPI.Bind), string(base.Cfg.SyncAPI.Listen))
|
base.SetupAndServeHTTP(string(base.Cfg.SyncAPI.Bind), string(base.Cfg.SyncAPI.Listen))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,7 @@ func (m *Monolith) AddAllPublicRoutes(publicMux *mux.Router) {
|
||||||
)
|
)
|
||||||
mediaapi.AddPublicRoutes(publicMux, &m.Config.MediaAPI, m.UserAPI, m.Client)
|
mediaapi.AddPublicRoutes(publicMux, &m.Config.MediaAPI, m.UserAPI, m.Client)
|
||||||
syncapi.AddPublicRoutes(
|
syncapi.AddPublicRoutes(
|
||||||
publicMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI, m.FedClient, &m.Config.SyncAPI,
|
publicMux, m.KafkaConsumer, m.UserAPI, m.RoomserverAPI,
|
||||||
|
m.KeyAPI, m.StateAPI, m.FedClient, &m.Config.SyncAPI,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ type KeyInternalAPI interface {
|
||||||
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
||||||
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
|
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
|
||||||
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
|
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
|
||||||
|
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeyError is returned if there was a problem performing/querying the server
|
// KeyError is returned if there was a problem performing/querying the server
|
||||||
|
|
@ -131,3 +132,21 @@ type QueryKeysResponse struct {
|
||||||
// Set if there was a fatal error processing this query
|
// Set if there was a fatal error processing this query
|
||||||
Error *KeyError
|
Error *KeyError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueryKeyChangesRequest struct {
|
||||||
|
// The partition which had key events sent to
|
||||||
|
Partition int32
|
||||||
|
// The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
|
||||||
|
Offset int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryKeyChangesResponse struct {
|
||||||
|
// The set of users who have had their keys change.
|
||||||
|
UserIDs []string
|
||||||
|
// The partition being served - useful if the partition is unknown at request time
|
||||||
|
Partition int32
|
||||||
|
// The latest offset represented in this response.
|
||||||
|
Offset int64
|
||||||
|
// Set if there was a problem handling the request.
|
||||||
|
Error *KeyError
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,21 @@ type KeyInternalAPI struct {
|
||||||
Producer *producers.KeyChange
|
Producer *producers.KeyChange
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
|
||||||
|
if req.Partition < 0 {
|
||||||
|
req.Partition = a.Producer.DefaultPartition()
|
||||||
|
}
|
||||||
|
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Partition, req.Offset)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
res.Offset = latest
|
||||||
|
res.Partition = req.Partition
|
||||||
|
res.UserIDs = userIDs
|
||||||
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
||||||
a.uploadDeviceKeys(ctx, req, res)
|
a.uploadDeviceKeys(ctx, req, res)
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ const (
|
||||||
PerformUploadKeysPath = "/keyserver/performUploadKeys"
|
PerformUploadKeysPath = "/keyserver/performUploadKeys"
|
||||||
PerformClaimKeysPath = "/keyserver/performClaimKeys"
|
PerformClaimKeysPath = "/keyserver/performClaimKeys"
|
||||||
QueryKeysPath = "/keyserver/queryKeys"
|
QueryKeysPath = "/keyserver/queryKeys"
|
||||||
|
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
|
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
|
@ -101,3 +102,20 @@ func (h *httpKeyInternalAPI) QueryKeys(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpKeyInternalAPI) QueryKeyChanges(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryKeyChangesRequest,
|
||||||
|
response *api.QueryKeyChangesResponse,
|
||||||
|
) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyChanges")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + QueryKeyChangesPath
|
||||||
|
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
if err != nil {
|
||||||
|
response.Error = &api.KeyError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -58,4 +58,15 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(QueryKeyChangesPath,
|
||||||
|
httputil.MakeInternalAPI("queryKeyChanges", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryKeyChangesRequest{}
|
||||||
|
response := api.QueryKeyChangesResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
s.QueryKeyChanges(req.Context(), &request, &response)
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ func NewInternalAPI(
|
||||||
keyChangeProducer := &producers.KeyChange{
|
keyChangeProducer := &producers.KeyChange{
|
||||||
Topic: string(cfg.Matrix.Kafka.Topics.OutputKeyChangeEvent),
|
Topic: string(cfg.Matrix.Kafka.Topics.OutputKeyChangeEvent),
|
||||||
Producer: producer,
|
Producer: producer,
|
||||||
|
DB: db,
|
||||||
}
|
}
|
||||||
return &internal.KeyInternalAPI{
|
return &internal.KeyInternalAPI{
|
||||||
DB: db,
|
DB: db,
|
||||||
|
|
|
||||||
|
|
@ -15,10 +15,12 @@
|
||||||
package producers
|
package producers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -26,6 +28,16 @@ import (
|
||||||
type KeyChange struct {
|
type KeyChange struct {
|
||||||
Topic string
|
Topic string
|
||||||
Producer sarama.SyncProducer
|
Producer sarama.SyncProducer
|
||||||
|
DB storage.Database
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultPartition returns the default partition this process is sending key changes to.
|
||||||
|
// NB: A keyserver MUST send key changes to only 1 partition or else query operations will
|
||||||
|
// become inconsistent. Partitions can be sharded (e.g by hash of user ID of key change) but
|
||||||
|
// then all keyservers must be queried to calculate the entire set of key changes between
|
||||||
|
// two sync tokens.
|
||||||
|
func (p *KeyChange) DefaultPartition() int32 {
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProduceKeyChanges creates new change events for each key
|
// ProduceKeyChanges creates new change events for each key
|
||||||
|
|
@ -46,6 +58,10 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"user_id": key.UserID,
|
"user_id": key.UserID,
|
||||||
"device_id": key.DeviceID,
|
"device_id": key.DeviceID,
|
||||||
|
|
|
||||||
|
|
@ -43,4 +43,12 @@ type Database interface {
|
||||||
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
||||||
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||||
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
||||||
|
|
||||||
|
// StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed
|
||||||
|
// their keys in some way.
|
||||||
|
StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
|
||||||
|
|
||||||
|
// KeyChanges returns a list of user IDs who have modified their keys from the offset given.
|
||||||
|
// Returns the offset of the latest key change.
|
||||||
|
KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
97
keyserver/storage/postgres/key_changes_table.go
Normal file
97
keyserver/storage/postgres/key_changes_table.go
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
var keyChangesSchema = `
|
||||||
|
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||||
|
partition BIGINT NOT NULL,
|
||||||
|
log_offset BIGINT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
||||||
|
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
||||||
|
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
||||||
|
const upsertKeyChangeSQL = "" +
|
||||||
|
"INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" +
|
||||||
|
" VALUES ($1, $2, $3)" +
|
||||||
|
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" +
|
||||||
|
" DO UPDATE SET user_id = $3"
|
||||||
|
|
||||||
|
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
||||||
|
// take the max offset value as the latest offset.
|
||||||
|
const selectKeyChangesSQL = "" +
|
||||||
|
"SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 GROUP BY user_id"
|
||||||
|
|
||||||
|
type keyChangesStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
upsertKeyChangeStmt *sql.Stmt
|
||||||
|
selectKeyChangesStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||||
|
s := &keyChangesStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(keyChangesSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||||
|
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyChangesStatements) SelectKeyChanges(
|
||||||
|
ctx context.Context, partition int32, fromOffset int64,
|
||||||
|
) (userIDs []string, latestOffset int64, err error) {
|
||||||
|
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
var offset int64
|
||||||
|
if err := rows.Scan(&userID, &offset); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if offset > latestOffset {
|
||||||
|
latestOffset = offset
|
||||||
|
}
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -35,9 +35,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
kc, err := NewPostgresKeyChangesTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
|
KeyChangesTable: kc,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ type Database struct {
|
||||||
DB *sql.DB
|
DB *sql.DB
|
||||||
OneTimeKeysTable tables.OneTimeKeys
|
OneTimeKeysTable tables.OneTimeKeys
|
||||||
DeviceKeysTable tables.DeviceKeys
|
DeviceKeysTable tables.DeviceKeys
|
||||||
|
KeyChangesTable tables.KeyChanges
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||||
|
|
@ -72,3 +73,11 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st
|
||||||
})
|
})
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||||
|
return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) {
|
||||||
|
return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset)
|
||||||
|
}
|
||||||
|
|
|
||||||
98
keyserver/storage/sqlite3/key_changes_table.go
Normal file
98
keyserver/storage/sqlite3/key_changes_table.go
Normal file
|
|
@ -0,0 +1,98 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
var keyChangesSchema = `
|
||||||
|
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||||
|
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||||
|
partition BIGINT NOT NULL,
|
||||||
|
offset BIGINT NOT NULL,
|
||||||
|
-- The key owner
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
UNIQUE (partition, offset)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
||||||
|
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
||||||
|
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
||||||
|
const upsertKeyChangeSQL = "" +
|
||||||
|
"INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
|
||||||
|
" VALUES ($1, $2, $3)" +
|
||||||
|
" ON CONFLICT (partition, offset)" +
|
||||||
|
" DO UPDATE SET user_id = $3"
|
||||||
|
|
||||||
|
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
||||||
|
// take the max offset value as the latest offset.
|
||||||
|
const selectKeyChangesSQL = "" +
|
||||||
|
"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 GROUP BY user_id"
|
||||||
|
|
||||||
|
type keyChangesStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
upsertKeyChangeStmt *sql.Stmt
|
||||||
|
selectKeyChangesStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||||
|
s := &keyChangesStatements{
|
||||||
|
db: db,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(keyChangesSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||||
|
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyChangesStatements) SelectKeyChanges(
|
||||||
|
ctx context.Context, partition int32, fromOffset int64,
|
||||||
|
) (userIDs []string, latestOffset int64, err error) {
|
||||||
|
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
var offset int64
|
||||||
|
if err := rows.Scan(&userID, &offset); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if offset > latestOffset {
|
||||||
|
latestOffset = offset
|
||||||
|
}
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
@ -33,9 +33,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
kc, err := NewSqliteKeyChangesTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
DB: db,
|
DB: db,
|
||||||
OneTimeKeysTable: otk,
|
OneTimeKeysTable: otk,
|
||||||
DeviceKeysTable: dk,
|
DeviceKeysTable: dk,
|
||||||
|
KeyChangesTable: kc,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
57
keyserver/storage/storage_test.go
Normal file
57
keyserver/storage/storage_test.go
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ctx = context.Background()
|
||||||
|
|
||||||
|
func MustNotError(t *testing.T, err error) {
|
||||||
|
t.Helper()
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("operation failed: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyChanges(t *testing.T) {
|
||||||
|
db, err := NewDatabase("file::memory:", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to NewDatabase: %s", err)
|
||||||
|
}
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
||||||
|
userIDs, latest, err := db.KeyChanges(ctx, 0, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||||
|
}
|
||||||
|
if latest != 2 {
|
||||||
|
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
||||||
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyChangesNoDupes(t *testing.T) {
|
||||||
|
db, err := NewDatabase("file::memory:", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to NewDatabase: %s", err)
|
||||||
|
}
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
|
||||||
|
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
|
||||||
|
userIDs, latest, err := db.KeyChanges(ctx, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||||
|
}
|
||||||
|
if latest != 2 {
|
||||||
|
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
||||||
|
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -35,3 +35,8 @@ type DeviceKeys interface {
|
||||||
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
|
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
|
||||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
|
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type KeyChanges interface {
|
||||||
|
InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
|
||||||
|
SelectKeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
|
||||||
}).Panicf("could not save account data")
|
}).Panicf("could not save account data")
|
||||||
}
|
}
|
||||||
|
|
||||||
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0))
|
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0, nil))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
|
||||||
s.notifier.OnNewSendToDevice(
|
s.notifier.OnNewSendToDevice(
|
||||||
output.UserID,
|
output.UserID,
|
||||||
[]string{output.DeviceID},
|
[]string{output.DeviceID},
|
||||||
types.NewStreamToken(0, streamPos),
|
types.NewStreamToken(0, streamPos, nil),
|
||||||
)
|
)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ func (s *OutputTypingEventConsumer) Start() error {
|
||||||
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
|
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
|
||||||
s.notifier.OnNewEvent(
|
s.notifier.OnNewEvent(
|
||||||
nil, roomID, nil,
|
nil, roomID, nil,
|
||||||
types.NewStreamToken(0, types.StreamPosition(latestSyncPosition)),
|
types.NewStreamToken(0, types.StreamPosition(latestSyncPosition), nil),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -94,6 +94,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
|
||||||
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
|
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos))
|
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos, nil))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,9 @@ import (
|
||||||
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
|
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
syncinternal "github.com/matrix-org/dendrite/syncapi/internal"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
syncapi "github.com/matrix-org/dendrite/syncapi/sync"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
@ -31,13 +33,14 @@ import (
|
||||||
|
|
||||||
// OutputKeyChangeEventConsumer consumes events that originated in the key server.
|
// OutputKeyChangeEventConsumer consumes events that originated in the key server.
|
||||||
type OutputKeyChangeEventConsumer struct {
|
type OutputKeyChangeEventConsumer struct {
|
||||||
keyChangeConsumer *internal.ContinualConsumer
|
keyChangeConsumer *internal.ContinualConsumer
|
||||||
db storage.Database
|
db storage.Database
|
||||||
serverName gomatrixserverlib.ServerName // our server name
|
serverName gomatrixserverlib.ServerName // our server name
|
||||||
currentStateAPI currentstateAPI.CurrentStateInternalAPI
|
currentStateAPI currentstateAPI.CurrentStateInternalAPI
|
||||||
// keyAPI api.KeyInternalAPI
|
keyAPI api.KeyInternalAPI
|
||||||
partitionToOffset map[int32]int64
|
partitionToOffset map[int32]int64
|
||||||
partitionToOffsetMu sync.Mutex
|
partitionToOffsetMu sync.Mutex
|
||||||
|
notifier *syncapi.Notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
|
// NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
|
||||||
|
|
@ -46,6 +49,8 @@ func NewOutputKeyChangeEventConsumer(
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
topic string,
|
topic string,
|
||||||
kafkaConsumer sarama.Consumer,
|
kafkaConsumer sarama.Consumer,
|
||||||
|
n *syncapi.Notifier,
|
||||||
|
keyAPI api.KeyInternalAPI,
|
||||||
currentStateAPI currentstateAPI.CurrentStateInternalAPI,
|
currentStateAPI currentstateAPI.CurrentStateInternalAPI,
|
||||||
store storage.Database,
|
store storage.Database,
|
||||||
) *OutputKeyChangeEventConsumer {
|
) *OutputKeyChangeEventConsumer {
|
||||||
|
|
@ -60,9 +65,11 @@ func NewOutputKeyChangeEventConsumer(
|
||||||
keyChangeConsumer: &consumer,
|
keyChangeConsumer: &consumer,
|
||||||
db: store,
|
db: store,
|
||||||
serverName: serverName,
|
serverName: serverName,
|
||||||
|
keyAPI: keyAPI,
|
||||||
currentStateAPI: currentStateAPI,
|
currentStateAPI: currentStateAPI,
|
||||||
partitionToOffset: make(map[int32]int64),
|
partitionToOffset: make(map[int32]int64),
|
||||||
partitionToOffsetMu: sync.Mutex{},
|
partitionToOffsetMu: sync.Mutex{},
|
||||||
|
notifier: n,
|
||||||
}
|
}
|
||||||
|
|
||||||
consumer.ProcessMessage = s.onMessage
|
consumer.ProcessMessage = s.onMessage
|
||||||
|
|
@ -107,36 +114,22 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// TODO: f.e queryRes.UserIDsToCount : notify users by waking up streams
|
// TODO: f.e queryRes.UserIDsToCount : notify users by waking up streams
|
||||||
return nil
|
posUpdate := types.NewStreamToken(0, 0, map[string]*types.LogPosition{
|
||||||
}
|
syncinternal.DeviceListLogName: &types.LogPosition{
|
||||||
|
Offset: msg.Offset,
|
||||||
// Catchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
Partition: msg.Partition,
|
||||||
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
},
|
||||||
// be already filled in with join/leave information.
|
})
|
||||||
func (s *OutputKeyChangeEventConsumer) Catchup(
|
for userID := range queryRes.UserIDsToCount {
|
||||||
ctx context.Context, userID string, res *types.Response, tok types.StreamingToken,
|
s.notifier.OnNewKeyChange(posUpdate, userID, output.UserID)
|
||||||
) (hasNew bool, err error) {
|
|
||||||
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
|
|
||||||
newlyJoinedRooms := joinedRooms(res, userID)
|
|
||||||
newlyLeftRooms := leftRooms(res)
|
|
||||||
if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 {
|
|
||||||
changed, left, err := s.trackChangedUsers(ctx, userID, newlyJoinedRooms, newlyLeftRooms)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
res.DeviceLists.Changed = changed
|
|
||||||
res.DeviceLists.Left = left
|
|
||||||
hasNew = len(changed) > 0 || len(left) > 0
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
// TODO: now also track users who we already share rooms with but who have updated their devices between the two tokens
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.HeaderedEvent) {
|
func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.HeaderedEvent) {
|
||||||
// work out who we are now sharing rooms with which we previously were not and notify them about the joining
|
// work out who we are now sharing rooms with which we previously were not and notify them about the joining
|
||||||
// users keys:
|
// users keys:
|
||||||
changed, _, err := s.trackChangedUsers(context.Background(), *ev.StateKey(), []string{ev.RoomID()}, nil)
|
changed, _, err := syncinternal.TrackChangedUsers(context.Background(), s.currentStateAPI, *ev.StateKey(), []string{ev.RoomID()}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("OnJoinEvent: failed to work out changed users")
|
log.WithError(err).Error("OnJoinEvent: failed to work out changed users")
|
||||||
return
|
return
|
||||||
|
|
@ -149,7 +142,7 @@ func (s *OutputKeyChangeEventConsumer) OnJoinEvent(ev *gomatrixserverlib.Headere
|
||||||
|
|
||||||
func (s *OutputKeyChangeEventConsumer) OnLeaveEvent(ev *gomatrixserverlib.HeaderedEvent) {
|
func (s *OutputKeyChangeEventConsumer) OnLeaveEvent(ev *gomatrixserverlib.HeaderedEvent) {
|
||||||
// work out who we are no longer sharing any rooms with and notify them about the leaving user
|
// work out who we are no longer sharing any rooms with and notify them about the leaving user
|
||||||
_, left, err := s.trackChangedUsers(context.Background(), *ev.StateKey(), nil, []string{ev.RoomID()})
|
_, left, err := syncinternal.TrackChangedUsers(context.Background(), s.currentStateAPI, *ev.StateKey(), nil, []string{ev.RoomID()})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("OnLeaveEvent: failed to work out left users")
|
log.WithError(err).Error("OnLeaveEvent: failed to work out left users")
|
||||||
return
|
return
|
||||||
|
|
@ -160,129 +153,3 @@ func (s *OutputKeyChangeEventConsumer) OnLeaveEvent(ev *gomatrixserverlib.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:gocyclo
|
|
||||||
func (s *OutputKeyChangeEventConsumer) trackChangedUsers(
|
|
||||||
ctx context.Context, userID string, newlyJoinedRooms, newlyLeftRooms []string,
|
|
||||||
) (changed, left []string, err error) {
|
|
||||||
// process leaves first, then joins afterwards so if we join/leave/join/leave we err on the side of including users.
|
|
||||||
|
|
||||||
// Leave algorithm:
|
|
||||||
// - Get set of users and number of times they appear in rooms prior to leave. - QuerySharedUsersRequest with 'IncludeRoomID'.
|
|
||||||
// - Get users in newly left room. - QueryCurrentState
|
|
||||||
// - Loop set of users and decrement by 1 for each user in newly left room.
|
|
||||||
// - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync.
|
|
||||||
var queryRes currentstateAPI.QuerySharedUsersResponse
|
|
||||||
err = s.currentStateAPI.QuerySharedUsers(ctx, ¤tstateAPI.QuerySharedUsersRequest{
|
|
||||||
UserID: userID,
|
|
||||||
IncludeRoomIDs: newlyLeftRooms,
|
|
||||||
}, &queryRes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
var stateRes currentstateAPI.QueryBulkStateContentResponse
|
|
||||||
err = s.currentStateAPI.QueryBulkStateContent(ctx, ¤tstateAPI.QueryBulkStateContentRequest{
|
|
||||||
RoomIDs: newlyLeftRooms,
|
|
||||||
StateTuples: []gomatrixserverlib.StateKeyTuple{
|
|
||||||
{
|
|
||||||
EventType: gomatrixserverlib.MRoomMember,
|
|
||||||
StateKey: "*",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
AllowWildcards: true,
|
|
||||||
}, &stateRes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
for _, state := range stateRes.Rooms {
|
|
||||||
for tuple, membership := range state {
|
|
||||||
if membership != gomatrixserverlib.Join {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
queryRes.UserIDsToCount[tuple.StateKey]--
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for userID, count := range queryRes.UserIDsToCount {
|
|
||||||
if count <= 0 {
|
|
||||||
left = append(left, userID) // left is returned
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Join algorithm:
|
|
||||||
// - Get the set of all joined users prior to joining room - QuerySharedUsersRequest with 'ExcludeRoomID'.
|
|
||||||
// - Get users in newly joined room - QueryCurrentState
|
|
||||||
// - Loop set of users in newly joined room, do they appear in the set of users prior to joining?
|
|
||||||
// - If yes: then they already shared a room in common, do nothing.
|
|
||||||
// - If no: then they are a brand new user so inform BOTH parties of this via 'changed=[...]'
|
|
||||||
err = s.currentStateAPI.QuerySharedUsers(ctx, ¤tstateAPI.QuerySharedUsersRequest{
|
|
||||||
UserID: userID,
|
|
||||||
ExcludeRoomIDs: newlyJoinedRooms,
|
|
||||||
}, &queryRes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, left, err
|
|
||||||
}
|
|
||||||
err = s.currentStateAPI.QueryBulkStateContent(ctx, ¤tstateAPI.QueryBulkStateContentRequest{
|
|
||||||
RoomIDs: newlyJoinedRooms,
|
|
||||||
StateTuples: []gomatrixserverlib.StateKeyTuple{
|
|
||||||
{
|
|
||||||
EventType: gomatrixserverlib.MRoomMember,
|
|
||||||
StateKey: "*",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
AllowWildcards: true,
|
|
||||||
}, &stateRes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, left, err
|
|
||||||
}
|
|
||||||
for _, state := range stateRes.Rooms {
|
|
||||||
for tuple, membership := range state {
|
|
||||||
if membership != gomatrixserverlib.Join {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// new user who we weren't previously sharing rooms with
|
|
||||||
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
|
|
||||||
changed = append(changed, tuple.StateKey) // changed is returned
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return changed, left, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func joinedRooms(res *types.Response, userID string) []string {
|
|
||||||
var roomIDs []string
|
|
||||||
for roomID, join := range res.Rooms.Join {
|
|
||||||
// we would expect to see our join event somewhere if we newly joined the room.
|
|
||||||
// Normal events get put in the join section so it's not enough to know the room ID is present in 'join'.
|
|
||||||
newlyJoined := membershipEventPresent(join.State.Events, userID)
|
|
||||||
if newlyJoined {
|
|
||||||
roomIDs = append(roomIDs, roomID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newlyJoined = membershipEventPresent(join.Timeline.Events, userID)
|
|
||||||
if newlyJoined {
|
|
||||||
roomIDs = append(roomIDs, roomID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return roomIDs
|
|
||||||
}
|
|
||||||
|
|
||||||
func leftRooms(res *types.Response) []string {
|
|
||||||
roomIDs := make([]string, len(res.Rooms.Leave))
|
|
||||||
i := 0
|
|
||||||
for roomID := range res.Rooms.Leave {
|
|
||||||
roomIDs[i] = roomID
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
return roomIDs
|
|
||||||
}
|
|
||||||
|
|
||||||
func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID string) bool {
|
|
||||||
for _, ev := range events {
|
|
||||||
// it's enough to know that we have our member event here, don't need to check membership content
|
|
||||||
// as it's implied by being in the respective section of the sync response.
|
|
||||||
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -158,7 +158,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
|
||||||
}).Panicf("roomserver output log: write event failure")
|
}).Panicf("roomserver output log: write event failure")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0))
|
s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -176,7 +176,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
|
||||||
}).Panicf("roomserver output log: write invite failure")
|
}).Panicf("roomserver output log: write invite failure")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0))
|
s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0, nil))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -194,7 +194,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
|
||||||
}
|
}
|
||||||
// Notify any active sync requests that the invite has been retired.
|
// Notify any active sync requests that the invite has been retired.
|
||||||
// Invites share the same stream counter as PDUs
|
// Invites share the same stream counter as PDUs
|
||||||
s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0))
|
s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.NewStreamToken(sp, 0, nil))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
219
syncapi/internal/keychange.go
Normal file
219
syncapi/internal/keychange.go
Normal file
|
|
@ -0,0 +1,219 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/Shopify/sarama"
|
||||||
|
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const DeviceListLogName = "dl"
|
||||||
|
|
||||||
|
// DeviceListCatchup fills in the given response for the given user ID to bring it up-to-date with device lists. hasNew=true if the response
|
||||||
|
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
|
||||||
|
// be already filled in with join/leave information.
|
||||||
|
func DeviceListCatchup(
|
||||||
|
ctx context.Context, keyAPI keyapi.KeyInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
|
||||||
|
userID string, res *types.Response, tok types.StreamingToken,
|
||||||
|
) (newTok *types.StreamingToken, hasNew bool, err error) {
|
||||||
|
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
|
||||||
|
newlyJoinedRooms := joinedRooms(res, userID)
|
||||||
|
newlyLeftRooms := leftRooms(res)
|
||||||
|
if len(newlyJoinedRooms) > 0 || len(newlyLeftRooms) > 0 {
|
||||||
|
changed, left, err := TrackChangedUsers(ctx, stateAPI, userID, newlyJoinedRooms, newlyLeftRooms)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
res.DeviceLists.Changed = changed
|
||||||
|
res.DeviceLists.Left = left
|
||||||
|
hasNew = len(changed) > 0 || len(left) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// now also track users who we already share rooms with but who have updated their devices between the two tokens
|
||||||
|
|
||||||
|
var partition int32
|
||||||
|
var offset int64
|
||||||
|
// Extract partition/offset from sync token
|
||||||
|
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
|
||||||
|
logOffset := tok.Log(DeviceListLogName)
|
||||||
|
if logOffset != nil {
|
||||||
|
partition = logOffset.Partition
|
||||||
|
offset = logOffset.Offset
|
||||||
|
} else {
|
||||||
|
partition = -1
|
||||||
|
offset = sarama.OffsetOldest
|
||||||
|
}
|
||||||
|
var queryRes api.QueryKeyChangesResponse
|
||||||
|
keyAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{
|
||||||
|
Partition: partition,
|
||||||
|
Offset: offset,
|
||||||
|
}, &queryRes)
|
||||||
|
if queryRes.Error != nil {
|
||||||
|
// don't fail the catchup because we may have got useful information by tracking membership
|
||||||
|
util.GetLogger(ctx).WithError(queryRes.Error).Error("QueryKeyChanges failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userSet := make(map[string]bool)
|
||||||
|
for _, userID := range res.DeviceLists.Changed {
|
||||||
|
userSet[userID] = true
|
||||||
|
}
|
||||||
|
for _, userID := range queryRes.UserIDs {
|
||||||
|
if !userSet[userID] {
|
||||||
|
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
||||||
|
hasNew = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Make a new streaming token using the new offset
|
||||||
|
tok.SetLog(DeviceListLogName, &types.LogPosition{
|
||||||
|
Offset: queryRes.Offset,
|
||||||
|
Partition: queryRes.Partition,
|
||||||
|
})
|
||||||
|
newTok = &tok
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response.
|
||||||
|
// nolint:gocyclo
|
||||||
|
func TrackChangedUsers(
|
||||||
|
ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, newlyJoinedRooms, newlyLeftRooms []string,
|
||||||
|
) (changed, left []string, err error) {
|
||||||
|
// process leaves first, then joins afterwards so if we join/leave/join/leave we err on the side of including users.
|
||||||
|
|
||||||
|
// Leave algorithm:
|
||||||
|
// - Get set of users and number of times they appear in rooms prior to leave. - QuerySharedUsersRequest with 'IncludeRoomID'.
|
||||||
|
// - Get users in newly left room. - QueryCurrentState
|
||||||
|
// - Loop set of users and decrement by 1 for each user in newly left room.
|
||||||
|
// - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync.
|
||||||
|
var queryRes currentstateAPI.QuerySharedUsersResponse
|
||||||
|
err = stateAPI.QuerySharedUsers(ctx, ¤tstateAPI.QuerySharedUsersRequest{
|
||||||
|
UserID: userID,
|
||||||
|
IncludeRoomIDs: newlyLeftRooms,
|
||||||
|
}, &queryRes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
var stateRes currentstateAPI.QueryBulkStateContentResponse
|
||||||
|
err = stateAPI.QueryBulkStateContent(ctx, ¤tstateAPI.QueryBulkStateContentRequest{
|
||||||
|
RoomIDs: newlyLeftRooms,
|
||||||
|
StateTuples: []gomatrixserverlib.StateKeyTuple{
|
||||||
|
{
|
||||||
|
EventType: gomatrixserverlib.MRoomMember,
|
||||||
|
StateKey: "*",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
AllowWildcards: true,
|
||||||
|
}, &stateRes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
for _, state := range stateRes.Rooms {
|
||||||
|
for tuple, membership := range state {
|
||||||
|
if membership != gomatrixserverlib.Join {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
queryRes.UserIDsToCount[tuple.StateKey]--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for userID, count := range queryRes.UserIDsToCount {
|
||||||
|
if count <= 0 {
|
||||||
|
left = append(left, userID) // left is returned
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join algorithm:
|
||||||
|
// - Get the set of all joined users prior to joining room - QuerySharedUsersRequest with 'ExcludeRoomID'.
|
||||||
|
// - Get users in newly joined room - QueryCurrentState
|
||||||
|
// - Loop set of users in newly joined room, do they appear in the set of users prior to joining?
|
||||||
|
// - If yes: then they already shared a room in common, do nothing.
|
||||||
|
// - If no: then they are a brand new user so inform BOTH parties of this via 'changed=[...]'
|
||||||
|
err = stateAPI.QuerySharedUsers(ctx, ¤tstateAPI.QuerySharedUsersRequest{
|
||||||
|
UserID: userID,
|
||||||
|
ExcludeRoomIDs: newlyJoinedRooms,
|
||||||
|
}, &queryRes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, left, err
|
||||||
|
}
|
||||||
|
err = stateAPI.QueryBulkStateContent(ctx, ¤tstateAPI.QueryBulkStateContentRequest{
|
||||||
|
RoomIDs: newlyJoinedRooms,
|
||||||
|
StateTuples: []gomatrixserverlib.StateKeyTuple{
|
||||||
|
{
|
||||||
|
EventType: gomatrixserverlib.MRoomMember,
|
||||||
|
StateKey: "*",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
AllowWildcards: true,
|
||||||
|
}, &stateRes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, left, err
|
||||||
|
}
|
||||||
|
for _, state := range stateRes.Rooms {
|
||||||
|
for tuple, membership := range state {
|
||||||
|
if membership != gomatrixserverlib.Join {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// new user who we weren't previously sharing rooms with
|
||||||
|
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
|
||||||
|
changed = append(changed, tuple.StateKey) // changed is returned
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return changed, left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinedRooms(res *types.Response, userID string) []string {
|
||||||
|
var roomIDs []string
|
||||||
|
for roomID, join := range res.Rooms.Join {
|
||||||
|
// we would expect to see our join event somewhere if we newly joined the room.
|
||||||
|
// Normal events get put in the join section so it's not enough to know the room ID is present in 'join'.
|
||||||
|
newlyJoined := membershipEventPresent(join.State.Events, userID)
|
||||||
|
if newlyJoined {
|
||||||
|
roomIDs = append(roomIDs, roomID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newlyJoined = membershipEventPresent(join.Timeline.Events, userID)
|
||||||
|
if newlyJoined {
|
||||||
|
roomIDs = append(roomIDs, roomID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return roomIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
func leftRooms(res *types.Response) []string {
|
||||||
|
roomIDs := make([]string, len(res.Rooms.Leave))
|
||||||
|
i := 0
|
||||||
|
for roomID := range res.Rooms.Leave {
|
||||||
|
roomIDs[i] = roomID
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return roomIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID string) bool {
|
||||||
|
for _, ev := range events {
|
||||||
|
// it's enough to know that we have our member event here, don't need to check membership content
|
||||||
|
// as it's implied by being in the respective section of the sync response.
|
||||||
|
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil && *ev.StateKey == userID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package consumers
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
@ -7,14 +7,29 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/currentstateserver/api"
|
"github.com/matrix-org/dendrite/currentstateserver/api"
|
||||||
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
syncingUser = "@alice:localhost"
|
syncingUser = "@alice:localhost"
|
||||||
|
emptyToken = types.NewStreamToken(0, 0, nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mockKeyAPI struct{}
|
||||||
|
|
||||||
|
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
||||||
|
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) {
|
||||||
|
}
|
||||||
|
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) {
|
||||||
|
}
|
||||||
|
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) {
|
||||||
|
}
|
||||||
|
|
||||||
type mockCurrentStateAPI struct {
|
type mockCurrentStateAPI struct {
|
||||||
roomIDToJoinedMembers map[string][]string
|
roomIDToJoinedMembers map[string][]string
|
||||||
}
|
}
|
||||||
|
|
@ -144,18 +159,17 @@ func leaveResponseWithRooms(syncResponse *types.Response, userID string, roomIDs
|
||||||
func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
|
func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
|
||||||
newShareUser := "@bill:localhost"
|
newShareUser := "@bill:localhost"
|
||||||
newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNewUser:bar"
|
newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNewUser:bar"
|
||||||
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
|
syncResponse := types.NewResponse()
|
||||||
|
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
|
||||||
|
|
||||||
|
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
|
||||||
roomIDToJoinedMembers: map[string][]string{
|
roomIDToJoinedMembers: map[string][]string{
|
||||||
newlyJoinedRoom: {syncingUser, newShareUser},
|
newlyJoinedRoom: {syncingUser, newShareUser},
|
||||||
"!another:room": {syncingUser},
|
"!another:room": {syncingUser},
|
||||||
},
|
},
|
||||||
}, nil)
|
}, syncingUser, syncResponse, emptyToken)
|
||||||
syncResponse := types.NewResponse()
|
|
||||||
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
|
|
||||||
|
|
||||||
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Catchup returned an error: %s", err)
|
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
assertCatchup(t, hasNew, syncResponse, wantCatchup{
|
assertCatchup(t, hasNew, syncResponse, wantCatchup{
|
||||||
hasNew: true,
|
hasNew: true,
|
||||||
|
|
@ -167,18 +181,17 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
|
||||||
func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
|
func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
|
||||||
removeUser := "@bill:localhost"
|
removeUser := "@bill:localhost"
|
||||||
newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareLeftUser:bar"
|
newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareLeftUser:bar"
|
||||||
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
|
syncResponse := types.NewResponse()
|
||||||
|
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
|
||||||
|
|
||||||
|
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
|
||||||
roomIDToJoinedMembers: map[string][]string{
|
roomIDToJoinedMembers: map[string][]string{
|
||||||
newlyLeftRoom: {removeUser},
|
newlyLeftRoom: {removeUser},
|
||||||
"!another:room": {syncingUser},
|
"!another:room": {syncingUser},
|
||||||
},
|
},
|
||||||
}, nil)
|
}, syncingUser, syncResponse, emptyToken)
|
||||||
syncResponse := types.NewResponse()
|
|
||||||
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
|
|
||||||
|
|
||||||
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Catchup returned an error: %s", err)
|
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
assertCatchup(t, hasNew, syncResponse, wantCatchup{
|
assertCatchup(t, hasNew, syncResponse, wantCatchup{
|
||||||
hasNew: true,
|
hasNew: true,
|
||||||
|
|
@ -190,16 +203,15 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
|
||||||
func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
|
func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
|
||||||
existingUser := "@bob:localhost"
|
existingUser := "@bob:localhost"
|
||||||
newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNoNewUsers:bar"
|
newlyJoinedRoom := "!TestKeyChangeCatchupOnJoinShareNoNewUsers:bar"
|
||||||
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
|
syncResponse := types.NewResponse()
|
||||||
|
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
|
||||||
|
|
||||||
|
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
|
||||||
roomIDToJoinedMembers: map[string][]string{
|
roomIDToJoinedMembers: map[string][]string{
|
||||||
newlyJoinedRoom: {syncingUser, existingUser},
|
newlyJoinedRoom: {syncingUser, existingUser},
|
||||||
"!another:room": {syncingUser, existingUser},
|
"!another:room": {syncingUser, existingUser},
|
||||||
},
|
},
|
||||||
}, nil)
|
}, syncingUser, syncResponse, emptyToken)
|
||||||
syncResponse := types.NewResponse()
|
|
||||||
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
|
|
||||||
|
|
||||||
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Catchup returned an error: %s", err)
|
t.Fatalf("Catchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
|
|
@ -212,18 +224,17 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
|
||||||
func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
|
func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
|
||||||
existingUser := "@bob:localhost"
|
existingUser := "@bob:localhost"
|
||||||
newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareNoUsers:bar"
|
newlyLeftRoom := "!TestKeyChangeCatchupOnLeaveShareNoUsers:bar"
|
||||||
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
|
syncResponse := types.NewResponse()
|
||||||
|
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
|
||||||
|
|
||||||
|
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
|
||||||
roomIDToJoinedMembers: map[string][]string{
|
roomIDToJoinedMembers: map[string][]string{
|
||||||
newlyLeftRoom: {existingUser},
|
newlyLeftRoom: {existingUser},
|
||||||
"!another:room": {syncingUser, existingUser},
|
"!another:room": {syncingUser, existingUser},
|
||||||
},
|
},
|
||||||
}, nil)
|
}, syncingUser, syncResponse, emptyToken)
|
||||||
syncResponse := types.NewResponse()
|
|
||||||
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
|
|
||||||
|
|
||||||
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Catchup returned an error: %s", err)
|
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
assertCatchup(t, hasNew, syncResponse, wantCatchup{
|
assertCatchup(t, hasNew, syncResponse, wantCatchup{
|
||||||
hasNew: false,
|
hasNew: false,
|
||||||
|
|
@ -234,11 +245,6 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
|
||||||
func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
|
func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
|
||||||
existingUser := "@bob1:localhost"
|
existingUser := "@bob1:localhost"
|
||||||
roomID := "!TestKeyChangeCatchupNoNewJoinsButMessages:bar"
|
roomID := "!TestKeyChangeCatchupNoNewJoinsButMessages:bar"
|
||||||
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
|
|
||||||
roomIDToJoinedMembers: map[string][]string{
|
|
||||||
roomID: {syncingUser, existingUser},
|
|
||||||
},
|
|
||||||
}, nil)
|
|
||||||
syncResponse := types.NewResponse()
|
syncResponse := types.NewResponse()
|
||||||
empty := ""
|
empty := ""
|
||||||
roomStateEvents := []gomatrixserverlib.ClientEvent{
|
roomStateEvents := []gomatrixserverlib.ClientEvent{
|
||||||
|
|
@ -280,9 +286,13 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
|
||||||
jr.Timeline.Events = roomTimelineEvents
|
jr.Timeline.Events = roomTimelineEvents
|
||||||
syncResponse.Rooms.Join[roomID] = jr
|
syncResponse.Rooms.Join[roomID] = jr
|
||||||
|
|
||||||
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
|
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
|
||||||
|
roomIDToJoinedMembers: map[string][]string{
|
||||||
|
roomID: {syncingUser, existingUser},
|
||||||
|
},
|
||||||
|
}, syncingUser, syncResponse, emptyToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Catchup returned an error: %s", err)
|
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
assertCatchup(t, hasNew, syncResponse, wantCatchup{
|
assertCatchup(t, hasNew, syncResponse, wantCatchup{
|
||||||
hasNew: false,
|
hasNew: false,
|
||||||
|
|
@ -297,18 +307,17 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
|
||||||
newlyLeftUser2 := "@debra:localhost"
|
newlyLeftUser2 := "@debra:localhost"
|
||||||
newlyJoinedRoom := "!join:bar"
|
newlyJoinedRoom := "!join:bar"
|
||||||
newlyLeftRoom := "!left:bar"
|
newlyLeftRoom := "!left:bar"
|
||||||
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
|
syncResponse := types.NewResponse()
|
||||||
|
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
|
||||||
|
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
|
||||||
|
|
||||||
|
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
|
||||||
roomIDToJoinedMembers: map[string][]string{
|
roomIDToJoinedMembers: map[string][]string{
|
||||||
newlyJoinedRoom: {syncingUser, newShareUser, newShareUser2},
|
newlyJoinedRoom: {syncingUser, newShareUser, newShareUser2},
|
||||||
newlyLeftRoom: {newlyLeftUser, newlyLeftUser2},
|
newlyLeftRoom: {newlyLeftUser, newlyLeftUser2},
|
||||||
"!another:room": {syncingUser},
|
"!another:room": {syncingUser},
|
||||||
},
|
},
|
||||||
}, nil)
|
}, syncingUser, syncResponse, emptyToken)
|
||||||
syncResponse := types.NewResponse()
|
|
||||||
syncResponse = joinResponseWithRooms(syncResponse, syncingUser, []string{newlyJoinedRoom})
|
|
||||||
syncResponse = leaveResponseWithRooms(syncResponse, syncingUser, []string{newlyLeftRoom})
|
|
||||||
|
|
||||||
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Catchup returned an error: %s", err)
|
t.Fatalf("Catchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
|
|
@ -333,12 +342,6 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
|
||||||
newShareUser := "@berta:localhost"
|
newShareUser := "@berta:localhost"
|
||||||
newShareUser2 := "@bobby:localhost"
|
newShareUser2 := "@bobby:localhost"
|
||||||
roomID := "!join:bar"
|
roomID := "!join:bar"
|
||||||
consumer := NewOutputKeyChangeEventConsumer(gomatrixserverlib.ServerName("localhost"), "some_topic", nil, &mockCurrentStateAPI{
|
|
||||||
roomIDToJoinedMembers: map[string][]string{
|
|
||||||
roomID: {newShareUser, newShareUser2},
|
|
||||||
"!another:room": {syncingUser},
|
|
||||||
},
|
|
||||||
}, nil)
|
|
||||||
syncResponse := types.NewResponse()
|
syncResponse := types.NewResponse()
|
||||||
roomEvents := []gomatrixserverlib.ClientEvent{
|
roomEvents := []gomatrixserverlib.ClientEvent{
|
||||||
{
|
{
|
||||||
|
|
@ -393,9 +396,14 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
|
||||||
lr.Timeline.Events = roomEvents
|
lr.Timeline.Events = roomEvents
|
||||||
syncResponse.Rooms.Leave[roomID] = lr
|
syncResponse.Rooms.Leave[roomID] = lr
|
||||||
|
|
||||||
hasNew, err := consumer.Catchup(context.Background(), syncingUser, syncResponse, types.NewStreamToken(0, 0))
|
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, &mockCurrentStateAPI{
|
||||||
|
roomIDToJoinedMembers: map[string][]string{
|
||||||
|
roomID: {newShareUser, newShareUser2},
|
||||||
|
"!another:room": {syncingUser},
|
||||||
|
},
|
||||||
|
}, syncingUser, syncResponse, emptyToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Catchup returned an error: %s", err)
|
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||||
}
|
}
|
||||||
assertCatchup(t, hasNew, syncResponse, wantCatchup{
|
assertCatchup(t, hasNew, syncResponse, wantCatchup{
|
||||||
hasNew: true,
|
hasNew: true,
|
||||||
|
|
@ -434,7 +434,7 @@ func (d *Database) syncPositionTx(
|
||||||
if maxInviteID > maxEventID {
|
if maxInviteID > maxEventID {
|
||||||
maxEventID = maxInviteID
|
maxEventID = maxInviteID
|
||||||
}
|
}
|
||||||
sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition()))
|
sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition()), nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -731,7 +731,7 @@ func (d *Database) CompleteSync(
|
||||||
|
|
||||||
// Use a zero value SyncPosition for fromPos so all EDU states are added.
|
// Use a zero value SyncPosition for fromPos so all EDU states are added.
|
||||||
err = d.addEDUDeltaToResponse(
|
err = d.addEDUDeltaToResponse(
|
||||||
types.NewStreamToken(0, 0), toPos, joinedRoomIDs, res,
|
types.NewStreamToken(0, 0, nil), toPos, joinedRoomIDs, res,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
|
|
@ -166,7 +166,7 @@ func TestSyncResponse(t *testing.T) {
|
||||||
Name: "IncrementalSync penultimate",
|
Name: "IncrementalSync penultimate",
|
||||||
DoSync: func() (*types.Response, error) {
|
DoSync: func() (*types.Response, error) {
|
||||||
from := types.NewStreamToken( // pretend we are at the penultimate event
|
from := types.NewStreamToken( // pretend we are at the penultimate event
|
||||||
positions[len(positions)-2], types.StreamPosition(0),
|
positions[len(positions)-2], types.StreamPosition(0), nil,
|
||||||
)
|
)
|
||||||
res := types.NewResponse()
|
res := types.NewResponse()
|
||||||
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
|
||||||
|
|
@ -179,7 +179,7 @@ func TestSyncResponse(t *testing.T) {
|
||||||
Name: "IncrementalSync limited",
|
Name: "IncrementalSync limited",
|
||||||
DoSync: func() (*types.Response, error) {
|
DoSync: func() (*types.Response, error) {
|
||||||
from := types.NewStreamToken( // pretend we are 10 events behind
|
from := types.NewStreamToken( // pretend we are 10 events behind
|
||||||
positions[len(positions)-11], types.StreamPosition(0),
|
positions[len(positions)-11], types.StreamPosition(0), nil,
|
||||||
)
|
)
|
||||||
res := types.NewResponse()
|
res := types.NewResponse()
|
||||||
// limit is set to 5
|
// limit is set to 5
|
||||||
|
|
@ -222,7 +222,7 @@ func TestSyncResponse(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
st.Fatalf("failed to do sync: %s", err)
|
st.Fatalf("failed to do sync: %s", err)
|
||||||
}
|
}
|
||||||
next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition())
|
next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition(), nil)
|
||||||
if res.NextBatch != next.String() {
|
if res.NextBatch != next.String() {
|
||||||
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
|
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
|
||||||
}
|
}
|
||||||
|
|
@ -246,7 +246,7 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
|
||||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||||
}
|
}
|
||||||
from := types.NewStreamToken(
|
from := types.NewStreamToken(
|
||||||
positions[len(positions)-2], types.StreamPosition(0),
|
positions[len(positions)-2], types.StreamPosition(0), nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
res := types.NewResponse()
|
res := types.NewResponse()
|
||||||
|
|
@ -291,7 +291,7 @@ func TestGetEventsInRangeWithStreamToken(t *testing.T) {
|
||||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||||
}
|
}
|
||||||
// head towards the beginning of time
|
// head towards the beginning of time
|
||||||
to := types.NewStreamToken(0, 0)
|
to := types.NewStreamToken(0, 0, nil)
|
||||||
|
|
||||||
// backpaginate 5 messages starting at the latest position.
|
// backpaginate 5 messages starting at the latest position.
|
||||||
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
|
paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
|
||||||
|
|
@ -534,14 +534,14 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
|
|
||||||
// At this point there should be no messages. We haven't sent anything
|
// At this point there should be no messages. We haven't sent anything
|
||||||
// yet.
|
// yet.
|
||||||
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0))
|
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
|
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
|
||||||
t.Fatal("first call should have no updates")
|
t.Fatal("first call should have no updates")
|
||||||
}
|
}
|
||||||
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0))
|
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -559,14 +559,14 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
// At this point we should get exactly one message. We're sending the sync position
|
// At this point we should get exactly one message. We're sending the sync position
|
||||||
// that we were given from the update and the send-to-device update will be updated
|
// that we were given from the update and the send-to-device update will be updated
|
||||||
// in the database to reflect that this was the sync position we sent the message at.
|
// in the database to reflect that this was the sync position we sent the message at.
|
||||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
|
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
|
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
|
||||||
t.Fatal("second call should have one update")
|
t.Fatal("second call should have one update")
|
||||||
}
|
}
|
||||||
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
|
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -574,35 +574,35 @@ func TestSendToDeviceBehaviour(t *testing.T) {
|
||||||
// At this point we should still have one message because we haven't progressed the
|
// At this point we should still have one message because we haven't progressed the
|
||||||
// sync position yet. This is equivalent to the client failing to /sync and retrying
|
// sync position yet. This is equivalent to the client failing to /sync and retrying
|
||||||
// with the same position.
|
// with the same position.
|
||||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
|
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
|
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
|
||||||
t.Fatal("third call should have one update still")
|
t.Fatal("third call should have one update still")
|
||||||
}
|
}
|
||||||
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
|
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// At this point we should now have no updates, because we've progressed the sync
|
// At this point we should now have no updates, because we've progressed the sync
|
||||||
// position. Therefore the update from before will not be sent again.
|
// position. Therefore the update from before will not be sent again.
|
||||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1))
|
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
|
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
|
||||||
t.Fatal("fourth call should have no updates")
|
t.Fatal("fourth call should have no updates")
|
||||||
}
|
}
|
||||||
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1))
|
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// At this point we should still have no updates, because no new updates have been
|
// At this point we should still have no updates, because no new updates have been
|
||||||
// sent.
|
// sent.
|
||||||
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2))
|
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2, nil))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -639,7 +639,7 @@ func TestInviteBehaviour(t *testing.T) {
|
||||||
}
|
}
|
||||||
// both invite events should appear in a new sync
|
// both invite events should appear in a new sync
|
||||||
beforeRetireRes := types.NewResponse()
|
beforeRetireRes := types.NewResponse()
|
||||||
beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false)
|
beforeRetireRes, err = db.IncrementalSync(ctx, beforeRetireRes, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("IncrementalSync failed: %s", err)
|
t.Fatalf("IncrementalSync failed: %s", err)
|
||||||
}
|
}
|
||||||
|
|
@ -654,7 +654,7 @@ func TestInviteBehaviour(t *testing.T) {
|
||||||
t.Fatalf("failed to get SyncPosition: %s", err)
|
t.Fatalf("failed to get SyncPosition: %s", err)
|
||||||
}
|
}
|
||||||
res := types.NewResponse()
|
res := types.NewResponse()
|
||||||
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0), latest, 0, false)
|
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, types.NewStreamToken(0, 0, nil), latest, 0, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("IncrementalSync failed: %s", err)
|
t.Fatalf("IncrementalSync failed: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -132,6 +132,16 @@ func (n *Notifier) OnNewSendToDevice(
|
||||||
n.wakeupUserDevice(userID, deviceIDs, latestPos)
|
n.wakeupUserDevice(userID, deviceIDs, latestPos)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) OnNewKeyChange(
|
||||||
|
posUpdate types.StreamingToken, wakeUserID, keyChangeUserID string,
|
||||||
|
) {
|
||||||
|
n.streamLock.Lock()
|
||||||
|
defer n.streamLock.Unlock()
|
||||||
|
latestPos := n.currPos.WithUpdates(posUpdate)
|
||||||
|
n.currPos = latestPos
|
||||||
|
n.wakeupUsers([]string{wakeUserID}, latestPos)
|
||||||
|
}
|
||||||
|
|
||||||
// GetListener returns a UserStreamListener that can be used to wait for
|
// GetListener returns a UserStreamListener that can be used to wait for
|
||||||
// updates for a user. Must be closed.
|
// updates for a user. Must be closed.
|
||||||
// notify for anything before sincePos
|
// notify for anything before sincePos
|
||||||
|
|
|
||||||
|
|
@ -32,11 +32,11 @@ var (
|
||||||
randomMessageEvent gomatrixserverlib.HeaderedEvent
|
randomMessageEvent gomatrixserverlib.HeaderedEvent
|
||||||
aliceInviteBobEvent gomatrixserverlib.HeaderedEvent
|
aliceInviteBobEvent gomatrixserverlib.HeaderedEvent
|
||||||
bobLeaveEvent gomatrixserverlib.HeaderedEvent
|
bobLeaveEvent gomatrixserverlib.HeaderedEvent
|
||||||
syncPositionVeryOld = types.NewStreamToken(5, 0)
|
syncPositionVeryOld = types.NewStreamToken(5, 0, nil)
|
||||||
syncPositionBefore = types.NewStreamToken(11, 0)
|
syncPositionBefore = types.NewStreamToken(11, 0, nil)
|
||||||
syncPositionAfter = types.NewStreamToken(12, 0)
|
syncPositionAfter = types.NewStreamToken(12, 0, nil)
|
||||||
syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1)
|
syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1, nil)
|
||||||
syncPositionAfter2 = types.NewStreamToken(13, 0)
|
syncPositionAfter2 = types.NewStreamToken(13, 0, nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
|
||||||
since = &tok
|
since = &tok
|
||||||
}
|
}
|
||||||
if since == nil {
|
if since == nil {
|
||||||
tok := types.NewStreamToken(0, 0)
|
tok := types.NewStreamToken(0, 0, nil)
|
||||||
since = &tok
|
since = &tok
|
||||||
}
|
}
|
||||||
timelineLimit := DefaultTimelineLimit
|
timelineLimit := DefaultTimelineLimit
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
|
||||||
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/internal"
|
||||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
|
@ -35,11 +38,16 @@ type RequestPool struct {
|
||||||
db storage.Database
|
db storage.Database
|
||||||
userAPI userapi.UserInternalAPI
|
userAPI userapi.UserInternalAPI
|
||||||
notifier *Notifier
|
notifier *Notifier
|
||||||
|
keyAPI keyapi.KeyInternalAPI
|
||||||
|
stateAPI currentstateAPI.CurrentStateInternalAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRequestPool makes a new RequestPool
|
// NewRequestPool makes a new RequestPool
|
||||||
func NewRequestPool(db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI) *RequestPool {
|
func NewRequestPool(
|
||||||
return &RequestPool{db, userAPI, n}
|
db storage.Database, n *Notifier, userAPI userapi.UserInternalAPI, keyAPI keyapi.KeyInternalAPI,
|
||||||
|
stateAPI currentstateAPI.CurrentStateInternalAPI,
|
||||||
|
) *RequestPool {
|
||||||
|
return &RequestPool{db, userAPI, n, keyAPI, stateAPI}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be
|
// OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be
|
||||||
|
|
@ -138,7 +146,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
||||||
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
|
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
|
||||||
res = types.NewResponse()
|
res = types.NewResponse()
|
||||||
|
|
||||||
since := types.NewStreamToken(0, 0)
|
since := types.NewStreamToken(0, 0, nil)
|
||||||
if req.since != nil {
|
if req.since != nil {
|
||||||
since = *req.since
|
since = *req.since
|
||||||
}
|
}
|
||||||
|
|
@ -164,6 +172,10 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
res, err = rp.appendDeviceLists(res, req.device.UserID, since)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Before we return the sync response, make sure that we take action on
|
// Before we return the sync response, make sure that we take action on
|
||||||
// any send-to-device database updates or deletions that we need to do.
|
// any send-to-device database updates or deletions that we need to do.
|
||||||
|
|
@ -192,6 +204,22 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rp *RequestPool) appendDeviceLists(
|
||||||
|
data *types.Response, userID string, since types.StreamingToken,
|
||||||
|
) (*types.Response, error) {
|
||||||
|
// TODO: Currently this code will race which may result in duplicates but not missing data.
|
||||||
|
// This happens because, whilst we are told the range to fetch here (since / latest) the
|
||||||
|
// QueryKeyChanges API only exposes a "from" value (on purpose to avoid racing, which then
|
||||||
|
// returns the latest position with which the response has authority on). We'd need to tweak
|
||||||
|
// the API to expose a "to" value to fix this.
|
||||||
|
_, _, err := internal.DeviceListCatchup(context.Background(), rp.keyAPI, rp.stateAPI, userID, data, since)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
// nolint:gocyclo
|
// nolint:gocyclo
|
||||||
func (rp *RequestPool) appendAccountData(
|
func (rp *RequestPool) appendAccountData(
|
||||||
data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition,
|
data *types.Response, userID string, req syncRequest, currentPos types.StreamPosition,
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,9 @@ import (
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
currentstateapi "github.com/matrix-org/dendrite/currentstateserver/api"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
@ -39,6 +41,8 @@ func AddPublicRoutes(
|
||||||
consumer sarama.Consumer,
|
consumer sarama.Consumer,
|
||||||
userAPI userapi.UserInternalAPI,
|
userAPI userapi.UserInternalAPI,
|
||||||
rsAPI api.RoomserverInternalAPI,
|
rsAPI api.RoomserverInternalAPI,
|
||||||
|
keyAPI keyapi.KeyInternalAPI,
|
||||||
|
currentStateAPI currentstateapi.CurrentStateInternalAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
cfg *config.SyncAPI,
|
cfg *config.SyncAPI,
|
||||||
) {
|
) {
|
||||||
|
|
@ -58,7 +62,7 @@ func AddPublicRoutes(
|
||||||
logrus.WithError(err).Panicf("failed to start notifier")
|
logrus.WithError(err).Panicf("failed to start notifier")
|
||||||
}
|
}
|
||||||
|
|
||||||
requestPool := sync.NewRequestPool(syncDB, notifier, userAPI)
|
requestPool := sync.NewRequestPool(syncDB, notifier, userAPI, keyAPI, currentStateAPI)
|
||||||
|
|
||||||
roomConsumer := consumers.NewOutputRoomEventConsumer(
|
roomConsumer := consumers.NewOutputRoomEventConsumer(
|
||||||
cfg, consumer, notifier, syncDB, rsAPI,
|
cfg, consumer, notifier, syncDB, rsAPI,
|
||||||
|
|
@ -88,5 +92,13 @@ func AddPublicRoutes(
|
||||||
logrus.WithError(err).Panicf("failed to start send-to-device consumer")
|
logrus.WithError(err).Panicf("failed to start send-to-device consumer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
|
||||||
|
cfg.Matrix.ServerName, string(cfg.Kafka.Topics.OutputKeyChangeEvent),
|
||||||
|
consumer, notifier, keyAPI, currentStateAPI, syncDB,
|
||||||
|
)
|
||||||
|
if err = keyChangeConsumer.Start(); err != nil {
|
||||||
|
logrus.WithError(err).Panicf("failed to start key change consumer")
|
||||||
|
}
|
||||||
|
|
||||||
routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg)
|
routing.Setup(router, requestPool, syncDB, userAPI, federation, rsAPI, cfg)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,23 @@ var (
|
||||||
// StreamPosition represents the offset in the sync stream a client is at.
|
// StreamPosition represents the offset in the sync stream a client is at.
|
||||||
type StreamPosition int64
|
type StreamPosition int64
|
||||||
|
|
||||||
|
// LogPosition represents the offset in a Kafka log a client is at.
|
||||||
|
type LogPosition struct {
|
||||||
|
Partition int32
|
||||||
|
Offset int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAfter returns true if this position is after `lp`.
|
||||||
|
func (p *LogPosition) IsAfter(lp *LogPosition) bool {
|
||||||
|
if lp == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if p.Partition != lp.Partition {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return p.Offset > lp.Offset
|
||||||
|
}
|
||||||
|
|
||||||
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
|
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
|
||||||
type StreamEvent struct {
|
type StreamEvent struct {
|
||||||
gomatrixserverlib.HeaderedEvent
|
gomatrixserverlib.HeaderedEvent
|
||||||
|
|
@ -90,6 +107,19 @@ const (
|
||||||
|
|
||||||
type StreamingToken struct {
|
type StreamingToken struct {
|
||||||
syncToken
|
syncToken
|
||||||
|
logs map[string]*LogPosition
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *StreamingToken) SetLog(name string, lp *LogPosition) {
|
||||||
|
t.logs[name] = lp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *StreamingToken) Log(name string) *LogPosition {
|
||||||
|
l, ok := t.logs[name]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *StreamingToken) PDUPosition() StreamPosition {
|
func (t *StreamingToken) PDUPosition() StreamPosition {
|
||||||
|
|
@ -99,7 +129,15 @@ func (t *StreamingToken) EDUPosition() StreamPosition {
|
||||||
return t.Positions[1]
|
return t.Positions[1]
|
||||||
}
|
}
|
||||||
func (t *StreamingToken) String() string {
|
func (t *StreamingToken) String() string {
|
||||||
return t.syncToken.String()
|
logStrings := []string{
|
||||||
|
t.syncToken.String(),
|
||||||
|
}
|
||||||
|
for name, lp := range t.logs {
|
||||||
|
logStr := fmt.Sprintf("%s-%d-%d", name, lp.Partition, lp.Offset)
|
||||||
|
logStrings = append(logStrings, logStr)
|
||||||
|
}
|
||||||
|
// E.g s11_22_33.dl0-134.ab1-441
|
||||||
|
return strings.Join(logStrings, ".")
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAfter returns true if ANY position in this token is greater than `other`.
|
// IsAfter returns true if ANY position in this token is greater than `other`.
|
||||||
|
|
@ -109,12 +147,22 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for name := range t.logs {
|
||||||
|
otherLog := other.Log(name)
|
||||||
|
if otherLog == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if t.logs[name].IsAfter(otherLog) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
|
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
|
||||||
// If the latter StreamingToken contains a field that is not 0, it is considered an update,
|
// If the latter StreamingToken contains a field that is not 0, it is considered an update,
|
||||||
// and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called.
|
// and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called.
|
||||||
|
// If the other token has a log, they will replace any existing log on this token.
|
||||||
func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) {
|
func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) {
|
||||||
ret.Type = t.Type
|
ret.Type = t.Type
|
||||||
ret.Positions = make([]StreamPosition, len(t.Positions))
|
ret.Positions = make([]StreamPosition, len(t.Positions))
|
||||||
|
|
@ -125,6 +173,13 @@ func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken)
|
||||||
}
|
}
|
||||||
ret.Positions[i] = other.Positions[i]
|
ret.Positions[i] = other.Positions[i]
|
||||||
}
|
}
|
||||||
|
for name := range t.logs {
|
||||||
|
otherLog := other.Log(name)
|
||||||
|
if otherLog == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.logs[name] = otherLog
|
||||||
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -139,7 +194,7 @@ func (t *TopologyToken) PDUPosition() StreamPosition {
|
||||||
return t.Positions[1]
|
return t.Positions[1]
|
||||||
}
|
}
|
||||||
func (t *TopologyToken) StreamToken() StreamingToken {
|
func (t *TopologyToken) StreamToken() StreamingToken {
|
||||||
return NewStreamToken(t.PDUPosition(), 0)
|
return NewStreamToken(t.PDUPosition(), 0, nil)
|
||||||
}
|
}
|
||||||
func (t *TopologyToken) String() string {
|
func (t *TopologyToken) String() string {
|
||||||
return t.syncToken.String()
|
return t.syncToken.String()
|
||||||
|
|
@ -174,9 +229,9 @@ func (t *TopologyToken) Decrement() {
|
||||||
// error if the token couldn't be parsed into an int64, or if the token type
|
// error if the token couldn't be parsed into an int64, or if the token type
|
||||||
// isn't a known type (returns ErrInvalidSyncTokenType in the latter
|
// isn't a known type (returns ErrInvalidSyncTokenType in the latter
|
||||||
// case).
|
// case).
|
||||||
func newSyncTokenFromString(s string) (token *syncToken, err error) {
|
func newSyncTokenFromString(s string) (token *syncToken, categories []string, err error) {
|
||||||
if len(s) == 0 {
|
if len(s) == 0 {
|
||||||
return nil, ErrInvalidSyncTokenLen
|
return nil, nil, ErrInvalidSyncTokenLen
|
||||||
}
|
}
|
||||||
|
|
||||||
token = new(syncToken)
|
token = new(syncToken)
|
||||||
|
|
@ -185,16 +240,17 @@ func newSyncTokenFromString(s string) (token *syncToken, err error) {
|
||||||
switch t := SyncTokenType(s[:1]); t {
|
switch t := SyncTokenType(s[:1]); t {
|
||||||
case SyncTokenTypeStream, SyncTokenTypeTopology:
|
case SyncTokenTypeStream, SyncTokenTypeTopology:
|
||||||
token.Type = t
|
token.Type = t
|
||||||
positions = strings.Split(s[1:], "_")
|
categories = strings.Split(s[1:], ".")
|
||||||
|
positions = strings.Split(categories[0], "_")
|
||||||
default:
|
default:
|
||||||
return nil, ErrInvalidSyncTokenType
|
return nil, nil, ErrInvalidSyncTokenType
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, pos := range positions {
|
for _, pos := range positions {
|
||||||
if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil {
|
if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
} else if posInt < 0 {
|
} else if posInt < 0 {
|
||||||
return nil, errors.New("negative position not allowed")
|
return nil, nil, errors.New("negative position not allowed")
|
||||||
} else {
|
} else {
|
||||||
token.Positions = append(token.Positions, StreamPosition(posInt))
|
token.Positions = append(token.Positions, StreamPosition(posInt))
|
||||||
}
|
}
|
||||||
|
|
@ -215,7 +271,7 @@ func NewTopologyToken(depth, streamPos StreamPosition) TopologyToken {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
|
func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
|
||||||
t, err := newSyncTokenFromString(tok)
|
t, _, err := newSyncTokenFromString(tok)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -233,16 +289,20 @@ func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStreamToken creates a new sync token for /sync
|
// NewStreamToken creates a new sync token for /sync
|
||||||
func NewStreamToken(pduPos, eduPos StreamPosition) StreamingToken {
|
func NewStreamToken(pduPos, eduPos StreamPosition, logs map[string]*LogPosition) StreamingToken {
|
||||||
|
if logs == nil {
|
||||||
|
logs = make(map[string]*LogPosition)
|
||||||
|
}
|
||||||
return StreamingToken{
|
return StreamingToken{
|
||||||
syncToken: syncToken{
|
syncToken: syncToken{
|
||||||
Type: SyncTokenTypeStream,
|
Type: SyncTokenTypeStream,
|
||||||
Positions: []StreamPosition{pduPos, eduPos},
|
Positions: []StreamPosition{pduPos, eduPos},
|
||||||
},
|
},
|
||||||
|
logs: logs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
|
func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
|
||||||
t, err := newSyncTokenFromString(tok)
|
t, categories, err := newSyncTokenFromString(tok)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -254,8 +314,35 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
|
||||||
err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions))
|
err = fmt.Errorf("token %s wrong number of values, got %d want at least 2", tok, len(t.Positions))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
logs := make(map[string]*LogPosition)
|
||||||
|
if len(categories) > 1 {
|
||||||
|
// dl-0-1234
|
||||||
|
// $log_name-$partition-$offset
|
||||||
|
for _, logStr := range categories[1:] {
|
||||||
|
segments := strings.Split(logStr, "-")
|
||||||
|
if len(segments) != 3 {
|
||||||
|
err = fmt.Errorf("token %s - invalid log: %s", tok, logStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var partition int64
|
||||||
|
partition, err = strconv.ParseInt(segments[1], 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var offset int64
|
||||||
|
offset, err = strconv.ParseInt(segments[2], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logs[segments[0]] = &LogPosition{
|
||||||
|
Partition: int32(partition),
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return StreamingToken{
|
return StreamingToken{
|
||||||
syncToken: *t,
|
syncToken: *t,
|
||||||
|
logs: logs,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,61 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSyncTokenWithLogs(t *testing.T) {
|
||||||
|
tests := map[string]*StreamingToken{
|
||||||
|
"s4_0": &StreamingToken{
|
||||||
|
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
|
||||||
|
logs: make(map[string]*LogPosition),
|
||||||
|
},
|
||||||
|
"s4_0.dl-0-123": &StreamingToken{
|
||||||
|
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
|
||||||
|
logs: map[string]*LogPosition{
|
||||||
|
"dl": &LogPosition{
|
||||||
|
Partition: 0,
|
||||||
|
Offset: 123,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"s4_0.dl-0-123.ab-1-14419482332": &StreamingToken{
|
||||||
|
syncToken: syncToken{Type: "s", Positions: []StreamPosition{4, 0}},
|
||||||
|
logs: map[string]*LogPosition{
|
||||||
|
"ab": &LogPosition{
|
||||||
|
Partition: 1,
|
||||||
|
Offset: 14419482332,
|
||||||
|
},
|
||||||
|
"dl": &LogPosition{
|
||||||
|
Partition: 0,
|
||||||
|
Offset: 123,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for tok, want := range tests {
|
||||||
|
got, err := NewStreamTokenFromString(tok)
|
||||||
|
if err != nil {
|
||||||
|
if want == nil {
|
||||||
|
continue // error expected
|
||||||
|
}
|
||||||
|
t.Errorf("%s errored: %s", tok, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, *want) {
|
||||||
|
t.Errorf("%s mismatch: got %v want %v", tok, got, want)
|
||||||
|
}
|
||||||
|
if got.String() != tok {
|
||||||
|
t.Errorf("%s reserialisation mismatch: got %s want %s", tok, got.String(), tok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewSyncTokenFromString(t *testing.T) {
|
func TestNewSyncTokenFromString(t *testing.T) {
|
||||||
shouldPass := map[string]syncToken{
|
shouldPass := map[string]syncToken{
|
||||||
"s4_0": NewStreamToken(4, 0).syncToken,
|
"s4_0": NewStreamToken(4, 0, nil).syncToken,
|
||||||
"s3_1": NewStreamToken(3, 1).syncToken,
|
"s3_1": NewStreamToken(3, 1, nil).syncToken,
|
||||||
"t3_1": NewTopologyToken(3, 1).syncToken,
|
"t3_1": NewTopologyToken(3, 1).syncToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -21,7 +71,7 @@ func TestNewSyncTokenFromString(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for test, expected := range shouldPass {
|
for test, expected := range shouldPass {
|
||||||
result, err := newSyncTokenFromString(test)
|
result, _, err := newSyncTokenFromString(test)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
@ -31,7 +81,7 @@ func TestNewSyncTokenFromString(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range shouldFail {
|
for _, test := range shouldFail {
|
||||||
if _, err := newSyncTokenFromString(test); err == nil {
|
if _, _, err := newSyncTokenFromString(test); err == nil {
|
||||||
t.Errorf("input '%v' should have errored but didn't", test)
|
t.Errorf("input '%v' should have errored but didn't", test)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -127,6 +127,7 @@ Can query specific device keys using POST
|
||||||
query for user with no keys returns empty key dict
|
query for user with no keys returns empty key dict
|
||||||
Can claim one time key using POST
|
Can claim one time key using POST
|
||||||
Can claim remote one time key using POST
|
Can claim remote one time key using POST
|
||||||
|
Local device key changes appear in v2 /sync
|
||||||
Can add account data
|
Can add account data
|
||||||
Can add account data to room
|
Can add account data to room
|
||||||
Can get account data without syncing
|
Can get account data without syncing
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue