diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 746a324fa..26454ef64 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -381,11 +381,50 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda newEvent := &rstypes.HeaderedEvent{PDU: eventToRedact} err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // if we are redacting a state event, also update the current_room_state table + if newEvent.StateKey() != nil { + if err = d.redactCurrentStateEvent(ctx, txn, newEvent, querier); err != nil { + return err + } + } return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent) }) return err } +// redactCurrentStateEvent updates the JSON data in the current_room_state table +func (d *Database) redactCurrentStateEvent(ctx context.Context, txn *sql.Tx, newEvent *rstypes.HeaderedEvent, querier api.QuerySenderIDAPI) error { + // resolve the state key, which may be user pseudoID + if *newEvent.StateKey() != "" { + validRoomID, err := spec.NewRoomID(newEvent.RoomID()) + if err != nil { + return err + } + var sku *spec.UserID + stateKey := newEvent.StateKey() + sku, err = querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey)) + if err == nil && sku != nil { + sKey := sku.String() + newEvent.StateKeyResolved = &sKey + } + } + + // get the current stream position of the event + streamEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, []string{newEvent.EventID()}) + if err == nil && len(streamEvents) > 0 { + var membershipPtr *string + var membership string + membership, err = streamEvents[0].Membership() + if err == nil { + membershipPtr = &membership + } + if err = d.CurrentRoomState.UpsertRoomState(ctx, txn, newEvent, membershipPtr, streamEvents[0].StreamPosition); err != nil { + return err + } + } + return nil +} + // fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. // Returns a map of room ID to list of events. func (d *Database) fetchStateEvents( diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index f57b0d618..a8f1723e7 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -994,7 +994,7 @@ func TestRedaction(t *testing.T) { alice := test.NewUser(t) room := test.NewRoom(t, alice) - redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"}) + redactedEvent := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "join", "displayname": "alice"}, test.WithStateKey(alice.ID)) redactionEvent := room.CreateEvent(t, alice, spec.MRoomRedaction, map[string]string{"redacts": redactedEvent.EventID()}) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := MustCreateDatabase(t, dbType) @@ -1028,5 +1028,25 @@ func TestRedaction(t *testing.T) { if depth.Exists() { t.Error("unexpected auth_events in redacted event") } + + dbTxn, err := db.NewDatabaseTransaction(context.Background()) + if err != nil { + t.Fatal(err) + } + filter := synctypes.DefaultStateFilter() + wantTypes := []string{spec.MRoomMember} + filter.Types = &wantTypes + evs, err = dbTxn.CurrentRoomState.SelectCurrentState(context.Background(), nil, redactedEvent.RoomID(), &filter, nil) + if err != nil { + t.Fatal(err) + } + if count := len(evs); count != 1 { + t.Fatalf("expected 1 event, got %d", count) + } + // we expect that the displayname does not exist anymore + displayname := gjson.GetBytes(evs[0].Content(), "displayname") + if displayname.Exists() { + t.Fatal("expected displayname to be redacted, but wasn't") + } }) }