From db5a99e82b8d893b5cc20e1b5758d6e55451f2f4 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 26 Aug 2020 14:37:05 +0100 Subject: [PATCH] Initial work on soft-fail --- roomserver/internal/input_authevents.go | 57 ++++++++++++++++++++++ roomserver/internal/input_events.go | 18 ++++++- roomserver/internal/input_latest_events.go | 35 +++++++------ roomserver/storage/interface.go | 2 + roomserver/storage/postgres/rooms_table.go | 16 ++++++ roomserver/storage/shared/storage.go | 26 ++++++++++ roomserver/storage/sqlite3/rooms_table.go | 16 ++++++ roomserver/storage/tables/interface.go | 1 + 8 files changed, 154 insertions(+), 17 deletions(-) diff --git a/roomserver/internal/input_authevents.go b/roomserver/internal/input_authevents.go index e3828f566..3858bbf27 100644 --- a/roomserver/internal/input_authevents.go +++ b/roomserver/internal/input_authevents.go @@ -16,13 +16,70 @@ package internal import ( "context" + "fmt" "sort" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" ) +// checkForSoftFail returns true if the event should be soft-failed +// and false otherwise. The return error value should be checked before +// the soft-fail bool. +func checkForSoftFail( + ctx context.Context, + db storage.Database, + event gomatrixserverlib.HeaderedEvent, +) (bool, error) { + // Look up the current state of the room. + authStateEntries, err := db.StateForRoomID(ctx, event.RoomID()) + if err != nil { + return false, fmt.Errorf("db.SnapshotNIDFromRoomID: %w", err) + } + + // As a special case, it's possible that the room will have no + // state because we haven't received a m.room.create event yet. + // If we're now processing the first create event then never + // soft-fail it. + if len(authStateEntries) == 0 && event.Type() == gomatrixserverlib.MRoomCreate { + return false, nil + } + + // Work out which of the state events we actually need. + stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()}) + + logger := logrus.WithField("room_id", event.RoomID()) + logger.Infof("EVENT %s TYPE %s", event.EventID(), event.Type()) + + logger.Infof("STATE NEEDED:") + for _, tuple := range stateNeeded.Tuples() { + logger.Infof("* %q %q", tuple.EventType, tuple.StateKey) + } + + // Load the actual auth events from the database. + authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) + if err != nil { + return false, fmt.Errorf("loadAuthEvents: %w", err) + } + + logger.Info("STATE RETRIEVED:") + for _, e := range authEvents.events { + logger.Infof("* %q %q -> %s", e.Type(), *e.StateKey(), string(e.Content())) + } + + // Check if the event is allowed. + if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil { + // return true, nil + logger.Info("SOFT-FAIL") + return false, fmt.Errorf("gomatrixserverlib.Allowed: %w", err) + } + + logger.Info("ALLOW") + return false, nil +} + // checkAuthEvents checks that the event passes authentication checks // Returns the numeric IDs for the auth events. func checkAuthEvents( diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go index a63082990..8724df6ba 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input_events.go @@ -42,15 +42,27 @@ func (r *RoomserverInternalAPI) processRoomEvent( // Parse and validate the event JSON headered := input.Event event := headered.Unwrap() + softfail := false - // Check that the event passes authentication checks and work out - // the numeric IDs for the auth events. + // Check that the event passes authentication checks based on the + // event-specified auth events and work out the numeric IDs for those. authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) if err != nil { logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") return } + // Check that the event passes authentication checks based on the + // current room state. + softfail, err = checkForSoftFail(ctx, r.DB, headered) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": event.EventID(), + "type": event.Type(), + "room": event.RoomID(), + }).WithError(err).Info("Error authing soft-failed event") + } + // If we don't have a transaction ID then get one. if input.TransactionID != nil { tdID := input.TransactionID @@ -68,6 +80,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( if err != nil { return "", fmt.Errorf("r.DB.StoreEvent: %w", err) } + // if storing this event results in it being redacted then do so. if redactedEventID == event.EventID() { r, rerr := eventutil.RedactEvent(redactionEvent, &event) @@ -105,6 +118,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( event, // event input.SendAsServer, // send as server input.TransactionID, // transaction ID + softfail, // event soft-failed? ); err != nil { return "", fmt.Errorf("r.updateLatestEvents: %w", err) } diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index f11a78d72..fbb2a96c6 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -54,6 +54,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( event gomatrixserverlib.Event, sendAsServer string, transactionID *api.TransactionID, + softfail bool, ) (err error) { updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID) if err != nil { @@ -71,6 +72,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( event: event, sendAsServer: sendAsServer, transactionID: transactionID, + softfail: softfail, } if err = u.doUpdateLatestEvents(); err != nil { @@ -93,6 +95,7 @@ type latestEventsUpdater struct { stateAtEvent types.StateAtEvent event gomatrixserverlib.Event transactionID *api.TransactionID + softfail bool // Which server to send this event as. sendAsServer string // The eventID of the event that was processed before this one. @@ -178,22 +181,24 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { return fmt.Errorf("u.api.updateMemberships: %w", err) } - update, err := u.makeOutputNewRoomEvent() - if err != nil { - return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) - } - updates = append(updates, *update) + if !u.softfail { + update, err := u.makeOutputNewRoomEvent() + if err != nil { + return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) + } + updates = append(updates, *update) - // Send the event to the output logs. - // We do this inside the database transaction to ensure that we only mark an event as sent if we sent it. - // (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but - // the write to the output log succeeds) - // TODO: This assumes that writing the event to the output log is synchronous. It should be possible to - // send the event asynchronously but we would need to ensure that 1) the events are written to the log in - // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the - // necessary bookkeeping we'll keep the event sending synchronous for now. - if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil { - return fmt.Errorf("u.api.WriteOutputEvents: %w", err) + // Send the event to the output logs. + // We do this inside the database transaction to ensure that we only mark an event as sent if we sent it. + // (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but + // the write to the output log succeeds) + // TODO: This assumes that writing the event to the output log is synchronous. It should be possible to + // send the event asynchronously but we would need to ensure that 1) the events are written to the log in + // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the + // necessary bookkeeping we'll keep the event sending synchronous for now. + if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil { + return fmt.Errorf("u.api.WriteOutputEvents: %w", err) + } } if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 545885f78..3b2311b9f 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -64,6 +64,8 @@ type Database interface { Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + // Look up current state for a room ID. + StateForRoomID(ctx context.Context, roomID string) ([]types.StateEntry, error) // Look up a room version from the room NID. GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 8e00cfdb8..baf9e1d83 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -74,6 +74,9 @@ const selectRoomVersionForRoomIDSQL = "" + const selectRoomVersionForRoomNIDSQL = "" + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" +const selectStateSnapshotNIDSQL = "" + + "SELECT state_snapshot_nid FROM roomserver_rooms WHERE room_id = $1" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt @@ -82,6 +85,7 @@ type roomStatements struct { updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomIDStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt + selectStateSnapshotNIDStmt *sql.Stmt } func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -98,6 +102,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, + {&s.selectStateSnapshotNIDStmt, selectStateSnapshotNIDSQL}, }.Prepare(db) } @@ -196,3 +201,14 @@ func (s *roomStatements) SelectRoomVersionForRoomNID( } return roomVersion, err } + +func (s *roomStatements) SelectStateSnapshotNID( + ctx context.Context, roomID string, +) (types.StateSnapshotNID, error) { + var stateSnapshotNID types.StateSnapshotNID + err := s.selectStateSnapshotNIDStmt.QueryRowContext(ctx, roomID).Scan(&stateSnapshotNID) + if err == sql.ErrNoRows { + return 0, errors.New("room not found") + } + return stateSnapshotNID, err +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 50ab5dde5..345337407 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -174,6 +174,32 @@ func (d *Database) SnapshotNIDFromEventID( return stateNID, err } +func (d *Database) StateForRoomID( + ctx context.Context, roomID string, +) ([]types.StateEntry, error) { + stateSnapshotNID, err := d.RoomsTable.SelectStateSnapshotNID(ctx, roomID) + if err != nil || stateSnapshotNID == 0 { + // the room doesn't exist or it doesn't have state + return nil, nil + } + stateBlockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{stateSnapshotNID}) + if err != nil { + return nil, fmt.Errorf("d.StateSnapshotTable.BulkSelectStateBlockNIDs: %w", err) + } + if len(stateBlockNIDs) != 1 { + return nil, fmt.Errorf("expected one StateBlockNIDList, got %d", len(stateBlockNIDs)) + } + stateEventLists, err := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs[0].StateBlockNIDs) + if err != nil { + return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) + } + stateEventNIDs := []types.StateEntry{} + for _, stateEventList := range stateEventLists { + stateEventNIDs = append(stateEventNIDs, stateEventList.StateEntries...) + } + return stateEventNIDs, nil +} + func (d *Database) EventIDs( ctx context.Context, eventNIDs []types.EventNID, ) (map[types.EventNID]string, error) { diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 6541cc0cb..7e14efca4 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -64,6 +64,9 @@ const selectRoomVersionForRoomIDSQL = "" + const selectRoomVersionForRoomNIDSQL = "" + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" +const selectStateSnapshotNIDSQL = "" + + "SELECT state_snapshot_nid FROM roomserver_rooms WHERE room_id = $1" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt @@ -73,6 +76,7 @@ type roomStatements struct { updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomIDStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt + selectStateSnapshotNIDStmt *sql.Stmt } func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -91,6 +95,7 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, + {&s.selectStateSnapshotNIDStmt, selectStateSnapshotNIDSQL}, }.Prepare(db) } @@ -195,3 +200,14 @@ func (s *roomStatements) SelectRoomVersionForRoomNID( } return roomVersion, err } + +func (s *roomStatements) SelectStateSnapshotNID( + ctx context.Context, roomID string, +) (types.StateSnapshotNID, error) { + var stateSnapshotNID types.StateSnapshotNID + err := s.selectStateSnapshotNIDStmt.QueryRowContext(ctx, roomID).Scan(&stateSnapshotNID) + if err == sql.ErrNoRows { + return 0, errors.New("room not found") + } + return stateSnapshotNID, err +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 47c12c2ca..ba0da321f 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -65,6 +65,7 @@ type Rooms interface { UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error SelectRoomVersionForRoomID(ctx context.Context, txn *sql.Tx, roomID string) (gomatrixserverlib.RoomVersion, error) SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) + SelectStateSnapshotNID(ctx context.Context, roomID string) (types.StateSnapshotNID, error) } type Transactions interface {