Compare commits

...

7 commits

Author SHA1 Message Date
Till Faelligen 8588850169
Merge branch 's7evink/resolve-state' of github.com:matrix-org/dendrite into s7evink/resolve-state 2023-11-15 19:50:49 +01:00
Till Faelligen f14517ae43
Merge branch 'main' of github.com:matrix-org/dendrite into s7evink/resolve-state 2023-11-15 19:50:37 +01:00
Till Faelligen c1ca1cca45
Also update the snapshot on the latest events 2023-11-15 19:50:21 +01:00
Till Faelligen 2a09c2631b
Show difference after fixing state 2023-10-31 09:56:15 +01:00
Till Faelligen 16b97f9e6c
Tweaks, use the roomUpdater to avoid a DB deadlock 2023-10-30 10:10:22 +01:00
Till Faelligen 64f70f98cf
Add a -fix parameter 2023-10-30 08:53:02 +01:00
Till Faelligen 8193fc284f
Quick and dirty resolve state for all snapshots 2023-10-27 17:36:40 +02:00
7 changed files with 216 additions and 61 deletions

View file

@ -28,10 +28,13 @@ import (
// //
// Usage: ./resolve-state --roomversion=version snapshot [snapshot ...] // Usage: ./resolve-state --roomversion=version snapshot [snapshot ...]
// e.g. ./resolve-state --roomversion=5 1254 1235 1282 // e.g. ./resolve-state --roomversion=5 1254 1235 1282
// e.g. ./resolve-state -room_id '!abc:localhost'
var roomVersion = flag.String("roomversion", "5", "the room version to parse events as") var roomVersion = flag.String("roomversion", "5", "the room version to parse events as")
var filterType = flag.String("filtertype", "", "the event types to filter on") var filterType = flag.String("filtertype", "", "the event types to filter on")
var difference = flag.Bool("difference", false, "whether to calculate the difference between snapshots") var difference = flag.Bool("difference", false, "whether to calculate the difference between snapshots")
var roomID = flag.String("room_id", "", "roomID to get the state for, using this flag ignores any passed snapshot NIDs and calculates the resolved state using ALL state snapshots")
var fixState = flag.Bool("fix", false, "attempt to fix the room state")
// dummyQuerier implements QuerySenderIDAPI. Does **NOT** do any "magic" for pseudoID rooms // dummyQuerier implements QuerySenderIDAPI. Does **NOT** do any "magic" for pseudoID rooms
// to avoid having to "start" a full roomserver API. // to avoid having to "start" a full roomserver API.
@ -58,8 +61,6 @@ func main() {
args := flag.Args() args := flag.Args()
fmt.Println("Room version", *roomVersion)
snapshotNIDs := []types.StateSnapshotNID{} snapshotNIDs := []types.StateSnapshotNID{}
for _, arg := range args { for _, arg := range args {
if i, err := strconv.Atoi(arg); err == nil { if i, err := strconv.Atoi(arg); err == nil {
@ -89,70 +90,37 @@ func main() {
roomInfo := &types.RoomInfo{ roomInfo := &types.RoomInfo{
RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion), RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
} }
if *roomID != "" {
roomInfo, err = roomserverDB.RoomInfo(ctx, *roomID)
if err != nil {
panic(err)
}
if roomInfo == nil {
panic("no room found")
}
snapshotNIDs, err = roomserverDB.GetAllStateSnapshots(ctx, roomInfo.RoomNID)
if err != nil {
panic(err)
}
}
fmt.Println("Room version", roomInfo.RoomVersion)
stateres := state.NewStateResolution(roomserverDB, roomInfo, rsAPI) stateres := state.NewStateResolution(roomserverDB, roomInfo, rsAPI)
fmt.Println("Fetching", len(snapshotNIDs), "snapshot NIDs") fmt.Println("Fetching", len(snapshotNIDs), "snapshot NIDs")
if *difference { if *difference {
if len(snapshotNIDs) != 2 { showDifference(ctx, snapshotNIDs, stateres, roomserverDB, roomInfo)
panic("need exactly two state snapshot NIDs to calculate difference")
}
var removed, added []types.StateEntry
removed, added, err = stateres.DifferenceBetweeenStateSnapshots(ctx, snapshotNIDs[0], snapshotNIDs[1])
if err != nil {
panic(err)
}
eventNIDMap := map[types.EventNID]struct{}{}
for _, entry := range append(removed, added...) {
eventNIDMap[entry.EventNID] = struct{}{}
}
eventNIDs := make([]types.EventNID, 0, len(eventNIDMap))
for eventNID := range eventNIDMap {
eventNIDs = append(eventNIDs, eventNID)
}
var eventEntries []types.Event
eventEntries, err = roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
panic(err)
}
events := make(map[types.EventNID]gomatrixserverlib.PDU, len(eventEntries))
for _, entry := range eventEntries {
events[entry.EventNID] = entry.PDU
}
if len(removed) > 0 {
fmt.Println("Removed:")
for _, r := range removed {
event := events[r.EventNID]
fmt.Println()
fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey())
fmt.Printf(" %s\n", string(event.Content()))
}
}
if len(removed) > 0 && len(added) > 0 {
fmt.Println()
}
if len(added) > 0 {
fmt.Println("Added:")
for _, a := range added {
event := events[a.EventNID]
fmt.Println()
fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey())
fmt.Printf(" %s\n", string(event.Content()))
}
}
return return
} }
var stateEntries []types.StateEntry var stateEntries []types.StateEntry
for _, snapshotNID := range snapshotNIDs {
for i, snapshotNID := range snapshotNIDs {
fmt.Printf("\r \a %d of %d", i, len(snapshotNIDs))
var entries []types.StateEntry var entries []types.StateEntry
entries, err = stateres.LoadStateAtSnapshot(ctx, snapshotNID) entries, err = stateres.LoadStateAtSnapshot(ctx, snapshotNID)
if err != nil { if err != nil {
@ -160,10 +128,11 @@ func main() {
} }
stateEntries = append(stateEntries, entries...) stateEntries = append(stateEntries, entries...)
} }
fmt.Println()
eventNIDMap := map[types.EventNID]struct{}{} eventNIDMap := map[types.EventNID]types.StateEntry{}
for _, entry := range stateEntries { for _, entry := range stateEntries {
eventNIDMap[entry.EventNID] = struct{}{} eventNIDMap[entry.EventNID] = entry
} }
eventNIDs := make([]types.EventNID, 0, len(eventNIDMap)) eventNIDs := make([]types.EventNID, 0, len(eventNIDMap))
@ -179,7 +148,9 @@ func main() {
authEventIDMap := make(map[string]struct{}) authEventIDMap := make(map[string]struct{})
events := make([]gomatrixserverlib.PDU, len(eventEntries)) events := make([]gomatrixserverlib.PDU, len(eventEntries))
eventIDNIDMap := make(map[string]types.EventNID)
for i := range eventEntries { for i := range eventEntries {
eventIDNIDMap[eventEntries[i].EventID()] = eventEntries[i].EventNID
events[i] = eventEntries[i].PDU events[i] = eventEntries[i].PDU
for _, authEventID := range eventEntries[i].AuthEventIDs() { for _, authEventID := range eventEntries[i].AuthEventIDs() {
authEventIDMap[authEventID] = struct{}{} authEventIDMap[authEventID] = struct{}{}
@ -198,17 +169,22 @@ func main() {
} }
authEvents := make([]gomatrixserverlib.PDU, len(authEventEntries)) authEvents := make([]gomatrixserverlib.PDU, len(authEventEntries))
resolvedRoomID := ""
for i := range authEventEntries { for i := range authEventEntries {
authEvents[i] = authEventEntries[i].PDU authEvents[i] = authEventEntries[i].PDU
if authEvents[i].RoomID().String() != "" {
resolvedRoomID = authEvents[i].RoomID().String()
}
} }
// Get the roomNID // Get the roomNID
roomInfo, err = roomserverDB.RoomInfo(ctx, authEvents[0].RoomID().String()) roomInfo, err = roomserverDB.RoomInfo(ctx, resolvedRoomID)
if err != nil { if err != nil {
panic(err) panic(err)
} }
fmt.Println("Resolving state") fmt.Println("Resolving state")
stateResStart := time.Now()
var resolved Events var resolved Events
resolved, err = gomatrixserverlib.ResolveConflicts( resolved, err = gomatrixserverlib.ResolveConflicts(
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
@ -226,7 +202,7 @@ func main() {
panic(err) panic(err)
} }
fmt.Println("Resolved state contains", len(resolved), "events") fmt.Printf("Resolved state contains %d events (resolution took %s)\n", len(resolved), time.Since(stateResStart))
sort.Sort(resolved) sort.Sort(resolved)
filteringEventType := *filterType filteringEventType := *filterType
count := 0 count := 0
@ -242,6 +218,142 @@ func main() {
fmt.Println() fmt.Println()
fmt.Println("Returned", count, "state events after filtering") fmt.Println("Returned", count, "state events after filtering")
if !*fixState {
return
}
fmt.Println()
fmt.Printf("\t\t !!! WARNING !!!\n")
fmt.Println("Attempting to fix the state of a room can make things even worse.")
fmt.Println("For the best result, please shut down Dendrite to avoid concurrent database changes.")
fmt.Println("If you have missing state events (e.g. users not in a room, missing power levels")
fmt.Println("make sure they would be added by checking the resolved state events above first (or by running without -fix).")
fmt.Println("If you are sure everything looks fine, press Return, if not, press CTRL+c.")
fmt.Scanln()
fmt.Println("Attempting to fix state")
initialSnapshotNID := roomInfo.StateSnapshotNID()
stateEntriesResolved := make([]types.StateEntry, len(resolved))
for i := range resolved {
eventNID := eventIDNIDMap[resolved[i].EventID()]
stateEntriesResolved[i] = eventNIDMap[eventNID]
}
var succeeded bool
roomUpdater, err := roomserverDB.GetRoomUpdater(ctx, roomInfo)
if err != nil {
panic(err)
}
defer sqlutil.EndTransactionWithCheck(roomUpdater, &succeeded, &err)
latestEvents := make([]types.StateAtEventAndReference, 0, len(roomUpdater.LatestEvents()))
for _, event := range roomUpdater.LatestEvents() {
// SetLatestEvents only uses the EventNID, so populate that
latestEvents = append(latestEvents, types.StateAtEventAndReference{
StateAtEvent: types.StateAtEvent{
StateEntry: types.StateEntry{
EventNID: event.EventNID,
},
},
})
}
var lastEventSent []types.Event
lastEventSent, err = roomUpdater.EventsFromIDs(ctx, roomInfo, []string{roomUpdater.LastEventIDSent()})
if err != nil {
fmt.Printf("Error: %s", err)
return
}
if len(lastEventSent) != 1 {
fmt.Printf("Error: expected to get one event from the database but didn't, got %d", len(lastEventSent))
return
}
var newSnapshotNID types.StateSnapshotNID
newSnapshotNID, err = roomUpdater.AddState(ctx, roomInfo.RoomNID, nil, stateEntriesResolved)
if err != nil {
fmt.Printf("Error: %s", err)
return
}
if err = roomUpdater.SetLatestEvents(roomInfo.RoomNID, latestEvents, lastEventSent[0].EventNID, newSnapshotNID); err != nil {
fmt.Printf("Error: %s", err)
return
}
for _, latestEvent := range roomUpdater.LatestEvents() {
if err = roomUpdater.SetState(ctx, latestEvent.EventNID, newSnapshotNID); err != nil {
fmt.Printf("Error: %s", err)
return
}
}
succeeded = true
if err = roomUpdater.Commit(); err != nil {
panic(err)
}
fmt.Printf("Successfully set new snapshot NID %d containing %d state events\n", newSnapshotNID, len(stateEntriesResolved))
showDifference(ctx, []types.StateSnapshotNID{newSnapshotNID, initialSnapshotNID}, stateres, roomserverDB, roomInfo)
}
func showDifference(ctx context.Context, snapshotNIDs []types.StateSnapshotNID, stateres state.StateResolution, roomserverDB storage.Database, roomInfo *types.RoomInfo) {
if len(snapshotNIDs) != 2 {
panic("need exactly two state snapshot NIDs to calculate difference")
}
removed, added, err := stateres.DifferenceBetweeenStateSnapshots(ctx, snapshotNIDs[0], snapshotNIDs[1])
if err != nil {
panic(err)
}
eventNIDMap := map[types.EventNID]struct{}{}
for _, entry := range append(removed, added...) {
eventNIDMap[entry.EventNID] = struct{}{}
}
eventNIDs := make([]types.EventNID, 0, len(eventNIDMap))
for eventNID := range eventNIDMap {
eventNIDs = append(eventNIDs, eventNID)
}
var eventEntries []types.Event
eventEntries, err = roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
panic(err)
}
events := make(map[types.EventNID]gomatrixserverlib.PDU, len(eventEntries))
for _, entry := range eventEntries {
events[entry.EventNID] = entry.PDU
}
if len(removed) > 0 {
fmt.Println("Removed:")
for _, r := range removed {
event := events[r.EventNID]
fmt.Println()
fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey())
fmt.Printf(" %s\n", string(event.Content()))
}
}
if len(removed) > 0 && len(added) > 0 {
fmt.Println()
}
if len(added) > 0 {
fmt.Println("Added:")
for _, a := range added {
event := events[a.EventNID]
fmt.Println()
fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey())
fmt.Printf(" %s\n", string(event.Content()))
}
}
} }
type Events []gomatrixserverlib.PDU type Events []gomatrixserverlib.PDU

View file

@ -30,6 +30,7 @@ import (
type Database interface { type Database interface {
UserRoomKeys UserRoomKeys
GetAllStateSnapshots(ctx context.Context, roomNID types.RoomNID) ([]types.StateSnapshotNID, error)
// Do we support processing input events for more than one room at a time? // Do we support processing input events for more than one room at a time?
SupportsConcurrentRoomInputs() bool SupportsConcurrentRoomInputs() bool
AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error)

View file

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -114,11 +115,14 @@ WHERE re.event_id = ANY($2)
` `
const getAllSnapshotsSQL = "SELECT state_snapshot_nid FROM roomserver_state_snapshots WHERE room_nid = $1"
type stateSnapshotStatements struct { type stateSnapshotStatements struct {
insertStateStmt *sql.Stmt insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt
bulkSelectStateForHistoryVisibilityStmt *sql.Stmt bulkSelectStateForHistoryVisibilityStmt *sql.Stmt
bulktSelectMembershipForHistoryVisibilityStmt *sql.Stmt bulktSelectMembershipForHistoryVisibilityStmt *sql.Stmt
getAllSnapshotsStmt *sql.Stmt
} }
func CreateStateSnapshotTable(db *sql.DB) error { func CreateStateSnapshotTable(db *sql.DB) error {
@ -134,9 +138,32 @@ func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) {
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
{&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL}, {&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL},
{&s.bulktSelectMembershipForHistoryVisibilityStmt, bulkSelectMembershipForHistoryVisibilitySQL}, {&s.bulktSelectMembershipForHistoryVisibilityStmt, bulkSelectMembershipForHistoryVisibilitySQL},
{&s.getAllSnapshotsStmt, getAllSnapshotsSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *stateSnapshotStatements) GetAllStateSnapshots(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.StateSnapshotNID, error) {
stmt := sqlutil.TxStmt(txn, s.getAllSnapshotsStmt)
rows, err := stmt.QueryContext(ctx, roomNID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows")
nids := make([]types.StateSnapshotNID, 0, 2000)
var nid types.StateSnapshotNID
for rows.Next() {
if err := rows.Scan(&nid); err != nil {
return nil, err
}
nids = append(nids, nid)
}
return nids, rows.Err()
}
func (s *stateSnapshotStatements) InsertState( func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs, ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
) (stateNID types.StateSnapshotNID, err error) { ) (stateNID types.StateSnapshotNID, err error) {

View file

@ -23,6 +23,7 @@ import (
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -36,6 +37,10 @@ type Database struct {
shared.Database shared.Database
} }
func (d *Database) GetAllStateSnapshots(ctx context.Context, roomNID types.RoomNID) ([]types.StateSnapshotNID, error) {
return d.StateSnapshotTable.GetAllStateSnapshots(ctx, nil, roomNID)
}
// Open a postgres database. // Open a postgres database.
func Open(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { func Open(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
var d Database var d Database

View file

@ -184,3 +184,7 @@ func (s *stateSnapshotStatements) selectStateBlockNIDsForRoomNID(
return res, rows.Err() return res, rows.Err()
} }
func (s *stateSnapshotStatements) GetAllStateSnapshots(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.StateSnapshotNID, error) {
return []types.StateSnapshotNID{}, fmt.Errorf("not implemented")
}

View file

@ -35,6 +35,10 @@ type Database struct {
shared.Database shared.Database
} }
func (d *Database) GetAllStateSnapshots(ctx context.Context, roomNID types.RoomNID) ([]types.StateSnapshotNID, error) {
return []types.StateSnapshotNID{}, fmt.Errorf("not implemented")
}
// Open a sqlite database. // Open a sqlite database.
func Open(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) { func Open(ctx context.Context, conMan *sqlutil.Connections, dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) (*Database, error) {
var d Database var d Database

View file

@ -96,6 +96,8 @@ type StateSnapshot interface {
BulkSelectMembershipForHistoryVisibility( BulkSelectMembershipForHistoryVisibility(
ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string,
) (map[string]*types.HeaderedEvent, error) ) (map[string]*types.HeaderedEvent, error)
GetAllStateSnapshots(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.StateSnapshotNID, error)
} }
type StateBlock interface { type StateBlock interface {