WIP Event rejection

This commit is contained in:
Kegan Dougal 2020-09-15 16:10:56 +01:00
parent 965f068d1a
commit 225fb1576b
8 changed files with 31 additions and 22 deletions

View file

@ -46,10 +46,11 @@ func (r *Inputer) processRoomEvent(
// Check that the event passes authentication checks and work out // Check that the event passes authentication checks and work out
// the numeric IDs for the auth events. // the numeric IDs for the auth events.
isRejected := false
authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event, rejecting event")
return isRejected = true
} }
// If we don't have a transaction ID then get one. // If we don't have a transaction ID then get one.
@ -65,10 +66,14 @@ func (r *Inputer) processRoomEvent(
} }
// Store the event. // Store the event.
_, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs, isRejected)
if err != nil { if err != nil {
return "", fmt.Errorf("r.DB.StoreEvent: %w", err) return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
} }
// We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it.
if isRejected {
return event.EventID(), nil
}
// if storing this event results in it being redacted then do so. // if storing this event results in it being redacted then do so.
if redactedEventID == event.EventID() { if redactedEventID == event.EventID() {
r, rerr := eventutil.RedactEvent(redactionEvent, &event) r, rerr := eventutil.RedactEvent(redactionEvent, &event)

View file

@ -535,7 +535,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse
var stateAtEvent types.StateAtEvent var stateAtEvent types.StateAtEvent
var redactedEventID string var redactedEventID string
var redactionEvent *gomatrixserverlib.Event var redactionEvent *gomatrixserverlib.Event
roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids) roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids, false)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
continue continue

View file

@ -140,14 +140,6 @@ func mustCreateEvents(t *testing.T, roomVer gomatrixserverlib.RoomVersion, event
return return
} }
func eventsJSON(events []gomatrixserverlib.Event) []json.RawMessage {
result := make([]json.RawMessage, len(events))
for i := range events {
result[i] = events[i].JSON()
}
return result
}
func mustLoadRawEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent { func mustLoadRawEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent {
t.Helper() t.Helper()
hs := make([]gomatrixserverlib.HeaderedEvent, len(events)) hs := make([]gomatrixserverlib.HeaderedEvent, len(events))

View file

@ -70,6 +70,7 @@ type Database interface {
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
StoreEvent( StoreEvent(
ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
isRejected bool,
) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) ) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
// Look up the state entries for a list of string event IDs // Look up the state entries for a list of string event IDs
// Returns an error if the there is an error talking to the database // Returns an error if the there is an error talking to the database

View file

@ -65,13 +65,14 @@ CREATE TABLE IF NOT EXISTS roomserver_events (
-- Needed for setting reference hashes when sending new events. -- Needed for setting reference hashes when sending new events.
reference_sha256 BYTEA NOT NULL, reference_sha256 BYTEA NOT NULL,
-- A list of numeric IDs for events that can authenticate this event. -- A list of numeric IDs for events that can authenticate this event.
auth_event_nids BIGINT[] NOT NULL auth_event_nids BIGINT[] NOT NULL,
is_rejected BOOLEAN NOT NULL DEFAULT FALSE
); );
` `
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth)" + "INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
" ON CONFLICT ON CONSTRAINT roomserver_event_id_unique" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique" +
" DO NOTHING" + " DO NOTHING" +
" RETURNING event_nid, state_snapshot_nid" " RETURNING event_nid, state_snapshot_nid"
@ -174,12 +175,14 @@ func (s *eventStatements) InsertEvent(
referenceSHA256 []byte, referenceSHA256 []byte,
authEventNIDs []types.EventNID, authEventNIDs []types.EventNID,
depth int64, depth int64,
isRejected bool,
) (types.EventNID, types.StateSnapshotNID, error) { ) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64 var eventNID int64
var stateNID int64 var stateNID int64
err := s.insertEventStmt.QueryRowContext( err := s.insertEventStmt.QueryRowContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
isRejected,
).Scan(&eventNID, &stateNID) ).Scan(&eventNID, &stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
} }

View file

@ -382,7 +382,7 @@ func (d *Database) GetLatestEventsForUpdate(
// nolint:gocyclo // nolint:gocyclo
func (d *Database) StoreEvent( func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event, ctx context.Context, event gomatrixserverlib.Event,
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, isRejected bool,
) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { ) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var ( var (
roomNID types.RoomNID roomNID types.RoomNID
@ -446,6 +446,7 @@ func (d *Database) StoreEvent(
event.EventReference().EventSHA256, event.EventReference().EventSHA256,
authEventNIDs, authEventNIDs,
event.Depth(), event.Depth(),
isRejected,
); err != nil { ); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID // We've already inserted the event so select the numeric event ID
@ -459,7 +460,9 @@ func (d *Database) StoreEvent(
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
} }
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) if !isRejected { // ignore rejected redaction events
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event)
}
return nil return nil
}) })
if err != nil { if err != nil {

View file

@ -41,13 +41,14 @@ const eventsSchema = `
depth INTEGER NOT NULL, depth INTEGER NOT NULL,
event_id TEXT NOT NULL UNIQUE, event_id TEXT NOT NULL UNIQUE,
reference_sha256 BLOB NOT NULL, reference_sha256 BLOB NOT NULL,
auth_event_nids TEXT NOT NULL DEFAULT '[]' auth_event_nids TEXT NOT NULL DEFAULT '[]',
is_rejected BOOLEAN NOT NULL DEFAULT FALSE
); );
` `
const insertEventSQL = ` const insertEventSQL = `
INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth) INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
` `
@ -150,13 +151,14 @@ func (s *eventStatements) InsertEvent(
referenceSHA256 []byte, referenceSHA256 []byte,
authEventNIDs []types.EventNID, authEventNIDs []types.EventNID,
depth int64, depth int64,
isRejected bool,
) (types.EventNID, types.StateSnapshotNID, error) { ) (types.EventNID, types.StateSnapshotNID, error) {
// attempt to insert: the last_row_id is the event NID // attempt to insert: the last_row_id is the event NID
var eventNID int64 var eventNID int64
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
result, err := insertStmt.ExecContext( result, err := insertStmt.ExecContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected,
) )
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err

View file

@ -34,7 +34,10 @@ type EventStateKeys interface {
} }
type Events interface { type Events interface {
InsertEvent(c context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64) (types.EventNID, types.StateSnapshotNID, error) InsertEvent(
ctx context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string,
referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool,
) (types.EventNID, types.StateSnapshotNID, error)
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError