Implement event redaction

Signed-off-by: Alex Chen <minecnly@gmail.com>
This commit is contained in:
Cnly 2019-07-27 17:56:26 +08:00
parent 78032b3f4c
commit 40fd47957a
9 changed files with 882 additions and 29 deletions

148
clientapi/routing/redact.go Normal file
View file

@ -0,0 +1,148 @@
// Copyright 2019 Alex Chen
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/common/transactions"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// https://matrix.org/docs/spec/client_server/r0.5.0#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid
type redactRequest struct {
Reason string `json:"reason,omitempty"`
}
type redactResponse struct {
EventID string `json:"event_id"`
}
// Redact implements PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}
func Redact(
req *http.Request,
device *authtypes.Device,
roomID, redactedEventID, txnID string,
cfg config.Dendrite,
queryAPI api.RoomserverQueryAPI,
producer *producers.RoomserverProducer,
txnCache *transactions.Cache,
) util.JSONResponse {
// TODO: Idempotency
var redactReq redactRequest
if resErr := httputil.UnmarshalJSONRequest(req, &redactReq); resErr != nil {
return *resErr
}
// Build a redaction event
builder := gomatrixserverlib.EventBuilder{
Sender: device.UserID,
RoomID: roomID,
Redacts: redactedEventID,
Type: gomatrixserverlib.MRoomRedaction,
}
err := builder.SetContent(redactReq)
if err != nil {
return httputil.LogThenError(req, err)
}
var queryRes api.QueryLatestEventsAndStateResponse
e, err := common.BuildEvent(req.Context(), &builder, cfg, time.Now(), queryAPI, &queryRes)
if err == common.ErrRoomNoExists {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Room does not exist"),
}
} else if err != nil {
return httputil.LogThenError(req, err)
}
// Do some basic checks e.g. ensuring the user is in the room and can send m.room.redaction events
stateEvents := make([]*gomatrixserverlib.Event, len(queryRes.StateEvents))
for i := range queryRes.StateEvents {
stateEvents[i] = &queryRes.StateEvents[i]
}
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
if err = gomatrixserverlib.Allowed(*e, &provider); err != nil {
// TODO: Is the error returned with suitable HTTP status code?
if _, ok := err.(*gomatrixserverlib.NotAllowed); ok {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(err.Error()),
}
}
return httputil.LogThenError(req, err)
}
// Ensure the user can redact the specific event
eventReq := api.QueryEventsByIDRequest{
EventIDs: []string{redactedEventID},
}
var eventResp api.QueryEventsByIDResponse
if err = queryAPI.QueryEventsByID(req.Context(), &eventReq, &eventResp); err != nil {
return httputil.LogThenError(req, err)
}
if len(eventResp.Events) == 0 {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Event to redact not found"),
}
}
redactedEvent := eventResp.Events[0]
if redactedEvent.Sender() != device.UserID {
// TODO: Allow power users to redact others' events
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("You are not allowed to redact this event"),
}
}
// Send the redaction event
txnAndDeviceID := api.TransactionID{
TransactionID: txnID,
DeviceID: device.ID,
}
// pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded
eventID, err := producer.SendEvents(
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, &txnAndDeviceID,
)
if err != nil {
return httputil.LogThenError(req, err)
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: redactResponse{eventID},
}
}

View file

@ -158,6 +158,17 @@ func Setup(
}), }),
).Methods(http.MethodPut, http.MethodOptions) ).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/redact/{eventID}/{txnID}",
common.MakeAuthAPI("redact", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars, err := common.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return Redact(req, device, vars["roomID"], vars["eventID"], vars["txnID"],
cfg, queryAPI, producer, transactionsCache)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
return Register(req, accountDB, deviceDB, &cfg) return Register(req, accountDB, deviceDB, &cfg)
})).Methods(http.MethodPost, http.MethodOptions) })).Methods(http.MethodPost, http.MethodOptions)

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -67,9 +68,10 @@ func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
} }
func (s *eventJSONStatements) insertEventJSON( func (s *eventJSONStatements) insertEventJSON(
ctx context.Context, eventNID types.EventNID, eventJSON []byte, ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
) error { ) error {
_, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) stmt := common.TxStmt(txn, s.insertEventJSONStmt)
_, err := stmt.ExecContext(ctx, int64(eventNID), eventJSON)
return err return err
} }

View file

@ -155,7 +155,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
} }
func (s *eventStatements) insertEvent( func (s *eventStatements) insertEvent(
ctx context.Context, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, roomNID types.RoomNID,
eventTypeNID types.EventTypeNID, eventTypeNID types.EventTypeNID,
eventStateKeyNID types.EventStateKeyNID, eventStateKeyNID types.EventStateKeyNID,
@ -166,7 +166,8 @@ func (s *eventStatements) insertEvent(
) (types.EventNID, types.StateSnapshotNID, error) { ) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64 var eventNID int64
var stateNID int64 var stateNID int64
err := s.insertEventStmt.QueryRowContext( stmt := common.TxStmt(txn, s.insertEventStmt)
err := stmt.QueryRowContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
).Scan(&eventNID, &stateNID) ).Scan(&eventNID, &stateNID)
@ -174,11 +175,12 @@ func (s *eventStatements) insertEvent(
} }
func (s *eventStatements) selectEvent( func (s *eventStatements) selectEvent(
ctx context.Context, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) (types.EventNID, types.StateSnapshotNID, error) { ) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64 var eventNID int64
var stateNID int64 var stateNID int64
err := s.selectEventStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) stmt := common.TxStmt(txn, s.selectEventStmt)
err := stmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
} }

View file

@ -0,0 +1,138 @@
// Copyright 2019 Alex Chen
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
const redactionsSchema = `
-- The redactions table holds redactions.
CREATE TABLE IF NOT EXISTS roomserver_redactions (
-- Local numeric ID for the redaction event.
event_nid BIGINT PRIMARY KEY,
-- String ID for the redacted event.
redacts TEXT NOT NULL,
-- Whether the redaction has been validated.
-- For use of the "accept first, validate later" strategy for rooms >= v3.
-- Should always be TRUE for rooms before v3.
validated BOOLEAN NOT NULL
);
CREATE INDEX IF NOT EXISTS roomserver_redactions_redacts ON roomserver_redactions(redacts);
`
const insertRedactionSQL = "" +
"INSERT INTO roomserver_redactions (event_nid, redacts, validated)" +
" VALUES ($1, $2, $3)"
const bulkSelectRedactionSQL = "" +
"SELECT event_nid, redacts, validated FROM roomserver_redactions" +
" WHERE redacts = ANY($1)"
const bulkUpdateValidationStatusSQL = "" +
" UPDATE roomserver_redactions SET validated = $2 WHERE event_nid = ANY($1)"
type redactionStatements struct {
insertRedactionStmt *sql.Stmt
bulkSelectRedactionStmt *sql.Stmt
bulkUpdateValidationStatusStmt *sql.Stmt
}
func (s *redactionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(redactionsSchema)
if err != nil {
return
}
return statementList{
{&s.insertRedactionStmt, insertRedactionSQL},
{&s.bulkSelectRedactionStmt, bulkSelectRedactionSQL},
{&s.bulkUpdateValidationStatusStmt, bulkUpdateValidationStatusSQL},
}.prepare(db)
}
func (s *redactionStatements) insertRedaction(
ctx context.Context,
txn *sql.Tx,
eventNID types.EventNID,
redactsEventID string,
validated bool,
) error {
stmt := common.TxStmt(txn, s.insertRedactionStmt)
_, err := stmt.ExecContext(ctx, int64(eventNID), redactsEventID, validated)
return err
}
func (s *redactionStatements) bulkSelectRedaction(
ctx context.Context,
txn *sql.Tx,
eventIDs []string,
) (
validated map[string]types.EventNID,
unvalidated map[string]types.EventNID,
err error,
) {
stmt := common.TxStmt(txn, s.bulkSelectRedactionStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, nil, err
}
defer func() { err = rows.Close() }()
validated = make(map[string]types.EventNID)
unvalidated = make(map[string]types.EventNID)
var (
redactedByNID types.EventNID
redactedEventID string
isValidated bool
)
for rows.Next() {
if err = rows.Scan(
&redactedByNID,
&redactedEventID,
&isValidated,
); err != nil {
return nil, nil, err
}
if isValidated {
validated[redactedEventID] = redactedByNID
} else {
unvalidated[redactedEventID] = redactedByNID
}
}
if err = rows.Err(); err != nil {
return nil, nil, err
}
return validated, unvalidated, nil
}
func (s *redactionStatements) bulkUpdateValidationStatus(
ctx context.Context,
txn *sql.Tx,
eventNIDs []types.EventNID,
newStatus bool,
) error {
stmt := common.TxStmt(txn, s.bulkUpdateValidationStatusStmt)
_, err := stmt.ExecContext(ctx, eventNIDsAsArray(eventNIDs), newStatus)
return err
}

View file

@ -31,6 +31,7 @@ type statements struct {
inviteStatements inviteStatements
membershipStatements membershipStatements
transactionStatements transactionStatements
redactionStatements
} }
func (s *statements) prepare(db *sql.DB) error { func (s *statements) prepare(db *sql.DB) error {
@ -49,6 +50,7 @@ func (s *statements) prepare(db *sql.DB) error {
s.inviteStatements.prepare, s.inviteStatements.prepare,
s.membershipStatements.prepare, s.membershipStatements.prepare,
s.transactionStatements.prepare, s.transactionStatements.prepare,
s.redactionStatements.prepare,
} { } {
if err = prepare(db); err != nil { if err = prepare(db); err != nil {
return err return err

View file

@ -20,6 +20,7 @@ import (
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -84,26 +85,35 @@ func (d *Database) StoreEvent(
} }
} }
if eventNID, stateNID, err = d.statements.insertEvent( err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
ctx, if eventNID, stateNID, err = d.statements.insertEvent(
roomNID, ctx, txn,
eventTypeNID, roomNID,
eventStateKeyNID, eventTypeNID,
event.EventID(), eventStateKeyNID,
event.EventReference().EventSHA256, event.EventID(),
authEventNIDs, event.EventReference().EventSHA256,
event.Depth(), authEventNIDs,
); err != nil { event.Depth(),
if err == sql.ErrNoRows { ); err != nil {
// We've already inserted the event so select the numeric event ID if err == sql.ErrNoRows {
eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) // We've already inserted the event so select the numeric event ID
eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID())
}
if err != nil {
return err
}
} }
if err != nil {
return 0, types.StateAtEvent{}, err
}
}
if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { if err = d.updateSpecialTablesForEvent(
ctx, txn, &event, eventNID,
); err != nil {
return err
}
return d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON())
})
if err != nil {
return 0, types.StateAtEvent{}, err return 0, types.StateAtEvent{}, err
} }
@ -167,6 +177,24 @@ func (d *Database) assignStateKeyNID(
return eventStateKeyNID, err return eventStateKeyNID, err
} }
func (d *Database) updateSpecialTablesForEvent(
ctx context.Context,
txn *sql.Tx,
event *gomatrixserverlib.Event,
eventNID types.EventNID,
) (err error) {
switch event.Type() {
case gomatrixserverlib.MRoomRedaction:
// TODO: After we support room versioning, set validated = false only for rooms >= v3.
if err = d.statements.insertRedaction(
ctx, txn, eventNID, event.Redacts(), false,
); err != nil {
return err
}
}
return nil
}
// StateEntriesForEventIDs implements input.EventDatabase // StateEntriesForEventIDs implements input.EventDatabase
func (d *Database) StateEntriesForEventIDs( func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string, ctx context.Context, eventIDs []string,
@ -210,7 +238,9 @@ func (d *Database) Events(
if err != nil { if err != nil {
return nil, err return nil, err
} }
results := make([]types.Event, len(eventJSONs)) results := make([]types.Event, len(eventJSONs))
eventPointers := make([]*gomatrixserverlib.Event, len(eventJSONs))
for i, eventJSON := range eventJSONs { for i, eventJSON := range eventJSONs {
result := &results[i] result := &results[i]
result.EventNID = eventJSON.EventNID result.EventNID = eventJSON.EventNID
@ -219,10 +249,153 @@ func (d *Database) Events(
if err != nil { if err != nil {
return nil, err return nil, err
} }
eventPointers[i] = &result.Event
} }
if err = d.applyRedactions(ctx, eventPointers); err != nil {
return nil, err
}
return results, nil return results, nil
} }
// applyRedactions applies necessary redactions to the given events.
// It will replace events referenced by the pointers with their redacted versions.
// It will update the validation status in the redactions table if there are
// redaction events newly validated.
func (d *Database) applyRedactions(
ctx context.Context,
eventPointers []*gomatrixserverlib.Event,
) error {
eventIDs := make([]string, len(eventPointers))
for i, e := range eventPointers {
eventIDs[i] = e.EventID()
}
validatedRedactions, unvalidatedRedactions, err := d.statements.bulkSelectRedaction(ctx, nil, eventIDs)
if err != nil {
return err
}
totalPossibleRedactions := len(validatedRedactions) + len(unvalidatedRedactions)
// Fast path if nothing to redact
if totalPossibleRedactions == 0 {
return nil
}
redactionNIDToEvent, err := d.fetchRedactionEvents(ctx, validatedRedactions, unvalidatedRedactions)
if err != nil {
return err
}
eventIDToEventPointer := make(map[string]*gomatrixserverlib.Event, len(eventPointers))
for _, p := range eventPointers {
eventIDToEventPointer[p.EventID()] = p
}
if len(unvalidatedRedactions) != 0 {
var newlyValidated map[string]types.EventNID
if newlyValidated, err = d.validateRedactions(
ctx, unvalidatedRedactions, redactionNIDToEvent, eventIDToEventPointer,
); err != nil {
return err
}
for redactedEventID, redactedByNID := range newlyValidated {
validatedRedactions[redactedEventID] = redactedByNID
}
}
for redactedEventID, redactedByNID := range validatedRedactions {
redactedEvent := eventIDToEventPointer[redactedEventID]
*redactedEvent = redactedEvent.Redact()
if err = redactedEvent.SetUnsignedField(
"redacted_because",
gomatrixserverlib.ToClientEvent(
*redactionNIDToEvent[redactedByNID],
gomatrixserverlib.FormatAll,
),
); err != nil {
return err
}
}
return nil
}
func (d *Database) fetchRedactionEvents(
ctx context.Context,
validatedRedactions, unvalidatedRedactions map[string]types.EventNID,
) (redactionNIDToEvent map[types.EventNID]*gomatrixserverlib.Event, err error) {
redactionEventsToFetch := make([]types.EventNID, 0, len(validatedRedactions)+len(unvalidatedRedactions))
for _, nid := range validatedRedactions {
redactionEventsToFetch = append(redactionEventsToFetch, nid)
}
for _, nid := range unvalidatedRedactions {
redactionEventsToFetch = append(redactionEventsToFetch, nid)
}
redactionJSONs, err := d.statements.bulkSelectEventJSON(ctx, redactionEventsToFetch)
if err != nil {
return nil, err
}
redactionNIDToEvent = make(map[types.EventNID]*gomatrixserverlib.Event, len(redactionJSONs))
for _, redactionJSON := range redactionJSONs {
e, err := gomatrixserverlib.NewEventFromTrustedJSON(redactionJSON.EventJSON, false)
if err != nil {
return nil, err
}
redactionNIDToEvent[redactionJSON.EventNID] = &e
}
return
}
func (d *Database) validateRedactions(
ctx context.Context,
unvalidatedRedactions map[string]types.EventNID,
redactionNIDToEvent map[types.EventNID]*gomatrixserverlib.Event,
eventIDToEvent map[string]*gomatrixserverlib.Event,
) (validatedRedactions map[string]types.EventNID, err error) {
validatedRedactions = make(map[string]types.EventNID, len(unvalidatedRedactions))
var expectedDomain, redactorDomain gomatrixserverlib.ServerName
for redactedEventID, redactedByNID := range unvalidatedRedactions {
if _, expectedDomain, err = gomatrixserverlib.SplitID(
'@', eventIDToEvent[redactedEventID].Sender(),
); err != nil {
return nil, err
}
if _, redactorDomain, err = gomatrixserverlib.SplitID(
'@', redactionNIDToEvent[redactedByNID].Sender(),
); err != nil {
return nil, err
}
if redactorDomain != expectedDomain {
// TODO: Still allow power users to redact
continue
}
validatedRedactions[redactedEventID] = redactedByNID
}
eventNIDs := make([]types.EventNID, 0, len(validatedRedactions))
for _, nid := range validatedRedactions {
eventNIDs = append(eventNIDs, nid)
}
if err = d.statements.bulkUpdateValidationStatus(
ctx, nil, eventNIDs, true,
); err != nil {
return nil, err
}
// TODO: We might want to clear the unvalidated redactions
return validatedRedactions, nil
}
// AddState implements input.EventDatabase // AddState implements input.EventDatabase
func (d *Database) AddState( func (d *Database) AddState(
ctx context.Context, ctx context.Context,
@ -276,7 +449,7 @@ func (d *Database) StateEntries(
func (d *Database) SnapshotNIDFromEventID( func (d *Database) SnapshotNIDFromEventID(
ctx context.Context, eventID string, ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
_, stateNID, err := d.statements.selectEvent(ctx, eventID) _, stateNID, err := d.statements.selectEvent(ctx, nil, eventID)
return stateNID, err return stateNID, err
} }

View file

@ -0,0 +1,145 @@
// Copyright 2019 Alex Chen
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
)
const redactionsSchema = `
-- The redactions table holds redactions.
CREATE TABLE IF NOT EXISTS syncapi_redactions (
-- The event ID for the redaction event.
event_id TEXT NOT NULL,
-- The event ID for the redacted event.
redacts TEXT NOT NULL,
-- Whether the redaction has been validated.
-- For use of the "accept first, validate later" strategy for rooms >= v3.
-- Should always be TRUE for rooms before v3.
validated BOOLEAN NOT NULL
);
CREATE INDEX IF NOT EXISTS syncapi_redactions_redacts ON syncapi_redactions(redacts);
`
const insertRedactionSQL = "" +
"INSERT INTO syncapi_redactions (event_id, redacts, validated)" +
" VALUES ($1, $2, $3)"
const bulkSelectRedactionSQL = "" +
"SELECT event_id, redacts, validated FROM syncapi_redactions" +
" WHERE redacts = ANY($1)"
const bulkUpdateValidationStatusSQL = "" +
" UPDATE syncapi_redactions SET validated = $2 WHERE event_id = ANY($1)"
type redactionStatements struct {
insertRedactionStmt *sql.Stmt
bulkSelectRedactionStmt *sql.Stmt
bulkUpdateValidationStatusStmt *sql.Stmt
}
// redactedToRedactionMap is a map in the form map[redactedEventID]redactionEventID.
type redactedToRedactionMap map[string]string
func (s *redactionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(redactionsSchema)
if err != nil {
return
}
if s.insertRedactionStmt, err = db.Prepare(insertRedactionSQL); err != nil {
return
}
if s.bulkSelectRedactionStmt, err = db.Prepare(bulkSelectRedactionSQL); err != nil {
return
}
if s.bulkUpdateValidationStatusStmt, err = db.Prepare(bulkUpdateValidationStatusSQL); err != nil {
return
}
return
}
func (s *redactionStatements) insertRedaction(
ctx context.Context,
txn *sql.Tx,
eventID string,
redactsEventID string,
validated bool,
) error {
stmt := common.TxStmt(txn, s.insertRedactionStmt)
_, err := stmt.ExecContext(ctx, eventID, redactsEventID, validated)
return err
}
// bulkSelectRedaction returns the redactions for the given event IDs.
func (s *redactionStatements) bulkSelectRedaction(
ctx context.Context,
txn *sql.Tx,
eventIDs []string,
) (
validated redactedToRedactionMap,
unvalidated redactedToRedactionMap,
err error,
) {
stmt := common.TxStmt(txn, s.bulkSelectRedactionStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, nil, err
}
defer func() { err = rows.Close() }()
validated = make(redactedToRedactionMap)
unvalidated = make(redactedToRedactionMap)
var (
redactedByID string
redactedEventID string
isValidated bool
)
for rows.Next() {
if err = rows.Scan(
&redactedByID,
&redactedEventID,
&isValidated,
); err != nil {
return nil, nil, err
}
if isValidated {
validated[redactedEventID] = redactedByID
} else {
unvalidated[redactedEventID] = redactedByID
}
}
if err = rows.Err(); err != nil {
return nil, nil, err
}
return validated, unvalidated, nil
}
func (s *redactionStatements) bulkUpdateValidationStatus(
ctx context.Context,
txn *sql.Tx,
eventIDs []string,
newStatus bool,
) error {
stmt := common.TxStmt(txn, s.bulkUpdateValidationStatusStmt)
_, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs), newStatus)
return err
}

View file

@ -60,6 +60,7 @@ type SyncServerDatasource struct {
events outputRoomEventsStatements events outputRoomEventsStatements
roomstate currentRoomStateStatements roomstate currentRoomStateStatements
invites inviteEventsStatements invites inviteEventsStatements
redactions redactionStatements
typingCache *cache.TypingCache typingCache *cache.TypingCache
} }
@ -85,6 +86,9 @@ func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, er
if err := d.invites.prepare(d.db); err != nil { if err := d.invites.prepare(d.db); err != nil {
return nil, err return nil, err
} }
if err := d.redactions.prepare(d.db); err != nil {
return nil, err
}
d.typingCache = cache.NewTypingCache() d.typingCache = cache.NewTypingCache()
return &d, nil return &d, nil
} }
@ -128,6 +132,10 @@ func (d *SyncServerDatasource) WriteEvent(
} }
pduPosition = pos pduPosition = pos
if err = d.updateSpecialTablesForEvent(ctx, txn, ev); err != nil {
return err
}
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
// Nothing to do, the event may have just been a message event. // Nothing to do, the event may have just been a message event.
return nil return nil
@ -138,6 +146,23 @@ func (d *SyncServerDatasource) WriteEvent(
return return
} }
func (d *SyncServerDatasource) updateSpecialTablesForEvent(
ctx context.Context,
txn *sql.Tx,
event *gomatrixserverlib.Event,
) (err error) {
switch event.Type() {
case gomatrixserverlib.MRoomRedaction:
// TODO: After we support room versioning, set validated = false only for rooms >= v3.
if err = d.redactions.insertRedaction(
ctx, txn, event.EventID(), event.Redacts(), false,
); err != nil {
return err
}
}
return nil
}
func (d *SyncServerDatasource) updateRoomState( func (d *SyncServerDatasource) updateRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
removedEventIDs []string, removedEventIDs []string,
@ -178,7 +203,18 @@ func (d *SyncServerDatasource) updateRoomState(
func (d *SyncServerDatasource) GetStateEvent( func (d *SyncServerDatasource) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string, ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.Event, error) { ) (*gomatrixserverlib.Event, error) {
return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey) e, err := d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey)
if e == nil || err != nil {
return e, err
}
if err = d.applyRedactionsForEventPointers(
ctx, nil, []*gomatrixserverlib.Event{e},
); err != nil {
return nil, err
}
return e, nil
} }
// GetStateEventsForRoom fetches the state events for a given room. // GetStateEventsForRoom fetches the state events for a given room.
@ -189,7 +225,10 @@ func (d *SyncServerDatasource) GetStateEventsForRoom(
) (stateEvents []gomatrixserverlib.Event, err error) { ) (stateEvents []gomatrixserverlib.Event, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error { err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID) stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID)
return err if err != nil {
return err
}
return d.applyRedactionsForEventLists(ctx, txn, stateEvents)
}) })
return return
} }
@ -418,8 +457,17 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
// transaction IDs for complete syncs // transaction IDs for complete syncs
recentEvents := streamEventsToEvents(nil, recentStreamEvents) recentEvents := streamEventsToEvents(nil, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
// Note that we're not passing txn into applyRedactions because txn is
// readonly but we may need to write during validation of redactions.
// This may be optimised in the future.
if err = d.applyRedactionsForEventLists(
ctx, nil, recentEvents, stateEvents,
); err != nil {
return
}
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 { if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 {
// Use the short form of batch token for prev_batch // Use the short form of batch token for prev_batch
@ -549,10 +597,26 @@ func (d *SyncServerDatasource) addInvitesToResponse(
if err != nil { if err != nil {
return err return err
} }
// Unzip the map into two lists so we can applyRedactions()
roomIDs := make([]string, 0, len(invites))
inviteEvents := make([]gomatrixserverlib.Event, 0, len(invites))
for roomID, inviteEvent := range invites { for roomID, inviteEvent := range invites {
roomIDs = append(roomIDs, roomID)
inviteEvents = append(inviteEvents, inviteEvent)
}
// Note that we're not passing txn into applyRedactions because txn may be
// readonly but we may need to write during validation of redactions.
// This may be optimised in the future.
if err = d.applyRedactionsForEventLists(ctx, nil, inviteEvents); err != nil {
return err
}
for i, roomID := range roomIDs {
ir := types.NewInviteResponse() ir := types.NewInviteResponse()
ir.InviteState.Events = gomatrixserverlib.ToClientEvents( ir.InviteState.Events = gomatrixserverlib.ToClientEvents(
[]gomatrixserverlib.Event{inviteEvent}, gomatrixserverlib.FormatSync, []gomatrixserverlib.Event{inviteEvents[i]}, gomatrixserverlib.FormatSync,
) )
// TODO: add the invite state from the invite event. // TODO: add the invite state from the invite event.
res.Rooms.Invite[roomID] = *ir res.Rooms.Invite[roomID] = *ir
@ -594,6 +658,15 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
return nil return nil
} }
// Note that we're not passing txn into applyRedactions because txn is
// readonly but we may need to write during validation of redactions.
// This may be optimised in the future.
if err = d.applyRedactionsForEventLists(
ctx, nil, recentEvents, delta.stateEvents,
); err != nil {
return err
}
switch delta.membership { switch delta.membership {
case "join": case "join":
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
@ -790,6 +863,165 @@ func (d *SyncServerDatasource) getStateDeltas(
return deltas, nil return deltas, nil
} }
// applyRedactionsForEventLists applies necessary redactions to the events in the lists in-place.
// It will replace the events with their redacted versions.
// It will update the validation status in the redactions table if there are
// redaction events newly validated.
func (d *SyncServerDatasource) applyRedactionsForEventLists(
ctx context.Context,
txn *sql.Tx,
eventLists ...[]gomatrixserverlib.Event,
) error {
totalLen := 0
for _, eventList := range eventLists {
totalLen += len(eventList)
}
eventPointers := make([]*gomatrixserverlib.Event, 0, totalLen)
for _, eventList := range eventLists {
for i := range eventList {
eventPointers = append(eventPointers, &eventList[i])
}
}
return d.applyRedactionsForEventPointers(ctx, txn, eventPointers)
}
// applyRedactionsForEventPointers applies necessary redactions to the events
// referenced by the given pointers. The events will be replaced with their
// redacted copies.
// There cannot be nil pointers in eventPointers.
func (d *SyncServerDatasource) applyRedactionsForEventPointers(
ctx context.Context,
txn *sql.Tx,
eventPointers []*gomatrixserverlib.Event,
) error {
eventIDs := make([]string, len(eventPointers))
for i, e := range eventPointers {
eventIDs[i] = e.EventID()
}
validatedRedactions, unvalidatedRedactions, err := d.redactions.bulkSelectRedaction(ctx, txn, eventIDs)
if err != nil {
return err
}
totalPossibleRedactions := len(validatedRedactions) + len(unvalidatedRedactions)
// Fast path if nothing to redact
if totalPossibleRedactions == 0 {
return nil
}
redactionIDToEvent, err := d.fetchRedactionEvents(ctx, txn, validatedRedactions, unvalidatedRedactions)
if err != nil {
return err
}
eventIDToEventPointer := make(map[string]*gomatrixserverlib.Event, len(eventPointers))
for _, p := range eventPointers {
eventIDToEventPointer[p.EventID()] = p
}
if len(unvalidatedRedactions) != 0 {
var newlyValidated redactedToRedactionMap
if newlyValidated, err = d.validateRedactions(
ctx, txn, unvalidatedRedactions, redactionIDToEvent, eventIDToEventPointer,
); err != nil {
return err
}
for redactedEventID, redactedByID := range newlyValidated {
validatedRedactions[redactedEventID] = redactedByID
}
}
for redactedEventID, redactedByID := range validatedRedactions {
redactedEvent := eventIDToEventPointer[redactedEventID]
*redactedEvent = redactedEvent.Redact()
if err = redactedEvent.SetUnsignedField(
"redacted_because",
gomatrixserverlib.ToClientEvent(
*redactionIDToEvent[redactedByID], gomatrixserverlib.FormatAll,
),
); err != nil {
return err
}
}
return nil
}
func (d *SyncServerDatasource) fetchRedactionEvents(
ctx context.Context,
txn *sql.Tx,
validatedRedactions, unvalidatedRedactions redactedToRedactionMap,
) (redactionIDToEvent map[string]*gomatrixserverlib.Event, err error) {
redactionEventsToFetch := make([]string, 0, len(validatedRedactions)+len(unvalidatedRedactions))
for _, id := range validatedRedactions {
redactionEventsToFetch = append(redactionEventsToFetch, id)
}
for _, id := range unvalidatedRedactions {
redactionEventsToFetch = append(redactionEventsToFetch, id)
}
redactionEvents, err := d.events.selectEvents(ctx, txn, redactionEventsToFetch)
if err != nil {
return nil, err
}
redactionIDToEvent = make(map[string]*gomatrixserverlib.Event, len(redactionEvents))
for _, redactionEvent := range redactionEvents {
redactionIDToEvent[redactionEvent.EventID()] = &redactionEvent.Event
}
return
}
func (d *SyncServerDatasource) validateRedactions(
ctx context.Context,
txn *sql.Tx,
unvalidatedRedactions redactedToRedactionMap,
redactionIDToEvent map[string]*gomatrixserverlib.Event,
eventIDToEvent map[string]*gomatrixserverlib.Event,
) (validatedRedactions redactedToRedactionMap, err error) {
validatedRedactions = make(redactedToRedactionMap, len(unvalidatedRedactions))
var expectedDomain, redactorDomain gomatrixserverlib.ServerName
for redactedEventID, redactedByID := range unvalidatedRedactions {
if _, expectedDomain, err = gomatrixserverlib.SplitID(
'@', eventIDToEvent[redactedEventID].Sender(),
); err != nil {
return nil, err
}
if _, redactorDomain, err = gomatrixserverlib.SplitID(
'@', redactionIDToEvent[redactedByID].Sender(),
); err != nil {
return nil, err
}
if redactorDomain != expectedDomain {
// TODO: Still allow power users to redact
continue
}
validatedRedactions[redactedEventID] = redactedByID
}
eventIDs := make([]string, 0, len(validatedRedactions))
for _, id := range validatedRedactions {
eventIDs = append(eventIDs, id)
}
if err = d.redactions.bulkUpdateValidationStatus(
ctx, txn, eventIDs, true,
); err != nil {
return nil, err
}
// TODO: We might want to clear the unvalidated redactions
return validatedRedactions, nil
}
// streamEventsToEvents converts streamEvent to Event. If device is non-nil and // streamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets // matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event. // added to the unsigned section of the output event.