mirror of
https://github.com/matrix-org/dendrite.git
synced 2024-11-29 17:51:56 -06:00
Fix RewritesState bug (#1557)
* Set RewritesState once * Check if any new state provided * Obey rewritesState * Don't nuke everything the sync API knows when purging state * Fix panic from duplicate insert * Consistency * Use HasState * Remove nolint * Clean up joined rooms on state rewrite
This commit is contained in:
parent
04dc019e5e
commit
3afc623098
|
@ -87,6 +87,12 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
|
||||||
case api.OutputTypeNewRoomEvent:
|
case api.OutputTypeNewRoomEvent:
|
||||||
ev := &output.NewRoomEvent.Event
|
ev := &output.NewRoomEvent.Event
|
||||||
|
|
||||||
|
if output.NewRoomEvent.RewritesState {
|
||||||
|
if err := s.db.PurgeRoomState(context.TODO(), ev.RoomID()); err != nil {
|
||||||
|
return fmt.Errorf("s.db.PurgeRoom: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.processMessage(*output.NewRoomEvent); err != nil {
|
if err := s.processMessage(*output.NewRoomEvent); err != nil {
|
||||||
// panic rather than continue with an inconsistent database
|
// panic rather than continue with an inconsistent database
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
|
|
|
@ -32,6 +32,7 @@ type Database interface {
|
||||||
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
||||||
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
|
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
|
||||||
|
PurgeRoomState(ctx context.Context, roomID string) error
|
||||||
|
|
||||||
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
||||||
|
|
||||||
|
|
|
@ -53,6 +53,9 @@ const insertJoinedHostsSQL = "" +
|
||||||
const deleteJoinedHostsSQL = "" +
|
const deleteJoinedHostsSQL = "" +
|
||||||
"DELETE FROM federationsender_joined_hosts WHERE event_id = ANY($1)"
|
"DELETE FROM federationsender_joined_hosts WHERE event_id = ANY($1)"
|
||||||
|
|
||||||
|
const deleteJoinedHostsForRoomSQL = "" +
|
||||||
|
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
|
||||||
|
|
||||||
const selectJoinedHostsSQL = "" +
|
const selectJoinedHostsSQL = "" +
|
||||||
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
|
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
|
||||||
" WHERE room_id = $1"
|
" WHERE room_id = $1"
|
||||||
|
@ -67,6 +70,7 @@ type joinedHostsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertJoinedHostsStmt *sql.Stmt
|
insertJoinedHostsStmt *sql.Stmt
|
||||||
deleteJoinedHostsStmt *sql.Stmt
|
deleteJoinedHostsStmt *sql.Stmt
|
||||||
|
deleteJoinedHostsForRoomStmt *sql.Stmt
|
||||||
selectJoinedHostsStmt *sql.Stmt
|
selectJoinedHostsStmt *sql.Stmt
|
||||||
selectAllJoinedHostsStmt *sql.Stmt
|
selectAllJoinedHostsStmt *sql.Stmt
|
||||||
selectJoinedHostsForRoomsStmt *sql.Stmt
|
selectJoinedHostsForRoomsStmt *sql.Stmt
|
||||||
|
@ -86,6 +90,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
|
||||||
if s.deleteJoinedHostsStmt, err = s.db.Prepare(deleteJoinedHostsSQL); err != nil {
|
if s.deleteJoinedHostsStmt, err = s.db.Prepare(deleteJoinedHostsSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.selectJoinedHostsStmt, err = s.db.Prepare(selectJoinedHostsSQL); err != nil {
|
if s.selectJoinedHostsStmt, err = s.db.Prepare(selectJoinedHostsSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -117,6 +124,14 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
|
) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, roomID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
|
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
|
||||||
ctx context.Context, txn *sql.Tx, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
) ([]types.JoinedHost, error) {
|
) ([]types.JoinedHost, error) {
|
||||||
|
|
|
@ -148,6 +148,20 @@ func (d *Database) StoreJSON(
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) PurgeRoomState(
|
||||||
|
ctx context.Context, roomID string,
|
||||||
|
) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
// If the event is a create event then we'll delete all of the existing
|
||||||
|
// data for the room. The only reason that a create event would be replayed
|
||||||
|
// to us in this way is if we're about to receive the entire room state.
|
||||||
|
if err := d.FederationSenderJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
|
||||||
|
return fmt.Errorf("d.FederationSenderJoinedHosts.DeleteJoinedHosts: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
|
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
|
return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
|
||||||
|
|
|
@ -53,6 +53,9 @@ const insertJoinedHostsSQL = "" +
|
||||||
const deleteJoinedHostsSQL = "" +
|
const deleteJoinedHostsSQL = "" +
|
||||||
"DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
|
"DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
|
||||||
|
|
||||||
|
const deleteJoinedHostsForRoomSQL = "" +
|
||||||
|
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
|
||||||
|
|
||||||
const selectJoinedHostsSQL = "" +
|
const selectJoinedHostsSQL = "" +
|
||||||
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
|
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
|
||||||
" WHERE room_id = $1"
|
" WHERE room_id = $1"
|
||||||
|
@ -67,6 +70,7 @@ type joinedHostsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertJoinedHostsStmt *sql.Stmt
|
insertJoinedHostsStmt *sql.Stmt
|
||||||
deleteJoinedHostsStmt *sql.Stmt
|
deleteJoinedHostsStmt *sql.Stmt
|
||||||
|
deleteJoinedHostsForRoomStmt *sql.Stmt
|
||||||
selectJoinedHostsStmt *sql.Stmt
|
selectJoinedHostsStmt *sql.Stmt
|
||||||
selectAllJoinedHostsStmt *sql.Stmt
|
selectAllJoinedHostsStmt *sql.Stmt
|
||||||
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
|
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
|
@ -86,6 +90,9 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error)
|
||||||
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
|
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
|
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -118,6 +125,14 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
|
||||||
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
|
) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, roomID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
|
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
|
||||||
ctx context.Context, txn *sql.Tx, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
) ([]types.JoinedHost, error) {
|
) ([]types.JoinedHost, error) {
|
||||||
|
|
|
@ -50,6 +50,7 @@ type FederationSenderQueueJSON interface {
|
||||||
type FederationSenderJoinedHosts interface {
|
type FederationSenderJoinedHosts interface {
|
||||||
InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error
|
InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error
|
||||||
DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error
|
DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error
|
||||||
|
DeleteJoinedHostsForRoom(ctx context.Context, txn *sql.Tx, roomID string) error
|
||||||
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
|
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
|
||||||
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
||||||
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
|
|
|
@ -116,7 +116,6 @@ type latestEventsUpdater struct {
|
||||||
|
|
||||||
func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
u.lastEventIDSent = u.updater.LastEventIDSent()
|
u.lastEventIDSent = u.updater.LastEventIDSent()
|
||||||
u.oldStateNID = u.updater.CurrentStateSnapshotNID()
|
|
||||||
|
|
||||||
// If we are doing a regular event update then we will get the
|
// If we are doing a regular event update then we will get the
|
||||||
// previous latest events to use as a part of the calculation. If
|
// previous latest events to use as a part of the calculation. If
|
||||||
|
@ -125,7 +124,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
// then start with an empty set - none of the forward extremities
|
// then start with an empty set - none of the forward extremities
|
||||||
// that we knew about before matter anymore.
|
// that we knew about before matter anymore.
|
||||||
oldLatest := []types.StateAtEventAndReference{}
|
oldLatest := []types.StateAtEventAndReference{}
|
||||||
if !u.stateAtEvent.Overwrite {
|
if !u.rewritesState {
|
||||||
|
u.oldStateNID = u.updater.CurrentStateSnapshotNID()
|
||||||
oldLatest = u.updater.LatestEvents()
|
oldLatest = u.updater.LatestEvents()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,7 +153,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
// Now that we know what the latest events are, it's time to get the
|
// Now that we know what the latest events are, it's time to get the
|
||||||
// latest state.
|
// latest state.
|
||||||
var updates []api.OutputEvent
|
var updates []api.OutputEvent
|
||||||
if extremitiesChanged {
|
if extremitiesChanged || u.rewritesState {
|
||||||
if err = u.latestState(); err != nil {
|
if err = u.latestState(); err != nil {
|
||||||
return fmt.Errorf("u.latestState: %w", err)
|
return fmt.Errorf("u.latestState: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -324,7 +324,6 @@ func (u *latestEventsUpdater) calculateLatest(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {
|
func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {
|
||||||
|
|
||||||
latestEventIDs := make([]string, len(u.latest))
|
latestEventIDs := make([]string, len(u.latest))
|
||||||
for i := range u.latest {
|
for i := range u.latest {
|
||||||
latestEventIDs[i] = u.latest[i].EventID
|
latestEventIDs[i] = u.latest[i].EventID
|
||||||
|
@ -365,11 +364,6 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
|
||||||
return nil, fmt.Errorf("failed to load add_state_events from db: %w", err)
|
return nil, fmt.Errorf("failed to load add_state_events from db: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// State is rewritten if the input room event HasState and we actually produced a delta on state events.
|
|
||||||
// Without this check, /get_missing_events which produce events with associated (but not complete) state
|
|
||||||
// will incorrectly purge the room and set it to no state. TODO: This is likely flakey, as if /gme produced
|
|
||||||
// a state conflict res which just so happens to include 2+ events we might purge the room state downstream.
|
|
||||||
ore.RewritesState = len(ore.AddsStateEventIDs) > 1
|
|
||||||
|
|
||||||
return &api.OutputEvent{
|
return &api.OutputEvent{
|
||||||
Type: api.OutputTypeNewRoomEvent,
|
Type: api.OutputTypeNewRoomEvent,
|
||||||
|
|
|
@ -379,7 +379,7 @@ func TestOutputRewritesState(t *testing.T) {
|
||||||
if len(producer.producedMessages) != 1 {
|
if len(producer.producedMessages) != 1 {
|
||||||
t.Fatalf("Rewritten events got output, want only 1 got %d", len(producer.producedMessages))
|
t.Fatalf("Rewritten events got output, want only 1 got %d", len(producer.producedMessages))
|
||||||
}
|
}
|
||||||
outputEvent := producer.producedMessages[0]
|
outputEvent := producer.producedMessages[len(producer.producedMessages)-1]
|
||||||
if !outputEvent.NewRoomEvent.RewritesState {
|
if !outputEvent.NewRoomEvent.RewritesState {
|
||||||
t.Errorf("RewritesState flag not set on output event")
|
t.Errorf("RewritesState flag not set on output event")
|
||||||
}
|
}
|
||||||
|
|
|
@ -149,7 +149,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
if msg.RewritesState {
|
if msg.RewritesState {
|
||||||
if err = s.db.PurgeRoom(ctx, ev.RoomID()); err != nil {
|
if err = s.db.PurgeRoomState(ctx, ev.RoomID()); err != nil {
|
||||||
return fmt.Errorf("s.db.PurgeRoom: %w", err)
|
return fmt.Errorf("s.db.PurgeRoom: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,9 +43,9 @@ type Database interface {
|
||||||
// Returns an error if there was a problem inserting this event.
|
// Returns an error if there was a problem inserting this event.
|
||||||
WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent,
|
WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent,
|
||||||
addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error)
|
addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error)
|
||||||
// PurgeRoom completely purges room state from the sync API. This is done when
|
// PurgeRoomState completely purges room state from the sync API. This is done when
|
||||||
// receiving an output event that completely resets the state.
|
// receiving an output event that completely resets the state.
|
||||||
PurgeRoom(ctx context.Context, roomID string) error
|
PurgeRoomState(ctx context.Context, roomID string) error
|
||||||
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
|
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
|
||||||
// If no event could be found, returns nil
|
// If no event could be found, returns nil
|
||||||
// If there was an issue during the retrieval, returns an error
|
// If there was an issue during the retrieval, returns an error
|
||||||
|
|
|
@ -276,7 +276,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) PurgeRoom(
|
func (d *Database) PurgeRoomState(
|
||||||
ctx context.Context, roomID string,
|
ctx context.Context, roomID string,
|
||||||
) error {
|
) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
@ -286,15 +286,6 @@ func (d *Database) PurgeRoom(
|
||||||
if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
|
if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
|
||||||
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
|
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
|
||||||
}
|
}
|
||||||
if err := d.OutputEvents.DeleteEventsForRoom(ctx, txn, roomID); err != nil {
|
|
||||||
return fmt.Errorf("d.Events.DeleteEventsForRoom: %w", err)
|
|
||||||
}
|
|
||||||
if err := d.Topology.DeleteTopologyForRoom(ctx, txn, roomID); err != nil {
|
|
||||||
return fmt.Errorf("d.Topology.DeleteTopologyForRoom: %w", err)
|
|
||||||
}
|
|
||||||
if err := d.BackwardExtremities.DeleteBackwardExtremitiesForRoom(ctx, txn, roomID); err != nil {
|
|
||||||
return fmt.Errorf("d.BackwardExtremities.DeleteBackwardExtremitiesForRoom: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue