Initial work on soft-fail

This commit is contained in:
Neil Alexander 2020-08-26 14:37:05 +01:00
parent 3205b9212d
commit db5a99e82b
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
8 changed files with 154 additions and 17 deletions

View file

@ -16,13 +16,70 @@ package internal
import (
"context"
"fmt"
"sort"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// checkForSoftFail returns true if the event should be soft-failed
// and false otherwise. The return error value should be checked before
// the soft-fail bool.
func checkForSoftFail(
ctx context.Context,
db storage.Database,
event gomatrixserverlib.HeaderedEvent,
) (bool, error) {
// Look up the current state of the room.
authStateEntries, err := db.StateForRoomID(ctx, event.RoomID())
if err != nil {
return false, fmt.Errorf("db.SnapshotNIDFromRoomID: %w", err)
}
// As a special case, it's possible that the room will have no
// state because we haven't received a m.room.create event yet.
// If we're now processing the first create event then never
// soft-fail it.
if len(authStateEntries) == 0 && event.Type() == gomatrixserverlib.MRoomCreate {
return false, nil
}
// Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})
logger := logrus.WithField("room_id", event.RoomID())
logger.Infof("EVENT %s TYPE %s", event.EventID(), event.Type())
logger.Infof("STATE NEEDED:")
for _, tuple := range stateNeeded.Tuples() {
logger.Infof("* %q %q", tuple.EventType, tuple.StateKey)
}
// Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
if err != nil {
return false, fmt.Errorf("loadAuthEvents: %w", err)
}
logger.Info("STATE RETRIEVED:")
for _, e := range authEvents.events {
logger.Infof("* %q %q -> %s", e.Type(), *e.StateKey(), string(e.Content()))
}
// Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
// return true, nil
logger.Info("SOFT-FAIL")
return false, fmt.Errorf("gomatrixserverlib.Allowed: %w", err)
}
logger.Info("ALLOW")
return false, nil
}
// checkAuthEvents checks that the event passes authentication checks
// Returns the numeric IDs for the auth events.
func checkAuthEvents(

View file

@ -42,15 +42,27 @@ func (r *RoomserverInternalAPI) processRoomEvent(
// Parse and validate the event JSON
headered := input.Event
event := headered.Unwrap()
softfail := false
// Check that the event passes authentication checks and work out
// the numeric IDs for the auth events.
// Check that the event passes authentication checks based on the
// event-specified auth events and work out the numeric IDs for those.
authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
if err != nil {
logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event")
return
}
// Check that the event passes authentication checks based on the
// current room state.
softfail, err = checkForSoftFail(ctx, r.DB, headered)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": event.EventID(),
"type": event.Type(),
"room": event.RoomID(),
}).WithError(err).Info("Error authing soft-failed event")
}
// If we don't have a transaction ID then get one.
if input.TransactionID != nil {
tdID := input.TransactionID
@ -68,6 +80,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
if err != nil {
return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
}
// if storing this event results in it being redacted then do so.
if redactedEventID == event.EventID() {
r, rerr := eventutil.RedactEvent(redactionEvent, &event)
@ -105,6 +118,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
event, // event
input.SendAsServer, // send as server
input.TransactionID, // transaction ID
softfail, // event soft-failed?
); err != nil {
return "", fmt.Errorf("r.updateLatestEvents: %w", err)
}

View file

@ -54,6 +54,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
event gomatrixserverlib.Event,
sendAsServer string,
transactionID *api.TransactionID,
softfail bool,
) (err error) {
updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
if err != nil {
@ -71,6 +72,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
event: event,
sendAsServer: sendAsServer,
transactionID: transactionID,
softfail: softfail,
}
if err = u.doUpdateLatestEvents(); err != nil {
@ -93,6 +95,7 @@ type latestEventsUpdater struct {
stateAtEvent types.StateAtEvent
event gomatrixserverlib.Event
transactionID *api.TransactionID
softfail bool
// Which server to send this event as.
sendAsServer string
// The eventID of the event that was processed before this one.
@ -178,22 +181,24 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
return fmt.Errorf("u.api.updateMemberships: %w", err)
}
update, err := u.makeOutputNewRoomEvent()
if err != nil {
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
}
updates = append(updates, *update)
if !u.softfail {
update, err := u.makeOutputNewRoomEvent()
if err != nil {
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
}
updates = append(updates, *update)
// Send the event to the output logs.
// We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
// (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
// the write to the output log succeeds)
// TODO: This assumes that writing the event to the output log is synchronous. It should be possible to
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
// necessary bookkeeping we'll keep the event sending synchronous for now.
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
// Send the event to the output logs.
// We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
// (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
// the write to the output log succeeds)
// TODO: This assumes that writing the event to the output log is synchronous. It should be possible to
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
// necessary bookkeeping we'll keep the event sending synchronous for now.
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
}
}
if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {

View file

@ -64,6 +64,8 @@ type Database interface {
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
// Look up current state for a room ID.
StateForRoomID(ctx context.Context, roomID string) ([]types.StateEntry, error)
// Look up a room version from the room NID.
GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.

View file

@ -74,6 +74,9 @@ const selectRoomVersionForRoomIDSQL = "" +
const selectRoomVersionForRoomNIDSQL = "" +
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
const selectStateSnapshotNIDSQL = "" +
"SELECT state_snapshot_nid FROM roomserver_rooms WHERE room_id = $1"
type roomStatements struct {
insertRoomNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
@ -82,6 +85,7 @@ type roomStatements struct {
updateLatestEventNIDsStmt *sql.Stmt
selectRoomVersionForRoomIDStmt *sql.Stmt
selectRoomVersionForRoomNIDStmt *sql.Stmt
selectStateSnapshotNIDStmt *sql.Stmt
}
func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
@ -98,6 +102,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
{&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL},
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
{&s.selectStateSnapshotNIDStmt, selectStateSnapshotNIDSQL},
}.Prepare(db)
}
@ -196,3 +201,14 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
}
return roomVersion, err
}
func (s *roomStatements) SelectStateSnapshotNID(
ctx context.Context, roomID string,
) (types.StateSnapshotNID, error) {
var stateSnapshotNID types.StateSnapshotNID
err := s.selectStateSnapshotNIDStmt.QueryRowContext(ctx, roomID).Scan(&stateSnapshotNID)
if err == sql.ErrNoRows {
return 0, errors.New("room not found")
}
return stateSnapshotNID, err
}

View file

@ -174,6 +174,32 @@ func (d *Database) SnapshotNIDFromEventID(
return stateNID, err
}
func (d *Database) StateForRoomID(
ctx context.Context, roomID string,
) ([]types.StateEntry, error) {
stateSnapshotNID, err := d.RoomsTable.SelectStateSnapshotNID(ctx, roomID)
if err != nil || stateSnapshotNID == 0 {
// the room doesn't exist or it doesn't have state
return nil, nil
}
stateBlockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{stateSnapshotNID})
if err != nil {
return nil, fmt.Errorf("d.StateSnapshotTable.BulkSelectStateBlockNIDs: %w", err)
}
if len(stateBlockNIDs) != 1 {
return nil, fmt.Errorf("expected one StateBlockNIDList, got %d", len(stateBlockNIDs))
}
stateEventLists, err := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs[0].StateBlockNIDs)
if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
}
stateEventNIDs := []types.StateEntry{}
for _, stateEventList := range stateEventLists {
stateEventNIDs = append(stateEventNIDs, stateEventList.StateEntries...)
}
return stateEventNIDs, nil
}
func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {

View file

@ -64,6 +64,9 @@ const selectRoomVersionForRoomIDSQL = "" +
const selectRoomVersionForRoomNIDSQL = "" +
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
const selectStateSnapshotNIDSQL = "" +
"SELECT state_snapshot_nid FROM roomserver_rooms WHERE room_id = $1"
type roomStatements struct {
db *sql.DB
insertRoomNIDStmt *sql.Stmt
@ -73,6 +76,7 @@ type roomStatements struct {
updateLatestEventNIDsStmt *sql.Stmt
selectRoomVersionForRoomIDStmt *sql.Stmt
selectRoomVersionForRoomNIDStmt *sql.Stmt
selectStateSnapshotNIDStmt *sql.Stmt
}
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
@ -91,6 +95,7 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
{&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL},
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
{&s.selectStateSnapshotNIDStmt, selectStateSnapshotNIDSQL},
}.Prepare(db)
}
@ -195,3 +200,14 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
}
return roomVersion, err
}
func (s *roomStatements) SelectStateSnapshotNID(
ctx context.Context, roomID string,
) (types.StateSnapshotNID, error) {
var stateSnapshotNID types.StateSnapshotNID
err := s.selectStateSnapshotNIDStmt.QueryRowContext(ctx, roomID).Scan(&stateSnapshotNID)
if err == sql.ErrNoRows {
return 0, errors.New("room not found")
}
return stateSnapshotNID, err
}

View file

@ -65,6 +65,7 @@ type Rooms interface {
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionForRoomID(ctx context.Context, txn *sql.Tx, roomID string) (gomatrixserverlib.RoomVersion, error)
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
SelectStateSnapshotNID(ctx context.Context, roomID string) (types.StateSnapshotNID, error)
}
type Transactions interface {