Redact the event in the current room state table

This commit is contained in:
Till Faelligen 2023-06-30 14:32:24 +02:00
parent 8b5afcf680
commit 552eaf2940
No known key found for this signature in database
GPG key ID: ACCDC9606D472758
2 changed files with 60 additions and 1 deletions

View file

@ -381,11 +381,50 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
newEvent := &rstypes.HeaderedEvent{PDU: eventToRedact} newEvent := &rstypes.HeaderedEvent{PDU: eventToRedact}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// if we are redacting a state event, also update the current_room_state table
if newEvent.StateKey() != nil {
if err = d.redactCurrentStateEvent(ctx, txn, newEvent, querier); err != nil {
return err
}
}
return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent) return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent)
}) })
return err return err
} }
// redactCurrentStateEvent updates the JSON data in the current_room_state table
func (d *Database) redactCurrentStateEvent(ctx context.Context, txn *sql.Tx, newEvent *rstypes.HeaderedEvent, querier api.QuerySenderIDAPI) error {
// resolve the state key, which may be user pseudoID
if *newEvent.StateKey() != "" {
validRoomID, err := spec.NewRoomID(newEvent.RoomID())
if err != nil {
return err
}
var sku *spec.UserID
stateKey := newEvent.StateKey()
sku, err = querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey))
if err == nil && sku != nil {
sKey := sku.String()
newEvent.StateKeyResolved = &sKey
}
}
// get the current stream position of the event
streamEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, []string{newEvent.EventID()})
if err == nil && len(streamEvents) > 0 {
var membershipPtr *string
var membership string
membership, err = streamEvents[0].Membership()
if err == nil {
membershipPtr = &membership
}
if err = d.CurrentRoomState.UpsertRoomState(ctx, txn, newEvent, membershipPtr, streamEvents[0].StreamPosition); err != nil {
return err
}
}
return nil
}
// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. // fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
// Returns a map of room ID to list of events. // Returns a map of room ID to list of events.
func (d *Database) fetchStateEvents( func (d *Database) fetchStateEvents(

View file

@ -994,7 +994,7 @@ func TestRedaction(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
redactedEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"}) redactedEvent := room.CreateAndInsert(t, alice, "m.room.member", map[string]interface{}{"membership": "join", "displayname": "alice"}, test.WithStateKey(alice.ID))
redactionEvent := room.CreateEvent(t, alice, spec.MRoomRedaction, map[string]string{"redacts": redactedEvent.EventID()}) redactionEvent := room.CreateEvent(t, alice, spec.MRoomRedaction, map[string]string{"redacts": redactedEvent.EventID()})
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := MustCreateDatabase(t, dbType) db, close := MustCreateDatabase(t, dbType)
@ -1028,5 +1028,25 @@ func TestRedaction(t *testing.T) {
if depth.Exists() { if depth.Exists() {
t.Error("unexpected auth_events in redacted event") t.Error("unexpected auth_events in redacted event")
} }
dbTxn, err := db.NewDatabaseTransaction(context.Background())
if err != nil {
t.Fatal(err)
}
filter := synctypes.DefaultStateFilter()
wantTypes := []string{spec.MRoomMember}
filter.Types = &wantTypes
evs, err = dbTxn.CurrentRoomState.SelectCurrentState(context.Background(), nil, redactedEvent.RoomID(), &filter, nil)
if err != nil {
t.Fatal(err)
}
if count := len(evs); count != 1 {
t.Fatalf("expected 1 event, got %d", count)
}
// we expect that the displayname does not exist anymore
displayname := gjson.GetBytes(evs[0].Content(), "displayname")
if displayname.Exists() {
t.Fatal("expected displayname to be redacted, but wasn't")
}
}) })
} }