Merge branch 'master' into neilalexander/keydb

This commit is contained in:
Neil Alexander 2020-05-21 10:35:47 +01:00 committed by GitHub
commit 45488487a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 301 additions and 149 deletions

4
build-dendritejs.sh Executable file
View file

@ -0,0 +1,4 @@
#!/bin/bash -eu
export GIT_COMMIT=$(git rev-list -1 HEAD) && \
GOOS=js GOARCH=wasm go build -ldflags "-X main.GitCommit=$GIT_COMMIT" -o main.wasm ./cmd/dendritejs

View file

@ -47,8 +47,10 @@ import (
_ "github.com/matrix-org/go-sqlite3-js" _ "github.com/matrix-org/go-sqlite3-js"
) )
var GitCommit string
func init() { func init() {
fmt.Println("dendrite.js starting...") fmt.Printf("[%s] dendrite.js starting...\n", GitCommit)
} }
const keyNameEd25519 = "_go_ed25519_key" const keyNameEd25519 = "_go_ed25519_key"

View file

@ -69,9 +69,14 @@ func Backfill(
// Populate the request. // Populate the request.
req := api.QueryBackfillRequest{ req := api.QueryBackfillRequest{
RoomID: roomID, RoomID: roomID,
EarliestEventsIDs: eIDs, // we don't know who the successors are for these events, which won't
ServerName: request.Origin(), // be a problem because we don't use that information when servicing /backfill requests,
// only when making them. TODO: Think of a better API shape
BackwardsExtremities: map[string][]string{
"": eIDs,
},
ServerName: request.Origin(),
} }
if req.Limit, err = strconv.Atoi(limit); err != nil { if req.Limit, err = strconv.Atoi(limit); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed")
@ -97,7 +102,7 @@ func Backfill(
} }
} }
var eventJSONs []json.RawMessage eventJSONs := []json.RawMessage{}
for _, e := range gomatrixserverlib.ReverseTopologicalOrdering( for _, e := range gomatrixserverlib.ReverseTopologicalOrdering(
evs, evs,
gomatrixserverlib.TopologicalOrderByPrevEvents, gomatrixserverlib.TopologicalOrderByPrevEvents,
@ -105,6 +110,12 @@ func Backfill(
eventJSONs = append(eventJSONs, e.JSON()) eventJSONs = append(eventJSONs, e.JSON())
} }
// sytest wants these in reversed order, similar to /messages, so reverse them now.
for i := len(eventJSONs)/2 - 1; i >= 0; i-- {
opp := len(eventJSONs) - 1 - i
eventJSONs[i], eventJSONs[opp] = eventJSONs[opp], eventJSONs[i]
}
txn := gomatrixserverlib.Transaction{ txn := gomatrixserverlib.Transaction{
Origin: cfg.Matrix.ServerName, Origin: cfg.Matrix.ServerName,
PDUs: eventJSONs, PDUs: eventJSONs,

View file

@ -21,6 +21,7 @@ import (
commonHTTP "github.com/matrix-org/dendrite/common/http" commonHTTP "github.com/matrix-org/dendrite/common/http"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go" opentracing "github.com/opentracing/opentracing-go"
) )
@ -228,14 +229,24 @@ type QueryStateAndAuthChainResponse struct {
type QueryBackfillRequest struct { type QueryBackfillRequest struct {
// The room to backfill // The room to backfill
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
// Events to start paginating from. // A map of backwards extremity event ID to a list of its prev_event IDs.
EarliestEventsIDs []string `json:"earliest_event_ids"` BackwardsExtremities map[string][]string `json:"backwards_extremities"`
// The maximum number of events to retrieve. // The maximum number of events to retrieve.
Limit int `json:"limit"` Limit int `json:"limit"`
// The server interested in the events. // The server interested in the events.
ServerName gomatrixserverlib.ServerName `json:"server_name"` ServerName gomatrixserverlib.ServerName `json:"server_name"`
} }
// PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order.
func (r *QueryBackfillRequest) PrevEventIDs() []string {
var prevEventIDs []string
for _, pes := range r.BackwardsExtremities {
prevEventIDs = append(prevEventIDs, pes...)
}
prevEventIDs = util.UniqueStrings(prevEventIDs)
return prevEventIDs
}
// QueryBackfillResponse is a response to QueryBackfill. // QueryBackfillResponse is a response to QueryBackfill.
type QueryBackfillResponse struct { type QueryBackfillResponse struct {
// Missing events, arbritrary order. // Missing events, arbritrary order.

View file

@ -60,7 +60,7 @@ func (r *RoomserverInternalAPI) InputRoomEvents(
defer r.mutex.Unlock() defer r.mutex.Unlock()
for i := range request.InputInviteEvents { for i := range request.InputInviteEvents {
var loopback *api.InputRoomEvent var loopback *api.InputRoomEvent
if loopback, err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil { if loopback, err = r.processInviteEvent(ctx, r, request.InputInviteEvents[i]); err != nil {
return err return err
} }
// The processInviteEvent function can optionally return a // The processInviteEvent function can optionally return a
@ -71,7 +71,7 @@ func (r *RoomserverInternalAPI) InputRoomEvents(
} }
} }
for i := range request.InputRoomEvents { for i := range request.InputRoomEvents {
if response.EventID, err = processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil { if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil {
return err return err
} }
} }

View file

@ -31,21 +31,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// OutputRoomEventWriter has the APIs needed to write an event to the output logs.
type OutputRoomEventWriter interface {
// Write a list of events for a room
WriteOutputEvents(roomID string, updates []api.OutputEvent) error
}
// processRoomEvent can only be called once at a time // processRoomEvent can only be called once at a time
// //
// TODO(#375): This should be rewritten to allow concurrent calls. The // TODO(#375): This should be rewritten to allow concurrent calls. The
// difficulty is in ensuring that we correctly annotate events with the correct // difficulty is in ensuring that we correctly annotate events with the correct
// state deltas when sending to kafka streams // state deltas when sending to kafka streams
func processRoomEvent( func (r *RoomserverInternalAPI) processRoomEvent(
ctx context.Context, ctx context.Context,
db storage.Database,
ow OutputRoomEventWriter,
input api.InputRoomEvent, input api.InputRoomEvent,
) (eventID string, err error) { ) (eventID string, err error) {
// Parse and validate the event JSON // Parse and validate the event JSON
@ -54,7 +46,7 @@ func processRoomEvent(
// Check that the event passes authentication checks and work out // Check that the event passes authentication checks and work out
// the numeric IDs for the auth events. // the numeric IDs for the auth events.
authEventNIDs, err := checkAuthEvents(ctx, db, headered, input.AuthEventIDs) authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event")
return return
@ -63,7 +55,7 @@ func processRoomEvent(
// If we don't have a transaction ID then get one. // If we don't have a transaction ID then get one.
if input.TransactionID != nil { if input.TransactionID != nil {
tdID := input.TransactionID tdID := input.TransactionID
eventID, err = db.GetTransactionEventID( eventID, err = r.DB.GetTransactionEventID(
ctx, tdID.TransactionID, tdID.SessionID, event.Sender(), ctx, tdID.TransactionID, tdID.SessionID, event.Sender(),
) )
// On error OR event with the transaction already processed/processesing // On error OR event with the transaction already processed/processesing
@ -73,7 +65,7 @@ func processRoomEvent(
} }
// Store the event. // Store the event.
roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) roomNID, stateAtEvent, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
if err != nil { if err != nil {
return return
} }
@ -93,16 +85,14 @@ func processRoomEvent(
if stateAtEvent.BeforeStateSnapshotNID == 0 { if stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet. // We haven't calculated a state for this event yet.
// Lets calculate one. // Lets calculate one.
err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event) err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event)
if err != nil { if err != nil {
return return
} }
} }
if err = updateLatestEvents( if err = r.updateLatestEvents(
ctx, // context ctx, // context
db, // roomserver database
ow, // output event writer
roomNID, // room NID to update roomNID, // room NID to update
stateAtEvent, // state at event (below) stateAtEvent, // state at event (below)
event, // event event, // event
@ -116,29 +106,36 @@ func processRoomEvent(
return event.EventID(), nil return event.EventID(), nil
} }
func calculateAndSetState( func (r *RoomserverInternalAPI) calculateAndSetState(
ctx context.Context, ctx context.Context,
db storage.Database,
input api.InputRoomEvent, input api.InputRoomEvent,
roomNID types.RoomNID, roomNID types.RoomNID,
stateAtEvent *types.StateAtEvent, stateAtEvent *types.StateAtEvent,
event gomatrixserverlib.Event, event gomatrixserverlib.Event,
) error { ) error {
var err error var err error
roomState := state.NewStateResolution(db) roomState := state.NewStateResolution(r.DB)
if input.HasState { if input.HasState {
// TODO: Check here if we think we're in the room already. // Check here if we think we're in the room already.
stateAtEvent.Overwrite = true stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID
// Request join memberships only for local users only.
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil {
// If we have no local users that are joined to the room then any state about
// the room that we have is quite possibly out of date. Therefore in that case
// we should overwrite it rather than merge it.
stateAtEvent.Overwrite = len(joinEventNIDs) == 0
}
// We've been told what the state at the event is so we don't need to calculate it. // We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state. // Check that those state events are in the database and store the state.
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return err return err
} }
if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
return err return err
} }
} else { } else {
@ -149,12 +146,11 @@ func calculateAndSetState(
return err return err
} }
} }
return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) return r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
} }
func processInviteEvent( func (r *RoomserverInternalAPI) processInviteEvent(
ctx context.Context, ctx context.Context,
db storage.Database,
ow *RoomserverInternalAPI, ow *RoomserverInternalAPI,
input api.InputInviteEvent, input api.InputInviteEvent,
) (*api.InputRoomEvent, error) { ) (*api.InputRoomEvent, error) {
@ -172,7 +168,10 @@ func processInviteEvent(
"target_user_id": targetUserID, "target_user_id": targetUserID,
}).Info("processing invite event") }).Info("processing invite event")
updater, err := db.MembershipUpdater(ctx, roomID, targetUserID, input.RoomVersion) _, domain, _ := gomatrixserverlib.SplitID('@', targetUserID)
isTargetLocalUser := domain == r.Cfg.Matrix.ServerName
updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocalUser, input.RoomVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -239,7 +238,7 @@ func processInviteEvent(
// up from local data (which is most likely to be if the event came // up from local data (which is most likely to be if the event came
// from the CS API). If we know about the room then we can insert // from the CS API). If we know about the room then we can insert
// the invite room state, if we don't then we just fail quietly. // the invite room state, if we don't then we just fail quietly.
if irs, ierr := buildInviteStrippedState(ctx, db, input); ierr == nil { if irs, ierr := buildInviteStrippedState(ctx, r.DB, input); ierr == nil {
if err = event.SetUnsignedField("invite_room_state", irs); err != nil { if err = event.SetUnsignedField("invite_room_state", irs); err != nil {
return nil, err return nil, err
} }

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -46,17 +45,15 @@ import (
// 7 <----- latest // 7 <----- latest
// //
// Can only be called once at a time // Can only be called once at a time
func updateLatestEvents( func (r *RoomserverInternalAPI) updateLatestEvents(
ctx context.Context, ctx context.Context,
db storage.Database,
ow OutputRoomEventWriter,
roomNID types.RoomNID, roomNID types.RoomNID,
stateAtEvent types.StateAtEvent, stateAtEvent types.StateAtEvent,
event gomatrixserverlib.Event, event gomatrixserverlib.Event,
sendAsServer string, sendAsServer string,
transactionID *api.TransactionID, transactionID *api.TransactionID,
) (err error) { ) (err error) {
updater, err := db.GetLatestEventsForUpdate(ctx, roomNID) updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
if err != nil { if err != nil {
return return
} }
@ -70,9 +67,8 @@ func updateLatestEvents(
u := latestEventsUpdater{ u := latestEventsUpdater{
ctx: ctx, ctx: ctx,
db: db, api: r,
updater: updater, updater: updater,
ow: ow,
roomNID: roomNID, roomNID: roomNID,
stateAtEvent: stateAtEvent, stateAtEvent: stateAtEvent,
event: event, event: event,
@ -94,9 +90,8 @@ func updateLatestEvents(
// when there are so many variables to pass around. // when there are so many variables to pass around.
type latestEventsUpdater struct { type latestEventsUpdater struct {
ctx context.Context ctx context.Context
db storage.Database api *RoomserverInternalAPI
updater types.RoomRecentEventsUpdater updater types.RoomRecentEventsUpdater
ow OutputRoomEventWriter
roomNID types.RoomNID roomNID types.RoomNID
stateAtEvent types.StateAtEvent stateAtEvent types.StateAtEvent
event gomatrixserverlib.Event event gomatrixserverlib.Event
@ -181,7 +176,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// If we need to generate any output events then here's where we do it. // If we need to generate any output events then here's where we do it.
// TODO: Move this! // TODO: Move this!
updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added) updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added)
if err != nil { if err != nil {
return err return err
} }
@ -200,7 +195,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in // send the event asynchronously but we would need to ensure that 1) the events are written to the log in
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
// necessary bookkeeping we'll keep the event sending synchronous for now. // necessary bookkeeping we'll keep the event sending synchronous for now.
if err = u.ow.WriteOutputEvents(u.event.RoomID(), updates); err != nil { if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
return err return err
} }
@ -213,7 +208,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
func (u *latestEventsUpdater) latestState() error { func (u *latestEventsUpdater) latestState() error {
var err error var err error
roomState := state.NewStateResolution(u.db) roomState := state.NewStateResolution(u.api.DB)
// Get a list of the current latest events. // Get a list of the current latest events.
latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) latestStateAtEvents := make([]types.StateAtEvent, len(u.latest))
@ -303,7 +298,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
latestEventIDs[i] = u.latest[i].EventID latestEventIDs[i] = u.latest[i].EventID
} }
roomVersion, err := u.db.GetRoomVersionForRoom(u.ctx, u.event.RoomID()) roomVersion, err := u.api.DB.GetRoomVersionForRoom(u.ctx, u.event.RoomID())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -329,7 +324,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
stateEventNIDs = append(stateEventNIDs, entry.EventNID) stateEventNIDs = append(stateEventNIDs, entry.EventNID)
} }
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))] stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs) eventIDMap, err := u.api.DB.EventIDs(u.ctx, stateEventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -19,7 +19,6 @@ import (
"fmt" "fmt"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -28,9 +27,8 @@ import (
// user affected by a change in the current state of the room. // user affected by a change in the current state of the room.
// Returns a list of output events to write to the kafka log to inform the // Returns a list of output events to write to the kafka log to inform the
// consumers about the invites added or retired by the change in current state. // consumers about the invites added or retired by the change in current state.
func updateMemberships( func (r *RoomserverInternalAPI) updateMemberships(
ctx context.Context, ctx context.Context,
db storage.Database,
updater types.RoomRecentEventsUpdater, updater types.RoomRecentEventsUpdater,
removed, added []types.StateEntry, removed, added []types.StateEntry,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
@ -48,7 +46,7 @@ func updateMemberships(
// Load the event JSON so we can look up the "membership" key. // Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that // TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON? // key without having to load the entire event JSON?
events, err := db.Events(ctx, eventNIDs) events, err := r.DB.Events(ctx, eventNIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -71,15 +69,16 @@ func updateMemberships(
ae = &ev.Event ae = &ev.Event
} }
} }
if updates, err = updateMembership(updater, targetUserNID, re, ae, updates); err != nil { if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
return nil, err return nil, err
} }
} }
return updates, nil return updates, nil
} }
func updateMembership( func (r *RoomserverInternalAPI) updateMembership(
updater types.RoomRecentEventsUpdater, targetUserNID types.EventStateKeyNID, updater types.RoomRecentEventsUpdater,
targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event, remove, add *gomatrixserverlib.Event,
updates []api.OutputEvent, updates []api.OutputEvent,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
@ -113,7 +112,7 @@ func updateMembership(
return updates, nil return updates, nil
} }
mu, err := updater.MembershipUpdater(targetUserNID) mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -132,6 +131,15 @@ func updateMembership(
} }
} }
func (r *RoomserverInternalAPI) isLocalTarget(event *gomatrixserverlib.Event) bool {
isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil {
_, domain, _ := gomatrixserverlib.SplitID('@', *statekey)
isTargetLocalUser = domain == r.Cfg.Matrix.ServerName
}
return isTargetLocalUser
}
func updateToInviteMembership( func updateToInviteMembership(
mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,

View file

@ -267,7 +267,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
var stateEntries []types.StateEntry var stateEntries []types.StateEntry
if stillInRoom { if stillInRoom {
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly) eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly, false)
if err != nil { if err != nil {
return err return err
} }
@ -495,14 +495,8 @@ func (r *RoomserverInternalAPI) QueryBackfill(
// defines the highest number of elements in the map below. // defines the highest number of elements in the map below.
visited := make(map[string]bool, request.Limit) visited := make(map[string]bool, request.Limit)
// The provided event IDs have already been seen by the request's emitter, // this will include these events which is what we want
// and will be retrieved anyway, so there's no need to care about them if front = request.PrevEventIDs()
// they appear in our exploration of the event tree.
for _, id := range request.EarliestEventsIDs {
visited[id] = true
}
front = request.EarliestEventsIDs
// Scan the event tree for events to send back. // Scan the event tree for events to send back.
resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName)
@ -534,10 +528,15 @@ func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *
if err != nil { if err != nil {
return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err) return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err)
} }
requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName) requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities)
// Request 100 items regardless of what the query asks for.
// We don't want to go much higher than this.
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
// (so we don't need to hit /state_ids which the test has no listener for)
// Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill( events, err := gomatrixserverlib.RequestBackfill(
ctx, requester, ctx, requester,
r.KeyRing, req.RoomID, roomVer, req.EarliestEventsIDs, req.Limit) r.KeyRing, req.RoomID, roomVer, req.PrevEventIDs(), 100)
if err != nil { if err != nil {
return err return err
} }
@ -592,7 +591,7 @@ func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, ser
return false, err return false, err
} }
eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true) eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false)
if err != nil { if err != nil {
return false, err return false, err
} }

View file

@ -7,6 +7,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -15,6 +16,7 @@ type backfillRequester struct {
db storage.Database db storage.Database
fedClient *gomatrixserverlib.FederationClient fedClient *gomatrixserverlib.FederationClient
thisServer gomatrixserverlib.ServerName thisServer gomatrixserverlib.ServerName
bwExtrems map[string][]string
// per-request state // per-request state
servers []gomatrixserverlib.ServerName servers []gomatrixserverlib.ServerName
@ -22,13 +24,14 @@ type backfillRequester struct {
eventIDMap map[string]gomatrixserverlib.Event eventIDMap map[string]gomatrixserverlib.Event
} }
func newBackfillRequester(db storage.Database, fedClient *gomatrixserverlib.FederationClient, thisServer gomatrixserverlib.ServerName) *backfillRequester { func newBackfillRequester(db storage.Database, fedClient *gomatrixserverlib.FederationClient, thisServer gomatrixserverlib.ServerName, bwExtrems map[string][]string) *backfillRequester {
return &backfillRequester{ return &backfillRequester{
db: db, db: db,
fedClient: fedClient, fedClient: fedClient,
thisServer: thisServer, thisServer: thisServer,
eventIDToBeforeStateIDs: make(map[string][]string), eventIDToBeforeStateIDs: make(map[string][]string),
eventIDMap: make(map[string]gomatrixserverlib.Event), eventIDMap: make(map[string]gomatrixserverlib.Event),
bwExtrems: bwExtrems,
} }
} }
@ -37,6 +40,11 @@ func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent
if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok { if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok {
return ids, nil return ids, nil
} }
if len(targetEvent.PrevEventIDs()) == 0 && targetEvent.Type() == "m.room.create" && targetEvent.StateKeyEquals("") {
util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID()).Info("Backfilled to the beginning of the room")
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = []string{}
return nil, nil
}
// if we have exactly 1 prev event and we know the state of the room at that prev event, then just roll forward the prev event. // if we have exactly 1 prev event and we know the state of the room at that prev event, then just roll forward the prev event.
// Else, we have to hit /state_ids because either we don't know the state at all at this event (new backwards extremity) or // Else, we have to hit /state_ids because either we don't know the state at all at this event (new backwards extremity) or
// we don't know the result of state res to merge forks (2 or more prev_events) // we don't know the result of state res to merge forks (2 or more prev_events)
@ -154,26 +162,44 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr
// It returns a list of servers which can be queried for backfill requests. These servers // It returns a list of servers which can be queried for backfill requests. These servers
// will be servers that are in the room already. The entries at the beginning are preferred servers // will be servers that are in the room already. The entries at the beginning are preferred servers
// and will be tried first. An empty list will fail the request. // and will be tried first. An empty list will fail the request.
func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) (servers []gomatrixserverlib.ServerName) { func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName {
// eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use
// its successor, so look it up.
successor := ""
FindSuccessor:
for sucID, prevEventIDs := range b.bwExtrems {
for _, pe := range prevEventIDs {
if pe == eventID {
successor = sucID
break FindSuccessor
}
}
}
if successor == "" {
logrus.WithField("event_id", eventID).Error("ServersAtEvent: failed to find successor of this event to determine room state")
return nil
}
eventID = successor
// getMembershipsBeforeEventNID requires a NID, so retrieving the NID for // getMembershipsBeforeEventNID requires a NID, so retrieving the NID for
// the event is necessary. // the event is necessary.
NIDs, err := b.db.EventNIDs(ctx, []string{eventID}) NIDs, err := b.db.EventNIDs(ctx, []string{eventID})
if err != nil { if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get event NID for event") logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get event NID for event")
return return nil
} }
stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID]) stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID])
if err != nil { if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
return return nil
} }
// possibly return all joined servers depending on history visiblity // possibly return all joined servers depending on history visiblity
memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries) memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries)
if err != nil { if err != nil {
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
return return nil
} }
logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis)) logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis))
@ -183,7 +209,7 @@ func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID
memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true) memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true)
if err != nil { if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
return return nil
} }
memberEvents = append(memberEvents, memberEventsFromVis...) memberEvents = append(memberEvents, memberEventsFromVis...)
@ -192,6 +218,7 @@ func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID
for _, event := range memberEvents { for _, event := range memberEvents {
serverSet[event.Origin()] = true serverSet[event.Origin()] = true
} }
var servers []gomatrixserverlib.ServerName
for server := range serverSet { for server := range serverSet {
if server == b.thisServer { if server == b.thisServer {
continue continue
@ -199,7 +226,7 @@ func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID
servers = append(servers, server) servers = append(servers, server)
} }
b.servers = servers b.servers = servers
return return servers
} }
// Backfill performs a backfill request to the given server. // Backfill performs a backfill request to the given server.
@ -270,7 +297,7 @@ func joinEventsFromHistoryVisibility(
if err != nil { if err != nil {
return nil, err return nil, err
} }
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true) joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -83,9 +83,9 @@ type Database interface {
GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error)
GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) GetCreatorIDForAlias(ctx context.Context, alias string) (string, error)
RemoveRoomAlias(ctx context.Context, alias string) error RemoveRoomAlias(ctx context.Context, alias string) error
MembershipUpdater(ctx context.Context, roomID, targetUserID string, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error) MembershipUpdater(ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion) (types.MembershipUpdater, error)
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error)
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error)
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
} }

