From 8dc95062101b3906ffb83604e2abca02d9a3dd03 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Mon, 14 Sep 2020 16:39:38 +0100 Subject: [PATCH 1/4] Don't use more than 999 variables in SQLite querys. (#1425) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Don't use more than 999 variables in SQLite querys. Solve this problem in a more general and reusable way. Also fix #1369 Add some unit tests. Signed-off-by: Henrik Sölver * Don't rely on testify for basic assertions * Readability improvements and linting Co-authored-by: Henrik Sölver --- go.mod | 1 + go.sum | 2 + internal/sqlutil/sql.go | 45 +++++ internal/sqlutil/sqlutil_test.go | 173 ++++++++++++++++++ .../storage/sqlite3/server_key_table.go | 67 +++---- 5 files changed, 255 insertions(+), 33 deletions(-) create mode 100644 internal/sqlutil/sqlutil_test.go diff --git a/go.mod b/go.mod index f1cb3c9be..6b1c03b5b 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/matrix-org/dendrite require ( + github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/Shopify/sarama v1.27.0 github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect github.com/gologme/log v1.2.0 diff --git a/go.sum b/go.sum index ac7827d9f..5c4f27a5d 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0 h1:p3puK8Sl2xK+2Fnn github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 1d2825d56..90562ded3 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -15,10 +15,14 @@ package sqlutil import ( + "context" "database/sql" "errors" "fmt" "runtime" + "strings" + + "github.com/matrix-org/util" ) // ErrUserExists is returned if a username already exists in the database. @@ -107,3 +111,44 @@ func SQLiteDriverName() string { } return "sqlite3" } + +func minOfInts(a, b int) int { + if a <= b { + return a + } + return b +} + +// QueryProvider defines the interface for querys used by RunLimitedVariablesQuery. +type QueryProvider interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) +} + +// SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement +// SQLlite can handle. See https://www.sqlite.org/limits.html for more information. +const SQLite3MaxVariables = 999 + +// RunLimitedVariablesQuery split up a query with more variables than the used database can handle in multiple queries. +func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, variables []interface{}, limit uint, rowHandler func(*sql.Rows) error) error { + var start int + for start < len(variables) { + n := minOfInts(len(variables)-start, int(limit)) + nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1) + rows, err := qp.QueryContext(ctx, nextQuery, variables[start:start+n]...) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("QueryContext returned an error") + return err + } + err = rowHandler(rows) + if closeErr := rows.Close(); closeErr != nil { + util.GetLogger(ctx).WithError(closeErr).Error("RunLimitedVariablesQuery: failed to close rows") + return err + } + if err != nil { + util.GetLogger(ctx).WithError(err).Error("RunLimitedVariablesQuery: rowHandler returned error") + return err + } + start = start + n + } + return nil +} diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go new file mode 100644 index 000000000..79469cddc --- /dev/null +++ b/internal/sqlutil/sqlutil_test.go @@ -0,0 +1,173 @@ +package sqlutil + +import ( + "context" + "database/sql" + "reflect" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" +) + +func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + r := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assertNoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }) + assertNoError(t, err, "Call returned an error") + if len(result) != len(v) { + t.Fatalf("Result should be 3 long") + } +} + +func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + r := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3). + AddRow(4) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3, 4} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assertNoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }) + assertNoError(t, err, "Call returned an error") + if len(result) != len(v) { + t.Fatalf("Result should be 4 long") + } +} + +func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + r1 := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3). + AddRow(4) + + r2 := mock.NewRows([]string{"id"}). + AddRow(5) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r1) + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1\)`).WillReturnRows(r2) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3, 4, 5} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assertNoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }) + assertNoError(t, err, "Call returned an error") + if len(result) != len(v) { + t.Fatalf("Result should be 5 long") + } + if !reflect.DeepEqual(v, result) { + t.Fatalf("Result is not as expected: got %v want %v", v, result) + } +} + +func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + // adding a string ID should result in rows.Scan returning an error + r := mock.NewRows([]string{"id"}). + AddRow("hej"). + AddRow(2). + AddRow(3) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{-1, -2, 3} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]uint, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id uint + err = rows.Scan(&id) + if err != nil { + return err + } + result = append(result, id) + } + return nil + }) + if err == nil { + t.Fatalf("Call did not return an error") + } +} + +func assertNoError(t *testing.T, err error, msg string) { + t.Helper() + if err == nil { + return + } + t.Fatalf(msg) +} diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go index f756ef5e3..2484d6368 100644 --- a/serverkeyapi/storage/sqlite3/server_key_table.go +++ b/serverkeyapi/storage/sqlite3/server_key_table.go @@ -18,9 +18,8 @@ package sqlite3 import ( "context" "database/sql" - "strings" + "fmt" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -88,48 +87,50 @@ func (s *serverKeyStatements) bulkSelectServerKeys( ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { - var nameAndKeyIDs []string + nameAndKeyIDs := make([]string, 0, len(requests)) for request := range requests { nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) } - - query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1) - + results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests)) iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) for i, v := range nameAndKeyIDs { iKeyIDs[i] = v } - rows, err := s.db.QueryContext(ctx, query, iKeyIDs...) + err := sqlutil.RunLimitedVariablesQuery( + ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables, + func(rows *sql.Rows) error { + for rows.Next() { + var serverName string + var keyID string + var key string + var validUntilTS int64 + var expiredTS int64 + if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { + return fmt.Errorf("bulkSelectServerKeys: %v", err) + } + r := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: gomatrixserverlib.ServerName(serverName), + KeyID: gomatrixserverlib.KeyID(keyID), + } + vk := gomatrixserverlib.VerifyKey{} + err := vk.Key.Decode(key) + if err != nil { + return fmt.Errorf("bulkSelectServerKeys: %v", err) + } + results[r] = gomatrixserverlib.PublicKeyLookupResult{ + VerifyKey: vk, + ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), + ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), + } + } + return nil + }, + ) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed") - results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} - for rows.Next() { - var serverName string - var keyID string - var key string - var validUntilTS int64 - var expiredTS int64 - if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { - return nil, err - } - r := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: gomatrixserverlib.ServerName(serverName), - KeyID: gomatrixserverlib.KeyID(keyID), - } - vk := gomatrixserverlib.VerifyKey{} - err = vk.Key.Decode(key) - if err != nil { - return nil, err - } - results[r] = gomatrixserverlib.PublicKeyLookupResult{ - VerifyKey: vk, - ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), - ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), - } - } return results, nil } From 965f068d1a6298b2ec733b0df983773a6ec8b622 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Sep 2020 11:17:46 +0100 Subject: [PATCH 2/4] Handle state with input event as new events (#1415) * SendEventWithState events as new * Use cumulative state IDs for final event * Error wrapping in calculateAndSetState * Handle overwriting same event type and state key * Hacky way to spot historical events * Don't exclude from sync * Don't generate output events when rewriting forward extremities * Update output event check * Historical output events * Define output room event type * Notify key changes on state * Don't send our membership event twice * Deduplicate state entries * Tweaks * Remove unnecessary nolint * Fix current state upsert in sync API * Send auth events as outliers, state events as rewrite * Sync API don't consume state events * Process events actually * Improve outlier check * Fix local room check * Remove extra room check, it seems to break the whole damn world * Fix federated join check * Fix nil pointer exception * Better comments on DeduplicateStateEntries * Reflow forced federated joins * Don't force federated join for possibly even local invites * Comment SendEventWithState better * Rewrite room state in sync API storage * Add TODO * Clean up all room data when receiving create event * Don't generate output events for rewrites, but instead notify that state is rewritten on the final new event * Rename to PurgeRoom * Exclude backfilled messages from /sync * Split out rewriting state from updating state from state res Co-authored-by: Kegan Dougal --- federationsender/internal/perform.go | 7 +- roomserver/api/input.go | 4 + roomserver/api/output.go | 14 + roomserver/api/wrapper.go | 101 ++++++++ roomserver/internal/helpers/auth.go | 2 +- roomserver/internal/input/input_events.go | 26 +- .../internal/input/input_latest_events.go | 12 +- roomserver/internal/perform/perform_join.go | 30 +-- roomserver/roomserver_test.go | 245 +++++++++++++++++- roomserver/types/types.go | 21 ++ roomserver/types/types_test.go | 26 ++ syncapi/consumers/roomserver.go | 6 + syncapi/storage/interface.go | 3 + .../postgres/backwards_extremities_table.go | 15 ++ .../postgres/current_room_state_table.go | 15 ++ .../postgres/output_room_events_table.go | 14 + .../output_room_events_topology_table.go | 15 ++ syncapi/storage/shared/syncserver.go | 23 ++ .../sqlite3/backwards_extremities_table.go | 15 ++ .../sqlite3/current_room_state_table.go | 17 +- .../sqlite3/output_room_events_table.go | 14 + .../output_room_events_topology_table.go | 14 + syncapi/storage/tables/interface.go | 7 + 23 files changed, 616 insertions(+), 30 deletions(-) create mode 100644 roomserver/types/types_test.go diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index 90abae236..a0abf7ff0 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -98,7 +98,10 @@ func (r *FederationSenderInternalAPI) PerformJoin( response.LastError = &gomatrix.HTTPError{ Code: 0, WrappedError: nil, - Message: lastErr.Error(), + Message: "Unknown HTTP error", + } + if lastErr != nil { + response.LastError.Message = lastErr.Error() } } @@ -195,7 +198,7 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer( // If we successfully performed a send_join above then the other // server now thinks we're a part of the room. Send the newly // returned state to the roomserver to update our local view. - if err = roomserverAPI.SendEventWithState( + if err = roomserverAPI.SendEventWithRewrite( ctx, r.rsAPI, respState, event.Headered(respMakeJoin.RoomVersion), diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 73c4994a7..651c0e9f8 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -33,6 +33,10 @@ const ( // KindBackfill event extend the contiguous graph going backwards. // They always have state. KindBackfill = 3 + // KindRewrite events are used when rewriting the head of the room + // graph with entirely new state. The output events generated will + // be state events rather than timeline events. + KindRewrite = 4 ) // DoNotSendToOtherServers tells us not to send the event to other matrix diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 013ebdc83..d57f3b04c 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -68,6 +68,17 @@ type OutputEvent struct { NewPeek *OutputNewPeek `json:"new_peek,omitempty"` } +// Type of the OutputNewRoomEvent. +type OutputRoomEventType int + +const ( + // The event is a timeline event and likely just happened. + OutputRoomTimeline OutputRoomEventType = iota + + // The event is a state event and quite possibly happened in the past. + OutputRoomState +) + // An OutputNewRoomEvent is written when the roomserver receives a new event. // It contains the full matrix room event and enough information for a // consumer to construct the current state of the room and the state before the @@ -80,6 +91,9 @@ type OutputEvent struct { type OutputNewRoomEvent struct { // The Event. Event gomatrixserverlib.HeaderedEvent `json:"event"` + // Does the event completely rewrite the room state? If so, then AddsStateEventIDs + // will contain the entire room state. + RewritesState bool `json:"rewrites_state"` // The latest events in the room after this event. // This can be used to set the prev events for new events in the room. // This also can be used to get the full current state after this event. diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 82a4a5719..e53393119 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -80,6 +80,107 @@ func SendEventWithState( return SendInputRoomEvents(ctx, rsAPI, ires) } +// SendEventWithRewrite writes an event with KindNew to the roomserver along +// with a number of rewrite and outlier events for state and auth events +// respectively. +func SendEventWithRewrite( + ctx context.Context, rsAPI RoomserverInternalAPI, state *gomatrixserverlib.RespState, + event gomatrixserverlib.HeaderedEvent, haveEventIDs map[string]bool, +) error { + isCurrentState := map[string]struct{}{} + for _, se := range state.StateEvents { + isCurrentState[se.EventID()] = struct{}{} + } + + authAndStateEvents, err := state.Events() + if err != nil { + return err + } + + var ires []InputRoomEvent + var stateIDs []string + + // This function generates three things: + // A - A set of "rewrite" events, which will form the newly rewritten + // state before the event, which includes every rewrite event that + // came before it in its state + // B - A set of "outlier" events, which are auth events but not part + // of the rewritten state + // C - A "new" event, which include all of the rewrite events in its + // state + for _, authOrStateEvent := range authAndStateEvents { + if authOrStateEvent.StateKey() == nil { + continue + } + if haveEventIDs[authOrStateEvent.EventID()] { + continue + } + if event.StateKey() == nil { + continue + } + + // We will handle an event as if it's an outlier if one of the + // following conditions is true: + storeAsOutlier := false + if authOrStateEvent.Type() == event.Type() && *authOrStateEvent.StateKey() == *event.StateKey() { + // The event is a state event but the input event is going to + // replace it, therefore it can't be added to the state or we'll + // get duplicate state keys in the state block. We'll send it + // as an outlier because we don't know if something will be + // referring to it as an auth event, but need it to be stored + // just in case. + storeAsOutlier = true + } else if _, ok := isCurrentState[authOrStateEvent.EventID()]; !ok { + // The event is an auth event and isn't a part of the state set. + // We'll send it as an outlier because we need it to be stored + // in case something is referring to it as an auth event. + storeAsOutlier = true + } + + if storeAsOutlier { + ires = append(ires, InputRoomEvent{ + Kind: KindOutlier, + Event: authOrStateEvent.Headered(event.RoomVersion), + AuthEventIDs: authOrStateEvent.AuthEventIDs(), + }) + continue + } + + // If the event isn't an outlier then we'll instead send it as a + // rewrite event, so that it'll form part of the rewritten state. + // These events will go through the membership and latest event + // updaters and we will generate output events, but they will be + // flagged as non-current (i.e. didn't just happen) events. + // Each of these rewrite events includes all of the rewrite events + // that came before in their StateEventIDs. + ires = append(ires, InputRoomEvent{ + Kind: KindRewrite, + Event: authOrStateEvent.Headered(event.RoomVersion), + AuthEventIDs: authOrStateEvent.AuthEventIDs(), + HasState: true, + StateEventIDs: stateIDs, + }) + + // Add the event ID into the StateEventIDs of all subsequent + // rewrite events, and the new event. + stateIDs = append(stateIDs, authOrStateEvent.EventID()) + } + + // Send the final event as a new event, which will generate + // a timeline output event for it. All of the rewrite events + // that came before will be sent as StateEventIDs, forming a + // new clean state before the event. + ires = append(ires, InputRoomEvent{ + Kind: KindNew, + Event: event, + AuthEventIDs: event.AuthEventIDs(), + HasState: true, + StateEventIDs: stateIDs, + }) + + return SendInputRoomEvents(ctx, rsAPI, ires) +} + // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( ctx context.Context, rsAPI RoomserverInternalAPI, ires []InputRoomEvent, diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 060f0a0e9..524a54510 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -36,7 +36,7 @@ func CheckAuthEvents( if err != nil { return nil, err } - // TODO: check for duplicate state keys here. + authStateEntries = types.DeduplicateStateEntries(authStateEntries) // Work out which of the state events we actually need. stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()}) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 6ee679da6..daf1afcd3 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -86,7 +86,7 @@ func (r *Inputer) processRoomEvent( "event_id": event.EventID(), "type": event.Type(), "room": event.RoomID(), - }).Info("Stored outlier") + }).Debug("Stored outlier") return event.EventID(), nil } @@ -107,6 +107,15 @@ func (r *Inputer) processRoomEvent( } } + if input.Kind == api.KindRewrite { + logrus.WithFields(logrus.Fields{ + "event_id": event.EventID(), + "type": event.Type(), + "room": event.RoomID(), + }).Debug("Stored rewrite") + return event.EventID(), nil + } + if err = r.updateLatestEvents( ctx, // context roomInfo, // room info for the room being updated @@ -114,6 +123,7 @@ func (r *Inputer) processRoomEvent( event, // event input.SendAsServer, // send as server input.TransactionID, // transaction ID + input.HasState, // rewrites state? ); err != nil { return "", fmt.Errorf("r.updateLatestEvents: %w", err) } @@ -167,19 +177,25 @@ func (r *Inputer) calculateAndSetState( // Check that those state events are in the database and store the state. var entries []types.StateEntry if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { - return err + return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err) } + entries = types.DeduplicateStateEntries(entries) if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { - return err + return fmt.Errorf("r.DB.AddState: %w", err) } } else { stateAtEvent.Overwrite = false // We haven't been told what the state at the event is so we need to calculate it from the prev_events if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil { - return err + return fmt.Errorf("roomState.CalculateAndStoreStateBeforeEvent: %w", err) } } - return r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + + err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + if err != nil { + return fmt.Errorf("r.DB.SetState: %w", err) + } + return nil } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 67a7d8a40..5c2a1de6a 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -54,6 +54,7 @@ func (r *Inputer) updateLatestEvents( event gomatrixserverlib.Event, sendAsServer string, transactionID *api.TransactionID, + rewritesState bool, ) (err error) { updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) if err != nil { @@ -71,6 +72,7 @@ func (r *Inputer) updateLatestEvents( event: event, sendAsServer: sendAsServer, transactionID: transactionID, + rewritesState: rewritesState, } if err = u.doUpdateLatestEvents(); err != nil { @@ -93,6 +95,7 @@ type latestEventsUpdater struct { stateAtEvent types.StateAtEvent event gomatrixserverlib.Event transactionID *api.TransactionID + rewritesState bool // Which server to send this event as. sendAsServer string // The eventID of the event that was processed before this one. @@ -178,7 +181,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { return fmt.Errorf("u.api.updateMemberships: %w", err) } - update, err := u.makeOutputNewRoomEvent() + var update *api.OutputEvent + update, err = u.makeOutputNewRoomEvent() if err != nil { return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) } @@ -305,6 +309,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) ore := api.OutputNewRoomEvent{ Event: u.event.Headered(u.roomInfo.RoomVersion), + RewritesState: u.rewritesState, LastSentEventID: u.lastEventIDSent, LatestEventIDs: latestEventIDs, TransactionID: u.transactionID, @@ -337,6 +342,11 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) } } + // State is rewritten if the input room event HasState and we actually produced a delta on state events. + // Without this check, /get_missing_events which produce events with associated (but not complete) state + // will incorrectly purge the room and set it to no state. TODO: This is likely flakey, as if /gme produced + // a state conflict res which just so happens to include 2+ events we might purge the room state downstream. + ore.RewritesState = len(ore.AddsStateEventIDs) > 1 return &api.OutputEvent{ Type: api.OutputTypeNewRoomEvent, diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 3d1942272..f76806c73 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -183,33 +183,33 @@ func (r *Joiner) performJoinRoomByID( return "", fmt.Errorf("eb.SetContent: %w", err) } - // First work out if this is in response to an existing invite - // from a federated server. If it is then we avoid the situation - // where we might think we know about a room in the following - // section but don't know the latest state as all of our users - // have left. + // Force a federated join if we aren't in the room and we've been + // given some server names to try joining by. serverInRoom, _ := helpers.IsServerCurrentlyInRoom(ctx, r.DB, r.ServerName, req.RoomIDOrAlias) + forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom + + // Force a federated join if we're dealing with a pending invite + // and we aren't in the room. isInvitePending, inviteSender, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) - if err == nil && isInvitePending && !serverInRoom { - // Check if there's an invite pending. + if err == nil && isInvitePending { _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) if ierr != nil { return "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - // Check that the domain isn't ours. If it's local then we don't - // need to do anything as our own copy of the room state will be - // up-to-date. + // If we were invited by someone from another server then we can + // assume they are in the room so we can join via them. if inviterDomain != r.Cfg.Matrix.ServerName { - // Add the server of the person who invited us to the server list, - // as they should be a fairly good bet. req.ServerNames = append(req.ServerNames, inviterDomain) - - // Perform a federated room join. - return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) + forceFederatedJoin = true } } + // If we should do a forced federated join then do that. + if forceFederatedJoin { + return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) + } + // Try to construct an actual join event from the template. // If this succeeds then it is a sign that the room already exists // locally on the homeserver. diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 786d4f31f..5a67a1bed 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -1,12 +1,15 @@ package roomserver import ( + "bytes" "context" + "crypto/ed25519" "encoding/json" "fmt" "os" "reflect" "testing" + "time" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal/caching" @@ -80,7 +83,73 @@ func deleteDatabase() { } } -func mustLoadEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent { +type fledglingEvent struct { + Type string + StateKey *string + Content interface{} + Sender string + RoomID string +} + +func mustCreateEvents(t *testing.T, roomVer gomatrixserverlib.RoomVersion, events []fledglingEvent) (result []gomatrixserverlib.HeaderedEvent) { + t.Helper() + depth := int64(1) + seed := make([]byte, ed25519.SeedSize) // zero seed + key := ed25519.NewKeyFromSeed(seed) + var prevs []string + roomState := make(map[gomatrixserverlib.StateKeyTuple]string) // state -> event ID + for _, ev := range events { + eb := gomatrixserverlib.EventBuilder{ + Sender: ev.Sender, + Depth: depth, + Type: ev.Type, + StateKey: ev.StateKey, + RoomID: ev.RoomID, + PrevEvents: prevs, + } + err := eb.SetContent(ev.Content) + if err != nil { + t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) + } + stateNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(&eb) + if err != nil { + t.Fatalf("mustCreateEvent: failed to work out auth_events : %s", err) + } + var authEvents []string + for _, tuple := range stateNeeded.Tuples() { + eventID := roomState[tuple] + if eventID != "" { + authEvents = append(authEvents, eventID) + } + } + eb.AuthEvents = authEvents + signedEvent, err := eb.Build(time.Now(), testOrigin, "ed25519:test", key, roomVer) + if err != nil { + t.Fatalf("mustCreateEvent: failed to sign event: %s", err) + } + depth++ + prevs = []string{signedEvent.EventID()} + if ev.StateKey != nil { + roomState[gomatrixserverlib.StateKeyTuple{ + EventType: ev.Type, + StateKey: *ev.StateKey, + }] = signedEvent.EventID() + } + result = append(result, signedEvent.Headered(roomVer)) + } + return +} + +func eventsJSON(events []gomatrixserverlib.Event) []json.RawMessage { + result := make([]json.RawMessage, len(events)) + for i := range events { + result[i] = events[i].JSON() + } + return result +} + +func mustLoadRawEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent { + t.Helper() hs := make([]gomatrixserverlib.HeaderedEvent, len(events)) for i := range events { e, err := gomatrixserverlib.NewEventFromTrustedJSON(events[i], false, ver) @@ -93,7 +162,8 @@ func mustLoadEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js return hs } -func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []gomatrixserverlib.HeaderedEvent) { +func mustCreateRoomserverAPI(t *testing.T) (api.RoomserverInternalAPI, *dummyProducer) { + t.Helper() cfg := &config.Dendrite{} cfg.Defaults() cfg.Global.ServerName = testOrigin @@ -112,9 +182,14 @@ func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js Cfg: cfg, } - rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{}) - hevents := mustLoadEvents(t, ver, events) - if err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil { + return NewInternalAPI(base, &test.NopJSONVerifier{}), dp +} + +func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []gomatrixserverlib.HeaderedEvent) { + t.Helper() + rsAPI, dp := mustCreateRoomserverAPI(t) + hevents := mustLoadRawEvents(t, ver, events) + if err := api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil { t.Errorf("failed to SendEvents: %s", err) } return rsAPI, dp, hevents @@ -170,3 +245,163 @@ func TestOutputRedactedEvent(t *testing.T) { } } } + +// This tests that rewriting state via KindRewrite works correctly. +// This creates a small room with a create/join/name state, then replays it +// with a new room name. We expect the output events to contain the original events, +// followed by a single OutputNewRoomEvent with RewritesState set to true with the +// rewritten state events (with the 2nd room name). +func TestOutputRewritesState(t *testing.T) { + roomID := "!foo:" + string(testOrigin) + alice := "@alice:" + string(testOrigin) + emptyKey := "" + originalEvents := mustCreateEvents(t, gomatrixserverlib.RoomVersionV6, []fledglingEvent{ + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "creator": alice, + "room_version": "6", + }, + StateKey: &emptyKey, + Type: gomatrixserverlib.MRoomCreate, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "membership": "join", + }, + StateKey: &alice, + Type: gomatrixserverlib.MRoomMember, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "body": "hello world", + }, + StateKey: nil, + Type: "m.room.message", + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "name": "Room Name", + }, + StateKey: &emptyKey, + Type: "m.room.name", + }, + }) + rewriteEvents := mustCreateEvents(t, gomatrixserverlib.RoomVersionV6, []fledglingEvent{ + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "creator": alice, + }, + StateKey: &emptyKey, + Type: gomatrixserverlib.MRoomCreate, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "membership": "join", + }, + StateKey: &alice, + Type: gomatrixserverlib.MRoomMember, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "name": "Room Name 2", + }, + StateKey: &emptyKey, + Type: "m.room.name", + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "body": "hello world 2", + }, + StateKey: nil, + Type: "m.room.message", + }, + }) + deleteDatabase() + rsAPI, producer := mustCreateRoomserverAPI(t) + defer deleteDatabase() + err := api.SendEvents(context.Background(), rsAPI, originalEvents, testOrigin, nil) + if err != nil { + t.Fatalf("failed to send original events: %s", err) + } + // assert we got them produced, this is just a sanity check and isn't the intention of this test + if len(producer.producedMessages) != len(originalEvents) { + t.Fatalf("SendEvents didn't result in same number of produced output events: got %d want %d", len(producer.producedMessages), len(originalEvents)) + } + producer.producedMessages = nil // we aren't actually interested in these events, just the rewrite ones + + var inputEvents []api.InputRoomEvent + // slowly build up the state IDs again, we're basically telling the roomserver what to store as a snapshot + var stateIDs []string + // skip the last event, we'll use this to tie together the rewrite as the KindNew event + for i := 0; i < len(rewriteEvents)-1; i++ { + ev := rewriteEvents[i] + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindRewrite, + Event: ev, + AuthEventIDs: ev.AuthEventIDs(), + HasState: true, + StateEventIDs: stateIDs, + }) + if ev.StateKey() != nil { + stateIDs = append(stateIDs, ev.EventID()) + } + } + lastEv := rewriteEvents[len(rewriteEvents)-1] + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindNew, + Event: lastEv, + AuthEventIDs: lastEv.AuthEventIDs(), + HasState: true, + StateEventIDs: stateIDs, + }) + if err := api.SendInputRoomEvents(context.Background(), rsAPI, inputEvents); err != nil { + t.Fatalf("SendInputRoomEvents returned error for rewrite events: %s", err) + } + // we should just have one output event with the entire state of the room in it + if len(producer.producedMessages) != 1 { + t.Fatalf("Rewritten events got output, want only 1 got %d", len(producer.producedMessages)) + } + outputEvent := producer.producedMessages[0] + if !outputEvent.NewRoomEvent.RewritesState { + t.Errorf("RewritesState flag not set on output event") + } + if !reflect.DeepEqual(stateIDs, outputEvent.NewRoomEvent.AddsStateEventIDs) { + t.Errorf("Output event is missing room state event IDs, got %v want %v", outputEvent.NewRoomEvent.AddsStateEventIDs, stateIDs) + } + if !bytes.Equal(outputEvent.NewRoomEvent.Event.JSON(), lastEv.JSON()) { + t.Errorf( + "Output event isn't the latest KindNew event:\ngot %s\nwant %s", + string(outputEvent.NewRoomEvent.Event.JSON()), + string(lastEv.JSON()), + ) + } + if len(outputEvent.NewRoomEvent.AddStateEvents) != len(stateIDs) { + t.Errorf("Output event is missing room state events themselves, got %d want %d", len(outputEvent.NewRoomEvent.AddStateEvents), len(stateIDs)) + } + // make sure the state got overwritten, check the room name + hasRoomName := false + for _, ev := range outputEvent.NewRoomEvent.AddStateEvents { + if ev.Type() == "m.room.name" { + hasRoomName = string(ev.Content()) == `{"name":"Room Name 2"}` + } + } + if !hasRoomName { + t.Errorf("Output event did not overwrite room state") + } +} diff --git a/roomserver/types/types.go b/roomserver/types/types.go index 60f4b0fd5..f5b45763f 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -16,6 +16,8 @@ package types import ( + "sort" + "github.com/matrix-org/gomatrixserverlib" ) @@ -72,6 +74,25 @@ func (a StateEntry) LessThan(b StateEntry) bool { return a.EventNID < b.EventNID } +// Deduplicate takes a set of state entries and ensures that there are no +// duplicate (event type, state key) tuples. If there are then we dedupe +// them, making sure that the latest/highest NIDs are always chosen. +func DeduplicateStateEntries(a []StateEntry) []StateEntry { + if len(a) < 2 { + return a + } + sort.SliceStable(a, func(i, j int) bool { + return a[i].LessThan(a[j]) + }) + for i := 0; i < len(a)-1; i++ { + if a[i].StateKeyTuple == a[i+1].StateKeyTuple { + a = append(a[:i], a[i+1:]...) + i-- + } + } + return a +} + // StateAtEvent is the state before and after a matrix event. type StateAtEvent struct { // Should this state overwrite the latest events and memberships of the room? diff --git a/roomserver/types/types_test.go b/roomserver/types/types_test.go new file mode 100644 index 000000000..b1e84b821 --- /dev/null +++ b/roomserver/types/types_test.go @@ -0,0 +1,26 @@ +package types + +import ( + "testing" +) + +func TestDeduplicateStateEntries(t *testing.T) { + entries := []StateEntry{ + {StateKeyTuple{1, 1}, 1}, + {StateKeyTuple{1, 1}, 2}, + {StateKeyTuple{1, 1}, 3}, + {StateKeyTuple{2, 2}, 4}, + {StateKeyTuple{2, 3}, 5}, + {StateKeyTuple{3, 3}, 6}, + } + expected := []EventNID{3, 4, 5, 6} + entries = DeduplicateStateEntries(entries) + if len(entries) != 4 { + t.Fatalf("Expected 4 entries, got %d entries", len(entries)) + } + for i, v := range entries { + if v.EventNID != expected[i] { + t.Fatalf("Expected position %d to be %d but got %d", i, expected[i], v.EventNID) + } + } +} diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index b6ab9bd50..d8d0a298a 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -149,6 +149,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( } } + if msg.RewritesState { + if err = s.db.PurgeRoom(ctx, ev.RoomID()); err != nil { + return fmt.Errorf("s.db.PurgeRoom: %w", err) + } + } + pduPos, err := s.db.WriteEvent( ctx, &ev, diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 807c7f5e4..ce7f1c152 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -43,6 +43,9 @@ type Database interface { // Returns an error if there was a problem inserting this event. WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent, addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error) + // PurgeRoom completely purges room state from the sync API. This is done when + // receiving an output event that completely resets the state. + PurgeRoom(ctx context.Context, roomID string) error // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index 71569a108..130565882 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" ) @@ -46,10 +47,14 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const deleteBackwardExtremitiesForRoomSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + deleteBackwardExtremitiesForRoomStmt *sql.Stmt } func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -67,6 +72,9 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { return nil, err } + if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -105,3 +113,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return } + +func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 5cb7baadf..0ca9eed97 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -69,6 +69,9 @@ const upsertRoomStateSQL = "" + const deleteRoomStateByEventIDSQL = "" + "DELETE FROM syncapi_current_room_state WHERE event_id = $1" +const DeleteRoomStateForRoomSQL = "" + + "DELETE FROM syncapi_current_room_state WHERE event_id = $1" + const selectRoomIDsWithMembershipSQL = "" + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" @@ -98,6 +101,7 @@ const selectEventsWithEventIDsSQL = "" + type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt + DeleteRoomStateForRoomStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt @@ -117,6 +121,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { return nil, err } + if s.DeleteRoomStateForRoomStmt, err = db.Prepare(DeleteRoomStateForRoomSQL); err != nil { + return nil, err + } if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { return nil, err } @@ -214,6 +221,14 @@ func (s *currentRoomStateStatements) DeleteRoomStateByEventID( return err } +func (s *currentRoomStateStatements) DeleteRoomStateForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt) + _, err := stmt.ExecContext(ctx, roomID) + return err +} + func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 5315de243..4b2101bbc 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -115,6 +115,9 @@ const selectStateInRangeSQL = "" + " ORDER BY id ASC" + " LIMIT $8" +const deleteEventsForRoomSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + type outputRoomEventsStatements struct { insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt @@ -124,6 +127,7 @@ type outputRoomEventsStatements struct { selectEarlyEventsStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt updateEventJSONStmt *sql.Stmt + deleteEventsForRoomStmt *sql.Stmt } func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { @@ -156,6 +160,9 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil { return nil, err } + if s.deleteEventsForRoomStmt, err = db.Prepare(deleteEventsForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -395,6 +402,13 @@ func (s *outputRoomEventsStatements) SelectEvents( return rowsToStreamEvents(rows) } +func (s *outputRoomEventsStatements) DeleteEventsForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID) + return err +} + func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { var result []types.StreamEvent for rows.Next() { diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 1ab3a1dc2..cbd20a075 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -71,12 +72,16 @@ const selectMaxPositionInTopologySQL = "" + "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + ") ORDER BY stream_position DESC LIMIT 1" +const deleteTopologyForRoomSQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt + deleteTopologyForRoomStmt *sql.Stmt } func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -100,6 +105,9 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { return nil, err } + if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -167,3 +175,10 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } + +func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 94580adb8..05a8768e8 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -276,6 +276,29 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e return nil } +func (d *Database) PurgeRoom( + ctx context.Context, roomID string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) + } + if err := d.OutputEvents.DeleteEventsForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.Events.DeleteEventsForRoom: %w", err) + } + if err := d.Topology.DeleteTopologyForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.Topology.DeleteTopologyForRoom: %w", err) + } + if err := d.BackwardExtremities.DeleteBackwardExtremitiesForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.BackwardExtremities.DeleteBackwardExtremitiesForRoom: %w", err) + } + return nil + }) +} + func (d *Database) WriteEvent( ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 116c33dc4..9a81e8e7f 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" ) @@ -46,11 +47,15 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const deleteBackwardExtremitiesForRoomSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { db *sql.DB insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + deleteBackwardExtremitiesForRoomStmt *sql.Stmt } func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -70,6 +75,9 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { return nil, err } + if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -108,3 +116,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return err } + +func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 6f822c90f..13d23be5f 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -51,12 +51,15 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_s const upsertRoomStateSQL = "" + "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + - " ON CONFLICT (event_id, room_id, type, sender, contains_url)" + + " ON CONFLICT (room_id, type, state_key)" + " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9" const deleteRoomStateByEventIDSQL = "" + "DELETE FROM syncapi_current_room_state WHERE event_id = $1" +const DeleteRoomStateForRoomSQL = "" + + "DELETE FROM syncapi_current_room_state WHERE event_id = $1" + const selectRoomIDsWithMembershipSQL = "" + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" @@ -88,6 +91,7 @@ type currentRoomStateStatements struct { streamIDStatements *streamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt + DeleteRoomStateForRoomStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt @@ -109,6 +113,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { return nil, err } + if s.DeleteRoomStateForRoomStmt, err = db.Prepare(DeleteRoomStateForRoomSQL); err != nil { + return nil, err + } if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { return nil, err } @@ -203,6 +210,14 @@ func (s *currentRoomStateStatements) DeleteRoomStateByEventID( return err } +func (s *currentRoomStateStatements) DeleteRoomStateForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt) + _, err := stmt.ExecContext(ctx, roomID) + return err +} + func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index f10d01066..587a40726 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -103,6 +103,9 @@ const selectStateInRangeSQL = "" + " ORDER BY id ASC" + " LIMIT $8" // limit +const deleteEventsForRoomSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + type outputRoomEventsStatements struct { db *sql.DB streamIDStatements *streamIDStatements @@ -114,6 +117,7 @@ type outputRoomEventsStatements struct { selectEarlyEventsStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt updateEventJSONStmt *sql.Stmt + deleteEventsForRoomStmt *sql.Stmt } func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { @@ -149,6 +153,9 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil { return nil, err } + if s.deleteEventsForRoomStmt, err = db.Prepare(deleteEventsForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -410,6 +417,13 @@ func (s *outputRoomEventsStatements) SelectEvents( return returnEvents, nil } +func (s *outputRoomEventsStatements) DeleteEventsForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID) + return err +} + func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { var result []types.StreamEvent for rows.Next() { diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index d8c97b7e3..d3ba9af62 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -65,6 +65,9 @@ const selectMaxPositionInTopologySQL = "" + "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + " WHERE room_id = $1 ORDER BY stream_position DESC" +const deleteTopologyForRoomSQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { db *sql.DB insertEventInTopologyStmt *sql.Stmt @@ -72,6 +75,7 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt + deleteTopologyForRoomStmt *sql.Stmt } func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -97,6 +101,9 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { return nil, err } + if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -164,3 +171,10 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } + +func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 631746c6e..da095be53 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -60,6 +60,8 @@ type Events interface { SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error + // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. + DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) } // Topology keeps track of the depths and stream positions for all events. @@ -77,6 +79,8 @@ type Topology interface { SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) + // DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely. + DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) } type CurrentRoomState interface { @@ -84,6 +88,7 @@ type CurrentRoomState interface { SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) UpsertRoomState(ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error + DeleteRoomStateForRoom(ctx context.Context, txn *sql.Tx, roomID string) error // SelectCurrentState returns all the current state events for the given room. SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter) ([]gomatrixserverlib.HeaderedEvent, error) // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. @@ -118,6 +123,8 @@ type BackwardsExtremities interface { SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error) // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) + // DeleteBackwardExtremitiesFoorRoomID removes all backward extremities for a room. This should only be done when removing the room entirely. + DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) } // SendToDevice tracks send-to-device messages which are sent to individual From ba6c7c4a5c4166b7085343886ab69ef331238ff4 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 15 Sep 2020 16:15:34 +0100 Subject: [PATCH 3/4] Disable prometheus to unbreak tests --- roomserver/roomserver_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 5a67a1bed..ef5901004 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -172,7 +172,7 @@ func mustCreateRoomserverAPI(t *testing.T) (api.RoomserverInternalAPI, *dummyPro dp := &dummyProducer{ topic: cfg.Global.Kafka.TopicFor(config.TopicOutputRoomEvent), } - cache, err := caching.NewInMemoryLRUCache(true) + cache, err := caching.NewInMemoryLRUCache(false) if err != nil { t.Fatalf("failed to make caches: %s", err) } From 18231f25b437d2f03b3be1e0536fc46d45c8691f Mon Sep 17 00:00:00 2001 From: Kegsay Date: Wed, 16 Sep 2020 13:00:52 +0100 Subject: [PATCH 4/4] Implement rejected events (#1426) * WIP Event rejection * Still send back errors for rejected events Instead, discard them at the federationapi /send layer rather than re-implementing checks at the clientapi/PerformJoin layer. * Implement rejected events Critically, rejected events CAN cause state resolution to happen as it can merge forks in the DAG. This is fine, _provided_ we do not add the rejected event when performing state resolution, which is what this PR does. It also fixes the error handling when NotAllowed happens, as we were checking too early and needlessly handling NotAllowed in more than one place. * Update test to match reality * Modify InputRoomEvents to no longer return an error Errors do not serialise across HTTP boundaries in polylith mode, so instead set fields on the InputRoomEventsResponse. Add `Err()` function to make the API shape basically the same. * Remove redundant returns; linting * Update blacklist --- cmd/roomserver-integration-tests/main.go | 3 +- federationapi/routing/send.go | 9 ++---- federationapi/routing/send_test.go | 14 ++++----- roomserver/api/api.go | 2 +- roomserver/api/api_trace.go | 7 ++--- roomserver/api/input.go | 16 ++++++++++ roomserver/api/wrapper.go | 3 +- roomserver/internal/alias.go | 3 +- roomserver/internal/input/input.go | 8 +++-- roomserver/internal/input/input_events.go | 31 +++++++++++++------ .../internal/perform/perform_backfill.go | 2 +- roomserver/internal/perform/perform_invite.go | 3 +- roomserver/internal/perform/perform_join.go | 3 +- roomserver/internal/perform/perform_leave.go | 3 +- roomserver/internal/query/query.go | 1 + roomserver/inthttp/client.go | 7 +++-- roomserver/inthttp/server.go | 4 +-- roomserver/roomserver_test.go | 8 ----- roomserver/state/state.go | 5 +-- roomserver/storage/interface.go | 1 + roomserver/storage/postgres/events_table.go | 12 ++++--- roomserver/storage/shared/storage.go | 7 +++-- roomserver/storage/sqlite3/events_table.go | 13 +++++--- roomserver/storage/tables/interface.go | 5 ++- roomserver/types/types.go | 3 ++ sytest-blacklist | 11 ++----- sytest-whitelist | 4 ++- 27 files changed, 114 insertions(+), 74 deletions(-) diff --git a/cmd/roomserver-integration-tests/main.go b/cmd/roomserver-integration-tests/main.go index 435747780..41ea6f4d8 100644 --- a/cmd/roomserver-integration-tests/main.go +++ b/cmd/roomserver-integration-tests/main.go @@ -215,7 +215,8 @@ func writeToRoomServer(input []string, roomserverURL string) error { if err != nil { return err } - return x.InputRoomEvents(context.Background(), &request, &response) + x.InputRoomEvents(context.Background(), &request, &response) + return response.Err() } // testRoomserver is used to run integration tests against a single roomserver. diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 9def7c3c3..cb7bea6c4 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -372,12 +372,9 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion, isInboundTxn) } - // Check that the event is allowed by the state at the event. - if err := checkAllowedByState(e, gomatrixserverlib.UnwrapEventHeaders(stateResp.StateEvents)); err != nil { - return err - } - - // pass the event to the roomserver + // pass the event to the roomserver which will do auth checks + // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently + // discarded by the caller of this function return api.SendEvents( context.Background(), t.rsAPI, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index a714d07ee..4f447f372 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -89,12 +89,11 @@ func (t *testRoomserverAPI) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) error { +) { t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) for _, ire := range request.InputRoomEvents { fmt.Println("InputRoomEvents: ", ire.Event.EventID()) } - return nil } func (t *testRoomserverAPI) PerformInvite( @@ -461,7 +460,8 @@ func TestBasicTransaction(t *testing.T) { assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) } -// The purpose of this test is to check that if the event received fails auth checks the transaction is failed. +// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver +// as it does the auth check. func TestTransactionFailAuthChecks(t *testing.T) { rsAPI := &testRoomserverAPI{ queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { @@ -479,11 +479,9 @@ func TestTransactionFailAuthChecks(t *testing.T) { testData[len(testData)-1], // a message event } txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, []string{ - // expect the event to have an error - testEvents[len(testEvents)-1].EventID(), - }) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, nil) // expect no messages to be sent to the roomserver + mustProcessTransaction(t, txn, []string{}) + // expect message to be sent to the roomserver + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) } // The purpose of this test is to make sure that when an event is received for which we do not know the prev_events, diff --git a/roomserver/api/api.go b/roomserver/api/api.go index eecefe322..2495157a6 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -16,7 +16,7 @@ type RoomserverInternalAPI interface { ctx context.Context, request *InputRoomEventsRequest, response *InputRoomEventsResponse, - ) error + ) PerformInvite( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 643309307..b7accb9a8 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -23,10 +23,9 @@ func (t *RoomserverInternalAPITrace) InputRoomEvents( ctx context.Context, req *InputRoomEventsRequest, res *InputRoomEventsResponse, -) error { - err := t.Impl.InputRoomEvents(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) - return err +) { + t.Impl.InputRoomEvents(ctx, req, res) + util.GetLogger(ctx).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) } func (t *RoomserverInternalAPITrace) PerformInvite( diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 651c0e9f8..862a6fa1f 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -16,6 +16,8 @@ package api import ( + "fmt" + "github.com/matrix-org/gomatrixserverlib" ) @@ -87,4 +89,18 @@ type InputRoomEventsRequest struct { // InputRoomEventsResponse is a response to InputRoomEvents type InputRoomEventsResponse struct { + ErrMsg string // set if there was any error + NotAllowed bool // true if an event in the input was not allowed. +} + +func (r *InputRoomEventsResponse) Err() error { + if r.ErrMsg == "" { + return nil + } + if r.NotAllowed { + return &gomatrixserverlib.NotAllowed{ + Message: r.ErrMsg, + } + } + return fmt.Errorf("InputRoomEventsResponse: %s", r.ErrMsg) } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index e53393119..cc048ddd4 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -187,7 +187,8 @@ func SendInputRoomEvents( ) error { request := InputRoomEventsRequest{InputRoomEvents: ires} var response InputRoomEventsResponse - return rsAPI.InputRoomEvents(ctx, &request, &response) + rsAPI.InputRoomEvents(ctx, &request, &response) + return response.Err() } // SendInvite event to the roomserver. diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index d576a8175..3e023d2a7 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -271,5 +271,6 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent( var inputRes api.InputRoomEventsResponse // Send the request - return r.InputRoomEvents(ctx, &inputReq, &inputRes) + r.InputRoomEvents(ctx, &inputReq, &inputRes) + return inputRes.Err() } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 51d20ad39..d340ac218 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -110,7 +110,7 @@ func (r *Inputer) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) error { +) { // Create a wait group. Each task that we dispatch will call Done on // this wait group so that we know when all of our events have been // processed. @@ -156,8 +156,10 @@ func (r *Inputer) InputRoomEvents( // that back to the caller. for _, task := range tasks { if task.err != nil { - return task.err + response.ErrMsg = task.err.Error() + _, rejected := task.err.(*gomatrixserverlib.NotAllowed) + response.NotAllowed = rejected + return } } - return nil } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index daf1afcd3..0558cd763 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -46,10 +46,11 @@ func (r *Inputer) processRoomEvent( // Check that the event passes authentication checks and work out // the numeric IDs for the auth events. - authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) - if err != nil { - logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") - return + isRejected := false + authEventNIDs, rejectionErr := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) + if rejectionErr != nil { + logrus.WithError(rejectionErr).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event, rejecting event") + isRejected = true } // If we don't have a transaction ID then get one. @@ -65,12 +66,13 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) + _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs, isRejected) if err != nil { return "", fmt.Errorf("r.DB.StoreEvent: %w", err) } + // if storing this event results in it being redacted then do so. - if redactedEventID == event.EventID() { + if !isRejected && redactedEventID == event.EventID() { r, rerr := eventutil.RedactEvent(redactionEvent, &event) if rerr != nil { return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr) @@ -101,12 +103,22 @@ func (r *Inputer) processRoomEvent( if stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event) + err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event, isRejected) if err != nil { return "", fmt.Errorf("r.calculateAndSetState: %w", err) } } + // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. + if isRejected { + logrus.WithFields(logrus.Fields{ + "event_id": event.EventID(), + "type": event.Type(), + "room": event.RoomID(), + }).Debug("Stored rejected event") + return event.EventID(), rejectionErr + } + if input.Kind == api.KindRewrite { logrus.WithFields(logrus.Fields{ "event_id": event.EventID(), @@ -157,11 +169,12 @@ func (r *Inputer) calculateAndSetState( roomInfo types.RoomInfo, stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, + isRejected bool, ) error { var err error roomState := state.NewStateResolution(r.DB, roomInfo) - if input.HasState { + if input.HasState && !isRejected { // Check here if we think we're in the room already. stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID @@ -188,7 +201,7 @@ func (r *Inputer) calculateAndSetState( stateAtEvent.Overwrite = false // We haven't been told what the state at the event is so we need to calculate it from the prev_events - if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, isRejected); err != nil { return fmt.Errorf("roomState.CalculateAndStoreStateBeforeEvent: %w", err) } } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 668c80787..eb1aa99b8 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -535,7 +535,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse var stateAtEvent types.StateAtEvent var redactedEventID string var redactionEvent *gomatrixserverlib.Event - roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids) + roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids, false) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") continue diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index e06ad062d..d6a64e7e8 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -183,7 +183,8 @@ func (r *Inviter) PerformInvite( }, } inputRes := &api.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { + r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) + if err = inputRes.Err(); err != nil { return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } } else { diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index f76806c73..e9aebb839 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -247,7 +247,8 @@ func (r *Joiner) performJoinRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) + if err = inputRes.Err(); err != nil { var notAllowed *gomatrixserverlib.NotAllowed if errors.As(err, ¬Allowed) { return "", &api.PerformError{ diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index aaa3b5b16..6aaf1bf3e 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -139,7 +139,8 @@ func (r *Leaver) performLeaveRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) + if err = inputRes.Err(); err != nil { return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index b34ae7701..fb981447f 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -70,6 +70,7 @@ func (r *Queryer) QueryStateAfterEvents( if err != nil { switch err.(type) { case types.MissingEventError: + util.GetLogger(ctx).Errorf("QueryStateAfterEvents: MissingEventError: %s", err) return nil default: return err diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 1ff1fc82b..f2510c753 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -149,12 +149,15 @@ func (h *httpRoomserverInternalAPI) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) error { +) { span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents") defer span.Finish() apiURL := h.roomserverURL + RoomserverInputRoomEventsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err != nil { + response.ErrMsg = err.Error() + } } func (h *httpRoomserverInternalAPI) PerformInvite( diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 5816d4d82..8ffa9cf9f 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -20,9 +20,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - if err := r.InputRoomEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } + r.InputRoomEvents(req.Context(), &request, &response) return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index ef5901004..912c5852f 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -140,14 +140,6 @@ func mustCreateEvents(t *testing.T, roomVer gomatrixserverlib.RoomVersion, event return } -func eventsJSON(events []gomatrixserverlib.Event) []json.RawMessage { - result := make([]json.RawMessage, len(events)) - for i := range events { - result[i] = events[i].JSON() - } - return result -} - func mustLoadRawEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent { t.Helper() hs := make([]gomatrixserverlib.HeaderedEvent, len(events)) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 37e6807a3..9ee6f40d4 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -159,7 +159,7 @@ func (v StateResolution) LoadCombinedStateAfterEvents( } fullState = append(fullState, entries...) } - if prevState.IsStateEvent() { + if prevState.IsStateEvent() && !prevState.IsRejected { // If the prev event was a state event then add an entry for the event itself // so that we get the state after the event rather than the state before. fullState = append(fullState, prevState.StateEntry) @@ -523,6 +523,7 @@ func init() { func (v StateResolution) CalculateAndStoreStateBeforeEvent( ctx context.Context, event gomatrixserverlib.Event, + isRejected bool, ) (types.StateSnapshotNID, error) { // Load the state at the prev events. prevEventRefs := event.PrevEvents() @@ -561,7 +562,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( if len(prevStates) == 1 { prevState := prevStates[0] - if prevState.EventStateKeyNID == 0 { + if prevState.EventStateKeyNID == 0 || prevState.IsRejected { // 3) None of the previous events were state events and they all // have the same state, so this event has exactly the same state // as the previous events. diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index be724da62..10a380e85 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -70,6 +70,7 @@ type Database interface { // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. StoreEvent( ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, + isRejected bool, ) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) // Look up the state entries for a list of string event IDs // Returns an error if the there is an error talking to the database diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index e66efb097..c8eb8e2d2 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -65,13 +65,14 @@ CREATE TABLE IF NOT EXISTS roomserver_events ( -- Needed for setting reference hashes when sending new events. reference_sha256 BYTEA NOT NULL, -- A list of numeric IDs for events that can authenticate this event. - auth_event_nids BIGINT[] NOT NULL + auth_event_nids BIGINT[] NOT NULL, + is_rejected BOOLEAN NOT NULL DEFAULT FALSE ); ` const insertEventSQL = "" + - "INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7)" + + "INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique" + " DO NOTHING" + " RETURNING event_nid, state_snapshot_nid" @@ -88,7 +89,7 @@ const bulkSelectStateEventByIDSQL = "" + " ORDER BY event_type_nid, event_state_key_nid ASC" const bulkSelectStateAtEventByIDSQL = "" + - "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id = ANY($1)" const updateEventStateSQL = "" + @@ -174,12 +175,14 @@ func (s *eventStatements) InsertEvent( referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, + isRejected bool, ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 err := s.insertEventStmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + isRejected, ).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } @@ -255,6 +258,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( &result.EventStateKeyNID, &result.EventNID, &result.BeforeStateSnapshotNID, + &result.IsRejected, ); err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 262b0f2f8..e710b99b7 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -382,7 +382,7 @@ func (d *Database) GetLatestEventsForUpdate( // nolint:gocyclo func (d *Database) StoreEvent( ctx context.Context, event gomatrixserverlib.Event, - txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, + txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, isRejected bool, ) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { var ( roomNID types.RoomNID @@ -446,6 +446,7 @@ func (d *Database) StoreEvent( event.EventReference().EventSHA256, authEventNIDs, event.Depth(), + isRejected, ); err != nil { if err == sql.ErrNoRows { // We've already inserted the event so select the numeric event ID @@ -459,7 +460,9 @@ func (d *Database) StoreEvent( if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) } - redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) + if !isRejected { // ignore rejected redaction events + redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) + } return nil }) if err != nil { diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index a866c85d0..773e9ade3 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -41,13 +41,14 @@ const eventsSchema = ` depth INTEGER NOT NULL, event_id TEXT NOT NULL UNIQUE, reference_sha256 BLOB NOT NULL, - auth_event_nids TEXT NOT NULL DEFAULT '[]' + auth_event_nids TEXT NOT NULL DEFAULT '[]', + is_rejected BOOLEAN NOT NULL DEFAULT FALSE ); ` const insertEventSQL = ` - INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT DO NOTHING; ` @@ -63,7 +64,7 @@ const bulkSelectStateEventByIDSQL = "" + " ORDER BY event_type_nid, event_state_key_nid ASC" const bulkSelectStateAtEventByIDSQL = "" + - "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id IN ($1)" const updateEventStateSQL = "" + @@ -150,13 +151,14 @@ func (s *eventStatements) InsertEvent( referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, + isRejected bool, ) (types.EventNID, types.StateSnapshotNID, error) { // attempt to insert: the last_row_id is the event NID var eventNID int64 insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) result, err := insertStmt.ExecContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), - eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, ) if err != nil { return 0, 0, err @@ -261,6 +263,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( &result.EventStateKeyNID, &result.EventNID, &result.BeforeStateSnapshotNID, + &result.IsRejected, ); err != nil { return nil, err } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index adb06212a..eba878ba5 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -34,7 +34,10 @@ type EventStateKeys interface { } type Events interface { - InsertEvent(c context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64) (types.EventNID, types.StateSnapshotNID, error) + InsertEvent( + ctx context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, + referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, + ) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError diff --git a/roomserver/types/types.go b/roomserver/types/types.go index f5b45763f..c0fcef65e 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -101,6 +101,9 @@ type StateAtEvent struct { Overwrite bool // The state before the event. BeforeStateSnapshotNID StateSnapshotNID + // True if this StateEntry is rejected. State resolution should then treat this + // StateEntry as being a message event (not a state event). + IsRejected bool // The state entry for the event itself, allows us to calculate the state after the event. StateEntry } diff --git a/sytest-blacklist b/sytest-blacklist index 705c9ff4c..246e68303 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -40,11 +40,6 @@ Ignore invite in incremental sync New room members see their own join event Existing members see new members' join events -# Blacklisted because the federation work for these hasn't been finished yet. -Can recv device messages over federation -Device messages over federation wake up /sync -Wildcard device messages over federation wake up /sync - # See https://github.com/matrix-org/sytest/pull/901 Remote invited user can see room metadata @@ -56,8 +51,8 @@ Inbound federation accepts a second soft-failed event # Caused by https://github.com/matrix-org/sytest/pull/911 Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state -# We don't implement device lists yet -Device list doesn't change if remote server is down - # We don't implement lazy membership loading yet. The only membership state included in a gapped incremental sync is for senders in the timeline + +# flakey since implementing rejected events +Inbound federation correctly soft fails events \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 0adeaee6f..91516428d 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -470,4 +470,6 @@ We can't peek into rooms with shared history_visibility We can't peek into rooms with invited history_visibility We can't peek into rooms with joined history_visibility Local users can peek by room alias -Peeked rooms only turn up in the sync for the device who peeked them \ No newline at end of file +Peeked rooms only turn up in the sync for the device who peeked them +Room state at a rejected message event is the same as its predecessor +Room state at a rejected state event is the same as its predecessor \ No newline at end of file