From 40fd47957a84ed8ce2245cbd1f32f19ffa81558e Mon Sep 17 00:00:00 2001 From: Cnly Date: Sat, 27 Jul 2019 17:56:26 +0800 Subject: [PATCH] Implement event redaction Signed-off-by: Alex Chen --- clientapi/routing/redact.go | 148 +++++++++++++++ clientapi/routing/routing.go | 11 ++ roomserver/storage/event_json_table.go | 6 +- roomserver/storage/events_table.go | 10 +- roomserver/storage/redactions_table.go | 138 ++++++++++++++ roomserver/storage/sql.go | 2 + roomserver/storage/storage.go | 211 ++++++++++++++++++++-- syncapi/storage/redactions_table.go | 145 +++++++++++++++ syncapi/storage/syncserver.go | 240 ++++++++++++++++++++++++- 9 files changed, 882 insertions(+), 29 deletions(-) create mode 100644 clientapi/routing/redact.go create mode 100644 roomserver/storage/redactions_table.go create mode 100644 syncapi/storage/redactions_table.go diff --git a/clientapi/routing/redact.go b/clientapi/routing/redact.go new file mode 100644 index 000000000..45ab60b43 --- /dev/null +++ b/clientapi/routing/redact.go @@ -0,0 +1,148 @@ +// Copyright 2019 Alex Chen +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + "time" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/common/transactions" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// https://matrix.org/docs/spec/client_server/r0.5.0#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid + +type redactRequest struct { + Reason string `json:"reason,omitempty"` +} + +type redactResponse struct { + EventID string `json:"event_id"` +} + +// Redact implements PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId} +func Redact( + req *http.Request, + device *authtypes.Device, + roomID, redactedEventID, txnID string, + cfg config.Dendrite, + queryAPI api.RoomserverQueryAPI, + producer *producers.RoomserverProducer, + txnCache *transactions.Cache, +) util.JSONResponse { + // TODO: Idempotency + + var redactReq redactRequest + if resErr := httputil.UnmarshalJSONRequest(req, &redactReq); resErr != nil { + return *resErr + } + + // Build a redaction event + builder := gomatrixserverlib.EventBuilder{ + Sender: device.UserID, + RoomID: roomID, + Redacts: redactedEventID, + Type: gomatrixserverlib.MRoomRedaction, + } + err := builder.SetContent(redactReq) + if err != nil { + return httputil.LogThenError(req, err) + } + + var queryRes api.QueryLatestEventsAndStateResponse + e, err := common.BuildEvent(req.Context(), &builder, cfg, time.Now(), queryAPI, &queryRes) + if err == common.ErrRoomNoExists { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("Room does not exist"), + } + } else if err != nil { + return httputil.LogThenError(req, err) + } + + // Do some basic checks e.g. ensuring the user is in the room and can send m.room.redaction events + stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents)) + for i := range queryRes.StateEvents { + stateEvents[i] = &queryRes.StateEvents[i] + } + provider := gomatrixserverlib.NewAuthEvents(stateEvents) + if err = gomatrixserverlib.Allowed(*e, &provider); err != nil { + // TODO: Is the error returned with suitable HTTP status code? + if _, ok := err.(*gomatrixserverlib.NotAllowed); ok { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(err.Error()), + } + } + + return httputil.LogThenError(req, err) + } + + // Ensure the user can redact the specific event + + eventReq := api.QueryEventsByIDRequest{ + EventIDs: []string{redactedEventID}, + } + var eventResp api.QueryEventsByIDResponse + if err = queryAPI.QueryEventsByID(req.Context(), &eventReq, &eventResp); err != nil { + return httputil.LogThenError(req, err) + } + + if len(eventResp.Events) == 0 { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("Event to redact not found"), + } + } + + redactedEvent := eventResp.Events[0] + + if redactedEvent.Sender() != device.UserID { + // TODO: Allow power users to redact others' events + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("You are not allowed to redact this event"), + } + } + + // Send the redaction event + + txnAndDeviceID := api.TransactionID{ + TransactionID: txnID, + DeviceID: device.ID, + } + + // pass the new event to the roomserver and receive the correct event ID + // event ID in case of duplicate transaction is discarded + eventID, err := producer.SendEvents( + req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, &txnAndDeviceID, + ) + if err != nil { + return httputil.LogThenError(req, err) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: redactResponse{eventID}, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 8135e49af..5e0cc3058 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -158,6 +158,17 @@ func Setup( }), ).Methods(http.MethodPut, http.MethodOptions) + r0mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnID}", + common.MakeAuthAPI("redact", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := common.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return Redact(req, device, vars["roomID"], vars["eventID"], vars["txnID"], + cfg, queryAPI, producer, transactionsCache) + }), + ).Methods(http.MethodPut, http.MethodOptions) + r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { return Register(req, accountDB, deviceDB, &cfg) })).Methods(http.MethodPost, http.MethodOptions) diff --git a/roomserver/storage/event_json_table.go b/roomserver/storage/event_json_table.go index b81667d9d..7202b34c3 100644 --- a/roomserver/storage/event_json_table.go +++ b/roomserver/storage/event_json_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -67,9 +68,10 @@ func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { } func (s *eventJSONStatements) insertEventJSON( - ctx context.Context, eventNID types.EventNID, eventJSON []byte, + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) + stmt := common.TxStmt(txn, s.insertEventJSONStmt) + _, err := stmt.ExecContext(ctx, int64(eventNID), eventJSON) return err } diff --git a/roomserver/storage/events_table.go b/roomserver/storage/events_table.go index 5bad939fa..57ac7b674 100644 --- a/roomserver/storage/events_table.go +++ b/roomserver/storage/events_table.go @@ -155,7 +155,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { } func (s *eventStatements) insertEvent( - ctx context.Context, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, @@ -166,7 +166,8 @@ func (s *eventStatements) insertEvent( ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - err := s.insertEventStmt.QueryRowContext( + stmt := common.TxStmt(txn, s.insertEventStmt) + err := stmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, ).Scan(&eventNID, &stateNID) @@ -174,11 +175,12 @@ func (s *eventStatements) insertEvent( } func (s *eventStatements) selectEvent( - ctx context.Context, eventID string, + ctx context.Context, txn *sql.Tx, eventID string, ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 - err := s.selectEventStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) + stmt := common.TxStmt(txn, s.selectEventStmt) + err := stmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } diff --git a/roomserver/storage/redactions_table.go b/roomserver/storage/redactions_table.go new file mode 100644 index 000000000..f9f88ebf3 --- /dev/null +++ b/roomserver/storage/redactions_table.go @@ -0,0 +1,138 @@ +// Copyright 2019 Alex Chen +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const redactionsSchema = ` +-- The redactions table holds redactions. +CREATE TABLE IF NOT EXISTS roomserver_redactions ( + -- Local numeric ID for the redaction event. + event_nid BIGINT PRIMARY KEY, + -- String ID for the redacted event. + redacts TEXT NOT NULL, + -- Whether the redaction has been validated. + -- For use of the "accept first, validate later" strategy for rooms >= v3. + -- Should always be TRUE for rooms before v3. + validated BOOLEAN NOT NULL +); + +CREATE INDEX IF NOT EXISTS roomserver_redactions_redacts ON roomserver_redactions(redacts); +` + +const insertRedactionSQL = "" + + "INSERT INTO roomserver_redactions (event_nid, redacts, validated)" + + " VALUES ($1, $2, $3)" + +const bulkSelectRedactionSQL = "" + + "SELECT event_nid, redacts, validated FROM roomserver_redactions" + + " WHERE redacts = ANY($1)" + +const bulkUpdateValidationStatusSQL = "" + + " UPDATE roomserver_redactions SET validated = $2 WHERE event_nid = ANY($1)" + +type redactionStatements struct { + insertRedactionStmt *sql.Stmt + bulkSelectRedactionStmt *sql.Stmt + bulkUpdateValidationStatusStmt *sql.Stmt +} + +func (s *redactionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(redactionsSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertRedactionStmt, insertRedactionSQL}, + {&s.bulkSelectRedactionStmt, bulkSelectRedactionSQL}, + {&s.bulkUpdateValidationStatusStmt, bulkUpdateValidationStatusSQL}, + }.prepare(db) +} + +func (s *redactionStatements) insertRedaction( + ctx context.Context, + txn *sql.Tx, + eventNID types.EventNID, + redactsEventID string, + validated bool, +) error { + stmt := common.TxStmt(txn, s.insertRedactionStmt) + _, err := stmt.ExecContext(ctx, int64(eventNID), redactsEventID, validated) + return err +} + +func (s *redactionStatements) bulkSelectRedaction( + ctx context.Context, + txn *sql.Tx, + eventIDs []string, +) ( + validated map[string]types.EventNID, + unvalidated map[string]types.EventNID, + err error, +) { + stmt := common.TxStmt(txn, s.bulkSelectRedactionStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) + if err != nil { + return nil, nil, err + } + defer func() { err = rows.Close() }() + + validated = make(map[string]types.EventNID) + unvalidated = make(map[string]types.EventNID) + + var ( + redactedByNID types.EventNID + redactedEventID string + isValidated bool + ) + for rows.Next() { + if err = rows.Scan( + &redactedByNID, + &redactedEventID, + &isValidated, + ); err != nil { + return nil, nil, err + } + if isValidated { + validated[redactedEventID] = redactedByNID + } else { + unvalidated[redactedEventID] = redactedByNID + } + } + if err = rows.Err(); err != nil { + return nil, nil, err + } + + return validated, unvalidated, nil +} + +func (s *redactionStatements) bulkUpdateValidationStatus( + ctx context.Context, + txn *sql.Tx, + eventNIDs []types.EventNID, + newStatus bool, +) error { + stmt := common.TxStmt(txn, s.bulkUpdateValidationStatusStmt) + _, err := stmt.ExecContext(ctx, eventNIDsAsArray(eventNIDs), newStatus) + return err +} diff --git a/roomserver/storage/sql.go b/roomserver/storage/sql.go index 05efa8dd4..39fcb2754 100644 --- a/roomserver/storage/sql.go +++ b/roomserver/storage/sql.go @@ -31,6 +31,7 @@ type statements struct { inviteStatements membershipStatements transactionStatements + redactionStatements } func (s *statements) prepare(db *sql.DB) error { @@ -49,6 +50,7 @@ func (s *statements) prepare(db *sql.DB) error { s.inviteStatements.prepare, s.membershipStatements.prepare, s.transactionStatements.prepare, + s.redactionStatements.prepare, } { if err = prepare(db); err != nil { return err diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index f6c2fccd4..8ef270bb7 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -20,6 +20,7 @@ import ( // Import the postgres database driver. _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -84,26 +85,35 @@ func (d *Database) StoreEvent( } } - if eventNID, stateNID, err = d.statements.insertEvent( - ctx, - roomNID, - eventTypeNID, - eventStateKeyNID, - event.EventID(), - event.EventReference().EventSHA256, - authEventNIDs, - event.Depth(), - ); err != nil { - if err == sql.ErrNoRows { - // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + if eventNID, stateNID, err = d.statements.insertEvent( + ctx, txn, + roomNID, + eventTypeNID, + eventStateKeyNID, + event.EventID(), + event.EventReference().EventSHA256, + authEventNIDs, + event.Depth(), + ); err != nil { + if err == sql.ErrNoRows { + // We've already inserted the event so select the numeric event ID + eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID()) + } + if err != nil { + return err + } } - if err != nil { - return 0, types.StateAtEvent{}, err - } - } - if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { + if err = d.updateSpecialTablesForEvent( + ctx, txn, &event, eventNID, + ); err != nil { + return err + } + + return d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()) + }) + if err != nil { return 0, types.StateAtEvent{}, err } @@ -167,6 +177,24 @@ func (d *Database) assignStateKeyNID( return eventStateKeyNID, err } +func (d *Database) updateSpecialTablesForEvent( + ctx context.Context, + txn *sql.Tx, + event *gomatrixserverlib.Event, + eventNID types.EventNID, +) (err error) { + switch event.Type() { + case gomatrixserverlib.MRoomRedaction: + // TODO: After we support room versioning, set validated = false only for rooms >= v3. + if err = d.statements.insertRedaction( + ctx, txn, eventNID, event.Redacts(), false, + ); err != nil { + return err + } + } + return nil +} + // StateEntriesForEventIDs implements input.EventDatabase func (d *Database) StateEntriesForEventIDs( ctx context.Context, eventIDs []string, @@ -210,7 +238,9 @@ func (d *Database) Events( if err != nil { return nil, err } + results := make([]types.Event, len(eventJSONs)) + eventPointers := make([]*gomatrixserverlib.Event, len(eventJSONs)) for i, eventJSON := range eventJSONs { result := &results[i] result.EventNID = eventJSON.EventNID @@ -219,10 +249,153 @@ func (d *Database) Events( if err != nil { return nil, err } + eventPointers[i] = &result.Event } + + if err = d.applyRedactions(ctx, eventPointers); err != nil { + return nil, err + } + return results, nil } +// applyRedactions applies necessary redactions to the given events. +// It will replace events referenced by the pointers with their redacted versions. +// It will update the validation status in the redactions table if there are +// redaction events newly validated. +func (d *Database) applyRedactions( + ctx context.Context, + eventPointers []*gomatrixserverlib.Event, +) error { + eventIDs := make([]string, len(eventPointers)) + for i, e := range eventPointers { + eventIDs[i] = e.EventID() + } + + validatedRedactions, unvalidatedRedactions, err := d.statements.bulkSelectRedaction(ctx, nil, eventIDs) + if err != nil { + return err + } + + totalPossibleRedactions := len(validatedRedactions) + len(unvalidatedRedactions) + + // Fast path if nothing to redact + if totalPossibleRedactions == 0 { + return nil + } + + redactionNIDToEvent, err := d.fetchRedactionEvents(ctx, validatedRedactions, unvalidatedRedactions) + if err != nil { + return err + } + + eventIDToEventPointer := make(map[string]*gomatrixserverlib.Event, len(eventPointers)) + for _, p := range eventPointers { + eventIDToEventPointer[p.EventID()] = p + } + + if len(unvalidatedRedactions) != 0 { + var newlyValidated map[string]types.EventNID + if newlyValidated, err = d.validateRedactions( + ctx, unvalidatedRedactions, redactionNIDToEvent, eventIDToEventPointer, + ); err != nil { + return err + } + for redactedEventID, redactedByNID := range newlyValidated { + validatedRedactions[redactedEventID] = redactedByNID + } + } + + for redactedEventID, redactedByNID := range validatedRedactions { + redactedEvent := eventIDToEventPointer[redactedEventID] + *redactedEvent = redactedEvent.Redact() + if err = redactedEvent.SetUnsignedField( + "redacted_because", + gomatrixserverlib.ToClientEvent( + *redactionNIDToEvent[redactedByNID], + gomatrixserverlib.FormatAll, + ), + ); err != nil { + return err + } + } + + return nil +} + +func (d *Database) fetchRedactionEvents( + ctx context.Context, + validatedRedactions, unvalidatedRedactions map[string]types.EventNID, +) (redactionNIDToEvent map[types.EventNID]*gomatrixserverlib.Event, err error) { + redactionEventsToFetch := make([]types.EventNID, 0, len(validatedRedactions)+len(unvalidatedRedactions)) + for _, nid := range validatedRedactions { + redactionEventsToFetch = append(redactionEventsToFetch, nid) + } + for _, nid := range unvalidatedRedactions { + redactionEventsToFetch = append(redactionEventsToFetch, nid) + } + + redactionJSONs, err := d.statements.bulkSelectEventJSON(ctx, redactionEventsToFetch) + if err != nil { + return nil, err + } + + redactionNIDToEvent = make(map[types.EventNID]*gomatrixserverlib.Event, len(redactionJSONs)) + for _, redactionJSON := range redactionJSONs { + e, err := gomatrixserverlib.NewEventFromTrustedJSON(redactionJSON.EventJSON, false) + if err != nil { + return nil, err + } + redactionNIDToEvent[redactionJSON.EventNID] = &e + } + + return +} + +func (d *Database) validateRedactions( + ctx context.Context, + unvalidatedRedactions map[string]types.EventNID, + redactionNIDToEvent map[types.EventNID]*gomatrixserverlib.Event, + eventIDToEvent map[string]*gomatrixserverlib.Event, +) (validatedRedactions map[string]types.EventNID, err error) { + validatedRedactions = make(map[string]types.EventNID, len(unvalidatedRedactions)) + + var expectedDomain, redactorDomain gomatrixserverlib.ServerName + for redactedEventID, redactedByNID := range unvalidatedRedactions { + if _, expectedDomain, err = gomatrixserverlib.SplitID( + '@', eventIDToEvent[redactedEventID].Sender(), + ); err != nil { + return nil, err + } + if _, redactorDomain, err = gomatrixserverlib.SplitID( + '@', redactionNIDToEvent[redactedByNID].Sender(), + ); err != nil { + return nil, err + } + + if redactorDomain != expectedDomain { + // TODO: Still allow power users to redact + continue + } + + validatedRedactions[redactedEventID] = redactedByNID + } + + eventNIDs := make([]types.EventNID, 0, len(validatedRedactions)) + for _, nid := range validatedRedactions { + eventNIDs = append(eventNIDs, nid) + } + if err = d.statements.bulkUpdateValidationStatus( + ctx, nil, eventNIDs, true, + ); err != nil { + return nil, err + } + + // TODO: We might want to clear the unvalidated redactions + + return validatedRedactions, nil +} + // AddState implements input.EventDatabase func (d *Database) AddState( ctx context.Context, @@ -276,7 +449,7 @@ func (d *Database) StateEntries( func (d *Database) SnapshotNIDFromEventID( ctx context.Context, eventID string, ) (types.StateSnapshotNID, error) { - _, stateNID, err := d.statements.selectEvent(ctx, eventID) + _, stateNID, err := d.statements.selectEvent(ctx, nil, eventID) return stateNID, err } diff --git a/syncapi/storage/redactions_table.go b/syncapi/storage/redactions_table.go new file mode 100644 index 000000000..452a9651e --- /dev/null +++ b/syncapi/storage/redactions_table.go @@ -0,0 +1,145 @@ +// Copyright 2019 Alex Chen +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" +) + +const redactionsSchema = ` +-- The redactions table holds redactions. +CREATE TABLE IF NOT EXISTS syncapi_redactions ( + -- The event ID for the redaction event. + event_id TEXT NOT NULL, + -- The event ID for the redacted event. + redacts TEXT NOT NULL, + -- Whether the redaction has been validated. + -- For use of the "accept first, validate later" strategy for rooms >= v3. + -- Should always be TRUE for rooms before v3. + validated BOOLEAN NOT NULL +); + +CREATE INDEX IF NOT EXISTS syncapi_redactions_redacts ON syncapi_redactions(redacts); +` + +const insertRedactionSQL = "" + + "INSERT INTO syncapi_redactions (event_id, redacts, validated)" + + " VALUES ($1, $2, $3)" + +const bulkSelectRedactionSQL = "" + + "SELECT event_id, redacts, validated FROM syncapi_redactions" + + " WHERE redacts = ANY($1)" + +const bulkUpdateValidationStatusSQL = "" + + " UPDATE syncapi_redactions SET validated = $2 WHERE event_id = ANY($1)" + +type redactionStatements struct { + insertRedactionStmt *sql.Stmt + bulkSelectRedactionStmt *sql.Stmt + bulkUpdateValidationStatusStmt *sql.Stmt +} + +// redactedToRedactionMap is a map in the form map[redactedEventID]redactionEventID. +type redactedToRedactionMap map[string]string + +func (s *redactionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(redactionsSchema) + if err != nil { + return + } + if s.insertRedactionStmt, err = db.Prepare(insertRedactionSQL); err != nil { + return + } + if s.bulkSelectRedactionStmt, err = db.Prepare(bulkSelectRedactionSQL); err != nil { + return + } + if s.bulkUpdateValidationStatusStmt, err = db.Prepare(bulkUpdateValidationStatusSQL); err != nil { + return + } + return +} + +func (s *redactionStatements) insertRedaction( + ctx context.Context, + txn *sql.Tx, + eventID string, + redactsEventID string, + validated bool, +) error { + stmt := common.TxStmt(txn, s.insertRedactionStmt) + _, err := stmt.ExecContext(ctx, eventID, redactsEventID, validated) + return err +} + +// bulkSelectRedaction returns the redactions for the given event IDs. +func (s *redactionStatements) bulkSelectRedaction( + ctx context.Context, + txn *sql.Tx, + eventIDs []string, +) ( + validated redactedToRedactionMap, + unvalidated redactedToRedactionMap, + err error, +) { + stmt := common.TxStmt(txn, s.bulkSelectRedactionStmt) + rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) + if err != nil { + return nil, nil, err + } + defer func() { err = rows.Close() }() + + validated = make(redactedToRedactionMap) + unvalidated = make(redactedToRedactionMap) + + var ( + redactedByID string + redactedEventID string + isValidated bool + ) + for rows.Next() { + if err = rows.Scan( + &redactedByID, + &redactedEventID, + &isValidated, + ); err != nil { + return nil, nil, err + } + if isValidated { + validated[redactedEventID] = redactedByID + } else { + unvalidated[redactedEventID] = redactedByID + } + } + if err = rows.Err(); err != nil { + return nil, nil, err + } + + return validated, unvalidated, nil +} + +func (s *redactionStatements) bulkUpdateValidationStatus( + ctx context.Context, + txn *sql.Tx, + eventIDs []string, + newStatus bool, +) error { + stmt := common.TxStmt(txn, s.bulkUpdateValidationStatusStmt) + _, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs), newStatus) + return err +} diff --git a/syncapi/storage/syncserver.go b/syncapi/storage/syncserver.go index b4d7ccbd2..0a924b520 100644 --- a/syncapi/storage/syncserver.go +++ b/syncapi/storage/syncserver.go @@ -60,6 +60,7 @@ type SyncServerDatasource struct { events outputRoomEventsStatements roomstate currentRoomStateStatements invites inviteEventsStatements + redactions redactionStatements typingCache *cache.TypingCache } @@ -85,6 +86,9 @@ func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, er if err := d.invites.prepare(d.db); err != nil { return nil, err } + if err := d.redactions.prepare(d.db); err != nil { + return nil, err + } d.typingCache = cache.NewTypingCache() return &d, nil } @@ -128,6 +132,10 @@ func (d *SyncServerDatasource) WriteEvent( } pduPosition = pos + if err = d.updateSpecialTablesForEvent(ctx, txn, ev); err != nil { + return err + } + if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { // Nothing to do, the event may have just been a message event. return nil @@ -138,6 +146,23 @@ func (d *SyncServerDatasource) WriteEvent( return } +func (d *SyncServerDatasource) updateSpecialTablesForEvent( + ctx context.Context, + txn *sql.Tx, + event *gomatrixserverlib.Event, +) (err error) { + switch event.Type() { + case gomatrixserverlib.MRoomRedaction: + // TODO: After we support room versioning, set validated = false only for rooms >= v3. + if err = d.redactions.insertRedaction( + ctx, txn, event.EventID(), event.Redacts(), false, + ); err != nil { + return err + } + } + return nil +} + func (d *SyncServerDatasource) updateRoomState( ctx context.Context, txn *sql.Tx, removedEventIDs []string, @@ -178,7 +203,18 @@ func (d *SyncServerDatasource) updateRoomState( func (d *SyncServerDatasource) GetStateEvent( ctx context.Context, roomID, evType, stateKey string, ) (*gomatrixserverlib.Event, error) { - return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey) + e, err := d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey) + if e == nil || err != nil { + return e, err + } + + if err = d.applyRedactionsForEventPointers( + ctx, nil, []*gomatrixserverlib.Event{e}, + ); err != nil { + return nil, err + } + + return e, nil } // GetStateEventsForRoom fetches the state events for a given room. @@ -189,7 +225,10 @@ func (d *SyncServerDatasource) GetStateEventsForRoom( ) (stateEvents []gomatrixserverlib.Event, err error) { err = common.WithTransaction(d.db, func(txn *sql.Tx) error { stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID) - return err + if err != nil { + return err + } + return d.applyRedactionsForEventLists(ctx, txn, stateEvents) }) return } @@ -418,8 +457,17 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // We don't include a device here as we don't need to send down // transaction IDs for complete syncs recentEvents := streamEventsToEvents(nil, recentStreamEvents) - stateEvents = removeDuplicates(stateEvents, recentEvents) + + // Note that we're not passing txn into applyRedactions because txn is + // readonly but we may need to write during validation of redactions. + // This may be optimised in the future. + if err = d.applyRedactionsForEventLists( + ctx, nil, recentEvents, stateEvents, + ); err != nil { + return + } + jr := types.NewJoinResponse() if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 { // Use the short form of batch token for prev_batch @@ -549,10 +597,26 @@ func (d *SyncServerDatasource) addInvitesToResponse( if err != nil { return err } + + // Unzip the map into two lists so we can applyRedactions() + roomIDs := make([]string, 0, len(invites)) + inviteEvents := make([]gomatrixserverlib.Event, 0, len(invites)) for roomID, inviteEvent := range invites { + roomIDs = append(roomIDs, roomID) + inviteEvents = append(inviteEvents, inviteEvent) + } + + // Note that we're not passing txn into applyRedactions because txn may be + // readonly but we may need to write during validation of redactions. + // This may be optimised in the future. + if err = d.applyRedactionsForEventLists(ctx, nil, inviteEvents); err != nil { + return err + } + + for i, roomID := range roomIDs { ir := types.NewInviteResponse() ir.InviteState.Events = gomatrixserverlib.ToClientEvents( - []gomatrixserverlib.Event{inviteEvent}, gomatrixserverlib.FormatSync, + []gomatrixserverlib.Event{inviteEvents[i]}, gomatrixserverlib.FormatSync, ) // TODO: add the invite state from the invite event. res.Rooms.Invite[roomID] = *ir @@ -594,6 +658,15 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( return nil } + // Note that we're not passing txn into applyRedactions because txn is + // readonly but we may need to write during validation of redactions. + // This may be optimised in the future. + if err = d.applyRedactionsForEventLists( + ctx, nil, recentEvents, delta.stateEvents, + ); err != nil { + return err + } + switch delta.membership { case "join": jr := types.NewJoinResponse() @@ -790,6 +863,165 @@ func (d *SyncServerDatasource) getStateDeltas( return deltas, nil } +// applyRedactionsForEventLists applies necessary redactions to the events in the lists in-place. +// It will replace the events with their redacted versions. +// It will update the validation status in the redactions table if there are +// redaction events newly validated. +func (d *SyncServerDatasource) applyRedactionsForEventLists( + ctx context.Context, + txn *sql.Tx, + eventLists ...[]gomatrixserverlib.Event, +) error { + totalLen := 0 + for _, eventList := range eventLists { + totalLen += len(eventList) + } + + eventPointers := make([]*gomatrixserverlib.Event, 0, totalLen) + for _, eventList := range eventLists { + for i := range eventList { + eventPointers = append(eventPointers, &eventList[i]) + } + } + + return d.applyRedactionsForEventPointers(ctx, txn, eventPointers) +} + +// applyRedactionsForEventPointers applies necessary redactions to the events +// referenced by the given pointers. The events will be replaced with their +// redacted copies. +// There cannot be nil pointers in eventPointers. +func (d *SyncServerDatasource) applyRedactionsForEventPointers( + ctx context.Context, + txn *sql.Tx, + eventPointers []*gomatrixserverlib.Event, +) error { + eventIDs := make([]string, len(eventPointers)) + for i, e := range eventPointers { + eventIDs[i] = e.EventID() + } + + validatedRedactions, unvalidatedRedactions, err := d.redactions.bulkSelectRedaction(ctx, txn, eventIDs) + if err != nil { + return err + } + + totalPossibleRedactions := len(validatedRedactions) + len(unvalidatedRedactions) + + // Fast path if nothing to redact + if totalPossibleRedactions == 0 { + return nil + } + + redactionIDToEvent, err := d.fetchRedactionEvents(ctx, txn, validatedRedactions, unvalidatedRedactions) + if err != nil { + return err + } + + eventIDToEventPointer := make(map[string]*gomatrixserverlib.Event, len(eventPointers)) + for _, p := range eventPointers { + eventIDToEventPointer[p.EventID()] = p + } + + if len(unvalidatedRedactions) != 0 { + var newlyValidated redactedToRedactionMap + if newlyValidated, err = d.validateRedactions( + ctx, txn, unvalidatedRedactions, redactionIDToEvent, eventIDToEventPointer, + ); err != nil { + return err + } + for redactedEventID, redactedByID := range newlyValidated { + validatedRedactions[redactedEventID] = redactedByID + } + } + + for redactedEventID, redactedByID := range validatedRedactions { + redactedEvent := eventIDToEventPointer[redactedEventID] + *redactedEvent = redactedEvent.Redact() + if err = redactedEvent.SetUnsignedField( + "redacted_because", + gomatrixserverlib.ToClientEvent( + *redactionIDToEvent[redactedByID], gomatrixserverlib.FormatAll, + ), + ); err != nil { + return err + } + } + + return nil +} + +func (d *SyncServerDatasource) fetchRedactionEvents( + ctx context.Context, + txn *sql.Tx, + validatedRedactions, unvalidatedRedactions redactedToRedactionMap, +) (redactionIDToEvent map[string]*gomatrixserverlib.Event, err error) { + redactionEventsToFetch := make([]string, 0, len(validatedRedactions)+len(unvalidatedRedactions)) + for _, id := range validatedRedactions { + redactionEventsToFetch = append(redactionEventsToFetch, id) + } + for _, id := range unvalidatedRedactions { + redactionEventsToFetch = append(redactionEventsToFetch, id) + } + + redactionEvents, err := d.events.selectEvents(ctx, txn, redactionEventsToFetch) + if err != nil { + return nil, err + } + + redactionIDToEvent = make(map[string]*gomatrixserverlib.Event, len(redactionEvents)) + for _, redactionEvent := range redactionEvents { + redactionIDToEvent[redactionEvent.EventID()] = &redactionEvent.Event + } + + return +} + +func (d *SyncServerDatasource) validateRedactions( + ctx context.Context, + txn *sql.Tx, + unvalidatedRedactions redactedToRedactionMap, + redactionIDToEvent map[string]*gomatrixserverlib.Event, + eventIDToEvent map[string]*gomatrixserverlib.Event, +) (validatedRedactions redactedToRedactionMap, err error) { + validatedRedactions = make(redactedToRedactionMap, len(unvalidatedRedactions)) + + var expectedDomain, redactorDomain gomatrixserverlib.ServerName + for redactedEventID, redactedByID := range unvalidatedRedactions { + if _, expectedDomain, err = gomatrixserverlib.SplitID( + '@', eventIDToEvent[redactedEventID].Sender(), + ); err != nil { + return nil, err + } + if _, redactorDomain, err = gomatrixserverlib.SplitID( + '@', redactionIDToEvent[redactedByID].Sender(), + ); err != nil { + return nil, err + } + + if redactorDomain != expectedDomain { + // TODO: Still allow power users to redact + continue + } + + validatedRedactions[redactedEventID] = redactedByID + } + + eventIDs := make([]string, 0, len(validatedRedactions)) + for _, id := range validatedRedactions { + eventIDs = append(eventIDs, id) + } + if err = d.redactions.bulkUpdateValidationStatus( + ctx, txn, eventIDs, true, + ); err != nil { + return nil, err + } + + // TODO: We might want to clear the unvalidated redactions + + return validatedRedactions, nil +} + // streamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event.