mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-23 06:43:09 -06:00
Initial work on soft-fail
This commit is contained in:
parent
3205b9212d
commit
db5a99e82b
|
|
@ -16,13 +16,70 @@ package internal
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// checkForSoftFail returns true if the event should be soft-failed
|
||||
// and false otherwise. The return error value should be checked before
|
||||
// the soft-fail bool.
|
||||
func checkForSoftFail(
|
||||
ctx context.Context,
|
||||
db storage.Database,
|
||||
event gomatrixserverlib.HeaderedEvent,
|
||||
) (bool, error) {
|
||||
// Look up the current state of the room.
|
||||
authStateEntries, err := db.StateForRoomID(ctx, event.RoomID())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("db.SnapshotNIDFromRoomID: %w", err)
|
||||
}
|
||||
|
||||
// As a special case, it's possible that the room will have no
|
||||
// state because we haven't received a m.room.create event yet.
|
||||
// If we're now processing the first create event then never
|
||||
// soft-fail it.
|
||||
if len(authStateEntries) == 0 && event.Type() == gomatrixserverlib.MRoomCreate {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Work out which of the state events we actually need.
|
||||
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})
|
||||
|
||||
logger := logrus.WithField("room_id", event.RoomID())
|
||||
logger.Infof("EVENT %s TYPE %s", event.EventID(), event.Type())
|
||||
|
||||
logger.Infof("STATE NEEDED:")
|
||||
for _, tuple := range stateNeeded.Tuples() {
|
||||
logger.Infof("* %q %q", tuple.EventType, tuple.StateKey)
|
||||
}
|
||||
|
||||
// Load the actual auth events from the database.
|
||||
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("loadAuthEvents: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("STATE RETRIEVED:")
|
||||
for _, e := range authEvents.events {
|
||||
logger.Infof("* %q %q -> %s", e.Type(), *e.StateKey(), string(e.Content()))
|
||||
}
|
||||
|
||||
// Check if the event is allowed.
|
||||
if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
|
||||
// return true, nil
|
||||
logger.Info("SOFT-FAIL")
|
||||
return false, fmt.Errorf("gomatrixserverlib.Allowed: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("ALLOW")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// checkAuthEvents checks that the event passes authentication checks
|
||||
// Returns the numeric IDs for the auth events.
|
||||
func checkAuthEvents(
|
||||
|
|
|
|||
|
|
@ -42,15 +42,27 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
|||
// Parse and validate the event JSON
|
||||
headered := input.Event
|
||||
event := headered.Unwrap()
|
||||
softfail := false
|
||||
|
||||
// Check that the event passes authentication checks and work out
|
||||
// the numeric IDs for the auth events.
|
||||
// Check that the event passes authentication checks based on the
|
||||
// event-specified auth events and work out the numeric IDs for those.
|
||||
authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event")
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the event passes authentication checks based on the
|
||||
// current room state.
|
||||
softfail, err = checkForSoftFail(ctx, r.DB, headered)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event_id": event.EventID(),
|
||||
"type": event.Type(),
|
||||
"room": event.RoomID(),
|
||||
}).WithError(err).Info("Error authing soft-failed event")
|
||||
}
|
||||
|
||||
// If we don't have a transaction ID then get one.
|
||||
if input.TransactionID != nil {
|
||||
tdID := input.TransactionID
|
||||
|
|
@ -68,6 +80,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
|||
if err != nil {
|
||||
return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
|
||||
}
|
||||
|
||||
// if storing this event results in it being redacted then do so.
|
||||
if redactedEventID == event.EventID() {
|
||||
r, rerr := eventutil.RedactEvent(redactionEvent, &event)
|
||||
|
|
@ -105,6 +118,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
|||
event, // event
|
||||
input.SendAsServer, // send as server
|
||||
input.TransactionID, // transaction ID
|
||||
softfail, // event soft-failed?
|
||||
); err != nil {
|
||||
return "", fmt.Errorf("r.updateLatestEvents: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
|
|||
event gomatrixserverlib.Event,
|
||||
sendAsServer string,
|
||||
transactionID *api.TransactionID,
|
||||
softfail bool,
|
||||
) (err error) {
|
||||
updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
|
||||
if err != nil {
|
||||
|
|
@ -71,6 +72,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
|
|||
event: event,
|
||||
sendAsServer: sendAsServer,
|
||||
transactionID: transactionID,
|
||||
softfail: softfail,
|
||||
}
|
||||
|
||||
if err = u.doUpdateLatestEvents(); err != nil {
|
||||
|
|
@ -93,6 +95,7 @@ type latestEventsUpdater struct {
|
|||
stateAtEvent types.StateAtEvent
|
||||
event gomatrixserverlib.Event
|
||||
transactionID *api.TransactionID
|
||||
softfail bool
|
||||
// Which server to send this event as.
|
||||
sendAsServer string
|
||||
// The eventID of the event that was processed before this one.
|
||||
|
|
@ -178,22 +181,24 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
|||
return fmt.Errorf("u.api.updateMemberships: %w", err)
|
||||
}
|
||||
|
||||
update, err := u.makeOutputNewRoomEvent()
|
||||
if err != nil {
|
||||
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
|
||||
}
|
||||
updates = append(updates, *update)
|
||||
if !u.softfail {
|
||||
update, err := u.makeOutputNewRoomEvent()
|
||||
if err != nil {
|
||||
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
|
||||
}
|
||||
updates = append(updates, *update)
|
||||
|
||||
// Send the event to the output logs.
|
||||
// We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
|
||||
// (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
|
||||
// the write to the output log succeeds)
|
||||
// TODO: This assumes that writing the event to the output log is synchronous. It should be possible to
|
||||
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in
|
||||
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
|
||||
// necessary bookkeeping we'll keep the event sending synchronous for now.
|
||||
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
|
||||
return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
|
||||
// Send the event to the output logs.
|
||||
// We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
|
||||
// (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
|
||||
// the write to the output log succeeds)
|
||||
// TODO: This assumes that writing the event to the output log is synchronous. It should be possible to
|
||||
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in
|
||||
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
|
||||
// necessary bookkeeping we'll keep the event sending synchronous for now.
|
||||
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
|
||||
return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
|
||||
|
|
|
|||
|
|
@ -64,6 +64,8 @@ type Database interface {
|
|||
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||
// Look up snapshot NID for an event ID string
|
||||
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
||||
// Look up current state for a room ID.
|
||||
StateForRoomID(ctx context.Context, roomID string) ([]types.StateEntry, error)
|
||||
// Look up a room version from the room NID.
|
||||
GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
|
||||
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
|
||||
|
|
|
|||
|
|
@ -74,6 +74,9 @@ const selectRoomVersionForRoomIDSQL = "" +
|
|||
const selectRoomVersionForRoomNIDSQL = "" +
|
||||
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
|
||||
|
||||
const selectStateSnapshotNIDSQL = "" +
|
||||
"SELECT state_snapshot_nid FROM roomserver_rooms WHERE room_id = $1"
|
||||
|
||||
type roomStatements struct {
|
||||
insertRoomNIDStmt *sql.Stmt
|
||||
selectRoomNIDStmt *sql.Stmt
|
||||
|
|
@ -82,6 +85,7 @@ type roomStatements struct {
|
|||
updateLatestEventNIDsStmt *sql.Stmt
|
||||
selectRoomVersionForRoomIDStmt *sql.Stmt
|
||||
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||
selectStateSnapshotNIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
|
|
@ -98,6 +102,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
|||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||
{&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL},
|
||||
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
||||
{&s.selectStateSnapshotNIDStmt, selectStateSnapshotNIDSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
|
|
@ -196,3 +201,14 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
|
|||
}
|
||||
return roomVersion, err
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectStateSnapshotNID(
|
||||
ctx context.Context, roomID string,
|
||||
) (types.StateSnapshotNID, error) {
|
||||
var stateSnapshotNID types.StateSnapshotNID
|
||||
err := s.selectStateSnapshotNIDStmt.QueryRowContext(ctx, roomID).Scan(&stateSnapshotNID)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, errors.New("room not found")
|
||||
}
|
||||
return stateSnapshotNID, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -174,6 +174,32 @@ func (d *Database) SnapshotNIDFromEventID(
|
|||
return stateNID, err
|
||||
}
|
||||
|
||||
func (d *Database) StateForRoomID(
|
||||
ctx context.Context, roomID string,
|
||||
) ([]types.StateEntry, error) {
|
||||
stateSnapshotNID, err := d.RoomsTable.SelectStateSnapshotNID(ctx, roomID)
|
||||
if err != nil || stateSnapshotNID == 0 {
|
||||
// the room doesn't exist or it doesn't have state
|
||||
return nil, nil
|
||||
}
|
||||
stateBlockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{stateSnapshotNID})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.StateSnapshotTable.BulkSelectStateBlockNIDs: %w", err)
|
||||
}
|
||||
if len(stateBlockNIDs) != 1 {
|
||||
return nil, fmt.Errorf("expected one StateBlockNIDList, got %d", len(stateBlockNIDs))
|
||||
}
|
||||
stateEventLists, err := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs[0].StateBlockNIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
|
||||
}
|
||||
stateEventNIDs := []types.StateEntry{}
|
||||
for _, stateEventList := range stateEventLists {
|
||||
stateEventNIDs = append(stateEventNIDs, stateEventList.StateEntries...)
|
||||
}
|
||||
return stateEventNIDs, nil
|
||||
}
|
||||
|
||||
func (d *Database) EventIDs(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
) (map[types.EventNID]string, error) {
|
||||
|
|
|
|||
|
|
@ -64,6 +64,9 @@ const selectRoomVersionForRoomIDSQL = "" +
|
|||
const selectRoomVersionForRoomNIDSQL = "" +
|
||||
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
|
||||
|
||||
const selectStateSnapshotNIDSQL = "" +
|
||||
"SELECT state_snapshot_nid FROM roomserver_rooms WHERE room_id = $1"
|
||||
|
||||
type roomStatements struct {
|
||||
db *sql.DB
|
||||
insertRoomNIDStmt *sql.Stmt
|
||||
|
|
@ -73,6 +76,7 @@ type roomStatements struct {
|
|||
updateLatestEventNIDsStmt *sql.Stmt
|
||||
selectRoomVersionForRoomIDStmt *sql.Stmt
|
||||
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||
selectStateSnapshotNIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||
|
|
@ -91,6 +95,7 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
|||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||
{&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL},
|
||||
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
||||
{&s.selectStateSnapshotNIDStmt, selectStateSnapshotNIDSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
|
|
@ -195,3 +200,14 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
|
|||
}
|
||||
return roomVersion, err
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectStateSnapshotNID(
|
||||
ctx context.Context, roomID string,
|
||||
) (types.StateSnapshotNID, error) {
|
||||
var stateSnapshotNID types.StateSnapshotNID
|
||||
err := s.selectStateSnapshotNIDStmt.QueryRowContext(ctx, roomID).Scan(&stateSnapshotNID)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, errors.New("room not found")
|
||||
}
|
||||
return stateSnapshotNID, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ type Rooms interface {
|
|||
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
||||
SelectRoomVersionForRoomID(ctx context.Context, txn *sql.Tx, roomID string) (gomatrixserverlib.RoomVersion, error)
|
||||
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
|
||||
SelectStateSnapshotNID(ctx context.Context, roomID string) (types.StateSnapshotNID, error)
|
||||
}
|
||||
|
||||
type Transactions interface {
|
||||
|
|
|
|||
Loading…
Reference in a new issue