Fix state lookup

This commit is contained in:
Neil Alexander 2020-08-28 15:04:23 +01:00
parent 08b6d866f5
commit ec60c49d24
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
5 changed files with 48 additions and 77 deletions

View file

@ -16,13 +16,14 @@ package internal
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"sort" "sort"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
) )
// checkForSoftFail returns true if the event should be soft-failed // checkForSoftFail returns true if the event should be soft-failed
@ -33,10 +34,27 @@ func checkForSoftFail(
db storage.Database, db storage.Database,
event gomatrixserverlib.HeaderedEvent, event gomatrixserverlib.HeaderedEvent,
) (bool, error) { ) (bool, error) {
// Look up the current state of the room. // Work out if the room exists.
authStateEntries, err := db.StateForRoomID(ctx, event.RoomID()) roomNID, err := db.RoomNID(ctx, event.RoomID())
if roomNID == 0 || err == sql.ErrNoRows {
return false, nil
}
if err != nil { if err != nil {
return false, fmt.Errorf("db.SnapshotNIDFromRoomID: %w", err) return false, fmt.Errorf("db.RoomNID: %w", err)
}
// If the room exist, gets the current state snapshot.
_, stateSnapshotNID, _, err := db.LatestEventIDs(ctx, roomNID)
if err != nil {
return true, fmt.Errorf("r.DB.LatestEventIDs: %w", err)
}
// Then get the state entries for the current state snapshot.
// We'll use this to check if the event is allowed right now.
roomState := state.NewStateResolution(db)
authStateEntries, err := roomState.LoadStateAtSnapshot(ctx, stateSnapshotNID)
if err != nil {
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
} }
// As a special case, it's possible that the room will have no // As a special case, it's possible that the room will have no
@ -50,33 +68,17 @@ func checkForSoftFail(
// Work out which of the state events we actually need. // Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})
logger := logrus.WithField("room_id", event.RoomID())
logger.Infof("EVENT %s TYPE %s", event.EventID(), event.Type())
logger.Infof("STATE NEEDED:")
for _, tuple := range stateNeeded.Tuples() {
logger.Infof("* %q %q", tuple.EventType, tuple.StateKey)
}
// Load the actual auth events from the database. // Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
if err != nil { if err != nil {
return false, fmt.Errorf("loadAuthEvents: %w", err) return true, fmt.Errorf("loadAuthEvents: %w", err)
}
logger.Info("STATE RETRIEVED:")
for _, e := range authEvents.events {
logger.Infof("* %q %q -> %s", e.Type(), *e.StateKey(), string(e.Content()))
} }
// Check if the event is allowed. // Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil { if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
// return true, nil // return true, nil
logger.Info("SOFT-FAIL") return true, fmt.Errorf("gomatrixserverlib.Allowed: %w", err)
return false, fmt.Errorf("gomatrixserverlib.Allowed: %w", err)
} }
logger.Info("ALLOW")
return false, nil return false, nil
} }

View file

@ -111,6 +111,15 @@ func (r *RoomserverInternalAPI) processRoomEvent(
} }
} }
if softfail {
logrus.WithFields(logrus.Fields{
"event_id": event.EventID(),
"type": event.Type(),
"room": event.RoomID(),
}).Info("Stored soft-failed event")
return event.EventID(), nil
}
if err = r.updateLatestEvents( if err = r.updateLatestEvents(
ctx, // context ctx, // context
roomNID, // room NID to update roomNID, // room NID to update
@ -118,7 +127,6 @@ func (r *RoomserverInternalAPI) processRoomEvent(
event, // event event, // event
input.SendAsServer, // send as server input.SendAsServer, // send as server
input.TransactionID, // transaction ID input.TransactionID, // transaction ID
softfail, // event soft-failed?
); err != nil { ); err != nil {
return "", fmt.Errorf("r.updateLatestEvents: %w", err) return "", fmt.Errorf("r.updateLatestEvents: %w", err)
} }

View file

@ -54,7 +54,6 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
event gomatrixserverlib.Event, event gomatrixserverlib.Event,
sendAsServer string, sendAsServer string,
transactionID *api.TransactionID, transactionID *api.TransactionID,
softfail bool,
) (err error) { ) (err error) {
updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID) updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
if err != nil { if err != nil {
@ -72,7 +71,6 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
event: event, event: event,
sendAsServer: sendAsServer, sendAsServer: sendAsServer,
transactionID: transactionID, transactionID: transactionID,
softfail: softfail,
} }
if err = u.doUpdateLatestEvents(); err != nil { if err = u.doUpdateLatestEvents(); err != nil {
@ -95,7 +93,6 @@ type latestEventsUpdater struct {
stateAtEvent types.StateAtEvent stateAtEvent types.StateAtEvent
event gomatrixserverlib.Event event gomatrixserverlib.Event
transactionID *api.TransactionID transactionID *api.TransactionID
softfail bool
// Which server to send this event as. // Which server to send this event as.
sendAsServer string sendAsServer string
// The eventID of the event that was processed before this one. // The eventID of the event that was processed before this one.
@ -181,7 +178,6 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
return fmt.Errorf("u.api.updateMemberships: %w", err) return fmt.Errorf("u.api.updateMemberships: %w", err)
} }
if !u.softfail {
update, err := u.makeOutputNewRoomEvent() update, err := u.makeOutputNewRoomEvent()
if err != nil { if err != nil {
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
@ -199,7 +195,6 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil { if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
return fmt.Errorf("u.api.WriteOutputEvents: %w", err) return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
} }
}
if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
return fmt.Errorf("u.updater.SetLatestEvents: %w", err) return fmt.Errorf("u.updater.SetLatestEvents: %w", err)

View file

@ -64,8 +64,6 @@ type Database interface {
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string // Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
// Look up current state for a room ID.
StateForRoomID(ctx context.Context, roomID string) ([]types.StateEntry, error)
// Look up a room version from the room NID. // Look up a room version from the room NID.
GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.

View file

@ -174,38 +174,6 @@ func (d *Database) SnapshotNIDFromEventID(
return stateNID, err return stateNID, err
} }
func (d *Database) StateForRoomID(
ctx context.Context, roomID string,
) ([]types.StateEntry, error) {
roomNID, err := d.RoomNIDExcludingStubs(ctx, roomID)
if err != nil {
return nil, err
}
if roomNID == 0 {
return nil, nil
}
_, stateSnapshotNID, _, err := d.LatestEventIDs(ctx, roomNID)
if err != nil || stateSnapshotNID == 0 {
return nil, err
}
stateBlockNIDLists, err := d.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateSnapshotNID})
if err != nil {
return nil, err
}
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
stateBlockNIDList := stateBlockNIDLists[0]
stateEventLists, err := d.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
if err != nil {
return nil, err
}
stateEventNIDs := []types.StateEntry{}
for _, stateEventList := range stateEventLists {
stateEventNIDs = append(stateEventNIDs, stateEventList.StateEntries...)
}
return stateEventNIDs, nil
}
func (d *Database) EventIDs( func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID, ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) { ) (map[types.EventNID]string, error) {