From dcedd1b6bf1e890ff425bdf1fcd8a2e0850778b5 Mon Sep 17 00:00:00 2001 From: devonh Date: Thu, 13 Oct 2022 14:38:13 +0000 Subject: [PATCH 1/3] Federation backoff fixes and tests (#2792) This fixes some edge cases where federation queue backoffs and blacklisting weren't behaving as expected. It also adds new tests for the federation queues to ensure their behaviour continues to work correctly. --- federationapi/queue/destinationqueue.go | 10 + federationapi/queue/queue.go | 1 + federationapi/queue/queue_test.go | 422 ++++++++++++++++++++++++ federationapi/statistics/statistics.go | 8 +- go.mod | 2 +- 5 files changed, 441 insertions(+), 2 deletions(-) create mode 100644 federationapi/queue/queue_test.go diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 5cb8cae1f..4ae554ef3 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -75,6 +75,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return } + // Create a database entry that associates the given PDU NID with // this destination queue. We'll then be able to retrieve the PDU // later. @@ -108,6 +109,8 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re case oq.notify <- struct{}{}: default: } + } else { + oq.overflowed.Store(true) } } @@ -153,6 +156,8 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share case oq.notify <- struct{}{}: default: } + } else { + oq.overflowed.Store(true) } } @@ -335,6 +340,11 @@ func (oq *destinationQueue) backgroundSend() { // We failed to send the transaction. Mark it as a failure. oq.statistics.Failure() + // Queue up another attempt since the transaction failed. + select { + case oq.notify <- struct{}{}: + default: + } } else if transaction { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 8245aa5bd..5d352eca6 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -332,6 +332,7 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { if oqs.disabled { return } + oqs.statistics.ForServer(srv).RemoveBlacklist() if queue := oqs.getQueue(srv); queue != nil { queue.wakeQueueIfNeeded() } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go new file mode 100644 index 000000000..8e4a675f4 --- /dev/null +++ b/federationapi/queue/queue_test.go @@ -0,0 +1,422 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package queue + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "testing" + "time" + + "go.uber.org/atomic" + "gotest.tools/v3/poll" + + "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/federationapi/storage" + "github.com/matrix-org/dendrite/federationapi/storage/shared" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +var dbMutex sync.Mutex + +type fakeDatabase struct { + storage.Database + pendingPDUServers map[gomatrixserverlib.ServerName]struct{} + pendingEDUServers map[gomatrixserverlib.ServerName]struct{} + blacklistedServers map[gomatrixserverlib.ServerName]struct{} + pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent + pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU + associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} + associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} +} + +func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { + dbMutex.Lock() + defer dbMutex.Unlock() + + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal([]byte(js), &event); err == nil { + receipt := &shared.Receipt{} + d.pendingPDUs[receipt] = &event + return receipt, nil + } + + var edu gomatrixserverlib.EDU + if err := json.Unmarshal([]byte(js), &edu); err == nil { + receipt := &shared.Receipt{} + d.pendingEDUs[receipt] = &edu + return receipt, nil + } + + return nil, errors.New("Failed to determine type of json to store") +} + +func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { + dbMutex.Lock() + defer dbMutex.Unlock() + + pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) + if receipts, ok := d.associatedPDUs[serverName]; ok { + for receipt := range receipts { + if event, ok := d.pendingPDUs[receipt]; ok { + pdus[receipt] = event + } + } + } + return pdus, nil +} + +func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { + dbMutex.Lock() + defer dbMutex.Unlock() + + edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) + if receipts, ok := d.associatedEDUs[serverName]; ok { + for receipt := range receipts { + if event, ok := d.pendingEDUs[receipt]; ok { + edus[receipt] = event + } + } + } + return edus, nil +} + +func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error { + dbMutex.Lock() + defer dbMutex.Unlock() + + if _, ok := d.pendingPDUs[receipt]; ok { + if _, ok := d.associatedPDUs[serverName]; !ok { + d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{}) + } + d.associatedPDUs[serverName][receipt] = struct{}{} + return nil + } else { + return errors.New("PDU doesn't exist") + } +} + +func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { + dbMutex.Lock() + defer dbMutex.Unlock() + + if _, ok := d.pendingEDUs[receipt]; ok { + if _, ok := d.associatedEDUs[serverName]; !ok { + d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{}) + } + d.associatedEDUs[serverName][receipt] = struct{}{} + return nil + } else { + return errors.New("EDU doesn't exist") + } +} + +func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { + dbMutex.Lock() + defer dbMutex.Unlock() + + if pdus, ok := d.associatedPDUs[serverName]; ok { + for _, receipt := range receipts { + delete(pdus, receipt) + } + } + + return nil +} + +func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { + dbMutex.Lock() + defer dbMutex.Unlock() + + if edus, ok := d.associatedEDUs[serverName]; ok { + for _, receipt := range receipts { + delete(edus, receipt) + } + } + + return nil +} + +func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { + dbMutex.Lock() + defer dbMutex.Unlock() + + var count int64 + if pdus, ok := d.associatedPDUs[serverName]; ok { + count = int64(len(pdus)) + } + return count, nil +} + +func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { + dbMutex.Lock() + defer dbMutex.Unlock() + + var count int64 + if edus, ok := d.associatedEDUs[serverName]; ok { + count = int64(len(edus)) + } + return count, nil +} + +func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + dbMutex.Lock() + defer dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingPDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + dbMutex.Lock() + defer dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingEDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { + dbMutex.Lock() + defer dbMutex.Unlock() + + d.blacklistedServers[serverName] = struct{}{} + return nil +} + +func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { + dbMutex.Lock() + defer dbMutex.Unlock() + + delete(d.blacklistedServers, serverName) + return nil +} + +func (d *fakeDatabase) RemoveAllServersFromBlacklist() error { + dbMutex.Lock() + defer dbMutex.Unlock() + + d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { + dbMutex.Lock() + defer dbMutex.Unlock() + + isBlacklisted := false + if _, ok := d.blacklistedServers[serverName]; ok { + isBlacklisted = true + } + + return isBlacklisted, nil +} + +type stubFederationRoomServerAPI struct { + rsapi.FederationRoomserverAPI +} + +func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Context, req *rsapi.QueryServerBannedFromRoomRequest, res *rsapi.QueryServerBannedFromRoomResponse) error { + res.Banned = false + return nil +} + +type stubFederationClient struct { + api.FederationClient + shouldTxSucceed bool + txCount atomic.Uint32 +} + +func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { + var result error + if !f.shouldTxSucceed { + result = fmt.Errorf("transaction failed") + } + + f.txCount.Add(1) + return gomatrixserverlib.RespSend{}, result +} + +func createDatabase() storage.Database { + return &fakeDatabase{ + pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), + pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), + associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), + } +} + +func mustCreateEvent(t *testing.T) *gomatrixserverlib.HeaderedEvent { + t.Helper() + content := `{"type":"m.room.message"}` + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10) + if err != nil { + t.Fatalf("failed to create event: %v", err) + } + return ev.Headered(gomatrixserverlib.RoomVersionV10) +} + +func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool) (storage.Database, *stubFederationClient, *OutgoingQueues) { + db := createDatabase() + + fc := &stubFederationClient{ + shouldTxSucceed: shouldTxSucceed, + txCount: *atomic.NewUint32(0), + } + rs := &stubFederationRoomServerAPI{} + stats := &statistics.Statistics{ + DB: db, + FailuresUntilBlacklist: failuresUntilBlacklist, + } + signingInfo := &SigningInfo{ + KeyID: "ed25519:auto", + PrivateKey: test.PrivateKeyA, + ServerName: "localhost", + } + queues := NewOutgoingQueues(db, process.NewProcessContext(), false, "localhost", fc, rs, stats, signingInfo) + + return db, fc, queues +} + +func TestSendTransactionOnSuccessRemovedFromDB(t *testing.T) { + ctx := context.Background() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues := testSetup(failuresUntilBlacklist, true) + + ev := mustCreateEvent(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() >= 1 { + data, err := db.GetPendingPDUs(ctx, destination, 100) + assert.NoError(t, err) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database") + } + return poll.Continue("waiting for more send attempts before checking database") + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendTransactionOnFailStoredInDB(t *testing.T) { + ctx := context.Background() + failuresUntilBlacklist := uint32(16) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues := testSetup(failuresUntilBlacklist, false) + + ev := mustCreateEvent(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending + if fc.txCount.Load() >= 2 { + data, err := db.GetPendingPDUs(ctx, destination, 100) + assert.NoError(t, err) + if len(data) == 1 { + return poll.Success() + } + return poll.Continue("waiting for event to be added to database") + } + return poll.Continue("waiting for more send attempts before checking database") + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendTransactionMultipleFailuresBlacklisted(t *testing.T) { + ctx := context.Background() + failuresUntilBlacklist := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues := testSetup(failuresUntilBlacklist, false) + + ev := mustCreateEvent(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() >= failuresUntilBlacklist { + data, err := db.GetPendingPDUs(ctx, destination, 100) + assert.NoError(t, err) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database") + } + return poll.Continue("waiting for more send attempts before checking database") + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestRetryServerSendsSuccessfully(t *testing.T) { + ctx := context.Background() + failuresUntilBlacklist := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues := testSetup(failuresUntilBlacklist, false) + + ev := mustCreateEvent(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + checkBlacklisted := func(log poll.LogT) poll.Result { + if fc.txCount.Load() >= failuresUntilBlacklist { + data, err := db.GetPendingPDUs(ctx, destination, 100) + assert.NoError(t, err) + if len(data) == 1 { + if val, _ := db.IsServerBlacklisted(destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be blacklisted") + } + return poll.Continue("waiting for event to be added to database") + } + return poll.Continue("waiting for more send attempts before checking database") + } + poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + fc.shouldTxSucceed = true + db.RemoveServerFromBlacklist(destination) + queues.RetryServer(destination) + checkRetry := func(log poll.LogT) poll.Result { + data, err := db.GetPendingPDUs(ctx, destination, 100) + assert.NoError(t, err) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database") + } + poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index db6d5c735..61a965791 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -95,8 +95,8 @@ func (s *ServerStatistics) cancel() { // we will unblacklist it. func (s *ServerStatistics) Success() { s.cancel() - s.successCounter.Inc() s.backoffCount.Store(0) + s.successCounter.Inc() if s.statistics.DB != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) @@ -174,6 +174,12 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } +// RemoveBlacklist removes the blacklisted status from the server. +func (s *ServerStatistics) RemoveBlacklist() { + s.cancel() + s.backoffCount.Store(0) +} + // SuccessCount returns the number of successful requests. This is // usually useful in constructing transaction IDs. func (s *ServerStatistics) SuccessCount() uint32 { diff --git a/go.mod b/go.mod index eefad89e6..eeae9608f 100644 --- a/go.mod +++ b/go.mod @@ -50,6 +50,7 @@ require ( golang.org/x/term v0.0.0-20220919170432-7a66f970e087 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 + gotest.tools/v3 v3.0.3 nhooyr.io/websocket v1.8.7 ) @@ -128,7 +129,6 @@ require ( gopkg.in/macaroon.v2 v2.1.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - gotest.tools/v3 v3.0.3 // indirect ) go 1.18 From f3be4b31850add1a33c932aa3fa7b0bb740554f2 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 13 Oct 2022 16:06:50 +0100 Subject: [PATCH 2/3] Revert "Federation backoff fixes and tests (#2792)" This reverts commit dcedd1b6bf1e890ff425bdf1fcd8a2e0850778b5. --- federationapi/queue/destinationqueue.go | 10 - federationapi/queue/queue.go | 1 - federationapi/queue/queue_test.go | 422 ------------------------ federationapi/statistics/statistics.go | 8 +- go.mod | 2 +- 5 files changed, 2 insertions(+), 441 deletions(-) delete mode 100644 federationapi/queue/queue_test.go diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 4ae554ef3..5cb8cae1f 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -75,7 +75,6 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return } - // Create a database entry that associates the given PDU NID with // this destination queue. We'll then be able to retrieve the PDU // later. @@ -109,8 +108,6 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re case oq.notify <- struct{}{}: default: } - } else { - oq.overflowed.Store(true) } } @@ -156,8 +153,6 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share case oq.notify <- struct{}{}: default: } - } else { - oq.overflowed.Store(true) } } @@ -340,11 +335,6 @@ func (oq *destinationQueue) backgroundSend() { // We failed to send the transaction. Mark it as a failure. oq.statistics.Failure() - // Queue up another attempt since the transaction failed. - select { - case oq.notify <- struct{}{}: - default: - } } else if transaction { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 5d352eca6..8245aa5bd 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -332,7 +332,6 @@ func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { if oqs.disabled { return } - oqs.statistics.ForServer(srv).RemoveBlacklist() if queue := oqs.getQueue(srv); queue != nil { queue.wakeQueueIfNeeded() } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go deleted file mode 100644 index 8e4a675f4..000000000 --- a/federationapi/queue/queue_test.go +++ /dev/null @@ -1,422 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package queue - -import ( - "context" - "encoding/json" - "fmt" - "sync" - "testing" - "time" - - "go.uber.org/atomic" - "gotest.tools/v3/poll" - - "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/federationapi/statistics" - "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" - rsapi "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/gomatrixserverlib" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" -) - -var dbMutex sync.Mutex - -type fakeDatabase struct { - storage.Database - pendingPDUServers map[gomatrixserverlib.ServerName]struct{} - pendingEDUServers map[gomatrixserverlib.ServerName]struct{} - blacklistedServers map[gomatrixserverlib.ServerName]struct{} - pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent - pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU - associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} - associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} -} - -func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { - dbMutex.Lock() - defer dbMutex.Unlock() - - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal([]byte(js), &event); err == nil { - receipt := &shared.Receipt{} - d.pendingPDUs[receipt] = &event - return receipt, nil - } - - var edu gomatrixserverlib.EDU - if err := json.Unmarshal([]byte(js), &edu); err == nil { - receipt := &shared.Receipt{} - d.pendingEDUs[receipt] = &edu - return receipt, nil - } - - return nil, errors.New("Failed to determine type of json to store") -} - -func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { - dbMutex.Lock() - defer dbMutex.Unlock() - - pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) - if receipts, ok := d.associatedPDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingPDUs[receipt]; ok { - pdus[receipt] = event - } - } - } - return pdus, nil -} - -func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { - dbMutex.Lock() - defer dbMutex.Unlock() - - edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) - if receipts, ok := d.associatedEDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingEDUs[receipt]; ok { - edus[receipt] = event - } - } - } - return edus, nil -} - -func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error { - dbMutex.Lock() - defer dbMutex.Unlock() - - if _, ok := d.pendingPDUs[receipt]; ok { - if _, ok := d.associatedPDUs[serverName]; !ok { - d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{}) - } - d.associatedPDUs[serverName][receipt] = struct{}{} - return nil - } else { - return errors.New("PDU doesn't exist") - } -} - -func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { - dbMutex.Lock() - defer dbMutex.Unlock() - - if _, ok := d.pendingEDUs[receipt]; ok { - if _, ok := d.associatedEDUs[serverName]; !ok { - d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{}) - } - d.associatedEDUs[serverName][receipt] = struct{}{} - return nil - } else { - return errors.New("EDU doesn't exist") - } -} - -func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - dbMutex.Lock() - defer dbMutex.Unlock() - - if pdus, ok := d.associatedPDUs[serverName]; ok { - for _, receipt := range receipts { - delete(pdus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - dbMutex.Lock() - defer dbMutex.Unlock() - - if edus, ok := d.associatedEDUs[serverName]; ok { - for _, receipt := range receipts { - delete(edus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - dbMutex.Lock() - defer dbMutex.Unlock() - - var count int64 - if pdus, ok := d.associatedPDUs[serverName]; ok { - count = int64(len(pdus)) - } - return count, nil -} - -func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - dbMutex.Lock() - defer dbMutex.Unlock() - - var count int64 - if edus, ok := d.associatedEDUs[serverName]; ok { - count = int64(len(edus)) - } - return count, nil -} - -func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - dbMutex.Lock() - defer dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingPDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - dbMutex.Lock() - defer dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingEDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { - dbMutex.Lock() - defer dbMutex.Unlock() - - d.blacklistedServers[serverName] = struct{}{} - return nil -} - -func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { - dbMutex.Lock() - defer dbMutex.Unlock() - - delete(d.blacklistedServers, serverName) - return nil -} - -func (d *fakeDatabase) RemoveAllServersFromBlacklist() error { - dbMutex.Lock() - defer dbMutex.Unlock() - - d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) - return nil -} - -func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { - dbMutex.Lock() - defer dbMutex.Unlock() - - isBlacklisted := false - if _, ok := d.blacklistedServers[serverName]; ok { - isBlacklisted = true - } - - return isBlacklisted, nil -} - -type stubFederationRoomServerAPI struct { - rsapi.FederationRoomserverAPI -} - -func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Context, req *rsapi.QueryServerBannedFromRoomRequest, res *rsapi.QueryServerBannedFromRoomResponse) error { - res.Banned = false - return nil -} - -type stubFederationClient struct { - api.FederationClient - shouldTxSucceed bool - txCount atomic.Uint32 -} - -func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { - var result error - if !f.shouldTxSucceed { - result = fmt.Errorf("transaction failed") - } - - f.txCount.Add(1) - return gomatrixserverlib.RespSend{}, result -} - -func createDatabase() storage.Database { - return &fakeDatabase{ - pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), - pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), - associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - } -} - -func mustCreateEvent(t *testing.T) *gomatrixserverlib.HeaderedEvent { - t.Helper() - content := `{"type":"m.room.message"}` - ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10) - if err != nil { - t.Fatalf("failed to create event: %v", err) - } - return ev.Headered(gomatrixserverlib.RoomVersionV10) -} - -func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool) (storage.Database, *stubFederationClient, *OutgoingQueues) { - db := createDatabase() - - fc := &stubFederationClient{ - shouldTxSucceed: shouldTxSucceed, - txCount: *atomic.NewUint32(0), - } - rs := &stubFederationRoomServerAPI{} - stats := &statistics.Statistics{ - DB: db, - FailuresUntilBlacklist: failuresUntilBlacklist, - } - signingInfo := &SigningInfo{ - KeyID: "ed25519:auto", - PrivateKey: test.PrivateKeyA, - ServerName: "localhost", - } - queues := NewOutgoingQueues(db, process.NewProcessContext(), false, "localhost", fc, rs, stats, signingInfo) - - return db, fc, queues -} - -func TestSendTransactionOnSuccessRemovedFromDB(t *testing.T) { - ctx := context.Background() - failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues := testSetup(failuresUntilBlacklist, true) - - ev := mustCreateEvent(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) - assert.NoError(t, err) - - check := func(log poll.LogT) poll.Result { - if fc.txCount.Load() >= 1 { - data, err := db.GetPendingPDUs(ctx, destination, 100) - assert.NoError(t, err) - if len(data) == 0 { - return poll.Success() - } - return poll.Continue("waiting for event to be removed from database") - } - return poll.Continue("waiting for more send attempts before checking database") - } - poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) -} - -func TestSendTransactionOnFailStoredInDB(t *testing.T) { - ctx := context.Background() - failuresUntilBlacklist := uint32(16) - destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues := testSetup(failuresUntilBlacklist, false) - - ev := mustCreateEvent(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) - assert.NoError(t, err) - - check := func(log poll.LogT) poll.Result { - // Wait for 2 backoff attempts to ensure there was adequate time to attempt sending - if fc.txCount.Load() >= 2 { - data, err := db.GetPendingPDUs(ctx, destination, 100) - assert.NoError(t, err) - if len(data) == 1 { - return poll.Success() - } - return poll.Continue("waiting for event to be added to database") - } - return poll.Continue("waiting for more send attempts before checking database") - } - poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) -} - -func TestSendTransactionMultipleFailuresBlacklisted(t *testing.T) { - ctx := context.Background() - failuresUntilBlacklist := uint32(2) - destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues := testSetup(failuresUntilBlacklist, false) - - ev := mustCreateEvent(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) - assert.NoError(t, err) - - check := func(log poll.LogT) poll.Result { - if fc.txCount.Load() >= failuresUntilBlacklist { - data, err := db.GetPendingPDUs(ctx, destination, 100) - assert.NoError(t, err) - if len(data) == 1 { - if val, _ := db.IsServerBlacklisted(destination); val { - return poll.Success() - } - return poll.Continue("waiting for server to be blacklisted") - } - return poll.Continue("waiting for event to be added to database") - } - return poll.Continue("waiting for more send attempts before checking database") - } - poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) -} - -func TestRetryServerSendsSuccessfully(t *testing.T) { - ctx := context.Background() - failuresUntilBlacklist := uint32(1) - destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues := testSetup(failuresUntilBlacklist, false) - - ev := mustCreateEvent(t) - err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) - assert.NoError(t, err) - - checkBlacklisted := func(log poll.LogT) poll.Result { - if fc.txCount.Load() >= failuresUntilBlacklist { - data, err := db.GetPendingPDUs(ctx, destination, 100) - assert.NoError(t, err) - if len(data) == 1 { - if val, _ := db.IsServerBlacklisted(destination); val { - return poll.Success() - } - return poll.Continue("waiting for server to be blacklisted") - } - return poll.Continue("waiting for event to be added to database") - } - return poll.Continue("waiting for more send attempts before checking database") - } - poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) - - fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) - checkRetry := func(log poll.LogT) poll.Result { - data, err := db.GetPendingPDUs(ctx, destination, 100) - assert.NoError(t, err) - if len(data) == 0 { - return poll.Success() - } - return poll.Continue("waiting for event to be removed from database") - } - poll.WaitOn(t, checkRetry, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) -} diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 61a965791..db6d5c735 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -95,8 +95,8 @@ func (s *ServerStatistics) cancel() { // we will unblacklist it. func (s *ServerStatistics) Success() { s.cancel() - s.backoffCount.Store(0) s.successCounter.Inc() + s.backoffCount.Store(0) if s.statistics.DB != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) @@ -174,12 +174,6 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } -// RemoveBlacklist removes the blacklisted status from the server. -func (s *ServerStatistics) RemoveBlacklist() { - s.cancel() - s.backoffCount.Store(0) -} - // SuccessCount returns the number of successful requests. This is // usually useful in constructing transaction IDs. func (s *ServerStatistics) SuccessCount() uint32 { diff --git a/go.mod b/go.mod index eeae9608f..eefad89e6 100644 --- a/go.mod +++ b/go.mod @@ -50,7 +50,6 @@ require ( golang.org/x/term v0.0.0-20220919170432-7a66f970e087 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 - gotest.tools/v3 v3.0.3 nhooyr.io/websocket v1.8.7 ) @@ -129,6 +128,7 @@ require ( gopkg.in/macaroon.v2 v2.1.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gotest.tools/v3 v3.0.3 // indirect ) go 1.18 From 088ad1dd21a06ad3c4ae2db2a73936ed3e0d809e Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 14 Oct 2022 09:14:54 +0200 Subject: [PATCH 3/3] Fix `outliers whose auth_events are in a different room are correctly rejected` (#2791) Fixes `outliers whose auth_events are in a different room are correctly rejected`, by validating that auth events are all from the same room and not using rejected events for event auth. --- go.mod | 2 +- go.sum | 4 +- roomserver/internal/helpers/auth.go | 20 +++++- roomserver/internal/input/input_events.go | 31 ++++++--- .../internal/input/input_events_test.go | 63 +++++++++++++++++++ sytest-whitelist | 3 +- test/event.go | 7 +++ test/room.go | 9 ++- 8 files changed, 124 insertions(+), 15 deletions(-) create mode 100644 roomserver/internal/input/input_events_test.go diff --git a/go.mod b/go.mod index eefad89e6..23d5655d8 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221011115330-49fa704b9a64 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 0d08ac692..a1069da37 100644 --- a/go.sum +++ b/go.sum @@ -384,8 +384,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221011115330-49fa704b9a64 h1:QJmfAPC3P0ZHJzYD/QtbNc5EztKlK1ipRWP5SO/m4jw= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221011115330-49fa704b9a64/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241 h1:e5o68MWeU7wjTvvNKmVo655oCYesoNRoPeBb1Xfz54g= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221014061925-a132619fa241/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c h1:iCHLYwwlPsf4TYFrvhKdhQoAM2lXzcmDZYqwBNWcnVk= github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 935a045df..03d8bca0b 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -19,10 +19,11 @@ import ( "fmt" "sort" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) // CheckForSoftFail returns true if the event should be soft-failed @@ -129,6 +130,12 @@ type authEvents struct { stateKeyNIDMap map[string]types.EventStateKeyNID state stateEntryMap events EventMap + valid bool +} + +// Valid verifies that all auth events are from the same room. +func (ae *authEvents) Valid() bool { + return ae.valid } // Create implements gomatrixserverlib.AuthEventProvider @@ -197,6 +204,7 @@ func loadAuthEvents( needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { + result.valid = true // Look up the numeric IDs for the state keys needed for auth. var neededStateKeys []string neededStateKeys = append(neededStateKeys, needed.Member...) @@ -218,6 +226,16 @@ func loadAuthEvents( if result.events, err = db.Events(ctx, eventNIDs); err != nil { return } + roomID := "" + for _, ev := range result.events { + if roomID == "" { + roomID = ev.RoomID() + } + if ev.RoomID() != roomID { + result.valid = false + break + } + } return } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index d1b6bc73e..60160e8e5 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -19,9 +19,16 @@ package input import ( "context" "database/sql" + "errors" "fmt" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/opentracing/opentracing-go" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" @@ -31,11 +38,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/opentracing/opentracing-go" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" ) // TODO: Does this value make sense? @@ -196,7 +198,7 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + if err = r.fetchAuthEvents(ctx, logger, roomInfo, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { return fmt.Errorf("r.fetchAuthEvents: %w", err) } @@ -336,7 +338,7 @@ func (r *Inputer) processRoomEvent( // doesn't have any associated state to store and we don't need to // notify anyone about it. if input.Kind == api.KindOutlier { - logger.Debug("Stored outlier") + logger.WithField("rejected", isRejected).Debug("Stored outlier") hooks.Run(hooks.KindNewEventPersisted, headered) return nil } @@ -536,6 +538,7 @@ func (r *Inputer) processStateBefore( func (r *Inputer) fetchAuthEvents( ctx context.Context, logger *logrus.Entry, + roomInfo *types.RoomInfo, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, known map[string]*types.Event, @@ -557,9 +560,19 @@ func (r *Inputer) fetchAuthEvents( continue } ev := authEvents[0] + + isRejected := false + if roomInfo != nil { + isRejected, err = r.DB.IsEventRejected(ctx, roomInfo.RoomNID, ev.EventID()) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("r.DB.IsEventRejected failed: %w", err) + } + } known[authEventID] = &ev // don't take the pointer of the iterated event - if err = auth.AddEvent(ev.Event); err != nil { - return fmt.Errorf("auth.AddEvent: %w", err) + if !isRejected { + if err = auth.AddEvent(ev.Event); err != nil { + return fmt.Errorf("auth.AddEvent: %w", err) + } } } diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go new file mode 100644 index 000000000..818e7715c --- /dev/null +++ b/roomserver/internal/input/input_events_test.go @@ -0,0 +1,63 @@ +package input + +import ( + "testing" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/test" +) + +func Test_EventAuth(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + + // create two rooms, so we can craft "illegal" auth events + room1 := test.NewRoom(t, alice) + room2 := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat)) + + authEventIDs := make([]string, 0, 4) + authEvents := []*gomatrixserverlib.Event{} + + // Add the legal auth events from room2 + for _, x := range room2.Events() { + if x.Type() == gomatrixserverlib.MRoomCreate { + authEventIDs = append(authEventIDs, x.EventID()) + authEvents = append(authEvents, x.Event) + } + if x.Type() == gomatrixserverlib.MRoomPowerLevels { + authEventIDs = append(authEventIDs, x.EventID()) + authEvents = append(authEvents, x.Event) + } + if x.Type() == gomatrixserverlib.MRoomJoinRules { + authEventIDs = append(authEventIDs, x.EventID()) + authEvents = append(authEvents, x.Event) + } + } + + // Add the illegal auth event from room1 (rooms are different) + for _, x := range room1.Events() { + if x.Type() == gomatrixserverlib.MRoomMember { + authEventIDs = append(authEventIDs, x.EventID()) + authEvents = append(authEvents, x.Event) + } + } + + // Craft the illegal join event, with auth events from different rooms + ev := room2.CreateEvent(t, bob, "m.room.member", map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID), test.WithAuthIDs(authEventIDs)) + + // Add the auth events to the allower + allower := gomatrixserverlib.NewAuthEvents(nil) + for _, a := range authEvents { + if err := allower.AddEvent(a); err != nil { + t.Fatalf("allower.AddEvent failed: %v", err) + } + } + + // Finally check that the event is NOT allowed + if err := gomatrixserverlib.Allowed(ev.Event, &allower); err == nil { + t.Fatalf("event should not be allowed, but it was") + } +} diff --git a/sytest-whitelist b/sytest-whitelist index a3218ed70..2bd8b9403 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -743,4 +743,5 @@ User joining then leaving public room appears and dissappears from directory User in remote room doesn't appear in user directory after server left room User in shared private room does appear in user directory until leave Existing members see new member's presence -Inbound federation can return missing events for joined visibility \ No newline at end of file +Inbound federation can return missing events for joined visibility +outliers whose auth_events are in a different room are correctly rejected \ No newline at end of file diff --git a/test/event.go b/test/event.go index 73fc656bd..0c7bf4355 100644 --- a/test/event.go +++ b/test/event.go @@ -30,6 +30,7 @@ type eventMods struct { unsigned interface{} keyID gomatrixserverlib.KeyID privKey ed25519.PrivateKey + authEvents []string } type eventModifier func(e *eventMods) @@ -52,6 +53,12 @@ func WithUnsigned(unsigned interface{}) eventModifier { } } +func WithAuthIDs(evs []string) eventModifier { + return func(e *eventMods) { + e.authEvents = evs + } +} + func WithKeyID(keyID gomatrixserverlib.KeyID) eventModifier { return func(e *eventMods) { e.keyID = keyID diff --git a/test/room.go b/test/room.go index 94eb51bbe..4328bf84f 100644 --- a/test/room.go +++ b/test/room.go @@ -21,8 +21,9 @@ import ( "testing" "time" - "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/eventutil" ) type Preset int @@ -174,11 +175,17 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten if err != nil { t.Fatalf("CreateEvent[%s]: failed to StateNeededForEventBuilder: %s", eventType, err) } + refs, err := eventsNeeded.AuthEventReferences(&r.authEvents) if err != nil { t.Fatalf("CreateEvent[%s]: failed to AuthEventReferences: %s", eventType, err) } builder.AuthEvents = refs + + if len(mod.authEvents) > 0 { + builder.AuthEvents = mod.authEvents + } + ev, err := builder.Build( mod.originServerTS, mod.origin, mod.keyID, mod.privKey, r.Version,