View file

@ -59,6 +59,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
-- This NID is updated if the join event gets updated (e.g. profile update), -- This NID is updated if the join event gets updated (e.g. profile update),
-- or if the user leaves/joins the room. -- or if the user leaves/joins the room.
event_nid BIGINT NOT NULL DEFAULT 0, event_nid BIGINT NOT NULL DEFAULT 0,
-- Local target is true if the target_nid refers to a local user rather than
-- a federated one. This is an optimisation for resetting state on federated
-- room joins.
target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid) UNIQUE (room_nid, target_nid)
); );
` `
@ -66,8 +70,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE // SELECT FOR UPDATE
const insertMembershipSQL = "" + const insertMembershipSQL = "" +
"INSERT INTO roomserver_membership (room_nid, target_nid)" + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
" VALUES ($1, $2)" + " VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
const selectMembershipFromRoomAndTargetSQL = "" + const selectMembershipFromRoomAndTargetSQL = "" +
@ -78,10 +82,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" " WHERE room_nid = $1 AND membership_nid = $2"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" +
" AND target_local = true"
const selectMembershipsFromRoomSQL = "" + const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" " WHERE room_nid = $1"
const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" +
" AND target_local = true"
const selectMembershipForUpdateSQL = "" + const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" + "SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE" " WHERE room_nid = $1 AND target_nid = $2 FOR UPDATE"
@ -91,12 +105,14 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND target_nid = $2"
type membershipStatements struct { type membershipStatements struct {
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt selectMembershipFromRoomAndTargetStmt *sql.Stmt
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
updateMembershipStmt *sql.Stmt selectMembershipsFromRoomStmt *sql.Stmt
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -110,7 +126,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
{&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL}, {&s.updateMembershipStmt, updateMembershipSQL},
}.prepare(db) }.prepare(db)
} }
@ -118,9 +136,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
func (s *membershipStatements) insertMembership( func (s *membershipStatements) insertMembership(
ctx context.Context, ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool,
) error { ) error {
stmt := common.TxStmt(txn, s.insertMembershipStmt) stmt := common.TxStmt(txn, s.insertMembershipStmt)
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID) _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
return err return err
} }
@ -145,9 +164,15 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget(
} }
func (s *membershipStatements) selectMembershipsFromRoom( func (s *membershipStatements) selectMembershipsFromRoom(
ctx context.Context, roomNID types.RoomNID, ctx context.Context, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID) var stmt *sql.Stmt
if localOnly {
stmt = s.selectLocalMembershipsFromRoomStmt
} else {
stmt = s.selectMembershipsFromRoomStmt
}
rows, err := stmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
return return
} }
@ -165,10 +190,16 @@ func (s *membershipStatements) selectMembershipsFromRoom(
func (s *membershipStatements) selectMembershipsFromRoomAndMembership( func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, membership membershipState, roomNID types.RoomNID, membership membershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
stmt := s.selectMembershipsFromRoomAndMembershipStmt var rows *sql.Rows
rows, err := stmt.QueryContext(ctx, roomNID, membership) var stmt *sql.Stmt
if localOnly {
stmt = s.selectLocalMembershipsFromRoomAndMembershipStmt
} else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt
}
rows, err = stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
return return
} }

View file

@ -459,8 +459,8 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error
return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID)
} }
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID) return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal)
} }
// RoomNID implements query.RoomserverQueryAPIDB // RoomNID implements query.RoomserverQueryAPIDB
@ -558,7 +558,7 @@ func (d *Database) StateEntriesForTuples(
// MembershipUpdater implements input.RoomEventDatabase // MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater( func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string, ctx context.Context, roomID, targetUserID string,
roomVersion gomatrixserverlib.RoomVersion, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
) (types.MembershipUpdater, error) { ) (types.MembershipUpdater, error) {
txn, err := d.db.Begin() txn, err := d.db.Begin()
if err != nil { if err != nil {
@ -581,7 +581,7 @@ func (d *Database) MembershipUpdater(
return nil, err return nil, err
} }
updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -603,9 +603,10 @@ func (d *Database) membershipUpdaterTxn(
txn *sql.Tx, txn *sql.Tx,
roomNID types.RoomNID, roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
targetLocal bool,
) (types.MembershipUpdater, error) { ) (types.MembershipUpdater, error) {
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
return nil, err return nil, err
} }
@ -748,15 +749,15 @@ func (d *Database) GetMembership(
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) { ) ([]types.EventNID, error) {
if joinOnly { if joinOnly {
return d.statements.selectMembershipsFromRoomAndMembership( return d.statements.selectMembershipsFromRoomAndMembership(
ctx, roomNID, membershipStateJoin, ctx, roomNID, membershipStateJoin, localOnly,
) )
} }
return d.statements.selectMembershipsFromRoom(ctx, roomNID) return d.statements.selectMembershipsFromRoom(ctx, roomNID, localOnly)
} }
// EventsFromIDs implements query.RoomserverQueryAPIEventDB // EventsFromIDs implements query.RoomserverQueryAPIEventDB

