diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 788c880eb..26a97c877 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -447,29 +447,7 @@ func (r *Inputer) processRoomEvent( func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixserverlib.Event) error { oldRoomID := event.RoomID() newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str - // un-publish old room - if err := r.DB.PublishRoom(ctx, oldRoomID, "", "", false); err != nil { - return fmt.Errorf("failed to unpublish room: %w", err) - } - // publish new room - if err := r.DB.PublishRoom(ctx, newRoomID, "", "", true); err != nil { - return fmt.Errorf("failed to publish room: %w", err) - } - - aliases, err := r.DB.GetAliasesForRoomID(ctx, oldRoomID) - if err != nil { - return fmt.Errorf("failed to get room aliases: %w", err) - } - - for _, alias := range aliases { - if err = r.DB.RemoveRoomAlias(ctx, alias); err != nil { - fmt.Errorf("failed to remove room alias: %w", err) - } - if err = r.DB.SetRoomAlias(ctx, alias, newRoomID, event.Sender()); err != nil { - return fmt.Errorf("failed to set room alias: %w", err) - } - } - return nil + return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender()) } // processStateBefore works out what the state is before the event and diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index c39a8cbba..06db4b2d8 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -172,4 +172,5 @@ type Database interface { ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) + UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 4455ec3bf..734023d1f 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1386,6 +1386,36 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget }) } +func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error { + + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // un-publish old room + if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil { + return fmt.Errorf("failed to unpublish room: %w", err) + } + // publish new room + if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil { + return fmt.Errorf("failed to publish room: %w", err) + } + + // Migrate any existing room aliases + aliases, err := d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, txn, oldRoomID) + if err != nil { + return fmt.Errorf("failed to get room aliases: %w", err) + } + + for _, alias := range aliases { + if err = d.RoomAliasesTable.DeleteRoomAlias(ctx, txn, alias); err != nil { + fmt.Errorf("failed to remove room alias: %w", err) + } + if err = d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, newRoomID, eventSender); err != nil { + return fmt.Errorf("failed to set room alias: %w", err) + } + } + return nil + }) +} + // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // it should live in this package!