Merge branch 'master' into neilalexander/apis

Neil Alexander 2020-08-13 09:34:35 +01:00
commit e02be084a3
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
15 changed files with 310 additions and 72 deletions

View file

@@ -33,7 +33,7 @@ Then point your favourite Matrix client at `http://localhost:8008`. For full ins
 We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver
 test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it
-updates with CI. As of July 2020 we're at around 48% CS API coverage and 50% Federation coverage, though check
+updates with CI. As of August 2020 we're at around 52% CS API coverage and 65% Federation coverage, though check
 CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse
 servers such as matrix.org reasonably well. There's a long list of features that are not implemented, notably:
 - Receipts
@@ -42,7 +42,6 @@ servers such as matrix.org reasonably well. There's a long list of features that
 - User Directory
 - Presence
 - Guests
-- E2E keys and device lists
 We are prioritising features that will benefit single-user homeservers first (e.g Receipts, E2E) rather
 than features that massive deployments may be interested in (User Directory, OpenID, Guests, Admin APIs, AS API).
@@ -56,6 +55,7 @@ This means Dendrite supports amongst others:
 - Media APIs
 - Redaction
 - Tagging
+- E2E keys and device lists
 # Contributing

View file

@@ -831,6 +831,7 @@ psh Trying to get push rules with unknown scope fails with 400
 psh Trying to get push rules with unknown template fails with 400
 psh Trying to get push rules with unknown attribute fails with 400
 psh Trying to get push rules with unknown rule_id fails with 404
+psh Rooms with names are correctly named in pushes
 v1s GET /initialSync with non-numeric 'limit'
 v1s GET /events with non-numeric 'limit'
 v1s GET /events with negative 'limit'
@@ -839,7 +840,7 @@ ath Event size limits
 syn Check creating invalid filters returns 4xx
 f,pre New federated private chats get full presence information (SYN-115)
 pre Left room members do not cause problems for presence
-crm Rooms can be created with an initial invite list (SYN-205)
+crm Rooms can be created with an initial invite list (SYN-205) (1 subtests)
 typ Typing notifications don't leak
 ban Non-present room members cannot ban others
 psh Getting push rules doesn't corrupt the cache SYN-390

View file

@@ -29,7 +29,9 @@ import (
 	"github.com/Shopify/sarama"
 	"github.com/gorilla/mux"
 	"github.com/matrix-org/dendrite/currentstateserver/api"
+	"github.com/matrix-org/dendrite/currentstateserver/internal"
 	"github.com/matrix-org/dendrite/currentstateserver/inthttp"
+	"github.com/matrix-org/dendrite/currentstateserver/storage"
 	"github.com/matrix-org/dendrite/internal/config"
 	"github.com/matrix-org/dendrite/internal/httputil"
 	"github.com/matrix-org/dendrite/internal/sqlutil"
@@ -76,7 +78,24 @@ func init() {
 	}
 }

-func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *roomserverAPI.OutputNewRoomEvent) error {
+func waitForOffsetProcessed(t *testing.T, db storage.Database, offset int64) {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	for {
+		poffsets, err := db.PartitionOffsets(ctx, kafkaTopic)
+		if err != nil {
+			t.Fatalf("failed to PartitionOffsets: %s", err)
+		}
+		for _, partition := range poffsets {
+			if partition.Offset >= offset {
+				return
+			}
+		}
+		time.Sleep(50 * time.Millisecond)
+	}
+}
+
+func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *roomserverAPI.OutputNewRoomEvent) int64 {
 	value, err := json.Marshal(roomserverAPI.OutputEvent{
 		Type: roomserverAPI.OutputTypeNewRoomEvent,
 		NewRoomEvent: out,
@@ -84,7 +103,7 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms
 	if err != nil {
 		t.Fatalf("failed to marshal output event: %s", err)
 	}
-	_, _, err = producer.SendMessage(&sarama.ProducerMessage{
+	_, offset, err := producer.SendMessage(&sarama.ProducerMessage{
 		Topic: kafkaTopic,
 		Key: sarama.StringEncoder(out.Event.RoomID()),
 		Value: sarama.ByteEncoder(value),
@@ -92,10 +111,10 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms
 	if err != nil {
 		t.Fatalf("failed to send message: %s", err)
 	}
-	return nil
+	return offset
 }

-func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer, func()) {
+func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, storage.Database, sarama.SyncProducer, func()) {
 	cfg := &config.Dendrite{}
 	cfg.Defaults()
 	stateDBName := "test_state.db"
@@ -117,26 +136,28 @@ func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.Sync
 	if err != nil {
 		t.Fatalf("Failed to create naffka consumer: %s", err)
 	}
-	return NewInternalAPI(&cfg.CurrentStateServer, naff), naff, func() {
+	stateAPI := NewInternalAPI(&cfg.CurrentStateServer, naff)
+	// type-cast to pull out the DB
+	stateAPIVal := stateAPI.(*internal.CurrentStateInternalAPI)
+	return stateAPI, stateAPIVal.DB, naff, func() {
 		os.Remove(naffkaDBName)
 		os.Remove(stateDBName)
 	}
 }

 func TestQueryCurrentState(t *testing.T) {
-	currStateAPI, producer, cancel := MustMakeInternalAPI(t)
+	currStateAPI, db, producer, cancel := MustMakeInternalAPI(t)
 	defer cancel()
 	plTuple := gomatrixserverlib.StateKeyTuple{
 		EventType: "m.room.power_levels",
 		StateKey: "",
 	}
 	plEvent := testEvents[4]
-	MustWriteOutputEvent(t, producer, &roomserverAPI.OutputNewRoomEvent{
+	offset := MustWriteOutputEvent(t, producer, &roomserverAPI.OutputNewRoomEvent{
 		Event: plEvent,
 		AddsStateEventIDs: []string{plEvent.EventID()},
 	})
-	// we have no good way to know /when/ the server has consumed the event
-	time.Sleep(100 * time.Millisecond)
+	waitForOffsetProcessed(t, db, offset)

 	testCases := []struct {
 		req api.QueryCurrentStateRequest
@@ -228,7 +249,7 @@ func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *r
 // This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
 func TestQuerySharedUsers(t *testing.T) {
-	currStateAPI, producer, cancel := MustMakeInternalAPI(t)
+	currStateAPI, db, producer, cancel := MustMakeInternalAPI(t)
 	defer cancel()
 	MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
 	MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
@@ -240,10 +261,8 @@ func TestQuerySharedUsers(t *testing.T) {
 	MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@bob:localhost", "join"))
 	MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@dave:localhost", "leave"))
-	MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
-	// we don't know when the server has processed the events
-	time.Sleep(10 * time.Millisecond)
+	offset := MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
+	waitForOffsetProcessed(t, db, offset)

 	testCases := []struct {
 		req api.QuerySharedUsersRequest

View file

@@ -110,6 +110,11 @@ type OneTimeKeysCount struct {
 type PerformUploadKeysRequest struct {
 	DeviceKeys []DeviceKeys
 	OneTimeKeys []OneTimeKeys
+	// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
+	// the display name for their respective device, and NOT to modify the keys. The key
+	// itself doesn't change, but it's easier to pretend to upload new keys and reuse the same code paths.
+	// Without this flag, requests to modify device display names would delete device keys.
+	OnlyDisplayNameUpdates bool
 }

 // PerformUploadKeysResponse is the response to PerformUploadKeys
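
As a usage illustration (not part of the diff): a display-name-only upload under this flag could look like the sketch below. The request and response field names come from the hunks in this commit; the wrapper function, the `KeyInternalAPI` interface name and the import path are assumptions about the surrounding code.

```go
package example

import (
	"context"
	"fmt"

	"github.com/matrix-org/dendrite/keyserver/api" // import path assumed
)

// updateDisplayName refreshes only the display name of one device. KeyJSON is
// nil and OnlyDisplayNameUpdates is set, so the stored key material is untouched.
func updateDisplayName(ctx context.Context, keyAPI api.KeyInternalAPI, userID, deviceID, name string) error {
	var res api.PerformUploadKeysResponse
	keyAPI.PerformUploadKeys(ctx, &api.PerformUploadKeysRequest{
		DeviceKeys: []api.DeviceKeys{{
			UserID:      userID,
			DeviceID:    deviceID,
			DisplayName: name,
			KeyJSON:     nil, // no key material is re-uploaded
		}},
		OnlyDisplayNameUpdates: true,
	}, &res)
	if res.Error != nil {
		return fmt.Errorf("failed to update display name: %v", res.Error)
	}
	return nil
}
```

The userapi hunk near the end of this commit performs essentially this call from `PerformDeviceUpdate`.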

View file

@@ -67,6 +67,11 @@ type DeviceListUpdater struct {
 	producer KeyChangeProducer
 	fedClient *gomatrixserverlib.FederationClient
 	workerChans []chan gomatrixserverlib.ServerName
+	// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
+	// block on or timeout via a select.
+	userIDToChan map[string]chan bool
+	userIDToChanMu *sync.Mutex
 }

 // DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
@@ -80,8 +85,9 @@ type DeviceListUpdaterDatabase interface {
 	MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error

 	// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
-	// for this (user, device). Does not modify the stream ID for keys.
-	StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
+	// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
+	// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
+	StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error

 	// PrevIDsExists returns true if all prev IDs exist for this user.
 	PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
@@ -98,12 +104,14 @@ func NewDeviceListUpdater(
 	numWorkers int,
 ) *DeviceListUpdater {
 	return &DeviceListUpdater{
 		userIDToMutex: make(map[string]*sync.Mutex),
 		mu: &sync.Mutex{},
 		db: db,
 		producer: producer,
 		fedClient: fedClient,
 		workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
+		userIDToChan: make(map[string]chan bool),
+		userIDToChanMu: &sync.Mutex{},
 	}
 }
@@ -137,6 +145,22 @@ func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
 	return u.userIDToMutex[userID]
 }

+// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it.
+// Blocks until the device list is synced or the timeout is reached.
+func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error {
+	mu := u.mutex(userID)
+	mu.Lock()
+	err := u.db.MarkDeviceListStale(ctx, userID, true)
+	mu.Unlock()
+	if err != nil {
+		return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err)
+	}
+	u.notifyWorkers(userID)
+	return nil
+}
+
+// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest,
+// which assumes when /send 200 OKs that the device lists have been updated.
 func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
 	isDeviceListStale, err := u.update(ctx, event)
 	if err != nil {
@@ -169,22 +193,27 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
 		"stream_id": event.StreamID,
 		"prev_ids": event.PrevID,
 		"display_name": event.DeviceDisplayName,
+		"deleted": event.Deleted,
 	}).Info("DeviceListUpdater.Update")

 	// if we haven't missed anything update the database and notify users
 	if exists {
+		k := event.Keys
+		if event.Deleted {
+			k = nil
+		}
 		keys := []api.DeviceMessage{
 			{
 				DeviceKeys: api.DeviceKeys{
 					DeviceID: event.DeviceID,
 					DisplayName: event.DeviceDisplayName,
-					KeyJSON: event.Keys,
+					KeyJSON: k,
 					UserID: event.UserID,
 				},
 				StreamID: event.StreamID,
 			},
 		}
-		err = u.db.StoreRemoteDeviceKeys(ctx, keys)
+		err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
 		if err != nil {
 			return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
 		}
@@ -213,7 +242,35 @@ func (u *DeviceListUpdater) notifyWorkers(userID string) {
 	hash := fnv.New32a()
 	_, _ = hash.Write([]byte(remoteServer))
 	index := int(hash.Sum32()) % len(u.workerChans)
+	ch := u.assignChannel(userID)
 	u.workerChans[index] <- remoteServer
+	select {
+	case <-ch:
+	case <-time.After(10 * time.Second):
+		// we don't return an error in this case as it's not a failure condition.
+		// we mainly block for the benefit of sytest anyway
+	}
+}
+
+func (u *DeviceListUpdater) assignChannel(userID string) chan bool {
+	u.userIDToChanMu.Lock()
+	defer u.userIDToChanMu.Unlock()
+	if ch, ok := u.userIDToChan[userID]; ok {
+		return ch
+	}
+	ch := make(chan bool)
+	u.userIDToChan[userID] = ch
+	return ch
+}
+
+func (u *DeviceListUpdater) clearChannel(userID string) {
+	u.userIDToChanMu.Lock()
+	defer u.userIDToChanMu.Unlock()
+	if ch, ok := u.userIDToChan[userID]; ok {
+		close(ch)
+		delete(u.userIDToChan, userID)
+	}
 }

 func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
@@ -230,6 +287,7 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
 	}
 	// on failure, spin up a short-lived goroutine to inject the server name again.
+	scheduledRetries := make(map[gomatrixserverlib.ServerName]time.Time)
 	inject := func(srv gomatrixserverlib.ServerName, duration time.Duration) {
 		time.Sleep(duration)
 		ch <- srv
@@ -237,13 +295,20 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
 	for serverName := range ch {
 		if !shouldProcess(serverName) {
-			// do not inject into the channel as we know there will be a sleeping goroutine
-			// which will do it after the cooloff period expires
-			continue
+			if time.Now().Before(scheduledRetries[serverName]) {
+				// do not inject into the channel as we know there will be a sleeping goroutine
+				// which will do it after the cooloff period expires
+				continue
+			} else {
+				scheduledRetries[serverName] = time.Now().Add(cooloffPeriod)
+				go inject(serverName, cooloffPeriod) // TODO: Backoff?
+				continue
+			}
 		}
 		lastProcessed[serverName] = time.Now()
 		shouldRetry := u.processServer(serverName)
 		if shouldRetry {
+			scheduledRetries[serverName] = time.Now().Add(cooloffPeriod)
 			go inject(serverName, cooloffPeriod) // TODO: Backoff?
 		}
 	}
@@ -277,6 +342,8 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
 		if err != nil {
 			logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it")
 			hasFailures = true
+		} else {
+			u.clearChannel(userID)
 		}
 	}
 	return hasFailures
@@ -301,7 +368,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
 			},
 		}
 	}
-	err := u.db.StoreRemoteDeviceKeys(ctx, keys)
+	err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
 	if err != nil {
 		return fmt.Errorf("failed to store remote device keys: %w", err)
 	}
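
Stepping outside the diff for a moment: the blocking behaviour added above reduces to a per-user notification channel that the waiter selects on with a timeout, and that the worker closes once a resync succeeds (close acts as a broadcast). A stripped-down, self-contained sketch of that pattern, using hypothetical names rather than the updater's real types:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// notifier hands out one channel per key; waiters block on that channel and
// whoever completes the work closes it, releasing every waiter at once.
type notifier struct {
	mu    sync.Mutex
	chans map[string]chan struct{}
}

func (n *notifier) channelFor(key string) chan struct{} {
	n.mu.Lock()
	defer n.mu.Unlock()
	if ch, ok := n.chans[key]; ok {
		return ch
	}
	ch := make(chan struct{})
	n.chans[key] = ch
	return ch
}

func (n *notifier) done(key string) {
	n.mu.Lock()
	defer n.mu.Unlock()
	if ch, ok := n.chans[key]; ok {
		close(ch)
		delete(n.chans, key)
	}
}

func main() {
	n := &notifier{chans: make(map[string]chan struct{})}
	ch := n.channelFor("@alice:example.com") // register interest before starting the work

	go func() {
		time.Sleep(100 * time.Millisecond) // stand-in for the federation resync
		n.done("@alice:example.com")
	}()

	select {
	case <-ch:
		fmt.Println("device list resynced")
	case <-time.After(10 * time.Second):
		fmt.Println("timed out; treated as a soft failure rather than an error")
	}
}
```

The real updater uses `chan bool` and the same 10-second window; a timeout is deliberately not surfaced as an error, since the resync still completes once the worker gets to it.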

View file

@@ -81,7 +81,7 @@ func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context,
 // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
 // for this (user, device). Does not modify the stream ID for keys.
-func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
+func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
 	d.storedKeys = append(d.storedKeys, keys...)
 	return nil
 }
@@ -204,16 +204,6 @@ func TestUpdateNoPrevID(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Update returned an error: %s", err)
 	}
-	// At this point we show have this device list marked as stale and not store the keys or emitted anything
-	if !db.staleUsers[event.UserID] {
-		t.Errorf("%s not marked as stale", event.UserID)
-	}
-	if len(producer.events) > 0 {
-		t.Errorf("Update incorrect emitted %d device change events", len(producer.events))
-	}
-	if len(db.storedKeys) > 0 {
-		t.Errorf("Update incorrect stored %d device change events", len(db.storedKeys))
-	}
 	t.Log("waiting for /users/devices to be called...")
 	wg.Wait()
 	// wait a bit for db to be updated...

View file

@@ -28,6 +28,7 @@ import (
 	userapi "github.com/matrix-org/dendrite/userapi/api"
 	"github.com/matrix-org/gomatrixserverlib"
 	"github.com/matrix-org/util"
+	"github.com/sirupsen/logrus"
 	"github.com/tidwall/gjson"
 	"github.com/tidwall/sjson"
 )
@@ -205,7 +206,15 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
 			maxStreamID = m.StreamID
 		}
 	}
-	res.Devices = msgs
+	// remove deleted devices
+	var result []api.DeviceMessage
+	for _, m := range msgs {
+		if m.KeyJSON == nil {
+			continue
+		}
+		result = append(result, m)
+	}
+	res.Devices = result
 	res.StreamID = maxStreamID
 }
@@ -282,27 +291,21 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
 	fetchRemote := make(map[string]map[string][]string)
 	for domain, userToDeviceMap := range domainToDeviceKeys {
 		for userID, deviceIDs := range userToDeviceMap {
-			keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
-			// if we can't query the db or there are fewer keys than requested, fetch from remote.
-			// Likewise, we can't safely return keys from the db when all devices are requested as we don't
-			// know if one has just been added.
-			if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) {
-				if _, ok := fetchRemote[domain]; !ok {
-					fetchRemote[domain] = make(map[string][]string)
-				}
-				fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
-				continue
-			}
-			if res.DeviceKeys[userID] == nil {
-				res.DeviceKeys[userID] = make(map[string]json.RawMessage)
-			}
-			for _, key := range keys {
-				// inject the display name
-				key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
-					DisplayName string `json:"device_display_name,omitempty"`
-				}{key.DisplayName})
-				res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
-			}
+			// we can't safely return keys from the db when all devices are requested as we don't
+			// know if one has just been added.
+			if len(deviceIDs) > 0 {
+				err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
+				if err == nil {
+					continue
+				}
+				util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
+			}
+			// fetch device lists from remote
+			if _, ok := fetchRemote[domain]; !ok {
+				fetchRemote[domain] = make(map[string][]string)
+			}
+			fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
 		}
 	}
 	return fetchRemote
@@ -324,6 +327,45 @@ func (a *KeyInternalAPI) queryRemoteKeys(
 			defer wg.Done()
 			fedCtx, cancel := context.WithTimeout(ctx, timeout)
 			defer cancel()
+			// for users who we do not have any knowledge about, try to start doing device list updates for them
+			// by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
+			// lacks a stream ID.
+			var userIDsForAllDevices []string
+			for userID, deviceIDs := range devKeys {
+				if len(deviceIDs) == 0 {
+					userIDsForAllDevices = append(userIDsForAllDevices, userID)
+					delete(devKeys, userID)
+				}
+			}
+			for _, userID := range userIDsForAllDevices {
+				err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
+				if err != nil {
+					logrus.WithFields(logrus.Fields{
+						logrus.ErrorKey: err,
+						"user_id": userID,
+						"server": serverName,
+					}).Error("Failed to manually update device lists for user")
+					// try to do it via /keys/query
+					devKeys[userID] = []string{}
+					continue
+				}
+				// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
+				// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
+				err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
+				if err != nil {
+					logrus.WithFields(logrus.Fields{
+						logrus.ErrorKey: err,
+						"user_id": userID,
+						"server": serverName,
+					}).Error("Failed to manually update device lists for user")
+					// try to do it via /keys/query
+					devKeys[userID] = []string{}
+					continue
+				}
+			}
+			if len(devKeys) == 0 {
+				return
+			}
 			queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys)
 			if err != nil {
 				failMu.Lock()
@@ -357,6 +399,37 @@ func (a *KeyInternalAPI) queryRemoteKeys(
 	}
 }

+func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
+	ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
+) error {
+	keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
+	// if we can't query the db or there are fewer keys than requested, fetch from remote.
+	if err != nil {
+		return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
+	}
+	if len(keys) < len(deviceIDs) {
+		return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
+	}
+	if len(deviceIDs) == 0 && len(keys) == 0 {
+		return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
+	}
+	if res.DeviceKeys[userID] == nil {
+		res.DeviceKeys[userID] = make(map[string]json.RawMessage)
+	}
+	for _, key := range keys {
+		if len(key.KeyJSON) == 0 {
+			continue // ignore deleted keys
+		}
+		// inject the display name
+		key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
+			DisplayName string `json:"device_display_name,omitempty"`
+		}{key.DisplayName})
+		res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
+	}
+	return nil
+}
+
 func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
 	var keysToStore []api.DeviceMessage
 	// assert that the user ID / device ID are not lying for each key
@@ -403,6 +476,10 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
 		}
 		return
 	}
+	if req.OnlyDisplayNameUpdates {
+		// add the display name field from keysToStore into existingKeys
+		keysToStore = appendDisplayNames(existingKeys, keysToStore)
+	}
 	// store the device keys and emit changes
 	err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
 	if err != nil {
@@ -475,3 +552,16 @@ func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage)
 	}
 	return a.Producer.ProduceKeyChanges(keysAdded)
 }
+
+func appendDisplayNames(existing, new []api.DeviceMessage) []api.DeviceMessage {
+	for i, existingDevice := range existing {
+		for _, newDevice := range new {
+			if existingDevice.DeviceID != newDevice.DeviceID {
+				continue
+			}
+			existingDevice.DisplayName = newDevice.DisplayName
+			existing[i] = existingDevice
+		}
+	}
+	return existing
+}

View file

@@ -43,8 +43,9 @@ type Database interface {
 	StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error

 	// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
-	// for this (user, device). Does not modify the stream ID for keys.
-	StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
+	// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
+	// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
+	StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error

 	// PrevIDsExists returns true if all prev IDs exist for this user.
 	PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
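
To make the new `clearUserIDs` semantics concrete, here is a hedged sketch of the two call shapes this commit uses. The `Database` interface is re-declared locally for the example, and `incremental`, `snapshot` and the user ID are placeholders; the real call sites are in the DeviceListUpdater hunks above.

```go
package example

import (
	"context"

	"github.com/matrix-org/dendrite/keyserver/api" // import path assumed
)

// Database is the subset of the keyserver storage interface exercised below.
type Database interface {
	StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
}

func storeExamples(ctx context.Context, db Database, incremental, snapshot []api.DeviceMessage) error {
	// Incremental update (e.g. a single m.device_list_update EDU): clear nothing,
	// matching the nil passed by DeviceListUpdater.update.
	if err := db.StoreRemoteDeviceKeys(ctx, incremental, nil); err != nil {
		return err
	}
	// Full snapshot of one remote user's devices (e.g. after /users/devices): clear
	// that user's rows first so deleted devices do not linger, matching the
	// []string{res.UserID} passed by updateDeviceList.
	return db.StoreRemoteDeviceKeys(ctx, snapshot, []string{"@bob:remote.example"})
}
```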

View file

@@ -61,6 +61,9 @@ const selectMaxStreamForUserSQL = "" +
 const countStreamIDsForUserSQL = "" +
 	"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"

+const deleteAllDeviceKeysSQL = "" +
+	"DELETE FROM keyserver_device_keys WHERE user_id=$1"
+
 type deviceKeysStatements struct {
 	db *sql.DB
 	upsertDeviceKeysStmt *sql.Stmt
@@ -68,6 +71,7 @@ type deviceKeysStatements struct {
 	selectBatchDeviceKeysStmt *sql.Stmt
 	selectMaxStreamForUserStmt *sql.Stmt
 	countStreamIDsForUserStmt *sql.Stmt
+	deleteAllDeviceKeysStmt *sql.Stmt
 }

 func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -93,6 +97,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
 	if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
 		return nil, err
 	}
+	if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
+		return nil, err
+	}
 	return s, nil
 }
@@ -154,6 +161,11 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
 	return nil
 }

+func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
+	_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+	return err
+}
+
 func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
 	rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
 	if err != nil {

View file

@@ -61,8 +61,14 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i
 	return count == len(prevIDs), nil
 }

-func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
+func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
 	return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+		for _, userID := range clearUserIDs {
+			err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
+			if err != nil {
+				return err
+			}
+		}
 		return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
 	})
 }

View file

@@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" +
 const countStreamIDsForUserSQL = "" +
 	"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"

+const deleteAllDeviceKeysSQL = "" +
+	"DELETE FROM keyserver_device_keys WHERE user_id=$1"
+
 type deviceKeysStatements struct {
 	db *sql.DB
 	writer *sqlutil.TransactionWriter
@@ -65,6 +68,7 @@ type deviceKeysStatements struct {
 	selectDeviceKeysStmt *sql.Stmt
 	selectBatchDeviceKeysStmt *sql.Stmt
 	selectMaxStreamForUserStmt *sql.Stmt
+	deleteAllDeviceKeysStmt *sql.Stmt
 }

 func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -88,9 +92,17 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
 	if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
 		return nil, err
 	}
+	if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
+		return nil, err
+	}
 	return s, nil
 }

+func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
+	_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+	return err
+}
+
 func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
 	deviceIDMap := make(map[string]bool)
 	for _, d := range deviceIDs {

View file

@@ -38,6 +38,7 @@ type DeviceKeys interface {
 	SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
 	CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
 	SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
+	DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
 }

 type KeyChanges interface {

View file

@@ -96,7 +96,8 @@ func DeviceListCatchup(
 		return hasNew, nil
 	}
 	// QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user.
-	queryRes.UserIDs = filterSharedUsers(ctx, stateAPI, userID, queryRes.UserIDs)
+	var sharedUsersMap map[string]int
+	sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, stateAPI, userID, queryRes.UserIDs)
 	util.GetLogger(ctx).Debugf(
 		"QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v",
 		partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs,
@@ -114,13 +115,20 @@ func DeviceListCatchup(
 	}
 	// if the response has any join/leave events, add them now.
 	// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
-	for _, userID := range membershipEvents(res) {
+	joinUserIDs, leaveUserIDs := membershipEvents(res)
+	for _, userID := range joinUserIDs {
 		if !userSet[userID] {
 			res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
 			hasNew = true
 			userSet[userID] = true
 		}
 	}
+	for _, userID := range leaveUserIDs {
+		if sharedUsersMap[userID] == 0 {
+			// we no longer share a room with this user when they left, so add to left list.
+			res.DeviceLists.Left = append(res.DeviceLists.Left, userID)
+		}
+	}
 	// set the new token
 	to.SetLog(DeviceListLogName, &types.LogPosition{
 		Partition: queryRes.Partition,
@@ -221,7 +229,7 @@ func TrackChangedUsers(
 func filterSharedUsers(
 	ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, usersWithChangedKeys []string,
-) []string {
+) (map[string]int, []string) {
 	var result []string
 	var sharedUsersRes currentstateAPI.QuerySharedUsersResponse
 	err := stateAPI.QuerySharedUsers(ctx, &currentstateAPI.QuerySharedUsersRequest{
@@ -229,7 +237,7 @@ func filterSharedUsers(
 	}, &sharedUsersRes)
 	if err != nil {
 		// default to all users so we do needless queries rather than miss some important device update
-		return usersWithChangedKeys
+		return nil, usersWithChangedKeys
 	}
 	// We forcibly put ourselves in this list because we should be notified about our own device updates
 	// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
@@ -241,7 +249,7 @@ func filterSharedUsers(
 			result = append(result, uid)
 		}
 	}
-	return result
+	return sharedUsersRes.UserIDsToCount, result
 }

 func joinedRooms(res *types.Response, userID string) []string {
@@ -288,16 +296,16 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin
 // "For optimal performance, Alice should be added to changed in Bob's sync only when she adds a new device,
 // or when Alice and Bob now share a room but didn't share any room previously. However, for the sake of simpler
 // logic, a server may add Alice to changed when Alice and Bob share a new room, even if they previously already shared a room."
-func membershipEvents(res *types.Response) (userIDs []string) {
+func membershipEvents(res *types.Response) (joinUserIDs, leaveUserIDs []string) {
 	for _, room := range res.Rooms.Join {
 		for _, ev := range room.Timeline.Events {
 			if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil {
 				if strings.Contains(string(ev.Content), `"join"`) {
-					userIDs = append(userIDs, *ev.StateKey)
+					joinUserIDs = append(joinUserIDs, *ev.StateKey)
 				} else if strings.Contains(string(ev.Content), `"leave"`) {
-					userIDs = append(userIDs, *ev.StateKey)
+					leaveUserIDs = append(leaveUserIDs, *ev.StateKey)
 				} else if strings.Contains(string(ev.Content), `"ban"`) {
-					userIDs = append(userIDs, *ev.StateKey)
+					leaveUserIDs = append(leaveUserIDs, *ev.StateKey)
 				}
 			}
 		}
 	}
View file

@@ -142,7 +142,12 @@ Server correctly handles incoming m.device_list_update
 Device deletion propagates over federation
 If remote user leaves room, changes device and rejoins we see update in sync
 If remote user leaves room, changes device and rejoins we see update in /keys/changes
+If remote user leaves room we no longer receive device updates
+If a device list update goes missing, the server resyncs on the next one
+Get left notifs in sync and /keys/changes when other user leaves
 Can query remote device keys using POST after notification
+Server correctly resyncs when client query keys and there is no remote cache
+Server correctly resyncs when server leaves and rejoins a room
 Can add account data
 Can add account data to room
 Can get account data without syncing

View file

@@ -180,6 +180,27 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
 		util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
 		return err
 	}
+	if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
+		// display name has changed: update the device key
+		var uploadRes keyapi.PerformUploadKeysResponse
+		a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
+			DeviceKeys: []keyapi.DeviceKeys{
+				{
+					DeviceID: dev.ID,
+					DisplayName: *req.DisplayName,
+					KeyJSON: nil,
+					UserID: dev.UserID,
+				},
+			},
+			OnlyDisplayNameUpdates: true,
+		}, &uploadRes)
+		if uploadRes.Error != nil {
+			return fmt.Errorf("Failed to update device key display name: %v", uploadRes.Error)
+		}
+		if len(uploadRes.KeyErrors) > 0 {
+			return fmt.Errorf("Failed to update device key display name, key errors: %+v", uploadRes.KeyErrors)
+		}
+	}
 	return nil
 }