View file

@ -38,6 +38,7 @@ const membershipSchema = `
sender_nid INTEGER NOT NULL DEFAULT 0, sender_nid INTEGER NOT NULL DEFAULT 0,
membership_nid INTEGER NOT NULL DEFAULT 1, membership_nid INTEGER NOT NULL DEFAULT 1,
event_nid INTEGER NOT NULL DEFAULT 0, event_nid INTEGER NOT NULL DEFAULT 0,
target_local BOOLEAN NOT NULL DEFAULT false,
UNIQUE (room_nid, target_nid) UNIQUE (room_nid, target_nid)
); );
` `
@ -45,8 +46,8 @@ const membershipSchema = `
// Insert a row in to membership table so that it can be locked by the // Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE // SELECT FOR UPDATE
const insertMembershipSQL = "" + const insertMembershipSQL = "" +
"INSERT INTO roomserver_membership (room_nid, target_nid)" + "INSERT INTO roomserver_membership (room_nid, target_nid, target_local)" +
" VALUES ($1, $2)" + " VALUES ($1, $2, $3)" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
const selectMembershipFromRoomAndTargetSQL = "" + const selectMembershipFromRoomAndTargetSQL = "" +
@ -57,10 +58,20 @@ const selectMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" " WHERE room_nid = $1 AND membership_nid = $2"
const selectLocalMembershipsFromRoomAndMembershipSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND membership_nid = $2" +
" AND target_local = true"
const selectMembershipsFromRoomSQL = "" + const selectMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" + "SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" " WHERE room_nid = $1"
const selectLocalMembershipsFromRoomSQL = "" +
"SELECT event_nid FROM roomserver_membership" +
" WHERE room_nid = $1" +
" AND target_local = true"
const selectMembershipForUpdateSQL = "" + const selectMembershipForUpdateSQL = "" +
"SELECT membership_nid FROM roomserver_membership" + "SELECT membership_nid FROM roomserver_membership" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND target_nid = $2"
@ -70,12 +81,14 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $4 AND target_nid = $5" " WHERE room_nid = $4 AND target_nid = $5"
type membershipStatements struct { type membershipStatements struct {
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt selectMembershipFromRoomAndTargetStmt *sql.Stmt
selectMembershipsFromRoomAndMembershipStmt *sql.Stmt selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
updateMembershipStmt *sql.Stmt selectMembershipsFromRoomStmt *sql.Stmt
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
} }
func (s *membershipStatements) prepare(db *sql.DB) (err error) { func (s *membershipStatements) prepare(db *sql.DB) (err error) {
@ -89,7 +102,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
{&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
{&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
{&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
{&s.selectLocalMembershipsFromRoomAndMembershipStmt, selectLocalMembershipsFromRoomAndMembershipSQL},
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL}, {&s.updateMembershipStmt, updateMembershipSQL},
}.prepare(db) }.prepare(db)
} }
@ -97,9 +112,10 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
func (s *membershipStatements) insertMembership( func (s *membershipStatements) insertMembership(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool,
) error { ) error {
stmt := common.TxStmt(txn, s.insertMembershipStmt) stmt := common.TxStmt(txn, s.insertMembershipStmt)
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID) _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
return err return err
} }
@ -127,9 +143,14 @@ func (s *membershipStatements) selectMembershipFromRoomAndTarget(
func (s *membershipStatements) selectMembershipsFromRoom( func (s *membershipStatements) selectMembershipsFromRoom(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt) var selectStmt *sql.Stmt
if localOnly {
selectStmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomStmt)
} else {
selectStmt = common.TxStmt(txn, s.selectMembershipsFromRoomStmt)
}
rows, err := selectStmt.QueryContext(ctx, roomNID) rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -145,11 +166,17 @@ func (s *membershipStatements) selectMembershipsFromRoom(
} }
return return
} }
func (s *membershipStatements) selectMembershipsFromRoomAndMembership( func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership membershipState, roomNID types.RoomNID, membership membershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) var stmt *sql.Stmt
if localOnly {
stmt = common.TxStmt(txn, s.selectLocalMembershipsFromRoomAndMembershipStmt)
} else {
stmt = common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt)
}
rows, err := stmt.QueryContext(ctx, roomNID, membership) rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil { if err != nil {
return return

View file

@ -569,9 +569,9 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error
return err return err
} }
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) { func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (mu types.MembershipUpdater, err error) {
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID) mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID, targetLocal)
return err return err
}) })
return return
@ -680,7 +680,7 @@ func (d *Database) StateEntriesForTuples(
// MembershipUpdater implements input.RoomEventDatabase // MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater( func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string, ctx context.Context, roomID, targetUserID string,
roomVersion gomatrixserverlib.RoomVersion, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
) (updater types.MembershipUpdater, err error) { ) (updater types.MembershipUpdater, err error) {
var txn *sql.Tx var txn *sql.Tx
txn, err = d.db.Begin() txn, err = d.db.Begin()
@ -716,7 +716,7 @@ func (d *Database) MembershipUpdater(
return nil, err return nil, err
} }
updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -738,9 +738,10 @@ func (d *Database) membershipUpdaterTxn(
txn *sql.Tx, txn *sql.Tx,
roomNID types.RoomNID, roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
targetLocal bool,
) (types.MembershipUpdater, error) { ) (types.MembershipUpdater, error) {
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
return nil, err return nil, err
} }
@ -896,17 +897,17 @@ func (d *Database) GetMembership(
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB // GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom( func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) (eventNIDs []types.EventNID, err error) { ) (eventNIDs []types.EventNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error { err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if joinOnly { if joinOnly {
eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(
ctx, txn, roomNID, membershipStateJoin, ctx, txn, roomNID, membershipStateJoin, localOnly,
) )
return nil return nil
} }
eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID) eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
return nil return nil
}) })
return return

View file

@ -172,7 +172,7 @@ type RoomRecentEventsUpdater interface {
MarkEventAsSent(eventNID EventNID) error MarkEventAsSent(eventNID EventNID) error
// Build a membership updater for the target user in this room. // Build a membership updater for the target user in this room.
// It will share the same transaction as this updater. // It will share the same transaction as this updater.
MembershipUpdater(targetUserNID EventStateKeyNID) (MembershipUpdater, error) MembershipUpdater(targetUserNID EventStateKeyNID, isTargetLocalUser bool) (MembershipUpdater, error)
// Implements Transaction so it can be committed or rolledback // Implements Transaction so it can be committed or rolledback
common.Transaction common.Transaction
} }

View file

@ -205,6 +205,7 @@ func (r *messagesReq) retrieveEvents() (
} }
var events []gomatrixserverlib.HeaderedEvent var events []gomatrixserverlib.HeaderedEvent
util.GetLogger(r.ctx).WithField("start", start).WithField("end", end).Infof("Fetched %d events locally", len(streamEvents))
// There can be two reasons for streamEvents to be empty: either we've // There can be two reasons for streamEvents to be empty: either we've
// reached the oldest event in the room (or the most recent one, depending // reached the oldest event in the room (or the most recent one, depending
@ -373,13 +374,13 @@ func (e eventsByDepth) Less(i, j int) bool {
// event, or if there is no remote homeserver to contact. // event, or if there is no remote homeserver to contact.
// Returns an error if there was an issue with retrieving the list of servers in // Returns an error if there was an issue with retrieving the list of servers in
// the room or sending the request. // the room or sending the request.
func (r *messagesReq) backfill(roomID string, fromEventIDs []string, limit int) ([]gomatrixserverlib.HeaderedEvent, error) { func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]string, limit int) ([]gomatrixserverlib.HeaderedEvent, error) {
var res api.QueryBackfillResponse var res api.QueryBackfillResponse
err := r.rsAPI.QueryBackfill(context.Background(), &api.QueryBackfillRequest{ err := r.rsAPI.QueryBackfill(context.Background(), &api.QueryBackfillRequest{
RoomID: roomID, RoomID: roomID,
EarliestEventsIDs: fromEventIDs, BackwardsExtremities: backwardsExtremities,
Limit: limit, Limit: limit,
ServerName: r.cfg.Matrix.ServerName, ServerName: r.cfg.Matrix.ServerName,
}, &res) }, &res)
if err != nil { if err != nil {
return nil, fmt.Errorf("QueryBackfill failed: %w", err) return nil, fmt.Errorf("QueryBackfill failed: %w", err)
@ -412,7 +413,14 @@ func (r *messagesReq) backfill(roomID string, fromEventIDs []string, limit int)
} }
} }
return res.Events, nil // we may have got more than the requested limit so resize now
events := res.Events
if len(events) > limit {
// last `limit` events
events = events[len(events)-limit:]
}
return events, nil
} }
// setToDefault returns the default value for the "to" query parameter of a // setToDefault returns the default value for the "to" query parameter of a

