mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-23 06:43:09 -06:00
Fix state lookup
This commit is contained in:
parent
08b6d866f5
commit
ec60c49d24
|
|
@ -16,13 +16,14 @@ package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// checkForSoftFail returns true if the event should be soft-failed
|
// checkForSoftFail returns true if the event should be soft-failed
|
||||||
|
|
@ -33,10 +34,27 @@ func checkForSoftFail(
|
||||||
db storage.Database,
|
db storage.Database,
|
||||||
event gomatrixserverlib.HeaderedEvent,
|
event gomatrixserverlib.HeaderedEvent,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
// Look up the current state of the room.
|
// Work out if the room exists.
|
||||||
authStateEntries, err := db.StateForRoomID(ctx, event.RoomID())
|
roomNID, err := db.RoomNID(ctx, event.RoomID())
|
||||||
|
if roomNID == 0 || err == sql.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("db.SnapshotNIDFromRoomID: %w", err)
|
return false, fmt.Errorf("db.RoomNID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the room exist, gets the current state snapshot.
|
||||||
|
_, stateSnapshotNID, _, err := db.LatestEventIDs(ctx, roomNID)
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("r.DB.LatestEventIDs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then get the state entries for the current state snapshot.
|
||||||
|
// We'll use this to check if the event is allowed right now.
|
||||||
|
roomState := state.NewStateResolution(db)
|
||||||
|
authStateEntries, err := roomState.LoadStateAtSnapshot(ctx, stateSnapshotNID)
|
||||||
|
if err != nil {
|
||||||
|
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// As a special case, it's possible that the room will have no
|
// As a special case, it's possible that the room will have no
|
||||||
|
|
@ -50,33 +68,17 @@ func checkForSoftFail(
|
||||||
// Work out which of the state events we actually need.
|
// Work out which of the state events we actually need.
|
||||||
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})
|
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})
|
||||||
|
|
||||||
logger := logrus.WithField("room_id", event.RoomID())
|
|
||||||
logger.Infof("EVENT %s TYPE %s", event.EventID(), event.Type())
|
|
||||||
|
|
||||||
logger.Infof("STATE NEEDED:")
|
|
||||||
for _, tuple := range stateNeeded.Tuples() {
|
|
||||||
logger.Infof("* %q %q", tuple.EventType, tuple.StateKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the actual auth events from the database.
|
// Load the actual auth events from the database.
|
||||||
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
|
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("loadAuthEvents: %w", err)
|
return true, fmt.Errorf("loadAuthEvents: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("STATE RETRIEVED:")
|
|
||||||
for _, e := range authEvents.events {
|
|
||||||
logger.Infof("* %q %q -> %s", e.Type(), *e.StateKey(), string(e.Content()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the event is allowed.
|
// Check if the event is allowed.
|
||||||
if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
|
if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
|
||||||
// return true, nil
|
// return true, nil
|
||||||
logger.Info("SOFT-FAIL")
|
return true, fmt.Errorf("gomatrixserverlib.Allowed: %w", err)
|
||||||
return false, fmt.Errorf("gomatrixserverlib.Allowed: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("ALLOW")
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -111,6 +111,15 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if softfail {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"event_id": event.EventID(),
|
||||||
|
"type": event.Type(),
|
||||||
|
"room": event.RoomID(),
|
||||||
|
}).Info("Stored soft-failed event")
|
||||||
|
return event.EventID(), nil
|
||||||
|
}
|
||||||
|
|
||||||
if err = r.updateLatestEvents(
|
if err = r.updateLatestEvents(
|
||||||
ctx, // context
|
ctx, // context
|
||||||
roomNID, // room NID to update
|
roomNID, // room NID to update
|
||||||
|
|
@ -118,7 +127,6 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
||||||
event, // event
|
event, // event
|
||||||
input.SendAsServer, // send as server
|
input.SendAsServer, // send as server
|
||||||
input.TransactionID, // transaction ID
|
input.TransactionID, // transaction ID
|
||||||
softfail, // event soft-failed?
|
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return "", fmt.Errorf("r.updateLatestEvents: %w", err)
|
return "", fmt.Errorf("r.updateLatestEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,6 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
|
||||||
event gomatrixserverlib.Event,
|
event gomatrixserverlib.Event,
|
||||||
sendAsServer string,
|
sendAsServer string,
|
||||||
transactionID *api.TransactionID,
|
transactionID *api.TransactionID,
|
||||||
softfail bool,
|
|
||||||
) (err error) {
|
) (err error) {
|
||||||
updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
|
updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -72,7 +71,6 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
|
||||||
event: event,
|
event: event,
|
||||||
sendAsServer: sendAsServer,
|
sendAsServer: sendAsServer,
|
||||||
transactionID: transactionID,
|
transactionID: transactionID,
|
||||||
softfail: softfail,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = u.doUpdateLatestEvents(); err != nil {
|
if err = u.doUpdateLatestEvents(); err != nil {
|
||||||
|
|
@ -95,7 +93,6 @@ type latestEventsUpdater struct {
|
||||||
stateAtEvent types.StateAtEvent
|
stateAtEvent types.StateAtEvent
|
||||||
event gomatrixserverlib.Event
|
event gomatrixserverlib.Event
|
||||||
transactionID *api.TransactionID
|
transactionID *api.TransactionID
|
||||||
softfail bool
|
|
||||||
// Which server to send this event as.
|
// Which server to send this event as.
|
||||||
sendAsServer string
|
sendAsServer string
|
||||||
// The eventID of the event that was processed before this one.
|
// The eventID of the event that was processed before this one.
|
||||||
|
|
@ -181,24 +178,22 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
return fmt.Errorf("u.api.updateMemberships: %w", err)
|
return fmt.Errorf("u.api.updateMemberships: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !u.softfail {
|
update, err := u.makeOutputNewRoomEvent()
|
||||||
update, err := u.makeOutputNewRoomEvent()
|
if err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
|
||||||
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
|
}
|
||||||
}
|
updates = append(updates, *update)
|
||||||
updates = append(updates, *update)
|
|
||||||
|
|
||||||
// Send the event to the output logs.
|
// Send the event to the output logs.
|
||||||
// We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
|
// We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
|
||||||
// (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
|
// (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
|
||||||
// the write to the output log succeeds)
|
// the write to the output log succeeds)
|
||||||
// TODO: This assumes that writing the event to the output log is synchronous. It should be possible to
|
// TODO: This assumes that writing the event to the output log is synchronous. It should be possible to
|
||||||
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in
|
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in
|
||||||
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
|
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
|
||||||
// necessary bookkeeping we'll keep the event sending synchronous for now.
|
// necessary bookkeeping we'll keep the event sending synchronous for now.
|
||||||
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
|
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
|
||||||
return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
|
return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
|
if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
|
||||||
|
|
|
||||||
|
|
@ -64,8 +64,6 @@ type Database interface {
|
||||||
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
// Look up snapshot NID for an event ID string
|
// Look up snapshot NID for an event ID string
|
||||||
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
||||||
// Look up current state for a room ID.
|
|
||||||
StateForRoomID(ctx context.Context, roomID string) ([]types.StateEntry, error)
|
|
||||||
// Look up a room version from the room NID.
|
// Look up a room version from the room NID.
|
||||||
GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
|
GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
|
||||||
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
|
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
|
||||||
|
|
|
||||||
|
|
@ -174,38 +174,6 @@ func (d *Database) SnapshotNIDFromEventID(
|
||||||
return stateNID, err
|
return stateNID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StateForRoomID(
|
|
||||||
ctx context.Context, roomID string,
|
|
||||||
) ([]types.StateEntry, error) {
|
|
||||||
roomNID, err := d.RoomNIDExcludingStubs(ctx, roomID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if roomNID == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
_, stateSnapshotNID, _, err := d.LatestEventIDs(ctx, roomNID)
|
|
||||||
if err != nil || stateSnapshotNID == 0 {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
stateBlockNIDLists, err := d.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateSnapshotNID})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
|
||||||
stateBlockNIDList := stateBlockNIDLists[0]
|
|
||||||
stateEventLists, err := d.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
stateEventNIDs := []types.StateEntry{}
|
|
||||||
for _, stateEventList := range stateEventLists {
|
|
||||||
stateEventNIDs = append(stateEventNIDs, stateEventList.StateEntries...)
|
|
||||||
}
|
|
||||||
return stateEventNIDs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) EventIDs(
|
func (d *Database) EventIDs(
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
) (map[types.EventNID]string, error) {
|
) (map[types.EventNID]string, error) {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue