Also copy room tags to new room

This commit is contained in:
Till Faelligen 2022-11-07 09:16:53 +01:00
parent dc275427b5
commit 94050db8e1
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
2 changed files with 29 additions and 9 deletions

View file

@ -759,4 +759,6 @@ Can filter rooms/{roomId}/members
Current state appears in timeline in private history with many messages after Current state appears in timeline in private history with many messages after
AS can publish rooms in their own list AS can publish rooms in their own list
AS and main public room lists are separate AS and main public room lists are separate
/upgrade preserves direct room state /upgrade preserves direct room state
local user has tags copied to the new room
remote user has tags copied to the new room

View file

@ -2,7 +2,9 @@ package consumers
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
@ -190,20 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error { func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
for _, membership := range localMembers { for _, membership := range localMembers {
// Copy any existing push rules from old -> new room // Copy any existing push rules from old -> new room
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership); err != nil { if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
return err return err
} }
// preserve m.direct room state // preserve m.direct room state
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership, roomSize); err != nil { if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil {
return err
}
// copy existing m.tag entries, if any
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
return err return err
} }
} }
return nil return nil
} }
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID string, newRoomID string, membership *localMembership) error { func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error {
pushRules, err := s.db.QueryPushRules(ctx, membership.Localpart) pushRules, err := s.db.QueryPushRules(ctx, localpart)
if err != nil { if err != nil {
return fmt.Errorf("failed to query pushrules for user: %w", err) return fmt.Errorf("failed to query pushrules for user: %w", err)
} }
@ -222,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID s
if err != nil { if err != nil {
return err return err
} }
if err = s.db.SaveAccountData(ctx, membership.Localpart, "", "m.push_rules", rules); err != nil { if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil {
return fmt.Errorf("failed to update pushrules: %w", err) return fmt.Errorf("failed to update pushrules: %w", err)
} }
} }
@ -230,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID s
} }
// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID // updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID string, membership *localMembership, roomSize int) error { func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error {
// this is most likely not a DM, so skip updating m.direct state // this is most likely not a DM, so skip updating m.direct state
if roomSize > 2 { if roomSize > 2 {
return nil return nil
} }
// Get direct message state // Get direct message state
directChatsRaw, err := s.db.GetAccountDataByType(ctx, membership.Localpart, "", "m.direct") directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct")
if err != nil { if err != nil {
return fmt.Errorf("failed to get m.direct from database: %w", err) return fmt.Errorf("failed to get m.direct from database: %w", err)
} }
@ -260,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
if err != nil { if err != nil {
return true return true
} }
if err = s.db.SaveAccountData(ctx, membership.Localpart, "", "m.direct", data); err != nil { if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil {
return true return true
} }
} }
@ -272,6 +279,17 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
return nil return nil
} }
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error {
tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag")
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
if tag == nil {
return nil
}
return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag)
}
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
if err != nil { if err != nil {