Merge branch 'master' into matthew/peeking-over-fed

This commit is contained in:
Kegsay 2020-10-22 12:06:43 +01:00 committed by GitHub
commit 0fd9e960bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 325 additions and 124 deletions

View file

@ -75,7 +75,7 @@ Then point your favourite Matrix client at `http://localhost:8008`.
We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver We use a script called Are We Synapse Yet which checks Sytest compliance rates. Sytest is a black-box homeserver
test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it test rig with around 900 tests. The script works out how many of these tests are passing on Dendrite and it
updates with CI. As of October 2020 we're at around 56% CS API coverage and 77% Federation coverage, though check updates with CI. As of October 2020 we're at around 57% CS API coverage and 81% Federation coverage, though check
CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse CI for the latest numbers. In practice, this means you can communicate locally and via federation with Synapse
servers such as matrix.org reasonably well. There's a long list of features that are not implemented, notably: servers such as matrix.org reasonably well. There's a long list of features that are not implemented, notably:
- Receipts - Receipts

132
cmd/resolve-state/main.go Normal file
View file

@ -0,0 +1,132 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"strconv"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/setup"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
// This is a utility for inspecting state snapshots and running state resolution
// against real snapshots in an actual database.
// It takes one or more state snapshot NIDs as arguments, along with a room version
// to use for unmarshalling events, and will produce resolved output.
//
// Usage: ./resolve-state --roomversion=version snapshot [snapshot ...]
// e.g. ./resolve-state --roomversion=5 1254 1235 1282
var roomVersion = flag.String("roomversion", "5", "the room version to parse events as")
// nolint:gocyclo
func main() {
ctx := context.Background()
cfg := setup.ParseFlags(true)
args := os.Args[1:]
fmt.Println("Room version", *roomVersion)
snapshotNIDs := []types.StateSnapshotNID{}
for _, arg := range args {
if i, err := strconv.Atoi(arg); err == nil {
snapshotNIDs = append(snapshotNIDs, types.StateSnapshotNID(i))
}
}
fmt.Println("Fetching", len(snapshotNIDs), "snapshot NIDs")
cache, err := caching.NewInMemoryLRUCache(true)
if err != nil {
panic(err)
}
roomserverDB, err := storage.Open(&cfg.RoomServer.Database, cache)
if err != nil {
panic(err)
}
blockNIDs, err := roomserverDB.StateBlockNIDs(ctx, snapshotNIDs)
if err != nil {
panic(err)
}
var stateEntries []types.StateEntryList
for _, list := range blockNIDs {
entries, err2 := roomserverDB.StateEntries(ctx, list.StateBlockNIDs)
if err2 != nil {
panic(err2)
}
stateEntries = append(stateEntries, entries...)
}
var eventNIDs []types.EventNID
for _, entry := range stateEntries {
for _, e := range entry.StateEntries {
eventNIDs = append(eventNIDs, e.EventNID)
}
}
fmt.Println("Fetching", len(eventNIDs), "state events")
eventEntries, err := roomserverDB.Events(ctx, eventNIDs)
if err != nil {
panic(err)
}
authEventIDMap := make(map[string]struct{})
eventPtrs := make([]*gomatrixserverlib.Event, len(eventEntries))
for i := range eventEntries {
eventPtrs[i] = &eventEntries[i].Event
for _, authEventID := range eventEntries[i].AuthEventIDs() {
authEventIDMap[authEventID] = struct{}{}
}
}
authEventIDs := make([]string, 0, len(authEventIDMap))
for authEventID := range authEventIDMap {
authEventIDs = append(authEventIDs, authEventID)
}
fmt.Println("Fetching", len(authEventIDs), "auth events")
authEventEntries, err := roomserverDB.EventsFromIDs(ctx, authEventIDs)
if err != nil {
panic(err)
}
authEventPtrs := make([]*gomatrixserverlib.Event, len(authEventEntries))
for i := range authEventEntries {
authEventPtrs[i] = &authEventEntries[i].Event
}
events := make([]gomatrixserverlib.Event, len(eventEntries))
authEvents := make([]gomatrixserverlib.Event, len(authEventEntries))
for i, ptr := range eventPtrs {
events[i] = *ptr
}
for i, ptr := range authEventPtrs {
authEvents[i] = *ptr
}
fmt.Println("Resolving state")
resolved, err := state.ResolveConflictsAdhoc(
gomatrixserverlib.RoomVersion(*roomVersion),
events,
authEvents,
)
if err != nil {
panic(err)
}
fmt.Println("Resolved state contains", len(resolved), "events")
for _, event := range resolved {
fmt.Println()
fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey())
fmt.Printf(" %s\n", string(event.Content()))
}
}

View file

@ -87,6 +87,12 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
case api.OutputTypeNewRoomEvent: case api.OutputTypeNewRoomEvent:
ev := &output.NewRoomEvent.Event ev := &output.NewRoomEvent.Event
if output.NewRoomEvent.RewritesState {
if err := s.db.PurgeRoomState(context.TODO(), ev.RoomID()); err != nil {
return fmt.Errorf("s.db.PurgeRoom: %w", err)
}
}
if err := s.processMessage(*output.NewRoomEvent); err != nil { if err := s.processMessage(*output.NewRoomEvent); err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{

View file

@ -32,6 +32,7 @@ type Database interface {
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given. // GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
PurgeRoomState(ctx context.Context, roomID string) error
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View file

@ -53,6 +53,9 @@ const insertJoinedHostsSQL = "" +
const deleteJoinedHostsSQL = "" + const deleteJoinedHostsSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE event_id = ANY($1)" "DELETE FROM federationsender_joined_hosts WHERE event_id = ANY($1)"
const deleteJoinedHostsForRoomSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
const selectJoinedHostsSQL = "" + const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" + "SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1" " WHERE room_id = $1"
@ -67,6 +70,7 @@ type joinedHostsStatements struct {
db *sql.DB db *sql.DB
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt
deleteJoinedHostsForRoomStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt selectJoinedHostsForRoomsStmt *sql.Stmt
@ -86,6 +90,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
if s.deleteJoinedHostsStmt, err = s.db.Prepare(deleteJoinedHostsSQL); err != nil { if s.deleteJoinedHostsStmt, err = s.db.Prepare(deleteJoinedHostsSQL); err != nil {
return return
} }
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
return
}
if s.selectJoinedHostsStmt, err = s.db.Prepare(selectJoinedHostsSQL); err != nil { if s.selectJoinedHostsStmt, err = s.db.Prepare(selectJoinedHostsSQL); err != nil {
return return
} }
@ -117,6 +124,14 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
return err return err
} }
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
_, err := stmt.ExecContext(ctx, roomID)
return err
}
func (s *joinedHostsStatements) SelectJoinedHostsWithTx( func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) { ) ([]types.JoinedHost, error) {

View file

@ -150,6 +150,20 @@ func (d *Database) StoreJSON(
}, nil }, nil
} }
func (d *Database) PurgeRoomState(
ctx context.Context, roomID string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err := d.FederationSenderJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.FederationSenderJoinedHosts.DeleteJoinedHosts: %w", err)
}
return nil
})
}
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), txn, serverName) return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), txn, serverName)

View file

@ -53,6 +53,9 @@ const insertJoinedHostsSQL = "" +
const deleteJoinedHostsSQL = "" + const deleteJoinedHostsSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE event_id = $1" "DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
const deleteJoinedHostsForRoomSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
const selectJoinedHostsSQL = "" + const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" + "SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1" " WHERE room_id = $1"
@ -64,11 +67,12 @@ const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt deleteJoinedHostsForRoomStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
} }
@ -86,6 +90,9 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error)
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil { if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
return return
} }
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
return
}
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
return return
} }
@ -118,6 +125,14 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
return nil return nil
} }
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
_, err := stmt.ExecContext(ctx, roomID)
return err
}
func (s *joinedHostsStatements) SelectJoinedHostsWithTx( func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) { ) ([]types.JoinedHost, error) {

View file

@ -50,6 +50,7 @@ type FederationSenderQueueJSON interface {
type FederationSenderJoinedHosts interface { type FederationSenderJoinedHosts interface {
InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error
DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error
DeleteJoinedHostsForRoom(ctx context.Context, txn *sql.Tx, roomID string) error
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)

View file

@ -17,7 +17,6 @@
package input package input
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
@ -28,7 +27,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// updateLatestEvents updates the list of latest events for this room in the database and writes the // updateLatestEvents updates the list of latest events for this room in the database and writes the
@ -118,7 +116,6 @@ type latestEventsUpdater struct {
func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) doUpdateLatestEvents() error {
u.lastEventIDSent = u.updater.LastEventIDSent() u.lastEventIDSent = u.updater.LastEventIDSent()
u.oldStateNID = u.updater.CurrentStateSnapshotNID()
// If we are doing a regular event update then we will get the // If we are doing a regular event update then we will get the
// previous latest events to use as a part of the calculation. If // previous latest events to use as a part of the calculation. If
@ -127,7 +124,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// then start with an empty set - none of the forward extremities // then start with an empty set - none of the forward extremities
// that we knew about before matter anymore. // that we knew about before matter anymore.
oldLatest := []types.StateAtEventAndReference{} oldLatest := []types.StateAtEventAndReference{}
if !u.stateAtEvent.Overwrite { if !u.rewritesState {
u.oldStateNID = u.updater.CurrentStateSnapshotNID()
oldLatest = u.updater.LatestEvents() oldLatest = u.updater.LatestEvents()
} }
@ -141,27 +139,32 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// Work out what the latest events are. This will include the new // Work out what the latest events are. This will include the new
// event if it is not already referenced. // event if it is not already referenced.
if err := u.calculateLatest( extremitiesChanged, err := u.calculateLatest(
oldLatest, oldLatest, &u.event,
types.StateAtEventAndReference{ types.StateAtEventAndReference{
EventReference: u.event.EventReference(), EventReference: u.event.EventReference(),
StateAtEvent: u.stateAtEvent, StateAtEvent: u.stateAtEvent,
}, },
); err != nil { )
if err != nil {
return fmt.Errorf("u.calculateLatest: %w", err) return fmt.Errorf("u.calculateLatest: %w", err)
} }
// Now that we know what the latest events are, it's time to get the // Now that we know what the latest events are, it's time to get the
// latest state. // latest state.
if err := u.latestState(); err != nil { var updates []api.OutputEvent
return fmt.Errorf("u.latestState: %w", err) if extremitiesChanged || u.rewritesState {
} if err = u.latestState(); err != nil {
return fmt.Errorf("u.latestState: %w", err)
}
// If we need to generate any output events then here's where we do it. // If we need to generate any output events then here's where we do it.
// TODO: Move this! // TODO: Move this!
updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added) if updates, err = u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added); err != nil {
if err != nil { return fmt.Errorf("u.api.updateMemberships: %w", err)
return fmt.Errorf("u.api.updateMemberships: %w", err) }
} else {
u.newStateNID = u.oldStateNID
} }
update, err := u.makeOutputNewRoomEvent() update, err := u.makeOutputNewRoomEvent()
@ -250,54 +253,77 @@ func (u *latestEventsUpdater) latestState() error {
// true if the new event is included in those extremites, false otherwise. // true if the new event is included in those extremites, false otherwise.
func (u *latestEventsUpdater) calculateLatest( func (u *latestEventsUpdater) calculateLatest(
oldLatest []types.StateAtEventAndReference, oldLatest []types.StateAtEventAndReference,
newEvent types.StateAtEventAndReference, newEvent *gomatrixserverlib.Event,
) error { newStateAndRef types.StateAtEventAndReference,
var newLatest []types.StateAtEventAndReference ) (bool, error) {
// First of all, get a list of all of the events in our current
// First of all, let's see if any of the existing forward extremities // set of forward extremities.
// now have entries in the previous events table. If they do then we existingRefs := make(map[string]*types.StateAtEventAndReference)
// will no longer include them as forward extremities. existingNIDs := make([]types.EventNID, len(oldLatest))
for _, l := range oldLatest { for i, old := range oldLatest {
referenced, err := u.updater.IsReferenced(l.EventReference) existingRefs[old.EventID] = &oldLatest[i]
if err != nil { existingNIDs[i] = old.EventNID
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", l.EventID)
return fmt.Errorf("u.updater.IsReferenced (old): %w", err)
} else if !referenced {
newLatest = append(newLatest, l)
}
} }
// Then check and see if our new event is already included in that set. // Look up the old extremity events. This allows us to find their
// This ordinarily won't happen but it covers the edge-case that we've // prev events.
// already seen this event before and it's a forward extremity, so rather events, err := u.api.DB.Events(u.ctx, existingNIDs)
// than adding a duplicate, we'll just return the set as complete.
for _, l := range newLatest {
if l.EventReference.EventID == newEvent.EventReference.EventID && bytes.Equal(l.EventReference.EventSHA256, newEvent.EventReference.EventSHA256) {
// We've already referenced this new event so we can just return
// the newly completed extremities at this point.
u.latest = newLatest
return nil
}
}
// At this point we've processed the old extremities, and we've checked
// that our new event isn't already in that set. Therefore now we can
// check if our *new* event is a forward extremity, and if it is, add
// it in.
referenced, err := u.updater.IsReferenced(newEvent.EventReference)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", newEvent.EventReference.EventID) return false, fmt.Errorf("u.api.DB.Events: %w", err)
return fmt.Errorf("u.updater.IsReferenced (new): %w", err) }
} else if !referenced || len(newLatest) == 0 {
newLatest = append(newLatest, newEvent) // Make a list of all of the prev events as referenced by all of
// the current forward extremities.
existingPrevs := make(map[string]struct{})
for _, old := range events {
for _, prevEventID := range old.PrevEventIDs() {
existingPrevs[prevEventID] = struct{}{}
}
}
// If the "new" event is already referenced by a forward extremity
// then do nothing - it's not a candidate to be a new extremity if
// it has been referenced.
if _, ok := existingPrevs[newEvent.EventID()]; ok {
return false, nil
}
// If the "new" event is already a forward extremity then stop, as
// nothing changes.
for _, event := range events {
if event.EventID() == newEvent.EventID() {
return false, nil
}
}
// Include our new event in the extremities.
newLatest := []types.StateAtEventAndReference{newStateAndRef}
// Then run through and see if the other extremities are still valid.
// If our new event references them then they are no longer good
// candidates.
for _, prevEventID := range newEvent.PrevEventIDs() {
delete(existingRefs, prevEventID)
}
// Ensure that we don't add any candidate forward extremities from
// the old set that are, themselves, referenced by the old set of
// forward extremities. This shouldn't happen but guards against
// the possibility anyway.
for prevEventID := range existingPrevs {
delete(existingRefs, prevEventID)
}
// Then re-add any old extremities that are still valid after all.
for _, old := range existingRefs {
newLatest = append(newLatest, *old)
} }
u.latest = newLatest u.latest = newLatest
return nil return true, nil
} }
func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) { func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {
latestEventIDs := make([]string, len(u.latest)) latestEventIDs := make([]string, len(u.latest))
for i := range u.latest { for i := range u.latest {
latestEventIDs[i] = u.latest[i].EventID latestEventIDs[i] = u.latest[i].EventID
@ -338,11 +364,6 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) return nil, fmt.Errorf("failed to load add_state_events from db: %w", err)
} }
} }
// State is rewritten if the input room event HasState and we actually produced a delta on state events.
// Without this check, /get_missing_events which produce events with associated (but not complete) state
// will incorrectly purge the room and set it to no state. TODO: This is likely flakey, as if /gme produced
// a state conflict res which just so happens to include 2+ events we might purge the room state downstream.
ore.RewritesState = len(ore.AddsStateEventIDs) > 1
return &api.OutputEvent{ return &api.OutputEvent{
Type: api.OutputTypeNewRoomEvent, Type: api.OutputTypeNewRoomEvent,

View file

@ -379,7 +379,7 @@ func TestOutputRewritesState(t *testing.T) {
if len(producer.producedMessages) != 1 { if len(producer.producedMessages) != 1 {
t.Fatalf("Rewritten events got output, want only 1 got %d", len(producer.producedMessages)) t.Fatalf("Rewritten events got output, want only 1 got %d", len(producer.producedMessages))
} }
outputEvent := producer.producedMessages[0] outputEvent := producer.producedMessages[len(producer.producedMessages)-1]
if !outputEvent.NewRoomEvent.RewritesState { if !outputEvent.NewRoomEvent.RewritesState {
t.Errorf("RewritesState flag not set on output event") t.Errorf("RewritesState flag not set on output event")
} }

View file

@ -526,13 +526,7 @@ func (v StateResolution) CalculateAndStoreStateBeforeEvent(
isRejected bool, isRejected bool,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
// Load the state at the prev events. // Load the state at the prev events.
prevEventRefs := event.PrevEvents() prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs())
prevEventIDs := make([]string, len(prevEventRefs))
for i := range prevEventRefs {
prevEventIDs[i] = prevEventRefs[i].EventID
}
prevStates, err := v.db.StateAtEventIDs(ctx, prevEventIDs)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View file

@ -27,23 +27,24 @@ import (
const redactionsArePermanent = true const redactionsArePermanent = true
type Database struct { type Database struct {
DB *sql.DB DB *sql.DB
Cache caching.RoomServerCaches Cache caching.RoomServerCaches
Writer sqlutil.Writer Writer sqlutil.Writer
EventsTable tables.Events EventsTable tables.Events
EventJSONTable tables.EventJSON EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms RoomsTable tables.Rooms
TransactionsTable tables.Transactions TransactionsTable tables.Transactions
StateSnapshotTable tables.StateSnapshot StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites InvitesTable tables.Invites
MembershipTable tables.Membership MembershipTable tables.Membership
PublishedTable tables.Published PublishedTable tables.Published
RedactionsTable tables.Redactions RedactionsTable tables.Redactions
GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error)
} }
func (d *Database) SupportsConcurrentRoomInputs() bool { func (d *Database) SupportsConcurrentRoomInputs() bool {
@ -372,6 +373,9 @@ func (d *Database) MembershipUpdater(
func (d *Database) GetLatestEventsForUpdate( func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomInfo types.RoomInfo, ctx context.Context, roomInfo types.RoomInfo,
) (*LatestEventsUpdater, error) { ) (*LatestEventsUpdater, error) {
if d.GetLatestEventsForUpdateFn != nil {
return d.GetLatestEventsForUpdateFn(ctx, roomInfo)
}
txn, err := d.DB.Begin() txn, err := d.DB.Begin()
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -120,23 +120,24 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
return nil, err return nil, err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Cache: cache, Cache: cache,
Writer: d.writer, Writer: d.writer,
EventsTable: d.events, EventsTable: d.events,
EventTypesTable: d.eventTypes, EventTypesTable: d.eventTypes,
EventStateKeysTable: d.eventStateKeys, EventStateKeysTable: d.eventStateKeys,
EventJSONTable: d.eventJSON, EventJSONTable: d.eventJSON,
RoomsTable: d.rooms, RoomsTable: d.rooms,
TransactionsTable: d.transactions, TransactionsTable: d.transactions,
StateBlockTable: stateBlock, StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot, StateSnapshotTable: stateSnapshot,
PrevEventsTable: d.prevEvents, PrevEventsTable: d.prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: d.invites, InvitesTable: d.invites,
MembershipTable: d.membership, MembershipTable: d.membership,
PublishedTable: published, PublishedTable: published,
RedactionsTable: redactions, RedactionsTable: redactions,
GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate,
} }
return &d, nil return &d, nil
} }

View file

@ -149,7 +149,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
} }
if msg.RewritesState { if msg.RewritesState {
if err = s.db.PurgeRoom(ctx, ev.RoomID()); err != nil { if err = s.db.PurgeRoomState(ctx, ev.RoomID()); err != nil {
return fmt.Errorf("s.db.PurgeRoom: %w", err) return fmt.Errorf("s.db.PurgeRoom: %w", err)
} }
} }
@ -189,14 +189,20 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
) error { ) error {
ev := msg.Event ev := msg.Event
// TODO: The state key check when excluding from sync is designed
// to stop us from lying to clients with old state, whilst still
// allowing normal timeline events through. This is an absolute
// hack but until we have some better strategy for dealing with
// old events in the sync API, this should at least prevent us
// from confusing clients into thinking they've joined/left rooms.
pduPos, err := s.db.WriteEvent( pduPos, err := s.db.WriteEvent(
ctx, ctx,
&ev, &ev,
[]gomatrixserverlib.HeaderedEvent{}, []gomatrixserverlib.HeaderedEvent{},
[]string{}, // adds no state []string{}, // adds no state
[]string{}, // removes no state []string{}, // removes no state
nil, // no transaction nil, // no transaction
false, // not excluded from sync ev.StateKey() != nil, // exclude from sync?
) )
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database

View file

@ -43,9 +43,9 @@ type Database interface {
// Returns an error if there was a problem inserting this event. // Returns an error if there was a problem inserting this event.
WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent, WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent,
addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error) addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error)
// PurgeRoom completely purges room state from the sync API. This is done when // PurgeRoomState completely purges room state from the sync API. This is done when
// receiving an output event that completely resets the state. // receiving an output event that completely resets the state.
PurgeRoom(ctx context.Context, roomID string) error PurgeRoomState(ctx context.Context, roomID string) error
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error

View file

@ -276,7 +276,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
return nil return nil
} }
func (d *Database) PurgeRoom( func (d *Database) PurgeRoomState(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -286,15 +286,6 @@ func (d *Database) PurgeRoom(
if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
} }
if err := d.OutputEvents.DeleteEventsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.Events.DeleteEventsForRoom: %w", err)
}
if err := d.Topology.DeleteTopologyForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.Topology.DeleteTopologyForRoom: %w", err)
}
if err := d.BackwardExtremities.DeleteBackwardExtremitiesForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.BackwardExtremities.DeleteBackwardExtremitiesForRoom: %w", err)
}
return nil return nil
}) })
} }