mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 22:03:10 -06:00
Merge branch 'master' into neilalexander/apis
This commit is contained in:
commit
e02be084a3
|
|
@ -33,7 +33,7 @@ Then point your favourite Matrix client at `http://localhost:8008`. For full ins
|
||||||
|
|
||||||
We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver
|
We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver
|
||||||
test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it
|
test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it
|
||||||
updates with CI. As of July 2020 we're at around 48% CS API coverage and 50% Federation coverage, though check
|
updates with CI. As of August 2020 we're at around 52% CS API coverage and 65% Federation coverage, though check
|
||||||
CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse
|
CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse
|
||||||
servers such as matrix.org reasonably well. There's a long list of features that are not implemented, notably:
|
servers such as matrix.org reasonably well. There's a long list of features that are not implemented, notably:
|
||||||
- Receipts
|
- Receipts
|
||||||
|
|
@ -42,7 +42,6 @@ servers such as matrix.org reasonably well. There's a long list of features that
|
||||||
- User Directory
|
- User Directory
|
||||||
- Presence
|
- Presence
|
||||||
- Guests
|
- Guests
|
||||||
- E2E keys and device lists
|
|
||||||
|
|
||||||
We are prioritising features that will benefit single-user homeservers first (e.g Receipts, E2E) rather
|
We are prioritising features that will benefit single-user homeservers first (e.g Receipts, E2E) rather
|
||||||
than features that massive deployments may be interested in (User Directory, OpenID, Guests, Admin APIs, AS API).
|
than features that massive deployments may be interested in (User Directory, OpenID, Guests, Admin APIs, AS API).
|
||||||
|
|
@ -56,6 +55,7 @@ This means Dendrite supports amongst others:
|
||||||
- Media APIs
|
- Media APIs
|
||||||
- Redaction
|
- Redaction
|
||||||
- Tagging
|
- Tagging
|
||||||
|
- E2E keys and device lists
|
||||||
|
|
||||||
|
|
||||||
# Contributing
|
# Contributing
|
||||||
|
|
|
||||||
|
|
@ -831,6 +831,7 @@ psh Trying to get push rules with unknown scope fails with 400
|
||||||
psh Trying to get push rules with unknown template fails with 400
|
psh Trying to get push rules with unknown template fails with 400
|
||||||
psh Trying to get push rules with unknown attribute fails with 400
|
psh Trying to get push rules with unknown attribute fails with 400
|
||||||
psh Trying to get push rules with unknown rule_id fails with 404
|
psh Trying to get push rules with unknown rule_id fails with 404
|
||||||
|
psh Rooms with names are correctly named in pushes
|
||||||
v1s GET /initialSync with non-numeric 'limit'
|
v1s GET /initialSync with non-numeric 'limit'
|
||||||
v1s GET /events with non-numeric 'limit'
|
v1s GET /events with non-numeric 'limit'
|
||||||
v1s GET /events with negative 'limit'
|
v1s GET /events with negative 'limit'
|
||||||
|
|
@ -839,7 +840,7 @@ ath Event size limits
|
||||||
syn Check creating invalid filters returns 4xx
|
syn Check creating invalid filters returns 4xx
|
||||||
f,pre New federated private chats get full presence information (SYN-115)
|
f,pre New federated private chats get full presence information (SYN-115)
|
||||||
pre Left room members do not cause problems for presence
|
pre Left room members do not cause problems for presence
|
||||||
crm Rooms can be created with an initial invite list (SYN-205)
|
crm Rooms can be created with an initial invite list (SYN-205) (1 subtests)
|
||||||
typ Typing notifications don't leak
|
typ Typing notifications don't leak
|
||||||
ban Non-present room members cannot ban others
|
ban Non-present room members cannot ban others
|
||||||
psh Getting push rules doesn't corrupt the cache SYN-390
|
psh Getting push rules doesn't corrupt the cache SYN-390
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,9 @@ import (
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/matrix-org/dendrite/currentstateserver/api"
|
"github.com/matrix-org/dendrite/currentstateserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/currentstateserver/internal"
|
||||||
"github.com/matrix-org/dendrite/currentstateserver/inthttp"
|
"github.com/matrix-org/dendrite/currentstateserver/inthttp"
|
||||||
|
"github.com/matrix-org/dendrite/currentstateserver/storage"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
|
@ -76,7 +78,24 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *roomserverAPI.OutputNewRoomEvent) error {
|
func waitForOffsetProcessed(t *testing.T, db storage.Database, offset int64) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
for {
|
||||||
|
poffsets, err := db.PartitionOffsets(ctx, kafkaTopic)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to PartitionOffsets: %s", err)
|
||||||
|
}
|
||||||
|
for _, partition := range poffsets {
|
||||||
|
if partition.Offset >= offset {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *roomserverAPI.OutputNewRoomEvent) int64 {
|
||||||
value, err := json.Marshal(roomserverAPI.OutputEvent{
|
value, err := json.Marshal(roomserverAPI.OutputEvent{
|
||||||
Type: roomserverAPI.OutputTypeNewRoomEvent,
|
Type: roomserverAPI.OutputTypeNewRoomEvent,
|
||||||
NewRoomEvent: out,
|
NewRoomEvent: out,
|
||||||
|
|
@ -84,7 +103,7 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to marshal output event: %s", err)
|
t.Fatalf("failed to marshal output event: %s", err)
|
||||||
}
|
}
|
||||||
_, _, err = producer.SendMessage(&sarama.ProducerMessage{
|
_, offset, err := producer.SendMessage(&sarama.ProducerMessage{
|
||||||
Topic: kafkaTopic,
|
Topic: kafkaTopic,
|
||||||
Key: sarama.StringEncoder(out.Event.RoomID()),
|
Key: sarama.StringEncoder(out.Event.RoomID()),
|
||||||
Value: sarama.ByteEncoder(value),
|
Value: sarama.ByteEncoder(value),
|
||||||
|
|
@ -92,10 +111,10 @@ func MustWriteOutputEvent(t *testing.T, producer sarama.SyncProducer, out *rooms
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to send message: %s", err)
|
t.Fatalf("failed to send message: %s", err)
|
||||||
}
|
}
|
||||||
return nil
|
return offset
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.SyncProducer, func()) {
|
func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, storage.Database, sarama.SyncProducer, func()) {
|
||||||
cfg := &config.Dendrite{}
|
cfg := &config.Dendrite{}
|
||||||
cfg.Defaults()
|
cfg.Defaults()
|
||||||
stateDBName := "test_state.db"
|
stateDBName := "test_state.db"
|
||||||
|
|
@ -117,26 +136,28 @@ func MustMakeInternalAPI(t *testing.T) (api.CurrentStateInternalAPI, sarama.Sync
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create naffka consumer: %s", err)
|
t.Fatalf("Failed to create naffka consumer: %s", err)
|
||||||
}
|
}
|
||||||
return NewInternalAPI(&cfg.CurrentStateServer, naff), naff, func() {
|
stateAPI := NewInternalAPI(&cfg.CurrentStateServer, naff)
|
||||||
|
// type-cast to pull out the DB
|
||||||
|
stateAPIVal := stateAPI.(*internal.CurrentStateInternalAPI)
|
||||||
|
return stateAPI, stateAPIVal.DB, naff, func() {
|
||||||
os.Remove(naffkaDBName)
|
os.Remove(naffkaDBName)
|
||||||
os.Remove(stateDBName)
|
os.Remove(stateDBName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryCurrentState(t *testing.T) {
|
func TestQueryCurrentState(t *testing.T) {
|
||||||
currStateAPI, producer, cancel := MustMakeInternalAPI(t)
|
currStateAPI, db, producer, cancel := MustMakeInternalAPI(t)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
plTuple := gomatrixserverlib.StateKeyTuple{
|
plTuple := gomatrixserverlib.StateKeyTuple{
|
||||||
EventType: "m.room.power_levels",
|
EventType: "m.room.power_levels",
|
||||||
StateKey: "",
|
StateKey: "",
|
||||||
}
|
}
|
||||||
plEvent := testEvents[4]
|
plEvent := testEvents[4]
|
||||||
MustWriteOutputEvent(t, producer, &roomserverAPI.OutputNewRoomEvent{
|
offset := MustWriteOutputEvent(t, producer, &roomserverAPI.OutputNewRoomEvent{
|
||||||
Event: plEvent,
|
Event: plEvent,
|
||||||
AddsStateEventIDs: []string{plEvent.EventID()},
|
AddsStateEventIDs: []string{plEvent.EventID()},
|
||||||
})
|
})
|
||||||
// we have no good way to know /when/ the server has consumed the event
|
waitForOffsetProcessed(t, db, offset)
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
req api.QueryCurrentStateRequest
|
req api.QueryCurrentStateRequest
|
||||||
|
|
@ -228,7 +249,7 @@ func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *r
|
||||||
|
|
||||||
// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
|
// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
|
||||||
func TestQuerySharedUsers(t *testing.T) {
|
func TestQuerySharedUsers(t *testing.T) {
|
||||||
currStateAPI, producer, cancel := MustMakeInternalAPI(t)
|
currStateAPI, db, producer, cancel := MustMakeInternalAPI(t)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
|
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
|
||||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
|
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
|
||||||
|
|
@ -240,10 +261,8 @@ func TestQuerySharedUsers(t *testing.T) {
|
||||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@bob:localhost", "join"))
|
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@bob:localhost", "join"))
|
||||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@dave:localhost", "leave"))
|
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@dave:localhost", "leave"))
|
||||||
|
|
||||||
MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
|
offset := MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
|
||||||
|
waitForOffsetProcessed(t, db, offset)
|
||||||
// we don't know when the server has processed the events
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
req api.QuerySharedUsersRequest
|
req api.QuerySharedUsersRequest
|
||||||
|
|
|
||||||
|
|
@ -110,6 +110,11 @@ type OneTimeKeysCount struct {
|
||||||
type PerformUploadKeysRequest struct {
|
type PerformUploadKeysRequest struct {
|
||||||
DeviceKeys []DeviceKeys
|
DeviceKeys []DeviceKeys
|
||||||
OneTimeKeys []OneTimeKeys
|
OneTimeKeys []OneTimeKeys
|
||||||
|
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
||||||
|
// the display name for their respective device, and NOT to modify the keys. The key
|
||||||
|
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
|
||||||
|
// Without this flag, requests to modify device display names would delete device keys.
|
||||||
|
OnlyDisplayNameUpdates bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// PerformUploadKeysResponse is the response to PerformUploadKeys
|
// PerformUploadKeysResponse is the response to PerformUploadKeys
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,11 @@ type DeviceListUpdater struct {
|
||||||
producer KeyChangeProducer
|
producer KeyChangeProducer
|
||||||
fedClient *gomatrixserverlib.FederationClient
|
fedClient *gomatrixserverlib.FederationClient
|
||||||
workerChans []chan gomatrixserverlib.ServerName
|
workerChans []chan gomatrixserverlib.ServerName
|
||||||
|
|
||||||
|
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
|
||||||
|
// block on or timeout via a select.
|
||||||
|
userIDToChan map[string]chan bool
|
||||||
|
userIDToChanMu *sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
|
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
|
||||||
|
|
@ -80,8 +85,9 @@ type DeviceListUpdaterDatabase interface {
|
||||||
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
||||||
|
|
||||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
// for this (user, device). Does not modify the stream ID for keys.
|
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
|
||||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
|
||||||
|
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||||
|
|
||||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||||
|
|
@ -98,12 +104,14 @@ func NewDeviceListUpdater(
|
||||||
numWorkers int,
|
numWorkers int,
|
||||||
) *DeviceListUpdater {
|
) *DeviceListUpdater {
|
||||||
return &DeviceListUpdater{
|
return &DeviceListUpdater{
|
||||||
userIDToMutex: make(map[string]*sync.Mutex),
|
userIDToMutex: make(map[string]*sync.Mutex),
|
||||||
mu: &sync.Mutex{},
|
mu: &sync.Mutex{},
|
||||||
db: db,
|
db: db,
|
||||||
producer: producer,
|
producer: producer,
|
||||||
fedClient: fedClient,
|
fedClient: fedClient,
|
||||||
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
|
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
|
||||||
|
userIDToChan: make(map[string]chan bool),
|
||||||
|
userIDToChanMu: &sync.Mutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -137,6 +145,22 @@ func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
|
||||||
return u.userIDToMutex[userID]
|
return u.userIDToMutex[userID]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it.
|
||||||
|
// Blocks until the device list is synced or the timeout is reached.
|
||||||
|
func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error {
|
||||||
|
mu := u.mutex(userID)
|
||||||
|
mu.Lock()
|
||||||
|
err := u.db.MarkDeviceListStale(ctx, userID, true)
|
||||||
|
mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err)
|
||||||
|
}
|
||||||
|
u.notifyWorkers(userID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest,
|
||||||
|
// which assumes when /send 200 OKs that the device lists have been updated.
|
||||||
func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
|
func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
|
||||||
isDeviceListStale, err := u.update(ctx, event)
|
isDeviceListStale, err := u.update(ctx, event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -169,22 +193,27 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
|
||||||
"stream_id": event.StreamID,
|
"stream_id": event.StreamID,
|
||||||
"prev_ids": event.PrevID,
|
"prev_ids": event.PrevID,
|
||||||
"display_name": event.DeviceDisplayName,
|
"display_name": event.DeviceDisplayName,
|
||||||
|
"deleted": event.Deleted,
|
||||||
}).Info("DeviceListUpdater.Update")
|
}).Info("DeviceListUpdater.Update")
|
||||||
|
|
||||||
// if we haven't missed anything update the database and notify users
|
// if we haven't missed anything update the database and notify users
|
||||||
if exists {
|
if exists {
|
||||||
|
k := event.Keys
|
||||||
|
if event.Deleted {
|
||||||
|
k = nil
|
||||||
|
}
|
||||||
keys := []api.DeviceMessage{
|
keys := []api.DeviceMessage{
|
||||||
{
|
{
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: api.DeviceKeys{
|
||||||
DeviceID: event.DeviceID,
|
DeviceID: event.DeviceID,
|
||||||
DisplayName: event.DeviceDisplayName,
|
DisplayName: event.DeviceDisplayName,
|
||||||
KeyJSON: event.Keys,
|
KeyJSON: k,
|
||||||
UserID: event.UserID,
|
UserID: event.UserID,
|
||||||
},
|
},
|
||||||
StreamID: event.StreamID,
|
StreamID: event.StreamID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = u.db.StoreRemoteDeviceKeys(ctx, keys)
|
err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
|
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||||
}
|
}
|
||||||
|
|
@ -213,7 +242,35 @@ func (u *DeviceListUpdater) notifyWorkers(userID string) {
|
||||||
hash := fnv.New32a()
|
hash := fnv.New32a()
|
||||||
_, _ = hash.Write([]byte(remoteServer))
|
_, _ = hash.Write([]byte(remoteServer))
|
||||||
index := int(hash.Sum32()) % len(u.workerChans)
|
index := int(hash.Sum32()) % len(u.workerChans)
|
||||||
|
|
||||||
|
ch := u.assignChannel(userID)
|
||||||
u.workerChans[index] <- remoteServer
|
u.workerChans[index] <- remoteServer
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
// we don't return an error in this case as it's not a failure condition.
|
||||||
|
// we mainly block for the benefit of sytest anyway
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *DeviceListUpdater) assignChannel(userID string) chan bool {
|
||||||
|
u.userIDToChanMu.Lock()
|
||||||
|
defer u.userIDToChanMu.Unlock()
|
||||||
|
if ch, ok := u.userIDToChan[userID]; ok {
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
ch := make(chan bool)
|
||||||
|
u.userIDToChan[userID] = ch
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *DeviceListUpdater) clearChannel(userID string) {
|
||||||
|
u.userIDToChanMu.Lock()
|
||||||
|
defer u.userIDToChanMu.Unlock()
|
||||||
|
if ch, ok := u.userIDToChan[userID]; ok {
|
||||||
|
close(ch)
|
||||||
|
delete(u.userIDToChan, userID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
|
func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
|
||||||
|
|
@ -230,6 +287,7 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// on failure, spin up a short-lived goroutine to inject the server name again.
|
// on failure, spin up a short-lived goroutine to inject the server name again.
|
||||||
|
scheduledRetries := make(map[gomatrixserverlib.ServerName]time.Time)
|
||||||
inject := func(srv gomatrixserverlib.ServerName, duration time.Duration) {
|
inject := func(srv gomatrixserverlib.ServerName, duration time.Duration) {
|
||||||
time.Sleep(duration)
|
time.Sleep(duration)
|
||||||
ch <- srv
|
ch <- srv
|
||||||
|
|
@ -237,13 +295,20 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
|
||||||
|
|
||||||
for serverName := range ch {
|
for serverName := range ch {
|
||||||
if !shouldProcess(serverName) {
|
if !shouldProcess(serverName) {
|
||||||
// do not inject into the channel as we know there will be a sleeping goroutine
|
if time.Now().Before(scheduledRetries[serverName]) {
|
||||||
// which will do it after the cooloff period expires
|
// do not inject into the channel as we know there will be a sleeping goroutine
|
||||||
continue
|
// which will do it after the cooloff period expires
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
scheduledRetries[serverName] = time.Now().Add(cooloffPeriod)
|
||||||
|
go inject(serverName, cooloffPeriod) // TODO: Backoff?
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
lastProcessed[serverName] = time.Now()
|
lastProcessed[serverName] = time.Now()
|
||||||
shouldRetry := u.processServer(serverName)
|
shouldRetry := u.processServer(serverName)
|
||||||
if shouldRetry {
|
if shouldRetry {
|
||||||
|
scheduledRetries[serverName] = time.Now().Add(cooloffPeriod)
|
||||||
go inject(serverName, cooloffPeriod) // TODO: Backoff?
|
go inject(serverName, cooloffPeriod) // TODO: Backoff?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -277,6 +342,8 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it")
|
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it")
|
||||||
hasFailures = true
|
hasFailures = true
|
||||||
|
} else {
|
||||||
|
u.clearChannel(userID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return hasFailures
|
return hasFailures
|
||||||
|
|
@ -301,7 +368,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err := u.db.StoreRemoteDeviceKeys(ctx, keys)
|
err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to store remote device keys: %w", err)
|
return fmt.Errorf("failed to store remote device keys: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context,
|
||||||
|
|
||||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
// for this (user, device). Does not modify the stream ID for keys.
|
// for this (user, device). Does not modify the stream ID for keys.
|
||||||
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
|
||||||
d.storedKeys = append(d.storedKeys, keys...)
|
d.storedKeys = append(d.storedKeys, keys...)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -204,16 +204,6 @@ func TestUpdateNoPrevID(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Update returned an error: %s", err)
|
t.Fatalf("Update returned an error: %s", err)
|
||||||
}
|
}
|
||||||
// At this point we show have this device list marked as stale and not store the keys or emitted anything
|
|
||||||
if !db.staleUsers[event.UserID] {
|
|
||||||
t.Errorf("%s not marked as stale", event.UserID)
|
|
||||||
}
|
|
||||||
if len(producer.events) > 0 {
|
|
||||||
t.Errorf("Update incorrect emitted %d device change events", len(producer.events))
|
|
||||||
}
|
|
||||||
if len(db.storedKeys) > 0 {
|
|
||||||
t.Errorf("Update incorrect stored %d device change events", len(db.storedKeys))
|
|
||||||
}
|
|
||||||
t.Log("waiting for /users/devices to be called...")
|
t.Log("waiting for /users/devices to be called...")
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
// wait a bit for db to be updated...
|
// wait a bit for db to be updated...
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ import (
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
@ -205,7 +206,15 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
|
||||||
maxStreamID = m.StreamID
|
maxStreamID = m.StreamID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
res.Devices = msgs
|
// remove deleted devices
|
||||||
|
var result []api.DeviceMessage
|
||||||
|
for _, m := range msgs {
|
||||||
|
if m.KeyJSON == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, m)
|
||||||
|
}
|
||||||
|
res.Devices = result
|
||||||
res.StreamID = maxStreamID
|
res.StreamID = maxStreamID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -282,27 +291,21 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
||||||
fetchRemote := make(map[string]map[string][]string)
|
fetchRemote := make(map[string]map[string][]string)
|
||||||
for domain, userToDeviceMap := range domainToDeviceKeys {
|
for domain, userToDeviceMap := range domainToDeviceKeys {
|
||||||
for userID, deviceIDs := range userToDeviceMap {
|
for userID, deviceIDs := range userToDeviceMap {
|
||||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
// we can't safely return keys from the db when all devices are requested as we don't
|
||||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
|
||||||
// Likewise, we can't safely return keys from the db when all devices are requested as we don't
|
|
||||||
// know if one has just been added.
|
// know if one has just been added.
|
||||||
if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) {
|
if len(deviceIDs) > 0 {
|
||||||
if _, ok := fetchRemote[domain]; !ok {
|
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
|
||||||
fetchRemote[domain] = make(map[string][]string)
|
if err == nil {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
|
util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
if res.DeviceKeys[userID] == nil {
|
// fetch device lists from remote
|
||||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
if _, ok := fetchRemote[domain]; !ok {
|
||||||
}
|
fetchRemote[domain] = make(map[string][]string)
|
||||||
for _, key := range keys {
|
|
||||||
// inject the display name
|
|
||||||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
|
||||||
DisplayName string `json:"device_display_name,omitempty"`
|
|
||||||
}{key.DisplayName})
|
|
||||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
|
||||||
}
|
}
|
||||||
|
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fetchRemote
|
return fetchRemote
|
||||||
|
|
@ -324,6 +327,45 @@ func (a *KeyInternalAPI) queryRemoteKeys(
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
fedCtx, cancel := context.WithTimeout(ctx, timeout)
|
fedCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
// for users who we do not have any knowledge about, try to start doing device list updates for them
|
||||||
|
// by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
|
||||||
|
// lack a stream ID.
|
||||||
|
var userIDsForAllDevices []string
|
||||||
|
for userID, deviceIDs := range devKeys {
|
||||||
|
if len(deviceIDs) == 0 {
|
||||||
|
userIDsForAllDevices = append(userIDsForAllDevices, userID)
|
||||||
|
delete(devKeys, userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, userID := range userIDsForAllDevices {
|
||||||
|
err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
logrus.ErrorKey: err,
|
||||||
|
"user_id": userID,
|
||||||
|
"server": serverName,
|
||||||
|
}).Error("Failed to manually update device lists for user")
|
||||||
|
// try to do it via /keys/query
|
||||||
|
devKeys[userID] = []string{}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
|
||||||
|
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
||||||
|
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
logrus.ErrorKey: err,
|
||||||
|
"user_id": userID,
|
||||||
|
"server": serverName,
|
||||||
|
}).Error("Failed to manually update device lists for user")
|
||||||
|
// try to do it via /keys/query
|
||||||
|
devKeys[userID] = []string{}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(devKeys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys)
|
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failMu.Lock()
|
failMu.Lock()
|
||||||
|
|
@ -357,6 +399,37 @@ func (a *KeyInternalAPI) queryRemoteKeys(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||||
|
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
|
||||||
|
) error {
|
||||||
|
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
||||||
|
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
|
||||||
|
}
|
||||||
|
if len(keys) < len(deviceIDs) {
|
||||||
|
return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
|
||||||
|
}
|
||||||
|
if len(deviceIDs) == 0 && len(keys) == 0 {
|
||||||
|
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
||||||
|
}
|
||||||
|
if res.DeviceKeys[userID] == nil {
|
||||||
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
if len(key.KeyJSON) == 0 {
|
||||||
|
continue // ignore deleted keys
|
||||||
|
}
|
||||||
|
// inject the display name
|
||||||
|
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||||
|
DisplayName string `json:"device_display_name,omitempty"`
|
||||||
|
}{key.DisplayName})
|
||||||
|
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
var keysToStore []api.DeviceMessage
|
var keysToStore []api.DeviceMessage
|
||||||
// assert that the user ID / device ID are not lying for each key
|
// assert that the user ID / device ID are not lying for each key
|
||||||
|
|
@ -403,6 +476,10 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if req.OnlyDisplayNameUpdates {
|
||||||
|
// add the display name field from keysToStore into existingKeys
|
||||||
|
keysToStore = appendDisplayNames(existingKeys, keysToStore)
|
||||||
|
}
|
||||||
// store the device keys and emit changes
|
// store the device keys and emit changes
|
||||||
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -475,3 +552,16 @@ func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage)
|
||||||
}
|
}
|
||||||
return a.Producer.ProduceKeyChanges(keysAdded)
|
return a.Producer.ProduceKeyChanges(keysAdded)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func appendDisplayNames(existing, new []api.DeviceMessage) []api.DeviceMessage {
|
||||||
|
for i, existingDevice := range existing {
|
||||||
|
for _, newDevice := range new {
|
||||||
|
if existingDevice.DeviceID != newDevice.DeviceID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existingDevice.DisplayName = newDevice.DisplayName
|
||||||
|
existing[i] = existingDevice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return existing
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -43,8 +43,9 @@ type Database interface {
|
||||||
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||||
// for this (user, device). Does not modify the stream ID for keys.
|
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
|
||||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
|
||||||
|
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||||
|
|
||||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,9 @@ const selectMaxStreamForUserSQL = "" +
|
||||||
const countStreamIDsForUserSQL = "" +
|
const countStreamIDsForUserSQL = "" +
|
||||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
|
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
|
||||||
|
|
||||||
|
const deleteAllDeviceKeysSQL = "" +
|
||||||
|
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
|
|
@ -68,6 +71,7 @@ type deviceKeysStatements struct {
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
selectMaxStreamForUserStmt *sql.Stmt
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
countStreamIDsForUserStmt *sql.Stmt
|
countStreamIDsForUserStmt *sql.Stmt
|
||||||
|
deleteAllDeviceKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
|
|
@ -93,6 +97,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
|
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -154,6 +161,11 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
|
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -61,8 +61,14 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i
|
||||||
return count == len(prevIDs), nil
|
return count == len(prevIDs), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
|
||||||
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||||
|
for _, userID := range clearUserIDs {
|
||||||
|
err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" +
|
||||||
const countStreamIDsForUserSQL = "" +
|
const countStreamIDsForUserSQL = "" +
|
||||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||||
|
|
||||||
|
const deleteAllDeviceKeysSQL = "" +
|
||||||
|
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer *sqlutil.TransactionWriter
|
writer *sqlutil.TransactionWriter
|
||||||
|
|
@ -65,6 +68,7 @@ type deviceKeysStatements struct {
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
selectMaxStreamForUserStmt *sql.Stmt
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
|
deleteAllDeviceKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
|
|
@ -88,9 +92,17 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||||
|
_, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||||
deviceIDMap := make(map[string]bool)
|
deviceIDMap := make(map[string]bool)
|
||||||
for _, d := range deviceIDs {
|
for _, d := range deviceIDs {
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ type DeviceKeys interface {
|
||||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||||
|
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type KeyChanges interface {
|
type KeyChanges interface {
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,8 @@ func DeviceListCatchup(
|
||||||
return hasNew, nil
|
return hasNew, nil
|
||||||
}
|
}
|
||||||
// QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user.
|
// QueryKeyChanges gets ALL users who have changed keys, we want the ones who share rooms with the user.
|
||||||
queryRes.UserIDs = filterSharedUsers(ctx, stateAPI, userID, queryRes.UserIDs)
|
var sharedUsersMap map[string]int
|
||||||
|
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, stateAPI, userID, queryRes.UserIDs)
|
||||||
util.GetLogger(ctx).Debugf(
|
util.GetLogger(ctx).Debugf(
|
||||||
"QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v",
|
"QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v",
|
||||||
partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs,
|
partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs,
|
||||||
|
|
@ -114,13 +115,20 @@ func DeviceListCatchup(
|
||||||
}
|
}
|
||||||
// if the response has any join/leave events, add them now.
|
// if the response has any join/leave events, add them now.
|
||||||
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
|
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
|
||||||
for _, userID := range membershipEvents(res) {
|
joinUserIDs, leaveUserIDs := membershipEvents(res)
|
||||||
|
for _, userID := range joinUserIDs {
|
||||||
if !userSet[userID] {
|
if !userSet[userID] {
|
||||||
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
||||||
hasNew = true
|
hasNew = true
|
||||||
userSet[userID] = true
|
userSet[userID] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for _, userID := range leaveUserIDs {
|
||||||
|
if sharedUsersMap[userID] == 0 {
|
||||||
|
// we no longer share a room with this user when they left, so add to left list.
|
||||||
|
res.DeviceLists.Left = append(res.DeviceLists.Left, userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
// set the new token
|
// set the new token
|
||||||
to.SetLog(DeviceListLogName, &types.LogPosition{
|
to.SetLog(DeviceListLogName, &types.LogPosition{
|
||||||
Partition: queryRes.Partition,
|
Partition: queryRes.Partition,
|
||||||
|
|
@ -221,7 +229,7 @@ func TrackChangedUsers(
|
||||||
|
|
||||||
func filterSharedUsers(
|
func filterSharedUsers(
|
||||||
ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, usersWithChangedKeys []string,
|
ctx context.Context, stateAPI currentstateAPI.CurrentStateInternalAPI, userID string, usersWithChangedKeys []string,
|
||||||
) []string {
|
) (map[string]int, []string) {
|
||||||
var result []string
|
var result []string
|
||||||
var sharedUsersRes currentstateAPI.QuerySharedUsersResponse
|
var sharedUsersRes currentstateAPI.QuerySharedUsersResponse
|
||||||
err := stateAPI.QuerySharedUsers(ctx, ¤tstateAPI.QuerySharedUsersRequest{
|
err := stateAPI.QuerySharedUsers(ctx, ¤tstateAPI.QuerySharedUsersRequest{
|
||||||
|
|
@ -229,7 +237,7 @@ func filterSharedUsers(
|
||||||
}, &sharedUsersRes)
|
}, &sharedUsersRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// default to all users so we do needless queries rather than miss some important device update
|
// default to all users so we do needless queries rather than miss some important device update
|
||||||
return usersWithChangedKeys
|
return nil, usersWithChangedKeys
|
||||||
}
|
}
|
||||||
// We forcibly put ourselves in this list because we should be notified about our own device updates
|
// We forcibly put ourselves in this list because we should be notified about our own device updates
|
||||||
// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
|
// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
|
||||||
|
|
@ -241,7 +249,7 @@ func filterSharedUsers(
|
||||||
result = append(result, uid)
|
result = append(result, uid)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return sharedUsersRes.UserIDsToCount, result
|
||||||
}
|
}
|
||||||
|
|
||||||
func joinedRooms(res *types.Response, userID string) []string {
|
func joinedRooms(res *types.Response, userID string) []string {
|
||||||
|
|
@ -288,16 +296,16 @@ func membershipEventPresent(events []gomatrixserverlib.ClientEvent, userID strin
|
||||||
// "For optimal performance, Alice should be added to changed in Bob's sync only when she adds a new device,
|
// "For optimal performance, Alice should be added to changed in Bob's sync only when she adds a new device,
|
||||||
// or when Alice and Bob now share a room but didn't share any room previously. However, for the sake of simpler
|
// or when Alice and Bob now share a room but didn't share any room previously. However, for the sake of simpler
|
||||||
// logic, a server may add Alice to changed when Alice and Bob share a new room, even if they previously already shared a room."
|
// logic, a server may add Alice to changed when Alice and Bob share a new room, even if they previously already shared a room."
|
||||||
func membershipEvents(res *types.Response) (userIDs []string) {
|
func membershipEvents(res *types.Response) (joinUserIDs, leaveUserIDs []string) {
|
||||||
for _, room := range res.Rooms.Join {
|
for _, room := range res.Rooms.Join {
|
||||||
for _, ev := range room.Timeline.Events {
|
for _, ev := range room.Timeline.Events {
|
||||||
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil {
|
if ev.Type == gomatrixserverlib.MRoomMember && ev.StateKey != nil {
|
||||||
if strings.Contains(string(ev.Content), `"join"`) {
|
if strings.Contains(string(ev.Content), `"join"`) {
|
||||||
userIDs = append(userIDs, *ev.StateKey)
|
joinUserIDs = append(joinUserIDs, *ev.StateKey)
|
||||||
} else if strings.Contains(string(ev.Content), `"leave"`) {
|
} else if strings.Contains(string(ev.Content), `"leave"`) {
|
||||||
userIDs = append(userIDs, *ev.StateKey)
|
leaveUserIDs = append(leaveUserIDs, *ev.StateKey)
|
||||||
} else if strings.Contains(string(ev.Content), `"ban"`) {
|
} else if strings.Contains(string(ev.Content), `"ban"`) {
|
||||||
userIDs = append(userIDs, *ev.StateKey)
|
leaveUserIDs = append(leaveUserIDs, *ev.StateKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -142,7 +142,12 @@ Server correctly handles incoming m.device_list_update
|
||||||
Device deletion propagates over federation
|
Device deletion propagates over federation
|
||||||
If remote user leaves room, changes device and rejoins we see update in sync
|
If remote user leaves room, changes device and rejoins we see update in sync
|
||||||
If remote user leaves room, changes device and rejoins we see update in /keys/changes
|
If remote user leaves room, changes device and rejoins we see update in /keys/changes
|
||||||
|
If remote user leaves room we no longer receive device updates
|
||||||
|
If a device list update goes missing, the server resyncs on the next one
|
||||||
|
Get left notifs in sync and /keys/changes when other user leaves
|
||||||
Can query remote device keys using POST after notification
|
Can query remote device keys using POST after notification
|
||||||
|
Server correctly resyncs when client query keys and there is no remote cache
|
||||||
|
Server correctly resyncs when server leaves and rejoins a room
|
||||||
Can add account data
|
Can add account data
|
||||||
Can add account data to room
|
Can add account data to room
|
||||||
Can get account data without syncing
|
Can get account data without syncing
|
||||||
|
|
|
||||||
|
|
@ -180,6 +180,27 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
||||||
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
|
||||||
|
// display name has changed: update the device key
|
||||||
|
var uploadRes keyapi.PerformUploadKeysResponse
|
||||||
|
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
|
||||||
|
DeviceKeys: []keyapi.DeviceKeys{
|
||||||
|
{
|
||||||
|
DeviceID: dev.ID,
|
||||||
|
DisplayName: *req.DisplayName,
|
||||||
|
KeyJSON: nil,
|
||||||
|
UserID: dev.UserID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
OnlyDisplayNameUpdates: true,
|
||||||
|
}, &uploadRes)
|
||||||
|
if uploadRes.Error != nil {
|
||||||
|
return fmt.Errorf("Failed to update device key display name: %v", uploadRes.Error)
|
||||||
|
}
|
||||||
|
if len(uploadRes.KeyErrors) > 0 {
|
||||||
|
return fmt.Errorf("Failed to update device key display name, key errors: %+v", uploadRes.KeyErrors)
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue