Merge branch 'master' into neilalexander/apis

Neil Alexander 2020-08-13 09:34:35 +01:00
commit e02be084a3
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
15 changed files with 310 additions and 72 deletions


@@ -33,7 +33,7 @@ Then point your favourite Matrix client at `http://localhost:8008`. For full ins
We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver
test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it
updates with CI. As of July 2020 we're at around 48% CS API coverage and 50% Federation coverage, though check
updates with CI. As of August 2020 we're at around 52% CS API coverage and 65% Federation coverage, though check
CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse
servers such as matrix.org reasonably well. There's a long list of features that are not implemented, notably:
- Receipts
@@ -42,7 +42,6 @@ servers such as matrix.org reasonably well. There's a long list of features that
- User Directory
- Presence
- Guests
- E2E keys and device lists
We are prioritising features that will benefit single-user homeservers first (e.g. Receipts, E2E) rather
than features that massive deployments may be interested in (User Directory, OpenID, Guests, Admin APIs, AS API).
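To make the script's job concrete, here is a minimal sketch of the pass-rate calculation only, assuming a hypothetical results file with one `PASS <name>` or `FAIL <name>` line per test (the real script works from Sytest's own output and also breaks results down by area, as the prefixed test names below show):

package main

import (
	"bufio"
	"fmt"
	"os"
	"strings"
)

func main() {
	// results.log is hypothetical: one "PASS <name>" or "FAIL <name>" line per test.
	f, err := os.Open("results.log")
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	defer f.Close()
	var passed, total int
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		switch {
		case strings.HasPrefix(sc.Text(), "PASS "):
			passed++
			total++
		case strings.HasPrefix(sc.Text(), "FAIL "):
			total++
		}
	}
	if total == 0 {
		fmt.Println("no test results found")
		return
	}
	fmt.Printf("%d/%d tests passing (%.0f%%)\n", passed, total, 100*float64(passed)/float64(total))
}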
@@ -56,6 +55,7 @@ This means Dendrite supports amongst others:
- Media APIs
- Redaction
- Tagging
- E2E keys and device lists
# Contributing


@@ -831,6 +831,7 @@ psh Trying to get push rules with unknown scope fails with 400
psh Trying to get push rules with unknown template fails with 400
psh Trying to get push rules with unknown attribute fails with 400
psh Trying to get push rules with unknown rule_id fails with 404
psh Rooms with names are correctly named in pushes
v1s GET /initialSync with non-numeric 'limit'
v1s GET /events with non-numeric 'limit'
v1s GET /events with negative 'limit'
@@ -839,7 +840,7 @@ ath Event size limits
syn Check creating invalid filters returns 4xx
f,pre New federated private chats get full presence information (SYN-115)
pre Left room members do not cause problems for presence
crm Rooms can be created with an initial invite list (SYN-205)
crm Rooms can be created with an initial invite list (SYN-205) (1 subtests)
typ Typing notifications don't leak
ban Non-present room members cannot ban others
psh Getting push rules doesn't corrupt the cache SYN-390


@@ -29,7 +29,9 @@ import (
"github.com/Shopify/sarama"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/currentstateserver/internal"
"github.com/matrix-org/dendrite/currentstateserver/inthttp"
"github.com/matrix-org/dendrite/currentstateserver/storage"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/sqlutil"
@@ -76,7 +78,24 @@ func init() {
}
}
func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *roomserverAPI.OutputNewRoomEvent) error {
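// waitForOffsetProcessed polls the database until the consumer has caught up with the given
// Kafka offset, failing the test if that takes longer than five seconds.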
func waitForOffsetProcessed(t *testing.T, db storage.Database, offset int64) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
for {
poffsets, err := db.PartitionOffsets(ctx, kafkaTopic)
if err != nil {
t.Fatalf("failed to PartitionOffsets: %s", err)
}
for _, partition := range poffsets {
if partition.Offset >= offset {
return
}
}
time.Sleep(50 * time.Millisecond)
}
}
func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *roomserverAPI.OutputNewRoomEvent) int64 {
value, err := json.Marshal(roomserverAPI.OutputEvent{
Type: roomserverAPI.OutputTypeNewRoomEvent,
NewRoomEvent: out,
@@ -84,7 +103,7 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms
if err != nil {
t.Fatalf("failed to marshal output event: %s", err)
}
_, _, err = producer.SendMessage(&sarama.ProducerMessage{
_, offset, err := producer.SendMessage(&sarama.ProducerMessage{
Topic: kafkaTopic,
Key: sarama.StringEncoder(out.Event.RoomID()),
Value: sarama.ByteEncoder(value),
@@ -92,10 +111,10 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms
if err != nil {
t.Fatalf("failed to send message: %s", err)
}
return nil
return offset
}
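// MustMakeInternalAPI spins up a current state server backed by naffka, returning the internal API,
// its database (so tests can wait for consumed offsets), a producer for injecting events and a cleanup function.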
func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer, func()) {
func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, storage.Database, sarama.SyncProducer, func()) {
cfg := &config.Dendrite{}
cfg.Defaults()
stateDBName := "test_state.db"
@@ -117,26 +136,28 @@ func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.Sync
if err != nil {
t.Fatalf("Failed to create naffka consumer: %s", err)
}
return NewInternalAPI(&cfg.CurrentStateServer, naff), naff, func() {
stateAPI := NewInternalAPI(&cfg.CurrentStateServer, naff)
// type-cast to pull out the DB
stateAPIVal := stateAPI.(*internal.CurrentStateInternalAPI)
return stateAPI, stateAPIVal.DB, naff, func() {
os.Remove(naffkaDBName)
os.Remove(stateDBName)
}
}
func TestQueryCurrentState(t *testing.T) {
currStateAPI, producer, cancel := MustMakeInternalAPI(t)
currStateAPI, db, producer, cancel := MustMakeInternalAPI(t)
defer cancel()
plTuple := gomatrixserverlib.StateKeyTuple{
EventType: "m.room.power_levels",
StateKey: "",
}
plEvent := testEvents[4]
MustWriteOutputEvent(t, producer, &roomserverAPI.OutputNewRoomEvent{
offset := MustWriteOutputEvent(t, producer, &roomserverAPI.OutputNewRoomEvent{
Event: plEvent,
AddsStateEventIDs: []string{plEvent.EventID()},
})
// we have no good way to know /when/ the server has consumed the event
time.Sleep(100 * time.Millisecond)
waitForOffsetProcessed(t, db, offset)
testCases := []struct {
req api.QueryCurrentStateRequest
@@ -228,7 +249,7 @@ func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *r
// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
func TestQuerySharedUsers(t *testing.T) {
currStateAPI, producer, cancel := MustMakeInternalAPI(t)
currStateAPI, db, producer, cancel := MustMakeInternalAPI(t)
defer cancel()
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
@@ -240,10 +261,8 @@ func TestQuerySharedUsers(t *testing.T) {
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@bob:localhost", "join"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@dave:localhost", "leave"))
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
// we don't know when the server has processed the events
time.Sleep(10 * time.Millisecond)
offset := MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
waitForOffsetProcessed(t, db, offset)
testCases := []struct {
req api.QuerySharedUsersRequest


@@ -110,6 +110,11 @@ type OneTimeKeysCount struct {
type PerformUploadKeysRequest struct {
DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present only to update
// the display name for their respective device, and NOT to modify the keys. The key
// itself doesn't change, but it's easier to pretend to upload new keys and reuse the same code paths.
// Without this flag, requests to modify device display names would delete device keys.
OnlyDisplayNameUpdates bool
}
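To illustrate the flag, a display-name-only change is submitted as an ordinary key upload whose KeyJSON is nil, which is exactly what the userapi does at the end of this commit. A minimal sketch, with keyAPI, the user ID and the device ID as placeholders:

var uploadRes api.PerformUploadKeysResponse
keyAPI.PerformUploadKeys(ctx, &api.PerformUploadKeysRequest{
	DeviceKeys: []api.DeviceKeys{{
		UserID:      "@alice:example.com", // placeholder user
		DeviceID:    "FIRSTDEVICE",        // placeholder device
		DisplayName: "Alice's phone",
		KeyJSON:     nil, // keys untouched: only the name changes
	}},
	OnlyDisplayNameUpdates: true,
}, &uploadRes)
if uploadRes.Error != nil {
	// handle the failure
}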
// PerformUploadKeysResponse is the response to PerformUploadKeys


@@ -67,6 +67,11 @@ type DeviceListUpdater struct {
producer KeyChangeProducer
fedClient *gomatrixserverlib.FederationClient
workerChans []chan gomatrixserverlib.ServerName
// When a user's device list is stale, an entry is inserted into this map with a channel which `Update` will
// block on, or time out on, via a select.
userIDToChan map[string]chan bool
userIDToChanMu *sync.Mutex
}
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
@@ -80,8 +85,9 @@ type DeviceListUpdaterDatabase interface {
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys.
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
@@ -98,12 +104,14 @@ func NewDeviceListUpdater(
numWorkers int,
) *DeviceListUpdater {
return &DeviceListUpdater{
userIDToMutex: make(map[string]*sync.Mutex),
mu: &sync.Mutex{},
db: db,
producer: producer,
fedClient: fedClient,
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToMutex: make(map[string]*sync.Mutex),
mu: &sync.Mutex{},
db: db,
producer: producer,
fedClient: fedClient,
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToChan: make(map[string]chan bool),
userIDToChanMu: &sync.Mutex{},
}
}
@@ -137,6 +145,22 @@ func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
return u.userIDToMutex[userID]
}
// ManualUpdate invalidates the device list for the given user, then fetches the latest list and tracks it.
// It blocks until the device list is synced or the timeout is reached.
func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error {
mu := u.mutex(userID)
mu.Lock()
err := u.db.MarkDeviceListStale(ctx, userID, true)
mu.Unlock()
if err != nil {
return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err)
}
u.notifyWorkers(userID)
return nil
}
// Update blocks until the update has been stored in the database. It blocks primarily to satisfy sytest,
// which assumes that when /send returns 200 OK the device lists have been updated.
func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
isDeviceListStale, err := u.update(ctx, event)
if err != nil {
@@ -169,22 +193,27 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
"stream_id": event.StreamID,
"prev_ids": event.PrevID,
"display_name": event.DeviceDisplayName,
"deleted": event.Deleted,
}).Info("DeviceListUpdater.Update")
// if we haven't missed anything update the database and notify users
if exists {
k := event.Keys
if event.Deleted {
k = nil
}
keys := []api.DeviceMessage{
{
DeviceKeys: api.DeviceKeys{
DeviceID: event.DeviceID,
DisplayName: event.DeviceDisplayName,
KeyJSON: event.Keys,
KeyJSON: k,
UserID: event.UserID,
},
StreamID: event.StreamID,
},
}
err = u.db.StoreRemoteDeviceKeys(ctx, keys)
err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
if err != nil {
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
}
@@ -213,7 +242,35 @@ func (u *DeviceListUpdater) notifyWorkers(userID string) {
hash := fnv.New32a()
_, _ = hash.Write([]byte(remoteServer))
index := int(hash.Sum32()) % len(u.workerChans)
ch := u.assignChannel(userID)
u.workerChans[index] <- remoteServer
select {
case <-ch:
case <-time.After(10 * time.Second):
// we don't return an error in this case as it's not a failure condition.
// we mainly block for the benefit of sytest anyway
}
}
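// assignChannel returns the channel to block on for the given user, creating it if it doesn't already exist.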
func (u *DeviceListUpdater) assignChannel(userID string) chan bool {
u.userIDToChanMu.Lock()
defer u.userIDToChanMu.Unlock()
if ch, ok := u.userIDToChan[userID]; ok {
return ch
}
ch := make(chan bool)
u.userIDToChan[userID] = ch
return ch
}
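// clearChannel closes and removes the user's channel, if one exists, releasing any caller blocked in notifyWorkers.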
func (u *DeviceListUpdater) clearChannel(userID string) {
u.userIDToChanMu.Lock()
defer u.userIDToChanMu.Unlock()
if ch, ok := u.userIDToChan[userID]; ok {
close(ch)
delete(u.userIDToChan, userID)
}
}
func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
@@ -230,6 +287,7 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
}
// on failure, spin up a short-lived goroutine to inject the server name again.
scheduledRetries := make(map[gomatrixserverlib.ServerName]time.Time)
inject := func(srv gomatrixserverlib.ServerName, duration time.Duration) {
time.Sleep(duration)
ch <- srv
@@ -237,13 +295,20 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
for serverName := range ch {
if !shouldProcess(serverName) {
// do not inject into the channel as we know there will be a sleeping goroutine
// which will do it after the cooloff period expires
continue
if time.Now().Before(scheduledRetries[serverName]) {
// do not inject into the channel as we know there will be a sleeping goroutine
// which will do it after the cooloff period expires
continue
} else {
scheduledRetries[serverName] = time.Now().Add(cooloffPeriod)
go inject(serverName, cooloffPeriod) // TODO: Backoff?
continue
}
}
lastProcessed[serverName] = time.Now()
shouldRetry := u.processServer(serverName)
if shouldRetry {
scheduledRetries[serverName] = time.Now().Add(cooloffPeriod)
go inject(serverName, cooloffPeriod) // TODO: Backoff?
}
}
@@ -277,6 +342,8 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
if err != nil {
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it")
hasFailures = true
} else {
u.clearChannel(userID)
}
}
return hasFailures
@@ -301,7 +368,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
},
}
}
err := u.db.StoreRemoteDeviceKeys(ctx, keys)
err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
if err != nil {
return fmt.Errorf("failed to store remote device keys: %w", err)
}
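Taken together, these helpers implement a simple rendezvous: callers mark a device list stale and block on a per-user channel, and a worker closes that channel once the resync has been stored, releasing them. A self-contained sketch of the same pattern under hypothetical names (the real code hashes the remote server name to pick a worker channel and retries after a cooloff on failure):

package main

import (
	"fmt"
	"sync"
	"time"
)

// waiters maps a key (here, a user ID) to a channel that is closed once a
// background worker has finished processing that key.
type waiters struct {
	mu sync.Mutex
	m  map[string]chan bool
}

func (w *waiters) assign(key string) chan bool {
	w.mu.Lock()
	defer w.mu.Unlock()
	if ch, ok := w.m[key]; ok {
		return ch
	}
	ch := make(chan bool)
	w.m[key] = ch
	return ch
}

func (w *waiters) clear(key string) {
	w.mu.Lock()
	defer w.mu.Unlock()
	if ch, ok := w.m[key]; ok {
		close(ch)
		delete(w.m, key)
	}
}

func main() {
	w := &waiters{m: make(map[string]chan bool)}
	work := make(chan string)
	go func() { // worker: process each key, then release any waiter
		for key := range work {
			time.Sleep(100 * time.Millisecond) // stand-in for the federation resync
			w.clear(key)                       // releases every goroutine blocked on this key
		}
	}()
	ch := w.assign("@alice:example.com") // register interest before queueing work
	work <- "@alice:example.com"
	select {
	case <-ch:
		fmt.Println("device list resynced")
	case <-time.After(10 * time.Second):
		fmt.Println("timed out; treated as non-fatal, just like notifyWorkers")
	}
}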


@@ -81,7 +81,7 @@ func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context,
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys.
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
d.storedKeys = append(d.storedKeys, keys...)
return nil
}
@@ -204,16 +204,6 @@ func TestUpdateNoPrevID(t *testing.T) {
if err != nil {
t.Fatalf("Update returned an error: %s", err)
}
// At this point we should have this device list marked as stale, without having stored the keys or emitted anything
if !db.staleUsers[event.UserID] {
t.Errorf("%s not marked as stale", event.UserID)
}
if len(producer.events) > 0 {
t.Errorf("Update incorrect emitted %d device change events", len(producer.events))
}
if len(db.storedKeys) > 0 {
t.Errorf("Update incorrect stored %d device change events", len(db.storedKeys))
}
t.Log("waiting for /users/devices to be called...")
wg.Wait()
// wait a bit for db to be updated...


@@ -28,6 +28,7 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -205,7 +206,15 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
maxStreamID = m.StreamID
}
}
res.Devices = msgs
// remove deleted devices
var result []api.DeviceMessage
for _, m := range msgs {
if m.KeyJSON == nil {
continue
}
result = append(result, m)
}
res.Devices = result
res.StreamID = maxStreamID
}
@@ -282,27 +291,21 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
fetchRemote := make(map[string]map[string][]string)
for domain, userToDeviceMap := range domainToDeviceKeys {
for userID, deviceIDs := range userToDeviceMap {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
// Likewise, we can't safely return keys from the db when all devices are requested as we don't
// we can't safely return keys from the db when all devices are requested as we don't
// know if one has just been added.
if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) {
if _, ok := fetchRemote[domain]; !ok {
fetchRemote[domain] = make(map[string][]string)
if len(deviceIDs) > 0 {
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
if err == nil {
continue
}
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
continue
util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
}
if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
}
for _, key := range keys {
// inject the display name
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
}{key.DisplayName})
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
// fetch device lists from remote
if _, ok := fetchRemote[domain]; !ok {
fetchRemote[domain] = make(map[string][]string)
}
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
}
}
return fetchRemote
@@ -324,6 +327,45 @@ func (a *KeyInternalAPI) queryRemoteKeys(
defer wg.Done()
fedCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// for users we do not have any knowledge about, try to start device list updates for them
// by hitting /users/devices - otherwise fall back to /keys/query, which has nicer bulk properties but
// lacks a stream ID.
var userIDsForAllDevices []string
for userID, deviceIDs := range devKeys {
if len(deviceIDs) == 0 {
userIDsForAllDevices = append(userIDsForAllDevices, userID)
delete(devKeys, userID)
}
}
for _, userID := range userIDsForAllDevices {
err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
if err != nil {
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
"user_id": userID,
"server": serverName,
}).Error("Failed to manually update device lists for user")
// try to do it via /keys/query
devKeys[userID] = []string{}
continue
}
// refresh entries from the DB: unlike remoteKeysFromDatabase, we know we previously had no device info for this
// user, so populating all devices here isn't a problem so long as we have some devices.
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
if err != nil {
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
"user_id": userID,
"server": serverName,
}).Error("Failed to populate device keys from the database for user")
// try to do it via /keys/query
devKeys[userID] = []string{}
continue
}
}
if len(devKeys) == 0 {
return
}
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys)
if err != nil {
failMu.Lock()
@@ -357,6 +399,37 @@ func (a *KeyInternalAPI) queryRemoteKeys(
}
}
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
// if we can't query the db or there are fewer keys than requested, the caller falls back to remote.
if err != nil {
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
}
if len(keys) < len(deviceIDs) {
return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
}
if len(deviceIDs) == 0 && len(keys) == 0 {
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
}
if res.DeviceKeys[userID] == nil {
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
}
for _, key := range keys {
if len(key.KeyJSON) == 0 {
continue // ignore deleted keys
}
// inject the display name
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
}{key.DisplayName})
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
}
return nil
}
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key
@@ -403,6 +476,10 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}
return
}
if req.OnlyDisplayNameUpdates {
// add the display name field from keysToStore into existingKeys
keysToStore = appendDisplayNames(existingKeys, keysToStore)
}
// store the device keys and emit changes
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
if err != nil {
@@ -475,3 +552,16 @@ func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage)
}
return a.Producer.ProduceKeyChanges(keysAdded)
}
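// appendDisplayNames copies the display name of each device in `new` onto the matching device in
// `existing` and returns the updated `existing` slice; devices that only appear in `new` are dropped,
// since a display-name-only upload cannot introduce new keys.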
func appendDisplayNames(existing, new []api.DeviceMessage) []api.DeviceMessage {
for i, existingDevice := range existing {
for _, newDevice := range new {
if existingDevice.DeviceID != newDevice.DeviceID {
continue
}
existingDevice.DisplayName = newDevice.DisplayName
existing[i] = existingDevice
}
}
return existing
}


@@ -43,8 +43,9 @@ type Database interface {
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys.
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)


@@ -61,6 +61,9 @@ const selectMaxStreamForUserSQL = "" +
const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
@@ -68,6 +71,7 @@ type deviceKeysStatements struct {
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
countStreamIDsForUserStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -93,6 +97,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
return nil, err
}
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
return nil, err
}
return s, nil
}
@@ -154,6 +161,11 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
return nil
}
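// DeleteAllDeviceKeys removes every stored device key for the given user within the supplied
// transaction, so that a complete snapshot of the user's devices can be re-inserted afterwards.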
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil {


@@ -61,8 +61,14 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i
return count == len(prevIDs), nil
}
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
for _, userID := range clearUserIDs {
err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
if err != nil {
return err
}
}
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
})
}


@@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" +
const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
@@ -65,6 +68,7 @@ type deviceKeysStatements struct {
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -88,9 +92,17 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {


@@ -38,6 +38,7 @@ type DeviceKeys interface {
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}
type KeyChanges interface {


@@ -96,7 +96,8 @@ func DeviceListCatchup(
return hasNew, nil
}
// QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user.
queryRes.UserIDs = filterSharedUsers(ctx, stateAPI, userID, queryRes.UserIDs)
var sharedUsersMap map[string]int
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, stateAPI, userID, queryRes.UserIDs)
util.GetLogger(ctx).Debugf(
"QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v",
partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs,
@@ -114,13 +115,20 @@
}
// if the response has any join/leave events, add them now.
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
for _, userID := range membershipEvents(res) {
joinUserIDs, leaveUserIDs := membershipEvents(res)
for _, userID := range joinUserIDs {
if !userSet[userID] {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
hasNew = true
userSet[userID] = true
}
}
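// sharedUsersMap counts how many rooms we still share with each user after these events. For example,
// if we shared two rooms with a user and they left one, their count is 1 and they are not marked as left;
// only when the count reaches 0 do they belong in the left list.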
for _, userID := range leaveUserIDs {
if sharedUsersMap[userID] == 0 {
// we no longer share any room with this user now that they have left, so add them to the left list.
res.DeviceLists.Left = append(res.DeviceLists.Left, userID)
}
}
// set the new token
to.SetLog(DeviceListLogName, &types.LogPosition{
Partition: queryRes.Partition,
@@ -221,7 +229,7 @@
func filterSharedUsers(
ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, usersWithChangedKeys []string,
) []string {
) (map[string]int, []string) {
var result []string
var sharedUsersRes currentstateAPI.QuerySharedUsersResponse
err := stateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{
@@ -229,7 +237,7 @@
}, &sharedUsersRes)
if err != nil {
// default to all users so we do needless queries rather than miss some important device update
return usersWithChangedKeys
return nil, usersWithChangedKeys
}
// We forcibly put ourselves in this list because we should be notified about our own device updates
// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
@@ -241,7 +249,7 @@
result = append(result, uid)
}
}
return result
return sharedUsersRes.UserIDsToCount, result
}
func joinedRooms(res *types.Response, userID string) []string {
@@ -288,16 +296,16 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin
// "For optimal performance, Alice should be added to changed in Bob's sync only when she adds a new device,
// or when Alice and Bob now share a room but didn't share any room previously. However, for the sake of simpler
// logic, a server may add Alice to changed when Alice and Bob share a new room, even if they previously already shared a room."
func membershipEvents(res *types.Response) (userIDs []string) {
func membershipEvents(res *types.Response) (joinUserIDs, leaveUserIDs []string) {
for _, room := range res.Rooms.Join {
for _, ev := range room.Timeline.Events {
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil {
if strings.Contains(string(ev.Content), `"join"`) {
userIDs = append(userIDs, *ev.StateKey)
joinUserIDs = append(joinUserIDs, *ev.StateKey)
} else if strings.Contains(string(ev.Content), `"leave"`) {
userIDs = append(userIDs, *ev.StateKey)
leaveUserIDs = append(leaveUserIDs, *ev.StateKey)
} else if strings.Contains(string(ev.Content), `"ban"`) {
userIDs = append(userIDs, *ev.StateKey)
leaveUserIDs = append(leaveUserIDs, *ev.StateKey)
}
}
}


@@ -142,7 +142,12 @@ Server correctly handles incoming m.device_list_update
Device deletion propagates over federation
If remote user leaves room, changes device and rejoins we see update in sync
If remote user leaves room, changes device and rejoins we see update in /keys/changes
If remote user leaves room we no longer receive device updates
If a device list update goes missing, the server resyncs on the next one
Get left notifs in sync and /keys/changes when other user leaves
Can query remote device keys using POST after notification
Server correctly resyncs when client query keys and there is no remote cache
Server correctly resyncs when server leaves and rejoins a room
Can add account data
Can add account data to room
Can get account data without syncing


@@ -180,6 +180,27 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
}
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
// display name has changed: update the device key
var uploadRes keyapi.PerformUploadKeysResponse
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
DeviceKeys: []keyapi.DeviceKeys{
{
DeviceID: dev.ID,
DisplayName: *req.DisplayName,
KeyJSON: nil,
UserID: dev.UserID,
},
},
OnlyDisplayNameUpdates: true,
}, &uploadRes)
if uploadRes.Error != nil {
return fmt.Errorf("Failed to update device key display name: %v", uploadRes.Error)
}
if len(uploadRes.KeyErrors) > 0 {
return fmt.Errorf("Failed to update device key display name, key errors: %+v", uploadRes.KeyErrors)
}
}
return nil
}