diff --git a/sytest-whitelist b/sytest-whitelist index c2ed60b9f..bb4f0a279 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -759,4 +759,6 @@ Can filter rooms/{roomId}/members Current state appears in timeline in private history with many messages after AS can publish rooms in their own list AS and main public room lists are separate -/upgrade preserves direct room state \ No newline at end of file +/upgrade preserves direct room state +local user has tags copied to the new room +remote user has tags copied to the new room \ No newline at end of file diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index c818e1ad3..b6b30a095 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -2,7 +2,9 @@ package consumers import ( "context" + "database/sql" "encoding/json" + "errors" "fmt" "strings" "sync" @@ -190,20 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error { for _, membership := range localMembers { // Copy any existing push rules from old -> new room - if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership); err != nil { + if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { return err } // preserve m.direct room state - if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership, roomSize); err != nil { + if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil { + return err + } + + // copy existing m.tag entries, if any + if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { return err } } return nil } -func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID string, newRoomID string, membership *localMembership) error { - pushRules, err := s.db.QueryPushRules(ctx, membership.Localpart) +func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error { + pushRules, err := s.db.QueryPushRules(ctx, localpart) if err != nil { return fmt.Errorf("failed to query pushrules for user: %w", err) } @@ -222,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID s if err != nil { return err } - if err = s.db.SaveAccountData(ctx, membership.Localpart, "", "m.push_rules", rules); err != nil { + if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil { return fmt.Errorf("failed to update pushrules: %w", err) } } @@ -230,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID s } // updateMDirect copies the "is_direct" flag from oldRoomID to newROomID -func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID string, membership *localMembership, roomSize int) error { +func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error { // this is most likely not a DM, so skip updating m.direct state if roomSize > 2 { return nil } // Get direct message state - directChatsRaw, err := s.db.GetAccountDataByType(ctx, membership.Localpart, "", "m.direct") + directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct") if err != nil { return fmt.Errorf("failed to get m.direct from database: %w", err) } @@ -260,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, if err != nil { return true } - if err = s.db.SaveAccountData(ctx, membership.Localpart, "", "m.direct", data); err != nil { + if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil { return true } } @@ -272,6 +279,17 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, return nil } +func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error { + tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag") + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + if tag == nil { + return nil + } + return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag) +} + func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) if err != nil {