View file

@ -94,9 +94,8 @@ type Database interface {
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event. // EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
// BackwardExtremitiesForRoom returns the event IDs of all of the backward // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
// extremities we know of for a given room. BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error)
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error)
// MaxTopologicalPosition returns the highest topological position for a given room. // MaxTopologicalPosition returns the highest topological position for a given room.
MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error)
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and

View file

@ -41,7 +41,7 @@ const insertBackwardExtremitySQL = "" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
const selectBackwardExtremitiesForRoomSQL = "" + const selectBackwardExtremitiesForRoomSQL = "" +
"SELECT DISTINCT event_id FROM syncapi_backward_extremities WHERE room_id = $1" "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1"
const deleteBackwardExtremitySQL = "" + const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
@ -79,23 +79,24 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (eventIDs []string, err error) { ) (bwExtrems map[string][]string, err error) {
rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return return
} }
defer common.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed") defer common.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed")
bwExtrems = make(map[string][]string)
for rows.Next() { for rows.Next() {
var eID string var eID string
if err = rows.Scan(&eID); err != nil { var prevEventID string
if err = rows.Scan(&eID, &prevEventID); err != nil {
return return
} }
bwExtrems[eID] = append(bwExtrems[eID], prevEventID)
eventIDs = append(eventIDs, eID)
} }
return eventIDs, rows.Err() return bwExtrems, rows.Err()
} }
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(

View file

@ -310,6 +310,7 @@ func (d *Database) updateRoomState(
} }
membership = &value membership = &value
} }
if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
return err return err
} }
@ -367,7 +368,7 @@ func (d *Database) SyncPosition(ctx context.Context) (tok types.StreamingToken,
func (d *Database) BackwardExtremitiesForRoom( func (d *Database) BackwardExtremitiesForRoom(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (backwardExtremities []string, err error) { ) (backwardExtremities map[string][]string, err error) {
return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, roomID) return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, roomID)
} }

