diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index f8d888b71..a6438d5a7 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -113,60 +113,7 @@ func main() { fmt.Println("Fetching", len(snapshotNIDs), "snapshot NIDs") if *difference { - if len(snapshotNIDs) != 2 { - panic("need exactly two state snapshot NIDs to calculate difference") - } - var removed, added []types.StateEntry - removed, added, err = stateres.DifferenceBetweeenStateSnapshots(ctx, snapshotNIDs[0], snapshotNIDs[1]) - if err != nil { - panic(err) - } - - eventNIDMap := map[types.EventNID]struct{}{} - for _, entry := range append(removed, added...) { - eventNIDMap[entry.EventNID] = struct{}{} - } - - eventNIDs := make([]types.EventNID, 0, len(eventNIDMap)) - for eventNID := range eventNIDMap { - eventNIDs = append(eventNIDs, eventNID) - } - - var eventEntries []types.Event - eventEntries, err = roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs) - if err != nil { - panic(err) - } - - events := make(map[types.EventNID]gomatrixserverlib.PDU, len(eventEntries)) - for _, entry := range eventEntries { - events[entry.EventNID] = entry.PDU - } - - if len(removed) > 0 { - fmt.Println("Removed:") - for _, r := range removed { - event := events[r.EventNID] - fmt.Println() - fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey()) - fmt.Printf(" %s\n", string(event.Content())) - } - } - - if len(removed) > 0 && len(added) > 0 { - fmt.Println() - } - - if len(added) > 0 { - fmt.Println("Added:") - for _, a := range added { - event := events[a.EventNID] - fmt.Println() - fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey()) - fmt.Printf(" %s\n", string(event.Content())) - } - } - + showDifference(ctx, snapshotNIDs, stateres, roomserverDB, roomInfo) return } @@ -222,12 +169,16 @@ func main() { } authEvents := make([]gomatrixserverlib.PDU, len(authEventEntries)) + resolvedRoomID := "" for i := range authEventEntries { authEvents[i] = authEventEntries[i].PDU + if authEvents[i].RoomID().String() != "" { + resolvedRoomID = authEvents[i].RoomID().String() + } } // Get the roomNID - roomInfo, err = roomserverDB.RoomInfo(ctx, authEvents[0].RoomID().String()) + roomInfo, err = roomserverDB.RoomInfo(ctx, resolvedRoomID) if err != nil { panic(err) } @@ -283,6 +234,8 @@ func main() { fmt.Println("Attempting to fix state") + initialSnapshotNID := roomInfo.StateSnapshotNID() + stateEntriesResolved := make([]types.StateEntry, len(resolved)) for i := range resolved { eventNID := eventIDNIDMap[resolved[i].EventID()] @@ -339,8 +292,68 @@ func main() { } succeeded = true + if err = roomUpdater.Commit(); err != nil { + panic(err) + } + fmt.Printf("Successfully set new snapshot NID %d containing %d state events\n", newSnapshotNID, len(stateEntriesResolved)) + showDifference(ctx, []types.StateSnapshotNID{newSnapshotNID, initialSnapshotNID}, stateres, roomserverDB, roomInfo) - fmt.Printf("Successfully set new snapshot NID %d containing %d state events", newSnapshotNID, len(stateEntriesResolved)) +} + +func showDifference(ctx context.Context, snapshotNIDs []types.StateSnapshotNID, stateres state.StateResolution, roomserverDB storage.Database, roomInfo *types.RoomInfo) { + if len(snapshotNIDs) != 2 { + panic("need exactly two state snapshot NIDs to calculate difference") + } + + removed, added, err := stateres.DifferenceBetweeenStateSnapshots(ctx, snapshotNIDs[0], snapshotNIDs[1]) + if err != nil { + panic(err) + } + + eventNIDMap := map[types.EventNID]struct{}{} + for _, entry := range append(removed, added...) { + eventNIDMap[entry.EventNID] = struct{}{} + } + + eventNIDs := make([]types.EventNID, 0, len(eventNIDMap)) + for eventNID := range eventNIDMap { + eventNIDs = append(eventNIDs, eventNID) + } + + var eventEntries []types.Event + eventEntries, err = roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs) + if err != nil { + panic(err) + } + + events := make(map[types.EventNID]gomatrixserverlib.PDU, len(eventEntries)) + for _, entry := range eventEntries { + events[entry.EventNID] = entry.PDU + } + + if len(removed) > 0 { + fmt.Println("Removed:") + for _, r := range removed { + event := events[r.EventNID] + fmt.Println() + fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey()) + fmt.Printf(" %s\n", string(event.Content())) + } + } + + if len(removed) > 0 && len(added) > 0 { + fmt.Println() + } + + if len(added) > 0 { + fmt.Println("Added:") + for _, a := range added { + event := events[a.EventNID] + fmt.Println() + fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey()) + fmt.Printf(" %s\n", string(event.Content())) + } + } } type Events []gomatrixserverlib.PDU