View file

@ -41,7 +41,7 @@ const insertBackwardExtremitySQL = "" +
" ON CONFLICT (room_id, event_id, prev_event_id) DO NOTHING" " ON CONFLICT (room_id, event_id, prev_event_id) DO NOTHING"
const selectBackwardExtremitiesForRoomSQL = "" + const selectBackwardExtremitiesForRoomSQL = "" +
"SELECT DISTINCT event_id FROM syncapi_backward_extremities WHERE room_id = $1" "SELECT event_id, prev_event_id FROM syncapi_backward_extremities WHERE room_id = $1"
const deleteBackwardExtremitySQL = "" + const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
@ -79,23 +79,24 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
ctx context.Context, roomID string, ctx context.Context, roomID string,
) (eventIDs []string, err error) { ) (bwExtrems map[string][]string, err error) {
rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
if err != nil { if err != nil {
return return
} }
defer common.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed") defer common.CloseAndLogIfError(ctx, rows, "selectBackwardExtremitiesForRoom: rows.close() failed")
bwExtrems = make(map[string][]string)
for rows.Next() { for rows.Next() {
var eID string var eID string
if err = rows.Scan(&eID); err != nil { var prevEventID string
if err = rows.Scan(&eID, &prevEventID); err != nil {
return return
} }
bwExtrems[eID] = append(bwExtrems[eID], prevEventID)
eventIDs = append(eventIDs, eID)
} }
return eventIDs, rows.Err() return bwExtrems, rows.Err()
} }
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(

View file

@ -89,8 +89,8 @@ type CurrentRoomState interface {
type BackwardsExtremities interface { type BackwardsExtremities interface {
// InsertsBackwardExtremity inserts a new backwards extremity. // InsertsBackwardExtremity inserts a new backwards extremity.
InsertsBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string) (err error) InsertsBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string) (err error)
// SelectBackwardExtremitiesForRoom retrieves all backwards extremities for the room. // SelectBackwardExtremitiesForRoom retrieves all backwards extremities for the room, as a map of event_id to list of prev_event_ids.
SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (eventIDs []string, err error) SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error)
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
} }

View file

@ -16,6 +16,7 @@ package sync
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@ -30,6 +31,14 @@ import (
const defaultSyncTimeout = time.Duration(0) const defaultSyncTimeout = time.Duration(0)
const defaultTimelineLimit = 20 const defaultTimelineLimit = 20
type filter struct {
Room struct {
Timeline struct {
Limit *int `json:"limit"`
} `json:"timeline"`
} `json:"room"`
}
// syncRequest represents a /sync request, with sensible defaults/sanity checks applied. // syncRequest represents a /sync request, with sensible defaults/sanity checks applied.
type syncRequest struct { type syncRequest struct {
ctx context.Context ctx context.Context
@ -54,6 +63,17 @@ func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, e
} }
since = &tok since = &tok
} }
timelineLimit := defaultTimelineLimit
// TODO: read from stored filters too
filterQuery := req.URL.Query().Get("filter")
if filterQuery != "" && filterQuery[0] == '{' {
// attempt to parse the timeline limit at least
var f filter
err := json.Unmarshal([]byte(filterQuery), &f)
if err == nil && f.Room.Timeline.Limit != nil {
timelineLimit = *f.Room.Timeline.Limit
}
}
// TODO: Additional query params: set_presence, filter // TODO: Additional query params: set_presence, filter
return &syncRequest{ return &syncRequest{
ctx: req.Context(), ctx: req.Context(),
@ -61,7 +81,7 @@ func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, e
timeout: timeout, timeout: timeout,
since: since, since: since,
wantFullState: wantFullState, wantFullState: wantFullState,
limit: defaultTimelineLimit, // TODO: read from filter limit: timelineLimit,
log: util.GetLogger(req.Context()), log: util.GetLogger(req.Context()),
}, nil }, nil
} }

View file

@ -59,6 +59,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
"userID": userID, "userID": userID,
"since": syncReq.since, "since": syncReq.since,
"timeout": syncReq.timeout, "timeout": syncReq.timeout,
"limit": syncReq.limit,
}) })
currPos := rp.notifier.CurrentPosition() currPos := rp.notifier.CurrentPosition()

View file

@ -279,3 +279,8 @@ Inbound federation can return missing events for invite visibility
Inbound federation can get public room list Inbound federation can get public room list
An event which redacts itself should be ignored An event which redacts itself should be ignored
A pair of events which redact each other should be ignored A pair of events which redact each other should be ignored
Outbound federation can backfill events
Inbound federation can backfill events
Backfill checks the events requested belong to the room
Backfilled events whose prev_events are in a different room do not allow cross-room back-pagination
Outbound federation can request missing events