From 82a96176599aad737e5052a501442b29a5be5c69 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Tue, 1 Sep 2020 20:35:38 +0100 Subject: [PATCH 01/12] Put redactions/filters in the writer goroutine (#1378) * Put redactions in the writer goroutine * Update filters on writer goroutine --- syncapi/storage/shared/syncserver.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 3388473ae..255fe6b58 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -525,7 +525,13 @@ func (d *Database) GetFilter( func (d *Database) PutFilter( ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, ) (string, error) { - return d.Filter.InsertFilter(ctx, filter, localpart) + var filterID string + var err error + err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + filterID, err = d.Filter.InsertFilter(ctx, filter, localpart) + return err + }) + return filterID, err } func (d *Database) IncrementalSync( @@ -587,7 +593,10 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda } newEvent := ev.Headered(redactedBecause.RoomVersion) - return d.OutputEvents.UpdateEventJSON(ctx, &newEvent) + err = d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.OutputEvents.UpdateEventJSON(ctx, &newEvent) + }) + return err } // getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed From 02a73f29f861c637f30df4a2bb1fce400e481a3c Mon Sep 17 00:00:00 2001 From: Kegsay Date: Wed, 2 Sep 2020 10:02:48 +0100 Subject: [PATCH 02/12] Expand RoomInfo to cover more DB storage functions (#1377) * Factor more things to RoomInfo * Factor out remaining bits for RoomInfo * Linting for now --- roomserver/internal/alias.go | 10 +- roomserver/internal/input_events.go | 24 ++-- roomserver/internal/input_latest_events.go | 23 ++-- roomserver/internal/perform_backfill.go | 13 +- roomserver/internal/perform_invite.go | 2 +- roomserver/internal/query.go | 128 +++++++++++------- roomserver/state/state.go | 41 ++---- roomserver/storage/interface.go | 6 +- roomserver/storage/postgres/rooms_table.go | 17 --- .../storage/shared/latest_events_updater.go | 13 +- roomserver/storage/shared/storage.go | 28 +--- roomserver/storage/sqlite3/rooms_table.go | 17 --- roomserver/storage/sqlite3/storage.go | 4 +- roomserver/storage/tables/interface.go | 1 - 14 files changed, 148 insertions(+), 179 deletions(-) diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 4139582b6..d576a8175 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "time" "github.com/matrix-org/dendrite/roomserver/api" @@ -239,16 +240,19 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent( } builder.AuthEvents = refs - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomID) + roomInfo, err := r.DB.RoomInfo(ctx, roomID) if err != nil { return err } + if roomInfo == nil { + return fmt.Errorf("room %s does not exist", roomID) + } // Build the event now := time.Now() event, err := builder.Build( now, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID, - r.Cfg.Matrix.PrivateKey, roomVersion, + r.Cfg.Matrix.PrivateKey, roomInfo.RoomVersion, ) if err != nil { return err @@ -257,7 +261,7 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent( // Create the request ire := api.InputRoomEvent{ Kind: api.KindNew, - Event: event.Headered(roomVersion), + Event: event.Headered(roomInfo.RoomVersion), AuthEventIDs: event.AuthEventIDs(), SendAsServer: serverName, } diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go index a63082990..287db1af2 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input_events.go @@ -64,7 +64,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( } // Store the event. - roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) + _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) if err != nil { return "", fmt.Errorf("r.DB.StoreEvent: %w", err) } @@ -89,10 +89,18 @@ func (r *RoomserverInternalAPI) processRoomEvent( return event.EventID(), nil } + roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) + if err != nil { + return "", fmt.Errorf("r.DB.RoomInfo: %w", err) + } + if roomInfo == nil { + return "", fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) + } + if stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event) + err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event) if err != nil { return "", fmt.Errorf("r.calculateAndSetState: %w", err) } @@ -100,7 +108,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( if err = r.updateLatestEvents( ctx, // context - roomNID, // room NID to update + roomInfo, // room info for the room being updated stateAtEvent, // state at event (below) event, // event input.SendAsServer, // send as server @@ -135,19 +143,19 @@ func (r *RoomserverInternalAPI) processRoomEvent( func (r *RoomserverInternalAPI) calculateAndSetState( ctx context.Context, input api.InputRoomEvent, - roomNID types.RoomNID, + roomInfo types.RoomInfo, stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, ) error { var err error - roomState := state.NewStateResolution(r.DB) + roomState := state.NewStateResolution(r.DB, roomInfo) if input.HasState { // Check here if we think we're in the room already. stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID // Request join memberships only for local users only. - if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil { + if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { // If we have no local users that are joined to the room then any state about // the room that we have is quite possibly out of date. Therefore in that case // we should overwrite it rather than merge it. @@ -161,14 +169,14 @@ func (r *RoomserverInternalAPI) calculateAndSetState( return err } - if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { return err } } else { stateAtEvent.Overwrite = false // We haven't been told what the state at the event is so we need to calculate it from the prev_events - if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, roomNID); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil { return err } } diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index f11a78d72..d5e38e7a4 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -49,13 +49,13 @@ import ( // Can only be called once at a time func (r *RoomserverInternalAPI) updateLatestEvents( ctx context.Context, - roomNID types.RoomNID, + roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event, sendAsServer string, transactionID *api.TransactionID, ) (err error) { - updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID) + updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) if err != nil { return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) } @@ -66,7 +66,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( ctx: ctx, api: r, updater: updater, - roomNID: roomNID, + roomInfo: roomInfo, stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer, @@ -89,7 +89,7 @@ type latestEventsUpdater struct { ctx context.Context api *RoomserverInternalAPI updater *shared.LatestEventsUpdater - roomNID types.RoomNID + roomInfo *types.RoomInfo stateAtEvent types.StateAtEvent event gomatrixserverlib.Event transactionID *api.TransactionID @@ -196,7 +196,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { return fmt.Errorf("u.api.WriteOutputEvents: %w", err) } - if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { + if err = u.updater.SetLatestEvents(u.roomInfo.RoomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { return fmt.Errorf("u.updater.SetLatestEvents: %w", err) } @@ -209,7 +209,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB) + roomState := state.NewStateResolution(u.api.DB, *u.roomInfo) // Get a list of the current latest events. latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) @@ -221,7 +221,7 @@ func (u *latestEventsUpdater) latestState() error { // of the state after the events. The snapshot state will be resolved // using the correct state resolution algorithm for the room. u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents( - u.ctx, u.roomNID, latestStateAtEvents, + u.ctx, latestStateAtEvents, ) if err != nil { return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) @@ -303,13 +303,8 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) latestEventIDs[i] = u.latest[i].EventID } - roomVersion, err := u.api.DB.GetRoomVersionForRoom(u.ctx, u.event.RoomID()) - if err != nil { - return nil, err - } - ore := api.OutputNewRoomEvent{ - Event: u.event.Headered(roomVersion), + Event: u.event.Headered(u.roomInfo.RoomVersion), LastSentEventID: u.lastEventIDSent, LatestEventIDs: latestEventIDs, TransactionID: u.transactionID, @@ -337,7 +332,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) // include extra state events if they were added as nearly every downstream component will care about it // and we'd rather not have them all hit QueryEventsByID at the same time! if len(ore.AddsStateEventIDs) > 0 { - ore.AddStateEvents, err = u.extraEventsForIDs(roomVersion, ore.AddsStateEventIDs) + ore.AddStateEvents, err = u.extraEventsForIDs(u.roomInfo.RoomVersion, ore.AddsStateEventIDs) if err != nil { return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) } diff --git a/roomserver/internal/perform_backfill.go b/roomserver/internal/perform_backfill.go index 65c88860c..721f66106 100644 --- a/roomserver/internal/perform_backfill.go +++ b/roomserver/internal/perform_backfill.go @@ -162,6 +162,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr // It returns a list of servers which can be queried for backfill requests. These servers // will be servers that are in the room already. The entries at the beginning are preferred servers // and will be tried first. An empty list will fail the request. +// nolint:gocyclo func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName { // eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use // its successor, so look it up. @@ -189,7 +190,17 @@ FindSuccessor: return nil } - stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID]) + info, err := b.db.RoomInfo(ctx, roomID) + if err != nil { + logrus.WithError(err).WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room") + return nil + } + if info == nil || info.IsStub { + logrus.WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room, room is missing") + return nil + } + + stateEntries, err := stateBeforeEvent(ctx, b.db, *info, NIDs[eventID]) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil diff --git a/roomserver/internal/perform_invite.go b/roomserver/internal/perform_invite.go index 1cfbcc18c..6690de055 100644 --- a/roomserver/internal/perform_invite.go +++ b/roomserver/internal/perform_invite.go @@ -208,7 +208,7 @@ func buildInviteStrippedState( StateKey: "", }) } - roomState := state.NewStateResolution(db) + roomState := state.NewStateResolution(db, *info) stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( ctx, info.StateSnapshotNID, stateWanted, ) diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go index 897164330..f8e8ba04d 100644 --- a/roomserver/internal/query.go +++ b/roomserver/internal/query.go @@ -38,27 +38,22 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) + roomInfo, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { + return err + } + if roomInfo == nil || roomInfo.IsStub { response.RoomExists = false return nil } - roomState := state.NewStateResolution(r.DB) - - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - if info.IsStub { - return nil - } + roomState := state.NewStateResolution(r.DB, *roomInfo) response.RoomExists = true - response.RoomVersion = roomVersion + response.RoomVersion = roomInfo.RoomVersion var currentStateSnapshotNID types.StateSnapshotNID response.LatestEvents, currentStateSnapshotNID, response.Depth, err = - r.DB.LatestEventIDs(ctx, info.RoomNID) + r.DB.LatestEventIDs(ctx, roomInfo.RoomNID) if err != nil { return err } @@ -85,7 +80,7 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState( } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) + response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion)) } return nil @@ -97,23 +92,17 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( request *api.QueryStateAfterEventsRequest, response *api.QueryStateAfterEventsResponse, ) error { - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) - if err != nil { - response.RoomExists = false - return nil - } - - roomState := state.NewStateResolution(r.DB) - info, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { return err } - if info.IsStub { + if info == nil || info.IsStub { return nil } + + roomState := state.NewStateResolution(r.DB, *info) response.RoomExists = true - response.RoomVersion = roomVersion + response.RoomVersion = info.RoomVersion prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) if err != nil { @@ -128,7 +117,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( // Look up the currrent state for the requested tuples. stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( - ctx, info.RoomNID, prevStates, request.StateToFetch, + ctx, prevStates, request.StateToFetch, ) if err != nil { return err @@ -140,7 +129,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) + response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) } return nil @@ -168,7 +157,7 @@ func (r *RoomserverInternalAPI) QueryEventsByID( } for _, event := range events { - roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID()) + roomVersion, verr := r.roomVersion(event.RoomID()) if verr != nil { return verr } @@ -277,7 +266,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( events, err = r.DB.Events(ctx, eventNIDs) } else { - stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID) + stateEntries, err = stateBeforeEvent(ctx, r.DB, *info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err @@ -297,8 +286,8 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( return nil } -func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) { - roomState := state.NewStateResolution(db) +func stateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info) // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { @@ -370,20 +359,28 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see return } - isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, events[0].RoomID()) + roomID := events[0].RoomID() + isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, roomID) if err != nil { return } + info, err := r.DB.RoomInfo(ctx, roomID) + if err != nil { + return err + } + if info == nil { + return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) + } response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent( - ctx, request.EventID, request.ServerName, isServerInRoom, + ctx, *info, request.EventID, request.ServerName, isServerInRoom, ) return } func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent( - ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, + ctx context.Context, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { - roomState := state.NewStateResolution(r.DB) + roomState := state.NewStateResolution(r.DB, info) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { return false, err @@ -400,6 +397,7 @@ func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent( } // QueryMissingEvents implements api.RoomserverInternalAPI +// nolint:gocyclo func (r *RoomserverInternalAPI) QueryMissingEvents( ctx context.Context, request *api.QueryMissingEventsRequest, @@ -418,8 +416,22 @@ func (r *RoomserverInternalAPI) QueryMissingEvents( eventsToFilter[id] = true } } + events, err := r.DB.EventsFromIDs(ctx, front) + if err != nil { + return err + } + if len(events) == 0 { + return nil // we are missing the events being asked to search from, give up. + } + info, err := r.DB.RoomInfo(ctx, events[0].RoomID()) + if err != nil { + return err + } + if info == nil || info.IsStub { + return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) + } - resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) + resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName) if err != nil { return err } @@ -432,7 +444,7 @@ func (r *RoomserverInternalAPI) QueryMissingEvents( response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter)) for _, event := range loadedEvents { if !eventsToFilter[event.EventID()] { - roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID()) + roomVersion, verr := r.roomVersion(event.RoomID()) if verr != nil { return verr } @@ -467,8 +479,16 @@ func (r *RoomserverInternalAPI) PerformBackfill( // this will include these events which is what we want front = request.PrevEventIDs() + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if info == nil || info.IsStub { + return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID) + } + // Scan the event tree for events to send back. - resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) + resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName) if err != nil { return err } @@ -481,19 +501,14 @@ func (r *RoomserverInternalAPI) PerformBackfill( } for _, event := range loadedEvents { - roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID()) - if verr != nil { - return verr - } - - response.Events = append(response.Events, event.Headered(roomVersion)) + response.Events = append(response.Events, event.Headered(info.RoomVersion)) } return err } func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error { - roomVer, err := r.DB.GetRoomVersionForRoom(ctx, req.RoomID) + roomVer, err := r.roomVersion(req.RoomID) if err != nil { return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err) } @@ -642,7 +657,7 @@ func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, // TODO: Remove this when we have tests to assert correctness of this function // nolint:gocyclo func (r *RoomserverInternalAPI) scanEventTree( - ctx context.Context, front []string, visited map[string]bool, limit int, + ctx context.Context, info types.RoomInfo, front []string, visited map[string]bool, limit int, serverName gomatrixserverlib.ServerName, ) ([]types.EventNID, error) { var resultNIDs []types.EventNID @@ -708,7 +723,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = r.checkServerAllowedToSeeEvent(ctx, pre, serverName, isServerInRoom) + allowed, err = r.checkServerAllowedToSeeEvent(ctx, info, pre, serverName, isServerInRoom) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", @@ -744,13 +759,13 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain( if err != nil { return err } - if info.IsStub { + if info == nil || info.IsStub { return nil } response.RoomExists = true response.RoomVersion = info.RoomVersion - stateEvents, err := r.loadStateAtEventIDs(ctx, request.PrevEventIDs) + stateEvents, err := r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs) if err != nil { return err } @@ -788,8 +803,8 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain( return err } -func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { - roomState := state.NewStateResolution(r.DB) +func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.Event, error) { + roomState := state.NewStateResolution(r.DB, roomInfo) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { switch err.(type) { @@ -937,15 +952,26 @@ func (r *RoomserverInternalAPI) QueryRoomVersionForRoom( return nil } - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) + info, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { return err } - response.RoomVersion = roomVersion + if info == nil { + return fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", request.RoomID) + } + response.RoomVersion = info.RoomVersion r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion) return nil } +func (r *RoomserverInternalAPI) roomVersion(roomID string) (gomatrixserverlib.RoomVersion, error) { + var res api.QueryRoomVersionForRoomResponse + err := r.QueryRoomVersionForRoom(context.Background(), &api.QueryRoomVersionForRoomRequest{ + RoomID: roomID, + }, &res) + return res.RoomVersion, err +} + func (r *RoomserverInternalAPI) QueryPublishedRooms( ctx context.Context, req *api.QueryPublishedRoomsRequest, diff --git a/roomserver/state/state.go b/roomserver/state/state.go index b9ad4a504..37e6807a3 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -31,12 +31,14 @@ import ( ) type StateResolution struct { - db storage.Database + db storage.Database + roomInfo types.RoomInfo } -func NewStateResolution(db storage.Database) StateResolution { +func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution { return StateResolution{ - db: db, + db: db, + roomInfo: roomInfo, } } @@ -339,7 +341,7 @@ func (v StateResolution) loadStateAtSnapshotForNumericTuples( // This is typically the state before an event. // Returns a sorted list of state entries or an error if there was a problem talking to the database. func (v StateResolution) LoadStateAfterEventsForStringTuples( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { @@ -347,24 +349,18 @@ func (v StateResolution) LoadStateAfterEventsForStringTuples( if err != nil { return nil, err } - return v.loadStateAfterEventsForNumericTuples(ctx, roomNID, prevStates, numericTuples) + return v.loadStateAfterEventsForNumericTuples(ctx, prevStates, numericTuples) } func (v StateResolution) loadStateAfterEventsForNumericTuples( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID) - if err != nil { - return nil, err - } - if len(prevStates) == 1 { // Fast path for a single event. prevState := prevStates[0] - var result []types.StateEntry - result, err = v.loadStateAtSnapshotForNumericTuples( + result, err := v.loadStateAtSnapshotForNumericTuples( ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples, ) if err != nil { @@ -403,7 +399,7 @@ func (v StateResolution) loadStateAfterEventsForNumericTuples( // TODO: Add metrics for this as it could take a long time for big rooms // with large conflicts. - fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) + fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) if err != nil { return nil, err } @@ -527,7 +523,6 @@ func init() { func (v StateResolution) CalculateAndStoreStateBeforeEvent( ctx context.Context, event gomatrixserverlib.Event, - roomNID types.RoomNID, ) (types.StateSnapshotNID, error) { // Load the state at the prev events. prevEventRefs := event.PrevEvents() @@ -542,14 +537,13 @@ func (v StateResolution) CalculateAndStoreStateBeforeEvent( } // The state before this event will be the state after the events that came before it. - return v.CalculateAndStoreStateAfterEvents(ctx, roomNID, prevStates) + return v.CalculateAndStoreStateAfterEvents(ctx, prevStates) } // CalculateAndStoreStateAfterEvents finds the room state after the given events. // Stores the resulting state in the database and returns a numeric ID for that snapshot. func (v StateResolution) CalculateAndStoreStateAfterEvents( ctx context.Context, - roomNID types.RoomNID, prevStates []types.StateAtEvent, ) (types.StateSnapshotNID, error) { metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} @@ -558,7 +552,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // 2) There weren't any prev_events for this event so the state is // empty. metrics.algorithm = "empty_state" - stateNID, err := v.db.AddState(ctx, roomNID, nil, nil) + stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID, nil, nil) if err != nil { err = fmt.Errorf("v.db.AddState: %w", err) } @@ -590,7 +584,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // add the state event as a block of size one to the end of the blocks. metrics.algorithm = "single_delta" stateNID, err := v.db.AddState( - ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, + ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, ) if err != nil { err = fmt.Errorf("v.db.AddState: %w", err) @@ -601,7 +595,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // So fall through to calculateAndStoreStateAfterManyEvents } - stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) + stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, v.roomInfo.RoomNID, prevStates, metrics) if err != nil { return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err) } @@ -624,13 +618,8 @@ func (v StateResolution) calculateAndStoreStateAfterManyEvents( prevStates []types.StateAtEvent, metrics calculateStateMetrics, ) (types.StateSnapshotNID, error) { - roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID) - if err != nil { - return metrics.stop(0, err) - } - state, algorithm, conflictLength, err := - v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates) + v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) metrics.algorithm = algorithm if err != nil { return metrics.stop(0, err) diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 5f6416145..ef7a9f090 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -66,8 +66,6 @@ type Database interface { Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) - // Look up a room version from the room NID. - GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. StoreEvent( ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, @@ -91,7 +89,7 @@ type Database interface { // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error. // Returns the latest events in the room and the last eventID sent to the log along with an updater. // If this returns an error then no further action is required. - GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (*shared.LatestEventsUpdater, error) + GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error) // Look up event ID by transaction's info. // This is used to determine if the room event is processed/processing already. // Returns an empty string if no such event exists. @@ -136,8 +134,6 @@ type Database interface { // not found. // Returns an error if the retrieval went wrong. EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) - // Look up the room version for a given room. - GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) // Publish or unpublish a room from the room directory. PublishRoom(ctx context.Context, roomID string, publish bool) error // Returns a list of room IDs for rooms which are published. diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 691c04ba6..13c8e703d 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -68,9 +68,6 @@ const selectLatestEventNIDsForUpdateSQL = "" + const updateLatestEventNIDsSQL = "" + "UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1" -const selectRoomVersionForRoomIDSQL = "" + - "SELECT room_version FROM roomserver_rooms WHERE room_id = $1" - const selectRoomVersionForRoomNIDSQL = "" + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" @@ -83,7 +80,6 @@ type roomStatements struct { selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt - selectRoomVersionForRoomIDStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt } @@ -100,7 +96,6 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, - {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, }.Prepare(db) @@ -192,18 +187,6 @@ func (s *roomStatements) UpdateLatestEventNIDs( return err } -func (s *roomStatements) SelectRoomVersionForRoomID( - ctx context.Context, txn *sql.Tx, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - var roomVersion gomatrixserverlib.RoomVersion - stmt := sqlutil.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) - err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) - if err == sql.ErrNoRows { - return roomVersion, errors.New("room not found") - } - return roomVersion, err -} - func (s *roomStatements) SelectRoomVersionForRoomNID( ctx context.Context, roomNID types.RoomNID, ) (gomatrixserverlib.RoomVersion, error) { diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go index e9a0f6982..29eab0c98 100644 --- a/roomserver/storage/shared/latest_events_updater.go +++ b/roomserver/storage/shared/latest_events_updater.go @@ -12,15 +12,15 @@ import ( type LatestEventsUpdater struct { transaction d *Database - roomNID types.RoomNID + roomInfo types.RoomInfo latestEvents []types.StateAtEventAndReference lastEventIDSent string currentStateSnapshotNID types.StateSnapshotNID } -func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomNID types.RoomNID) (*LatestEventsUpdater, error) { +func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) { eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID) if err != nil { txn.Rollback() // nolint: errcheck return nil, err @@ -39,14 +39,13 @@ func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomN } } return &LatestEventsUpdater{ - transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, }, nil } // RoomVersion implements types.RoomRecentEventsUpdater func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { - version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID) - return + return u.roomInfo.RoomVersion } // LatestEvents implements types.RoomRecentEventsUpdater @@ -118,5 +117,5 @@ func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { } func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal) + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 4af61be8f..6e0ebd2c2 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -229,30 +229,6 @@ func (d *Database) StateEntries( return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) } -func (d *Database) GetRoomVersionForRoom( - ctx context.Context, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok { - return roomVersion, nil - } - return d.RoomsTable.SelectRoomVersionForRoomID( - ctx, nil, roomID, - ) -} - -func (d *Database) GetRoomVersionForRoomNID( - ctx context.Context, roomNID types.RoomNID, -) (gomatrixserverlib.RoomVersion, error) { - if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok { - if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok { - return roomVersion, nil - } - } - return d.RoomsTable.SelectRoomVersionForRoomNID( - ctx, roomNID, - ) -} - func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, roomID, creatorUserID) @@ -387,7 +363,7 @@ func (d *Database) MembershipUpdater( } func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, roomInfo types.RoomInfo, ) (*LatestEventsUpdater, error) { txn, err := d.DB.Begin() if err != nil { @@ -395,7 +371,7 @@ func (d *Database) GetLatestEventsForUpdate( } var updater *LatestEventsUpdater _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { - updater, err = NewLatestEventsUpdater(ctx, d, txn, roomNID) + updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) return nil }) return updater, err diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index fc1bcf22f..4c1699d00 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -58,9 +58,6 @@ const selectLatestEventNIDsForUpdateSQL = "" + const updateLatestEventNIDsSQL = "" + "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4" -const selectRoomVersionForRoomIDSQL = "" + - "SELECT room_version FROM roomserver_rooms WHERE room_id = $1" - const selectRoomVersionForRoomNIDSQL = "" + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" @@ -74,7 +71,6 @@ type roomStatements struct { selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt - selectRoomVersionForRoomIDStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt } @@ -93,7 +89,6 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, - {&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, }.Prepare(db) @@ -198,18 +193,6 @@ func (s *roomStatements) UpdateLatestEventNIDs( return err } -func (s *roomStatements) SelectRoomVersionForRoomID( - ctx context.Context, txn *sql.Tx, roomID string, -) (gomatrixserverlib.RoomVersion, error) { - var roomVersion gomatrixserverlib.RoomVersion - stmt := sqlutil.TxStmt(txn, s.selectRoomVersionForRoomIDStmt) - err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion) - if err == sql.ErrNoRows { - return roomVersion, errors.New("room not found") - } - return roomVersion, err -} - func (s *roomStatements) SelectRoomVersionForRoomNID( ctx context.Context, roomNID types.RoomNID, ) (gomatrixserverlib.RoomVersion, error) { diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 33782171e..4a74bf736 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -150,7 +150,7 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { } func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomNID types.RoomNID, + ctx context.Context, roomInfo types.RoomInfo, ) (*shared.LatestEventsUpdater, error) { // TODO: Do not use transactions. We should be holding open this transaction but we cannot have // multiple write transactions on sqlite. The code will perform additional @@ -158,7 +158,7 @@ func (d *Database) GetLatestEventsForUpdate( // 'database is locked' errors. As sqlite doesn't support multi-process on the // same DB anyway, and we only execute updates sequentially, the only worries // are for rolling back when things go wrong. (atomicity) - return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomNID) + return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo) } func (d *Database) MembershipUpdater( diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index ca9159d07..c599dd3fe 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -63,7 +63,6 @@ type Rooms interface { SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error - SelectRoomVersionForRoomID(ctx context.Context, txn *sql.Tx, roomID string) (gomatrixserverlib.RoomVersion, error) SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) } From e473320e733484b1cc6da0588fd2ccf4affb3d24 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Wed, 2 Sep 2020 13:47:31 +0100 Subject: [PATCH 03/12] Refactor roomserver/internal - split perform stuff out (#1380) - New package `perform` which contains all `Perform` functions - New package `helpers` which contains helper functions used by both perform and query/input functions. - Perform invite/leave have no idea how to `WriteOutputEvents` and this is now returned from `PerformInvite` or `PerformLeave` respectively. Still to do: - RSAPI is fed into the inviter/joiner/leaver - this introduces circular logic so will need to be removed. - Put query operations in a `query` package. - Put input operations (and output) in an `input` package. - Factor out helper functions as much as possible, possibly rejigging the storage layer in the process. --- build/scripts/complement.sh | 2 +- roomserver/internal/api.go | 120 ++++- .../{input_authevents.go => helpers/auth.go} | 16 +- .../auth_test.go} | 6 +- roomserver/internal/helpers/helpers.go | 326 ++++++++++++ roomserver/internal/input.go | 9 - roomserver/internal/input_events.go | 3 +- roomserver/internal/input_membership.go | 37 +- .../{ => perform}/perform_backfill.go | 237 ++++++++- .../internal/{ => perform}/perform_invite.go | 59 ++- .../internal/{ => perform}/perform_join.go | 37 +- .../internal/{ => perform}/perform_leave.go | 126 ++--- .../internal/{ => perform}/perform_publish.go | 9 +- roomserver/internal/query.go | 466 +----------------- roomserver/roomserver.go | 14 +- 15 files changed, 820 insertions(+), 647 deletions(-) rename roomserver/internal/{input_authevents.go => helpers/auth.go} (96%) rename roomserver/internal/{input_authevents_test.go => helpers/auth_test.go} (97%) create mode 100644 roomserver/internal/helpers/helpers.go rename roomserver/internal/{ => perform}/perform_backfill.go (55%) rename roomserver/internal/{ => perform}/perform_invite.go (83%) rename roomserver/internal/{ => perform}/perform_join.go (89%) rename roomserver/internal/{ => perform}/perform_leave.go (53%) rename roomserver/internal/{ => perform}/perform_publish.go (67%) diff --git a/build/scripts/complement.sh b/build/scripts/complement.sh index 17ddea57e..c1e52dde6 100755 --- a/build/scripts/complement.sh +++ b/build/scripts/complement.sh @@ -15,5 +15,5 @@ tar -xzf master.tar.gz # Run the tests! cd complement-master -COMPLEMENT_BASE_IMAGE=complement-dendrite:latest go test -v ./tests +COMPLEMENT_BASE_IMAGE=complement-dendrite:latest go test -v -count=1 ./tests diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index f94c72f05..1897f7a53 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -1,12 +1,15 @@ package internal import ( + "context" "sync" "github.com/Shopify/sarama" fsAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/perform" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -20,7 +23,122 @@ type RoomserverInternalAPI struct { ServerName gomatrixserverlib.ServerName KeyRing gomatrixserverlib.JSONVerifier FedClient *gomatrixserverlib.FederationClient - OutputRoomEventTopic string // Kafka topic for new output room events + OutputRoomEventTopic string // Kafka topic for new output room events + Inviter *perform.Inviter + Joiner *perform.Joiner + Leaver *perform.Leaver + Publisher *perform.Publisher + Backfiller *perform.Backfiller mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent fsAPI fsAPI.FederationSenderInternalAPI } + +func NewRoomserverAPI( + cfg *config.RoomServer, roomserverDB storage.Database, producer sarama.SyncProducer, + outputRoomEventTopic string, caches caching.RoomServerCaches, fedClient *gomatrixserverlib.FederationClient, + keyRing gomatrixserverlib.JSONVerifier, +) *RoomserverInternalAPI { + a := &RoomserverInternalAPI{ + DB: roomserverDB, + Cfg: cfg, + Producer: producer, + Cache: caches, + ServerName: cfg.Matrix.ServerName, + KeyRing: keyRing, + FedClient: fedClient, + OutputRoomEventTopic: outputRoomEventTopic, + // perform-er structs get initialised when we have a federation sender to use + } + return a +} + +// SetFederationSenderInputAPI passes in a federation sender input API reference +// so that we can avoid the chicken-and-egg problem of both the roomserver input API +// and the federation sender input API being interdependent. +func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) { + r.fsAPI = fsAPI + + r.Inviter = &perform.Inviter{ + DB: r.DB, + Cfg: r.Cfg, + FSAPI: r.fsAPI, + RSAPI: r, + } + r.Joiner = &perform.Joiner{ + ServerName: r.Cfg.Matrix.ServerName, + Cfg: r.Cfg, + DB: r.DB, + FSAPI: r.fsAPI, + RSAPI: r, + } + r.Leaver = &perform.Leaver{ + Cfg: r.Cfg, + DB: r.DB, + FSAPI: r.fsAPI, + RSAPI: r, + } + r.Publisher = &perform.Publisher{ + DB: r.DB, + } + r.Backfiller = &perform.Backfiller{ + ServerName: r.ServerName, + DB: r.DB, + FedClient: r.FedClient, + KeyRing: r.KeyRing, + } +} + +func (r *RoomserverInternalAPI) PerformInvite( + ctx context.Context, + req *api.PerformInviteRequest, + res *api.PerformInviteResponse, +) error { + outputEvents, err := r.Inviter.PerformInvite(ctx, req, res) + if err != nil { + return err + } + if len(outputEvents) == 0 { + return nil + } + return r.WriteOutputEvents(req.Event.RoomID(), outputEvents) +} + +func (r *RoomserverInternalAPI) PerformJoin( + ctx context.Context, + req *api.PerformJoinRequest, + res *api.PerformJoinResponse, +) { + r.Joiner.PerformJoin(ctx, req, res) +} + +func (r *RoomserverInternalAPI) PerformLeave( + ctx context.Context, + req *api.PerformLeaveRequest, + res *api.PerformLeaveResponse, +) error { + outputEvents, err := r.Leaver.PerformLeave(ctx, req, res) + if err != nil { + return err + } + if len(outputEvents) == 0 { + return nil + } + return r.WriteOutputEvents(req.RoomID, outputEvents) +} + +func (r *RoomserverInternalAPI) PerformPublish( + ctx context.Context, + req *api.PerformPublishRequest, + res *api.PerformPublishResponse, +) { + r.Publisher.PerformPublish(ctx, req, res) +} + +// Query a given amount (or less) of events prior to a given set of events. +func (r *RoomserverInternalAPI) PerformBackfill( + ctx context.Context, + request *api.PerformBackfillRequest, + response *api.PerformBackfillResponse, +) error { + return r.Backfiller.PerformBackfill(ctx, request, response) +} diff --git a/roomserver/internal/input_authevents.go b/roomserver/internal/helpers/auth.go similarity index 96% rename from roomserver/internal/input_authevents.go rename to roomserver/internal/helpers/auth.go index e3828f566..060f0a0e9 100644 --- a/roomserver/internal/input_authevents.go +++ b/roomserver/internal/helpers/auth.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package helpers import ( "context" @@ -23,9 +23,9 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// checkAuthEvents checks that the event passes authentication checks +// CheckAuthEvents checks that the event passes authentication checks // Returns the numeric IDs for the auth events. -func checkAuthEvents( +func CheckAuthEvents( ctx context.Context, db storage.Database, event gomatrixserverlib.HeaderedEvent, @@ -63,7 +63,7 @@ func checkAuthEvents( type authEvents struct { stateKeyNIDMap map[string]types.EventStateKeyNID state stateEntryMap - events eventMap + events EventMap } // Create implements gomatrixserverlib.AuthEventProvider @@ -99,7 +99,7 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) * if !ok { return nil } - event, ok := ae.events.lookup(eventNID) + event, ok := ae.events.Lookup(eventNID) if !ok { return nil } @@ -118,7 +118,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * if !ok { return nil } - event, ok := ae.events.lookup(eventNID) + event, ok := ae.events.Lookup(eventNID) if !ok { return nil } @@ -224,10 +224,10 @@ func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.Even // Map from numeric event ID to event. // Implemented using binary search on a sorted array. -type eventMap []types.Event +type EventMap []types.Event // lookup an entry in the event map. -func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) { +func (m EventMap) Lookup(eventNID types.EventNID) (event *types.Event, ok bool) { // Since the list is sorted we can implement this using binary search. // This is faster than using a hash map. // We don't have to worry about pathological cases because the keys are fixed diff --git a/roomserver/internal/input_authevents_test.go b/roomserver/internal/helpers/auth_test.go similarity index 97% rename from roomserver/internal/input_authevents_test.go rename to roomserver/internal/helpers/auth_test.go index 6b981571b..2a1c3ea49 100644 --- a/roomserver/internal/input_authevents_test.go +++ b/roomserver/internal/helpers/auth_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package helpers import ( "testing" @@ -95,7 +95,7 @@ func TestStateEntryMap(t *testing.T) { } func TestEventMap(t *testing.T) { - events := eventMap([]types.Event{ + events := EventMap([]types.Event{ {EventNID: 1}, {EventNID: 2}, {EventNID: 3}, @@ -123,7 +123,7 @@ func TestEventMap(t *testing.T) { } for _, testCase := range testCases { - gotEvent, gotOK := events.lookup(testCase.inputEventNID) + gotEvent, gotOK := events.Lookup(testCase.inputEventNID) if testCase.wantOK != gotOK { t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK) } diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go new file mode 100644 index 000000000..d7bb40af5 --- /dev/null +++ b/roomserver/internal/helpers/helpers.go @@ -0,0 +1,326 @@ +package helpers + +import ( + "context" + "fmt" + + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/auth" + "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// TODO: temporary package which has helper functions used by both internal/perform packages. +// Move these to a more sensible place. + +func UpdateToInviteMembership( + mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, + roomVersion gomatrixserverlib.RoomVersion, +) ([]api.OutputEvent, error) { + // We may have already sent the invite to the user, either because we are + // reprocessing this event, or because the we received this invite from a + // remote server via the federation invite API. In those cases we don't need + // to send the event. + needsSending, err := mu.SetToInvite(*add) + if err != nil { + return nil, err + } + if needsSending { + // We notify the consumers using a special event even though we will + // notify them about the change in current state as part of the normal + // room event stream. This ensures that the consumers only have to + // consider a single stream of events when determining whether a user + // is invited, rather than having to combine multiple streams themselves. + onie := api.OutputNewInviteEvent{ + Event: add.Headered(roomVersion), + RoomVersion: roomVersion, + } + updates = append(updates, api.OutputEvent{ + Type: api.OutputTypeNewInviteEvent, + NewInviteEvent: &onie, + }) + } + return updates, nil +} + +func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) { + info, err := db.RoomInfo(ctx, roomID) + if err != nil { + return false, err + } + if info == nil { + return false, fmt.Errorf("unknown room %s", roomID) + } + + eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) + if err != nil { + return false, err + } + + events, err := db.Events(ctx, eventNIDs) + if err != nil { + return false, err + } + gmslEvents := make([]gomatrixserverlib.Event, len(events)) + for i := range events { + gmslEvents[i] = events[i].Event + } + return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil +} + +func IsInvitePending( + ctx context.Context, db storage.Database, + roomID, userID string, +) (bool, string, string, error) { + // Look up the room NID for the supplied room ID. + info, err := db.RoomInfo(ctx, roomID) + if err != nil { + return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err) + } + if info == nil { + return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID) + } + + // Look up the state key NID for the supplied user ID. + targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID}) + if err != nil { + return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) + } + targetUserNID, targetUserFound := targetUserNIDs[userID] + if !targetUserFound { + return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) + } + + // Let's see if we have an event active for the user in the room. If + // we do then it will contain a server name that we can direct the + // send_leave to. + senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID) + if err != nil { + return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err) + } + if len(senderUserNIDs) == 0 { + return false, "", "", nil + } + userNIDToEventID := make(map[types.EventStateKeyNID]string) + for i, nid := range senderUserNIDs { + userNIDToEventID[nid] = eventIDs[i] + } + + // Look up the user ID from the NID. + senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs) + if err != nil { + return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err) + } + if len(senderUsers) == 0 { + return false, "", "", fmt.Errorf("no senderUsers") + } + + senderUser, senderUserFound := senderUsers[senderUserNIDs[0]] + if !senderUserFound { + return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers) + } + + return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil +} + +// GetMembershipsAtState filters the state events to +// only keep the "m.room.member" events with a "join" membership. These events are returned. +// Returns an error if there was an issue fetching the events. +func GetMembershipsAtState( + ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, +) ([]types.Event, error) { + + var eventNIDs []types.EventNID + for _, entry := range stateEntries { + // Filter the events to retrieve to only keep the membership events + if entry.EventTypeNID == types.MRoomMemberNID { + eventNIDs = append(eventNIDs, entry.EventNID) + } + } + + // Get all of the events in this state + stateEvents, err := db.Events(ctx, eventNIDs) + if err != nil { + return nil, err + } + + if !joinedOnly { + return stateEvents, nil + } + + // Filter the events to only keep the "join" membership events + var events []types.Event + for _, event := range stateEvents { + membership, err := event.Membership() + if err != nil { + return nil, err + } + + if membership == gomatrixserverlib.Join { + events = append(events, event) + } + } + + return events, nil +} + +func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info) + // Lookup the event NID + eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) + if err != nil { + return nil, err + } + eventIDs := []string{eIDs[eventNID]} + + prevState, err := db.StateAtEventIDs(ctx, eventIDs) + if err != nil { + return nil, err + } + + // Fetch the state as it was when this event was fired + return roomState.LoadCombinedStateAfterEvents(ctx, prevState) +} + +func LoadEvents( + ctx context.Context, db storage.Database, eventNIDs []types.EventNID, +) ([]gomatrixserverlib.Event, error) { + stateEvents, err := db.Events(ctx, eventNIDs) + if err != nil { + return nil, err + } + + result := make([]gomatrixserverlib.Event, len(stateEvents)) + for i := range stateEvents { + result[i] = stateEvents[i].Event + } + return result, nil +} + +func LoadStateEvents( + ctx context.Context, db storage.Database, stateEntries []types.StateEntry, +) ([]gomatrixserverlib.Event, error) { + eventNIDs := make([]types.EventNID, len(stateEntries)) + for i := range stateEntries { + eventNIDs[i] = stateEntries[i].EventNID + } + return LoadEvents(ctx, db, eventNIDs) +} + +func CheckServerAllowedToSeeEvent( + ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, +) (bool, error) { + roomState := state.NewStateResolution(db, info) + stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) + if err != nil { + return false, err + } + + // TODO: We probably want to make it so that we don't have to pull + // out all the state if possible. + stateAtEvent, err := LoadStateEvents(ctx, db, stateEntries) + if err != nil { + return false, err + } + + return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil +} + +// TODO: Remove this when we have tests to assert correctness of this function +// nolint:gocyclo +func ScanEventTree( + ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int, + serverName gomatrixserverlib.ServerName, +) ([]types.EventNID, error) { + var resultNIDs []types.EventNID + var err error + var allowed bool + var events []types.Event + var next []string + var pre string + + // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be) + // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing + // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in + // duplicate events being sent in response to /backfill requests. + initialIgnoreList := make(map[string]bool, len(visited)) + for k, v := range visited { + initialIgnoreList[k] = v + } + + resultNIDs = make([]types.EventNID, 0, limit) + + var checkedServerInRoom bool + var isServerInRoom bool + + // Loop through the event IDs to retrieve the requested events and go + // through the whole tree (up to the provided limit) using the events' + // "prev_event" key. +BFSLoop: + for len(front) > 0 { + // Prevent unnecessary allocations: reset the slice only when not empty. + if len(next) > 0 { + next = make([]string, 0) + } + // Retrieve the events to process from the database. + events, err = db.EventsFromIDs(ctx, front) + if err != nil { + return resultNIDs, err + } + + if !checkedServerInRoom && len(events) > 0 { + // It's nasty that we have to extract the room ID from an event, but many federation requests + // only talk in event IDs, no room IDs at all (!!!) + ev := events[0] + isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID()) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") + } + checkedServerInRoom = true + } + + for _, ev := range events { + // Break out of the loop if the provided limit is reached. + if len(resultNIDs) == limit { + break BFSLoop + } + + if !initialIgnoreList[ev.EventID()] { + // Update the list of events to retrieve. + resultNIDs = append(resultNIDs, ev.EventNID) + } + // Loop through the event's parents. + for _, pre = range ev.PrevEventIDs() { + // Only add an event to the list of next events to process if it + // hasn't been seen before. + if !visited[pre] { + visited[pre] = true + allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom) + if err != nil { + util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( + "Error checking if allowed to see event", + ) + return resultNIDs, err + } + + // If the event hasn't been seen before and the HS + // requesting to retrieve it is allowed to do so, add it to + // the list of events to retrieve. + if allowed { + next = append(next, pre) + } else { + util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event") + } + } + } + } + // Repeat the same process with the parent events we just processed. + front = next + } + + return resultNIDs, err +} diff --git a/roomserver/internal/input.go b/roomserver/internal/input.go index e85e9830d..dbf67b792 100644 --- a/roomserver/internal/input.go +++ b/roomserver/internal/input.go @@ -23,17 +23,8 @@ import ( "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/roomserver/api" log "github.com/sirupsen/logrus" - - fsAPI "github.com/matrix-org/dendrite/federationsender/api" ) -// SetFederationSenderInputAPI passes in a federation sender input API reference -// so that we can avoid the chicken-and-egg problem of both the roomserver input API -// and the federation sender input API being interdependent. -func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) { - r.fsAPI = fsAPI -} - // WriteOutputEvents implements OutputRoomEventWriter func (r *RoomserverInternalAPI) WriteOutputEvents(roomID string, updates []api.OutputEvent) error { messages := make([]*sarama.ProducerMessage, len(updates)) diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go index 287db1af2..edc8b416a 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input_events.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -45,7 +46,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( // Check that the event passes authentication checks and work out // the numeric IDs for the auth events. - authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) + authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) if err != nil { logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") return diff --git a/roomserver/internal/input_membership.go b/roomserver/internal/input_membership.go index bcecfca0e..57a945966 100644 --- a/roomserver/internal/input_membership.go +++ b/roomserver/internal/input_membership.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -59,13 +60,13 @@ func (r *RoomserverInternalAPI) updateMemberships( var re *gomatrixserverlib.Event targetUserNID := change.EventStateKeyNID if change.removedEventNID != 0 { - ev, _ := eventMap(events).lookup(change.removedEventNID) + ev, _ := helpers.EventMap(events).Lookup(change.removedEventNID) if ev != nil { re = &ev.Event } } if change.addedEventNID != 0 { - ev, _ := eventMap(events).lookup(change.addedEventNID) + ev, _ := helpers.EventMap(events).Lookup(change.addedEventNID) if ev != nil { ae = &ev.Event } @@ -120,7 +121,7 @@ func (r *RoomserverInternalAPI) updateMembership( switch newMembership { case gomatrixserverlib.Invite: - return updateToInviteMembership(mu, add, updates, updater.RoomVersion()) + return helpers.UpdateToInviteMembership(mu, add, updates, updater.RoomVersion()) case gomatrixserverlib.Join: return updateToJoinMembership(mu, add, updates) case gomatrixserverlib.Leave, gomatrixserverlib.Ban: @@ -141,36 +142,6 @@ func (r *RoomserverInternalAPI) isLocalTarget(event *gomatrixserverlib.Event) bo return isTargetLocalUser } -func updateToInviteMembership( - mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, - roomVersion gomatrixserverlib.RoomVersion, -) ([]api.OutputEvent, error) { - // We may have already sent the invite to the user, either because we are - // reprocessing this event, or because the we received this invite from a - // remote server via the federation invite API. In those cases we don't need - // to send the event. - needsSending, err := mu.SetToInvite(*add) - if err != nil { - return nil, err - } - if needsSending { - // We notify the consumers using a special event even though we will - // notify them about the change in current state as part of the normal - // room event stream. This ensures that the consumers only have to - // consider a single stream of events when determining whether a user - // is invited, rather than having to combine multiple streams themselves. - onie := api.OutputNewInviteEvent{ - Event: add.Headered(roomVersion), - RoomVersion: roomVersion, - } - updates = append(updates, api.OutputEvent{ - Type: api.OutputTypeNewInviteEvent, - NewInviteEvent: &onie, - }) - } - return updates, nil -} - func updateToJoinMembership( mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, ) ([]api.OutputEvent, error) { diff --git a/roomserver/internal/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go similarity index 55% rename from roomserver/internal/perform_backfill.go rename to roomserver/internal/perform/perform_backfill.go index 721f66106..ebb66ef42 100644 --- a/roomserver/internal/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -1,9 +1,13 @@ -package internal +package perform import ( "context" + "fmt" + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -11,6 +15,189 @@ import ( "github.com/sirupsen/logrus" ) +type Backfiller struct { + ServerName gomatrixserverlib.ServerName + DB storage.Database + FedClient *gomatrixserverlib.FederationClient + KeyRing gomatrixserverlib.JSONVerifier +} + +// PerformBackfill implements api.RoomServerQueryAPI +func (r *Backfiller) PerformBackfill( + ctx context.Context, + request *api.PerformBackfillRequest, + response *api.PerformBackfillResponse, +) error { + // if we are requesting the backfill then we need to do a federation hit + // TODO: we could be more sensible and fetch as many events we already have then request the rest + // which is what the syncapi does already. + if request.ServerName == r.ServerName { + return r.backfillViaFederation(ctx, request, response) + } + // someone else is requesting the backfill, try to service their request. + var err error + var front []string + + // The limit defines the maximum number of events to retrieve, so it also + // defines the highest number of elements in the map below. + visited := make(map[string]bool, request.Limit) + + // this will include these events which is what we want + front = request.PrevEventIDs() + + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if info == nil || info.IsStub { + return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID) + } + + // Scan the event tree for events to send back. + resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName) + if err != nil { + return err + } + + // Retrieve events from the list that was filled previously. + var loadedEvents []gomatrixserverlib.Event + loadedEvents, err = helpers.LoadEvents(ctx, r.DB, resultNIDs) + if err != nil { + return err + } + + for _, event := range loadedEvents { + response.Events = append(response.Events, event.Headered(info.RoomVersion)) + } + + return err +} + +func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error { + info, err := r.DB.RoomInfo(ctx, req.RoomID) + if err != nil { + return err + } + if info == nil || info.IsStub { + return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) + } + requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities) + // Request 100 items regardless of what the query asks for. + // We don't want to go much higher than this. + // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass + // (so we don't need to hit /state_ids which the test has no listener for) + // Specifically the test "Outbound federation can backfill events" + events, err := gomatrixserverlib.RequestBackfill( + ctx, requester, + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100) + if err != nil { + return err + } + logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) + + // persist these new events - auth checks have already been done + roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) + if err != nil { + return err + } + + for _, ev := range backfilledEventMap { + // now add state for these events + stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()] + if !ok { + // this should be impossible as all events returned must have pass Step 5 of the PDU checks + // which requires a list of state IDs. + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks") + continue + } + var entries []types.StateEntry + if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil { + // attempt to fetch the missing events + r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs) + // try again + entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event") + return err + } + } + + var beforeStateSnapshotNID types.StateSnapshotNID + if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid") + return err + } + if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid") + } + } + + // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point. + + res.Events = events + return nil +} + +// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just +// best effort. +func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, + backfillRequester *backfillRequester, stateIDs []string) { + + servers := backfillRequester.servers + + // work out which are missing + nidMap, err := r.DB.EventNIDs(ctx, stateIDs) + if err != nil { + util.GetLogger(ctx).WithError(err).Warn("cannot query missing events") + return + } + missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event + for _, id := range stateIDs { + if _, ok := nidMap[id]; !ok { + missingMap[id] = nil + } + } + util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers)) + + // fetch the events from federation. Loop the servers first so if we find one that works we stick with them + for _, srv := range servers { + for id, ev := range missingMap { + if ev != nil { + continue // already found + } + logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) + res, err := r.FedClient.GetEvent(ctx, srv, id) + if err != nil { + logger.WithError(err).Warn("failed to get event from server") + continue + } + loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) + if err != nil { + logger.WithError(err).Warn("failed to load and verify event") + continue + } + logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result) + for _, res := range result { + if res.Error != nil { + logger.WithError(err).Warn("event failed PDU checks") + continue + } + missingMap[id] = res.Event + } + } + } + + var newEvents []gomatrixserverlib.HeaderedEvent + for _, ev := range missingMap { + if ev != nil { + newEvents = append(newEvents, *ev) + } + } + util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) + persistEvents(ctx, r.DB, newEvents) +} + // backfillRequester implements gomatrixserverlib.BackfillRequester type backfillRequester struct { db storage.Database @@ -200,7 +387,7 @@ FindSuccessor: return nil } - stateEntries, err := stateBeforeEvent(ctx, b.db, *info, NIDs[eventID]) + stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, *info, NIDs[eventID]) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil @@ -217,7 +404,7 @@ FindSuccessor: // Retrieve all "m.room.member" state events of "join" membership, which // contains the list of users in the room before the event, therefore all // the servers in it at that moment. - memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true) + memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, stateEntries, true) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") return nil @@ -314,3 +501,47 @@ func joinEventsFromHistoryVisibility( } return db.Events(ctx, joinEventNIDs) } + +func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { + var roomNID types.RoomNID + backfilledEventMap := make(map[string]types.Event) + for j, ev := range events { + nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs()) + if err != nil { // this shouldn't happen as RequestBackfill already found them + logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events") + continue + } + authNids := make([]types.EventNID, len(nidMap)) + i := 0 + for _, nid := range nidMap { + authNids[i] = nid + i++ + } + var stateAtEvent types.StateAtEvent + var redactedEventID string + var redactionEvent *gomatrixserverlib.Event + roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") + continue + } + // If storing this event results in it being redacted, then do so. + // It's also possible for this event to be a redaction which results in another event being + // redacted, which we don't care about since we aren't returning it in this backfill. + if redactedEventID == ev.EventID() { + eventToRedact := ev.Unwrap() + redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") + continue + } + ev = redactedEvent.Headered(ev.RoomVersion) + events[j] = ev + } + backfilledEventMap[ev.EventID()] = types.Event{ + EventNID: stateAtEvent.StateEntry.EventNID, + Event: ev.Unwrap(), + } + } + return roomNID, backfilledEventMap +} diff --git a/roomserver/internal/perform_invite.go b/roomserver/internal/perform/perform_invite.go similarity index 83% rename from roomserver/internal/perform_invite.go rename to roomserver/internal/perform/perform_invite.go index 6690de055..7320388e7 100644 --- a/roomserver/internal/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -1,11 +1,13 @@ -package internal +package perform import ( "context" "fmt" federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" @@ -13,22 +15,31 @@ import ( log "github.com/sirupsen/logrus" ) +type Inviter struct { + DB storage.Database + Cfg *config.RoomServer + FSAPI federationSenderAPI.FederationSenderInternalAPI + + // TODO FIXME: Remove this + RSAPI api.RoomserverInternalAPI +} + // nolint:gocyclo -func (r *RoomserverInternalAPI) PerformInvite( +func (r *Inviter) PerformInvite( ctx context.Context, req *api.PerformInviteRequest, res *api.PerformInviteResponse, -) error { +) ([]api.OutputEvent, error) { event := req.Event if event.StateKey() == nil { - return fmt.Errorf("invite must be a state event") + return nil, fmt.Errorf("invite must be a state event") } roomID := event.RoomID() targetUserID := *event.StateKey() info, err := r.DB.RoomInfo(ctx, roomID) if err != nil { - return fmt.Errorf("Failed to load RoomInfo: %w", err) + return nil, fmt.Errorf("Failed to load RoomInfo: %w", err) } log.WithFields(log.Fields{ @@ -52,11 +63,11 @@ func (r *RoomserverInternalAPI) PerformInvite( } if len(inviteState) == 0 { if err = event.SetUnsignedField("invite_room_state", struct{}{}); err != nil { - return fmt.Errorf("event.SetUnsignedField: %w", err) + return nil, fmt.Errorf("event.SetUnsignedField: %w", err) } } else { if err = event.SetUnsignedField("invite_room_state", inviteState); err != nil { - return fmt.Errorf("event.SetUnsignedField: %w", err) + return nil, fmt.Errorf("event.SetUnsignedField: %w", err) } } @@ -64,7 +75,7 @@ func (r *RoomserverInternalAPI) PerformInvite( if info != nil { _, isAlreadyJoined, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) if err != nil { - return fmt.Errorf("r.DB.GetMembership: %w", err) + return nil, fmt.Errorf("r.DB.GetMembership: %w", err) } } if isAlreadyJoined { @@ -99,7 +110,7 @@ func (r *RoomserverInternalAPI) PerformInvite( Code: api.PerformErrorNotAllowed, Msg: "User is already joined to room", } - return nil + return nil, nil } if isOriginLocal { @@ -107,7 +118,7 @@ func (r *RoomserverInternalAPI) PerformInvite( // try and see if the user is allowed to make this invite. We can't do // this for invites coming in over federation - we have to take those on // trust. - _, err = checkAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) + _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) if err != nil { log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( "processInviteEvent.checkAuthEvents failed for event", @@ -117,9 +128,9 @@ func (r *RoomserverInternalAPI) PerformInvite( Msg: err.Error(), Code: api.PerformErrorNotAllowed, } - return nil + return nil, nil } - return fmt.Errorf("checkAuthEvents: %w", err) + return nil, fmt.Errorf("checkAuthEvents: %w", err) } // If the invite originated from us and the target isn't local then we @@ -133,13 +144,13 @@ func (r *RoomserverInternalAPI) PerformInvite( InviteRoomState: inviteState, } fsRes := &federationSenderAPI.PerformInviteResponse{} - if err = r.fsAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { + if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { res.Error = &api.PerformError{ Msg: err.Error(), Code: api.PerformErrorNoOperation, } - log.WithError(err).WithField("event_id", event.EventID()).Error("r.fsAPI.PerformInvite failed") - return nil + log.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") + return nil, nil } event = fsRes.Event } @@ -159,8 +170,8 @@ func (r *RoomserverInternalAPI) PerformInvite( }, } inputRes := &api.InputRoomEventsResponse{} - if err = r.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { - return fmt.Errorf("r.InputRoomEvents: %w", err) + if err = r.RSAPI.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { + return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } } else { // The invite originated over federation. Process the membership @@ -168,25 +179,23 @@ func (r *RoomserverInternalAPI) PerformInvite( // invite. updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion) if err != nil { - return fmt.Errorf("r.DB.MembershipUpdater: %w", err) + return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) } unwrapped := event.Unwrap() - outputUpdates, err := updateToInviteMembership(updater, &unwrapped, nil, req.Event.RoomVersion) + outputUpdates, err := helpers.UpdateToInviteMembership(updater, &unwrapped, nil, req.Event.RoomVersion) if err != nil { - return fmt.Errorf("updateToInviteMembership: %w", err) + return nil, fmt.Errorf("updateToInviteMembership: %w", err) } if err = updater.Commit(); err != nil { - return fmt.Errorf("updater.Commit: %w", err) + return nil, fmt.Errorf("updater.Commit: %w", err) } - if err = r.WriteOutputEvents(roomID, outputUpdates); err != nil { - return fmt.Errorf("r.WriteOutputEvents: %w", err) - } + return outputUpdates, nil } - return nil + return nil, nil } func buildInviteStrippedState( diff --git a/roomserver/internal/perform_join.go b/roomserver/internal/perform/perform_join.go similarity index 89% rename from roomserver/internal/perform_join.go rename to roomserver/internal/perform/perform_join.go index 3b9b1b3ca..c8e6e8e60 100644 --- a/roomserver/internal/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -1,4 +1,4 @@ -package internal +package perform import ( "context" @@ -8,14 +8,27 @@ import ( "time" fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) +type Joiner struct { + ServerName gomatrixserverlib.ServerName + Cfg *config.RoomServer + FSAPI fsAPI.FederationSenderInternalAPI + DB storage.Database + + // TODO FIXME: Remove this + RSAPI api.RoomserverInternalAPI +} + // PerformJoin handles joining matrix rooms, including over federation by talking to the federationsender. -func (r *RoomserverInternalAPI) PerformJoin( +func (r *Joiner) PerformJoin( ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse, @@ -34,7 +47,7 @@ func (r *RoomserverInternalAPI) PerformJoin( res.RoomID = roomID } -func (r *RoomserverInternalAPI) performJoin( +func (r *Joiner) performJoin( ctx context.Context, req *api.PerformJoinRequest, ) (string, error) { @@ -63,7 +76,7 @@ func (r *RoomserverInternalAPI) performJoin( } } -func (r *RoomserverInternalAPI) performJoinRoomByAlias( +func (r *Joiner) performJoinRoomByAlias( ctx context.Context, req *api.PerformJoinRequest, ) (string, error) { @@ -85,7 +98,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias( ServerName: domain, // the server to ask } dirRes := fsAPI.PerformDirectoryLookupResponse{} - err = r.fsAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes) + err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes) if err != nil { logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias) return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err) @@ -112,7 +125,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias( // TODO: Break this function up a bit // nolint:gocyclo -func (r *RoomserverInternalAPI) performJoinRoomByID( +func (r *Joiner) performJoinRoomByID( ctx context.Context, req *api.PerformJoinRequest, ) (string, error) { @@ -161,8 +174,8 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( // where we might think we know about a room in the following // section but don't know the latest state as all of our users // have left. - serverInRoom, _ := r.isServerCurrentlyInRoom(ctx, r.ServerName, req.RoomIDOrAlias) - isInvitePending, inviteSender, _, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID) + serverInRoom, _ := helpers.IsServerCurrentlyInRoom(ctx, r.DB, r.ServerName, req.RoomIDOrAlias) + isInvitePending, inviteSender, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) if err == nil && isInvitePending && !serverInRoom { // Check if there's an invite pending. _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) @@ -194,7 +207,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( &eb, // the template join event r.Cfg.Matrix, // the server configuration time.Now(), // the event timestamp to use - r, // the roomserver API to use + r.RSAPI, // the roomserver API to use &buildRes, // the query response ) @@ -228,7 +241,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - if err = r.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + if err = r.RSAPI.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { var notAllowed *gomatrixserverlib.NotAllowed if errors.As(err, ¬Allowed) { return "", &api.PerformError{ @@ -271,7 +284,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID( return req.RoomIDOrAlias, nil } -func (r *RoomserverInternalAPI) performFederatedJoinRoomByID( +func (r *Joiner) performFederatedJoinRoomByID( ctx context.Context, req *api.PerformJoinRequest, ) error { @@ -283,7 +296,7 @@ func (r *RoomserverInternalAPI) performFederatedJoinRoomByID( Content: req.Content, // the membership event content } fedRes := fsAPI.PerformJoinResponse{} - r.fsAPI.PerformJoin(ctx, &fedReq, &fedRes) + r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes) if fedRes.LastError != nil { return &api.PerformError{ Code: api.PerformErrRemote, diff --git a/roomserver/internal/perform_leave.go b/roomserver/internal/perform/perform_leave.go similarity index 53% rename from roomserver/internal/perform_leave.go rename to roomserver/internal/perform/perform_leave.go index b8603147c..b4053eed6 100644 --- a/roomserver/internal/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -1,4 +1,4 @@ -package internal +package perform import ( "context" @@ -7,39 +7,50 @@ import ( "time" fsAPI "github.com/matrix-org/dendrite/federationsender/api" + "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" ) +type Leaver struct { + Cfg *config.RoomServer + DB storage.Database + FSAPI fsAPI.FederationSenderInternalAPI + + // TODO FIXME: Remove this + RSAPI api.RoomserverInternalAPI +} + // WriteOutputEvents implements OutputRoomEventWriter -func (r *RoomserverInternalAPI) PerformLeave( +func (r *Leaver) PerformLeave( ctx context.Context, req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, -) error { +) ([]api.OutputEvent, error) { _, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { - return fmt.Errorf("Supplied user ID %q in incorrect format", req.UserID) + return nil, fmt.Errorf("Supplied user ID %q in incorrect format", req.UserID) } if domain != r.Cfg.Matrix.ServerName { - return fmt.Errorf("User %q does not belong to this homeserver", req.UserID) + return nil, fmt.Errorf("User %q does not belong to this homeserver", req.UserID) } if strings.HasPrefix(req.RoomID, "!") { return r.performLeaveRoomByID(ctx, req, res) } - return fmt.Errorf("Room ID %q is invalid", req.RoomID) + return nil, fmt.Errorf("Room ID %q is invalid", req.RoomID) } -func (r *RoomserverInternalAPI) performLeaveRoomByID( +func (r *Leaver) performLeaveRoomByID( ctx context.Context, req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam -) error { +) ([]api.OutputEvent, error) { // If there's an invite outstanding for the room then respond to // that. - isInvitePending, senderUser, eventID, err := r.isInvitePending(ctx, req.RoomID, req.UserID) + isInvitePending, senderUser, eventID, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID) if err == nil && isInvitePending { return r.performRejectInvite(ctx, req, res, senderUser, eventID) } @@ -56,25 +67,25 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID( }, } latestRes := api.QueryLatestEventsAndStateResponse{} - if err = r.QueryLatestEventsAndState(ctx, &latestReq, &latestRes); err != nil { - return err + if err = r.RSAPI.QueryLatestEventsAndState(ctx, &latestReq, &latestRes); err != nil { + return nil, err } if !latestRes.RoomExists { - return fmt.Errorf("Room %q does not exist", req.RoomID) + return nil, fmt.Errorf("Room %q does not exist", req.RoomID) } // Now let's see if the user is in the room. if len(latestRes.StateEvents) == 0 { - return fmt.Errorf("User %q is not a member of room %q", req.UserID, req.RoomID) + return nil, fmt.Errorf("User %q is not a member of room %q", req.UserID, req.RoomID) } membership, err := latestRes.StateEvents[0].Membership() if err != nil { - return fmt.Errorf("Error getting membership: %w", err) + return nil, fmt.Errorf("Error getting membership: %w", err) } if membership != gomatrixserverlib.Join { // TODO: should be able to handle "invite" in this case too, if // it's a case of kicking or banning or such - return fmt.Errorf("User %q is not joined to the room (membership is %q)", req.UserID, membership) + return nil, fmt.Errorf("User %q is not joined to the room (membership is %q)", req.UserID, membership) } // Prepare the template for the leave event. @@ -87,10 +98,10 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID( Redacts: "", } if err = eb.SetContent(map[string]interface{}{"membership": "leave"}); err != nil { - return fmt.Errorf("eb.SetContent: %w", err) + return nil, fmt.Errorf("eb.SetContent: %w", err) } if err = eb.SetUnsigned(struct{}{}); err != nil { - return fmt.Errorf("eb.SetUnsigned: %w", err) + return nil, fmt.Errorf("eb.SetUnsigned: %w", err) } // We know that the user is in the room at this point so let's build @@ -103,11 +114,11 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID( &eb, // the template leave event r.Cfg.Matrix, // the server configuration time.Now(), // the event timestamp to use - r, // the roomserver API to use + r.RSAPI, // the roomserver API to use &buildRes, // the query response ) if err != nil { - return fmt.Errorf("eventutil.BuildEvent: %w", err) + return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) } // Give our leave event to the roomserver input stream. The @@ -124,22 +135,22 @@ func (r *RoomserverInternalAPI) performLeaveRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - if err = r.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { - return fmt.Errorf("r.InputRoomEvents: %w", err) + if err = r.RSAPI.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } - return nil + return nil, nil } -func (r *RoomserverInternalAPI) performRejectInvite( +func (r *Leaver) performRejectInvite( ctx context.Context, req *api.PerformLeaveRequest, res *api.PerformLeaveResponse, // nolint:unparam senderUser, eventID string, -) error { +) ([]api.OutputEvent, error) { _, domain, err := gomatrixserverlib.SplitID('@', senderUser) if err != nil { - return fmt.Errorf("User ID %q invalid: %w", senderUser, err) + return nil, fmt.Errorf("User ID %q invalid: %w", senderUser, err) } // Ask the federation sender to perform a federated leave for us. @@ -149,13 +160,13 @@ func (r *RoomserverInternalAPI) performRejectInvite( ServerNames: []gomatrixserverlib.ServerName{domain}, } leaveRes := fsAPI.PerformLeaveResponse{} - if err := r.fsAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { - return err + if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil { + return nil, err } // Withdraw the invite, so that the sync API etc are // notified that we rejected it. - return r.WriteOutputEvents(req.RoomID, []api.OutputEvent{ + return []api.OutputEvent{ { Type: api.OutputTypeRetireInviteEvent, RetireInviteEvent: &api.OutputRetireInviteEvent{ @@ -164,60 +175,5 @@ func (r *RoomserverInternalAPI) performRejectInvite( TargetUserID: req.UserID, }, }, - }) -} - -func (r *RoomserverInternalAPI) isInvitePending( - ctx context.Context, - roomID, userID string, -) (bool, string, string, error) { - // Look up the room NID for the supplied room ID. - info, err := r.DB.RoomInfo(ctx, roomID) - if err != nil { - return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err) - } - if info == nil { - return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID) - } - - // Look up the state key NID for the supplied user ID. - targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{userID}) - if err != nil { - return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err) - } - targetUserNID, targetUserFound := targetUserNIDs[userID] - if !targetUserFound { - return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs) - } - - // Let's see if we have an event active for the user in the room. If - // we do then it will contain a server name that we can direct the - // send_leave to. - senderUserNIDs, eventIDs, err := r.DB.GetInvitesForUser(ctx, info.RoomNID, targetUserNID) - if err != nil { - return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err) - } - if len(senderUserNIDs) == 0 { - return false, "", "", nil - } - userNIDToEventID := make(map[types.EventStateKeyNID]string) - for i, nid := range senderUserNIDs { - userNIDToEventID[nid] = eventIDs[i] - } - - // Look up the user ID from the NID. - senderUsers, err := r.DB.EventStateKeys(ctx, senderUserNIDs) - if err != nil { - return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err) - } - if len(senderUsers) == 0 { - return false, "", "", fmt.Errorf("no senderUsers") - } - - senderUser, senderUserFound := senderUsers[senderUserNIDs[0]] - if !senderUserFound { - return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers) - } - - return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil + }, nil } diff --git a/roomserver/internal/perform_publish.go b/roomserver/internal/perform/perform_publish.go similarity index 67% rename from roomserver/internal/perform_publish.go rename to roomserver/internal/perform/perform_publish.go index d7863620a..aab282f39 100644 --- a/roomserver/internal/perform_publish.go +++ b/roomserver/internal/perform/perform_publish.go @@ -1,12 +1,17 @@ -package internal +package perform import ( "context" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage" ) -func (r *RoomserverInternalAPI) PerformPublish( +type Publisher struct { + DB storage.Database +} + +func (r *Publisher) PerformPublish( ctx context.Context, req *api.PerformPublishRequest, res *api.PerformPublishResponse, diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go index f8e8ba04d..26b22c74b 100644 --- a/roomserver/internal/query.go +++ b/roomserver/internal/query.go @@ -20,11 +20,9 @@ import ( "context" "fmt" - "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/auth" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrixserverlib" @@ -74,7 +72,7 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState( return err } - stateEvents, err := r.loadStateEvents(ctx, stateEntries) + stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) if err != nil { return err } @@ -123,7 +121,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( return err } - stateEvents, err := r.loadStateEvents(ctx, stateEntries) + stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) if err != nil { return err } @@ -151,7 +149,7 @@ func (r *RoomserverInternalAPI) QueryEventsByID( eventNIDs = append(eventNIDs, nid) } - events, err := r.loadEvents(ctx, eventNIDs) + events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs) if err != nil { return err } @@ -168,31 +166,6 @@ func (r *RoomserverInternalAPI) QueryEventsByID( return nil } -func (r *RoomserverInternalAPI) loadStateEvents( - ctx context.Context, stateEntries []types.StateEntry, -) ([]gomatrixserverlib.Event, error) { - eventNIDs := make([]types.EventNID, len(stateEntries)) - for i := range stateEntries { - eventNIDs[i] = stateEntries[i].EventNID - } - return r.loadEvents(ctx, eventNIDs) -} - -func (r *RoomserverInternalAPI) loadEvents( - ctx context.Context, eventNIDs []types.EventNID, -) ([]gomatrixserverlib.Event, error) { - stateEvents, err := r.DB.Events(ctx, eventNIDs) - if err != nil { - return nil, err - } - - result := make([]gomatrixserverlib.Event, len(stateEvents)) - for i := range stateEvents { - result[i] = stateEvents[i].Event - } - return result, nil -} - // QueryMembershipForUser implements api.RoomserverInternalAPI func (r *RoomserverInternalAPI) QueryMembershipForUser( ctx context.Context, @@ -266,12 +239,12 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( events, err = r.DB.Events(ctx, eventNIDs) } else { - stateEntries, err = stateBeforeEvent(ctx, r.DB, *info, membershipEventNID) + stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, *info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err } - events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) + events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) } if err != nil { @@ -286,65 +259,6 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( return nil } -func stateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { - roomState := state.NewStateResolution(db, info) - // Lookup the event NID - eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) - if err != nil { - return nil, err - } - eventIDs := []string{eIDs[eventNID]} - - prevState, err := db.StateAtEventIDs(ctx, eventIDs) - if err != nil { - return nil, err - } - - // Fetch the state as it was when this event was fired - return roomState.LoadCombinedStateAfterEvents(ctx, prevState) -} - -// getMembershipsAtState filters the state events to -// only keep the "m.room.member" events with a "join" membership. These events are returned. -// Returns an error if there was an issue fetching the events. -func getMembershipsAtState( - ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, -) ([]types.Event, error) { - - var eventNIDs []types.EventNID - for _, entry := range stateEntries { - // Filter the events to retrieve to only keep the membership events - if entry.EventTypeNID == types.MRoomMemberNID { - eventNIDs = append(eventNIDs, entry.EventNID) - } - } - - // Get all of the events in this state - stateEvents, err := db.Events(ctx, eventNIDs) - if err != nil { - return nil, err - } - - if !joinedOnly { - return stateEvents, nil - } - - // Filter the events to only keep the "join" membership events - var events []types.Event - for _, event := range stateEvents { - membership, err := event.Membership() - if err != nil { - return nil, err - } - - if membership == gomatrixserverlib.Join { - events = append(events, event) - } - } - - return events, nil -} - // QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( ctx context.Context, @@ -360,7 +274,7 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( return } roomID := events[0].RoomID() - isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, roomID) + isServerInRoom, err := helpers.IsServerCurrentlyInRoom(ctx, r.DB, request.ServerName, roomID) if err != nil { return } @@ -371,31 +285,12 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( if info == nil { return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) } - response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent( - ctx, *info, request.EventID, request.ServerName, isServerInRoom, + response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( + ctx, r.DB, *info, request.EventID, request.ServerName, isServerInRoom, ) return } -func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent( - ctx context.Context, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, -) (bool, error) { - roomState := state.NewStateResolution(r.DB, info) - stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) - if err != nil { - return false, err - } - - // TODO: We probably want to make it so that we don't have to pull - // out all the state if possible. - stateAtEvent, err := r.loadStateEvents(ctx, stateEntries) - if err != nil { - return false, err - } - - return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil -} - // QueryMissingEvents implements api.RoomserverInternalAPI // nolint:gocyclo func (r *RoomserverInternalAPI) QueryMissingEvents( @@ -431,12 +326,12 @@ func (r *RoomserverInternalAPI) QueryMissingEvents( return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) } - resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName) + resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName) if err != nil { return err } - loadedEvents, err := r.loadEvents(ctx, resultNIDs) + loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs) if err != nil { return err } @@ -456,299 +351,6 @@ func (r *RoomserverInternalAPI) QueryMissingEvents( return err } -// PerformBackfill implements api.RoomServerQueryAPI -func (r *RoomserverInternalAPI) PerformBackfill( - ctx context.Context, - request *api.PerformBackfillRequest, - response *api.PerformBackfillResponse, -) error { - // if we are requesting the backfill then we need to do a federation hit - // TODO: we could be more sensible and fetch as many events we already have then request the rest - // which is what the syncapi does already. - if request.ServerName == r.ServerName { - return r.backfillViaFederation(ctx, request, response) - } - // someone else is requesting the backfill, try to service their request. - var err error - var front []string - - // The limit defines the maximum number of events to retrieve, so it also - // defines the highest number of elements in the map below. - visited := make(map[string]bool, request.Limit) - - // this will include these events which is what we want - front = request.PrevEventIDs() - - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - if info == nil || info.IsStub { - return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID) - } - - // Scan the event tree for events to send back. - resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName) - if err != nil { - return err - } - - // Retrieve events from the list that was filled previously. - var loadedEvents []gomatrixserverlib.Event - loadedEvents, err = r.loadEvents(ctx, resultNIDs) - if err != nil { - return err - } - - for _, event := range loadedEvents { - response.Events = append(response.Events, event.Headered(info.RoomVersion)) - } - - return err -} - -func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error { - roomVer, err := r.roomVersion(req.RoomID) - if err != nil { - return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err) - } - requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities) - // Request 100 items regardless of what the query asks for. - // We don't want to go much higher than this. - // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass - // (so we don't need to hit /state_ids which the test has no listener for) - // Specifically the test "Outbound federation can backfill events" - events, err := gomatrixserverlib.RequestBackfill( - ctx, requester, - r.KeyRing, req.RoomID, roomVer, req.PrevEventIDs(), 100) - if err != nil { - return err - } - logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) - - // persist these new events - auth checks have already been done - roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) - if err != nil { - return err - } - - for _, ev := range backfilledEventMap { - // now add state for these events - stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()] - if !ok { - // this should be impossible as all events returned must have pass Step 5 of the PDU checks - // which requires a list of state IDs. - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks") - continue - } - var entries []types.StateEntry - if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil { - // attempt to fetch the missing events - r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs) - // try again - entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs) - if err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event") - return err - } - } - - var beforeStateSnapshotNID types.StateSnapshotNID - if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid") - return err - } - if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid") - } - } - - // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point. - - res.Events = events - return nil -} - -func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) { - info, err := r.DB.RoomInfo(ctx, roomID) - if err != nil { - return false, err - } - if info == nil { - return false, fmt.Errorf("unknown room %s", roomID) - } - - eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) - if err != nil { - return false, err - } - - events, err := r.DB.Events(ctx, eventNIDs) - if err != nil { - return false, err - } - gmslEvents := make([]gomatrixserverlib.Event, len(events)) - for i := range events { - gmslEvents[i] = events[i].Event - } - return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil -} - -// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just -// best effort. -func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - backfillRequester *backfillRequester, stateIDs []string) { - - servers := backfillRequester.servers - - // work out which are missing - nidMap, err := r.DB.EventNIDs(ctx, stateIDs) - if err != nil { - util.GetLogger(ctx).WithError(err).Warn("cannot query missing events") - return - } - missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event - for _, id := range stateIDs { - if _, ok := nidMap[id]; !ok { - missingMap[id] = nil - } - } - util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers)) - - // fetch the events from federation. Loop the servers first so if we find one that works we stick with them - for _, srv := range servers { - for id, ev := range missingMap { - if ev != nil { - continue // already found - } - logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) - res, err := r.FedClient.GetEvent(ctx, srv, id) - if err != nil { - logger.WithError(err).Warn("failed to get event from server") - continue - } - loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) - if err != nil { - logger.WithError(err).Warn("failed to load and verify event") - continue - } - logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result) - for _, res := range result { - if res.Error != nil { - logger.WithError(err).Warn("event failed PDU checks") - continue - } - missingMap[id] = res.Event - } - } - } - - var newEvents []gomatrixserverlib.HeaderedEvent - for _, ev := range missingMap { - if ev != nil { - newEvents = append(newEvents, *ev) - } - } - util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) - persistEvents(ctx, r.DB, newEvents) -} - -// TODO: Remove this when we have tests to assert correctness of this function -// nolint:gocyclo -func (r *RoomserverInternalAPI) scanEventTree( - ctx context.Context, info types.RoomInfo, front []string, visited map[string]bool, limit int, - serverName gomatrixserverlib.ServerName, -) ([]types.EventNID, error) { - var resultNIDs []types.EventNID - var err error - var allowed bool - var events []types.Event - var next []string - var pre string - - // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be) - // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing - // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in - // duplicate events being sent in response to /backfill requests. - initialIgnoreList := make(map[string]bool, len(visited)) - for k, v := range visited { - initialIgnoreList[k] = v - } - - resultNIDs = make([]types.EventNID, 0, limit) - - var checkedServerInRoom bool - var isServerInRoom bool - - // Loop through the event IDs to retrieve the requested events and go - // through the whole tree (up to the provided limit) using the events' - // "prev_event" key. -BFSLoop: - for len(front) > 0 { - // Prevent unnecessary allocations: reset the slice only when not empty. - if len(next) > 0 { - next = make([]string, 0) - } - // Retrieve the events to process from the database. - events, err = r.DB.EventsFromIDs(ctx, front) - if err != nil { - return resultNIDs, err - } - - if !checkedServerInRoom && len(events) > 0 { - // It's nasty that we have to extract the room ID from an event, but many federation requests - // only talk in event IDs, no room IDs at all (!!!) - ev := events[0] - isServerInRoom, err = r.isServerCurrentlyInRoom(ctx, serverName, ev.RoomID()) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") - } - checkedServerInRoom = true - } - - for _, ev := range events { - // Break out of the loop if the provided limit is reached. - if len(resultNIDs) == limit { - break BFSLoop - } - - if !initialIgnoreList[ev.EventID()] { - // Update the list of events to retrieve. - resultNIDs = append(resultNIDs, ev.EventNID) - } - // Loop through the event's parents. - for _, pre = range ev.PrevEventIDs() { - // Only add an event to the list of next events to process if it - // hasn't been seen before. - if !visited[pre] { - visited[pre] = true - allowed, err = r.checkServerAllowedToSeeEvent(ctx, info, pre, serverName, isServerInRoom) - if err != nil { - util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( - "Error checking if allowed to see event", - ) - return resultNIDs, err - } - - // If the event hasn't been seen before and the HS - // requesting to retrieve it is allowed to do so, add it to - // the list of events to retrieve. - if allowed { - next = append(next, pre) - } else { - util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event") - } - } - } - } - // Repeat the same process with the parent events we just processed. - front = next - } - - return resultNIDs, err -} - // QueryStateAndAuthChain implements api.RoomserverInternalAPI func (r *RoomserverInternalAPI) QueryStateAndAuthChain( ctx context.Context, @@ -823,7 +425,7 @@ func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, roomInf return nil, err } - return r.loadStateEvents(ctx, stateEntries) + return helpers.LoadStateEvents(ctx, r.DB, stateEntries) } type eventsFromIDs func(context.Context, []string) ([]types.Event, error) @@ -879,50 +481,6 @@ func getAuthChain( return authEvents, nil } -func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { - var roomNID types.RoomNID - backfilledEventMap := make(map[string]types.Event) - for j, ev := range events { - nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs()) - if err != nil { // this shouldn't happen as RequestBackfill already found them - logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events") - continue - } - authNids := make([]types.EventNID, len(nidMap)) - i := 0 - for _, nid := range nidMap { - authNids[i] = nid - i++ - } - var stateAtEvent types.StateAtEvent - var redactedEventID string - var redactionEvent *gomatrixserverlib.Event - roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids) - if err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") - continue - } - // If storing this event results in it being redacted, then do so. - // It's also possible for this event to be a redaction which results in another event being - // redacted, which we don't care about since we aren't returning it in this backfill. - if redactedEventID == ev.EventID() { - eventToRedact := ev.Unwrap() - redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact) - if err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") - continue - } - ev = redactedEvent.Headered(ev.RoomVersion) - events[j] = ev - } - backfilledEventMap[ev.EventID()] = types.Event{ - EventNID: stateAtEvent.StateEntry.EventNID, - Event: ev.Unwrap(), - } - } - return roomNID, backfilledEventMap -} - // QueryRoomVersionCapabilities implements api.RoomserverInternalAPI func (r *RoomserverInternalAPI) QueryRoomVersionCapabilities( ctx context.Context, diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 21af5f32d..a428ad57b 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -47,14 +47,8 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to room server db") } - return &internal.RoomserverInternalAPI{ - DB: roomserverDB, - Cfg: cfg, - Producer: base.KafkaProducer, - OutputRoomEventTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)), - Cache: base.Caches, - ServerName: cfg.Matrix.ServerName, - FedClient: fedClient, - KeyRing: keyRing, - } + return internal.NewRoomserverAPI( + cfg, roomserverDB, base.KafkaProducer, string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)), + base.Caches, fedClient, keyRing, + ) } From 096191ca240776031370e99b93732557972ba92a Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 2 Sep 2020 15:26:30 +0100 Subject: [PATCH 04/12] Use federation sender for backfill/getting missing events (#1379) * Use federation sender for backfill and getting missing events * Fix internal URL paths * Update go.mod/go.sum for matrix-org/gomatrixserverlib#218 * Add missing server implementations in HTTP interface --- build/gobind/monolith.go | 2 +- cmd/dendrite-demo-libp2p/main.go | 2 +- cmd/dendrite-demo-yggdrasil/main.go | 2 +- cmd/dendrite-monolith-server/main.go | 2 +- cmd/dendrite-room-server/main.go | 3 +- cmd/dendritejs/main.go | 2 +- federationsender/api/api.go | 3 + federationsender/internal/api.go | 48 +++++++ federationsender/inthttp/client.go | 130 ++++++++++++++++++ federationsender/inthttp/server.go | 88 ++++++++++++ go.mod | 2 +- go.sum | 4 +- roomserver/internal/api.go | 8 +- .../internal/perform/perform_backfill.go | 23 ++-- roomserver/roomserver.go | 3 +- roomserver/roomserver_test.go | 2 +- 16 files changed, 295 insertions(+), 29 deletions(-) diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go index 59535c7b9..725c9c074 100644 --- a/build/gobind/monolith.go +++ b/build/gobind/monolith.go @@ -120,7 +120,7 @@ func (m *DendriteMonolith) Start() { keyAPI.SetUserAPI(userAPI) rsAPI := roomserver.NewInternalAPI( - base, keyRing, federation, + base, keyRing, ) eduInputAPI := eduserver.NewInternalAPI( diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index e2d23e895..d4f0cee04 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -155,7 +155,7 @@ func main() { stateAPI := currentstateserver.NewInternalAPI(&base.Base.Cfg.CurrentStateServer, base.Base.KafkaConsumer) rsAPI := roomserver.NewInternalAPI( - &base.Base, keyRing, federation, + &base.Base, keyRing, ) eduInputAPI := eduserver.NewInternalAPI( &base.Base, cache.New(), userAPI, diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 26999ebed..fcf3d4c56 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -104,7 +104,7 @@ func main() { keyAPI.SetUserAPI(userAPI) rsComponent := roomserver.NewInternalAPI( - base, keyRing, federation, + base, keyRing, ) rsAPI := rsComponent diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 815117463..717b21a9f 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -81,7 +81,7 @@ func main() { keyRing := serverKeyAPI.KeyRing() rsImpl := roomserver.NewInternalAPI( - base, keyRing, federation, + base, keyRing, ) // call functions directly on the impl unless running in HTTP mode rsAPI := rsImpl diff --git a/cmd/dendrite-room-server/main.go b/cmd/dendrite-room-server/main.go index 0d587e6ee..08ad34bfd 100644 --- a/cmd/dendrite-room-server/main.go +++ b/cmd/dendrite-room-server/main.go @@ -23,13 +23,12 @@ func main() { cfg := setup.ParseFlags(false) base := setup.NewBaseDendrite(cfg, "RoomServerAPI", true) defer base.Close() // nolint: errcheck - federation := base.CreateFederationClient() serverKeyAPI := base.ServerKeyAPIClient() keyRing := serverKeyAPI.KeyRing() fsAPI := base.FederationSenderHTTPClient() - rsAPI := roomserver.NewInternalAPI(base, keyRing, federation) + rsAPI := roomserver.NewInternalAPI(base, keyRing) rsAPI.SetFederationSenderAPI(fsAPI) roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI) diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index c95eb3fce..aeca70946 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -205,7 +205,7 @@ func main() { } stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) - rsAPI := roomserver.NewInternalAPI(base, keyRing, federation) + rsAPI := roomserver.NewInternalAPI(base, keyRing) eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI) asQuery := appservice.NewInternalAPI( base, userAPI, rsAPI, diff --git a/federationsender/api/api.go b/federationsender/api/api.go index cea0010d6..655d1d103 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -14,9 +14,12 @@ import ( // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError type FederationClient interface { + gomatrixserverlib.BackfillClient + gomatrixserverlib.FederatedStateClient GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) + GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. diff --git a/federationsender/internal/api.go b/federationsender/internal/api.go index 6b5f4c342..61663be31 100644 --- a/federationsender/internal/api.go +++ b/federationsender/internal/api.go @@ -136,3 +136,51 @@ func (a *FederationSenderInternalAPI) QueryKeys( } return ires.(gomatrixserverlib.RespQueryKeys), nil } + +func (a *FederationSenderInternalAPI) Backfill( + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, +) (res gomatrixserverlib.Transaction, err error) { + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) + }) + if err != nil { + return gomatrixserverlib.Transaction{}, err + } + return ires.(gomatrixserverlib.Transaction), nil +} + +func (a *FederationSenderInternalAPI) LookupState( + ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, +) (res gomatrixserverlib.RespState, err error) { + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) + }) + if err != nil { + return gomatrixserverlib.RespState{}, err + } + return ires.(gomatrixserverlib.RespState), nil +} + +func (a *FederationSenderInternalAPI) LookupStateIDs( + ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, +) (res gomatrixserverlib.RespStateIDs, err error) { + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.LookupStateIDs(ctx, s, roomID, eventID) + }) + if err != nil { + return gomatrixserverlib.RespStateIDs{}, err + } + return ires.(gomatrixserverlib.RespStateIDs), nil +} + +func (a *FederationSenderInternalAPI) GetEvent( + ctx context.Context, s gomatrixserverlib.ServerName, eventID string, +) (res gomatrixserverlib.Transaction, err error) { + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.GetEvent(ctx, s, eventID) + }) + if err != nil { + return gomatrixserverlib.Transaction{}, err + } + return ires.(gomatrixserverlib.Transaction), nil +} diff --git a/federationsender/inthttp/client.go b/federationsender/inthttp/client.go index 79e220c38..5bfe6089d 100644 --- a/federationsender/inthttp/client.go +++ b/federationsender/inthttp/client.go @@ -26,6 +26,10 @@ const ( FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices" FederationSenderClaimKeysPath = "/federationsender/client/claimKeys" FederationSenderQueryKeysPath = "/federationsender/client/queryKeys" + FederationSenderBackfillPath = "/federationsender/client/backfill" + FederationSenderLookupStatePath = "/federationsender/client/lookupState" + FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs" + FederationSenderGetEventPath = "/federationsender/client/getEvent" ) // NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. @@ -228,3 +232,129 @@ func (h *httpFederationSenderInternalAPI) QueryKeys( } return *response.Res, nil } + +type backfill struct { + S gomatrixserverlib.ServerName + RoomID string + Limit int + EventIDs []string + Res *gomatrixserverlib.Transaction + Err *api.FederationClientError +} + +func (h *httpFederationSenderInternalAPI) Backfill( + ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, +) (gomatrixserverlib.Transaction, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "Backfill") + defer span.Finish() + + request := backfill{ + S: s, + RoomID: roomID, + Limit: limit, + EventIDs: eventIDs, + } + var response backfill + apiURL := h.federationSenderURL + FederationSenderBackfillPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return gomatrixserverlib.Transaction{}, err + } + if response.Err != nil { + return gomatrixserverlib.Transaction{}, response.Err + } + return *response.Res, nil +} + +type lookupState struct { + S gomatrixserverlib.ServerName + RoomID string + EventID string + RoomVersion gomatrixserverlib.RoomVersion + Res *gomatrixserverlib.RespState + Err *api.FederationClientError +} + +func (h *httpFederationSenderInternalAPI) LookupState( + ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, +) (gomatrixserverlib.RespState, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "LookupState") + defer span.Finish() + + request := lookupState{ + S: s, + RoomID: roomID, + EventID: eventID, + RoomVersion: roomVersion, + } + var response lookupState + apiURL := h.federationSenderURL + FederationSenderLookupStatePath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return gomatrixserverlib.RespState{}, err + } + if response.Err != nil { + return gomatrixserverlib.RespState{}, response.Err + } + return *response.Res, nil +} + +type lookupStateIDs struct { + S gomatrixserverlib.ServerName + RoomID string + EventID string + Res *gomatrixserverlib.RespStateIDs + Err *api.FederationClientError +} + +func (h *httpFederationSenderInternalAPI) LookupStateIDs( + ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, +) (gomatrixserverlib.RespStateIDs, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "LookupStateIDs") + defer span.Finish() + + request := lookupStateIDs{ + S: s, + RoomID: roomID, + EventID: eventID, + } + var response lookupStateIDs + apiURL := h.federationSenderURL + FederationSenderLookupStateIDsPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return gomatrixserverlib.RespStateIDs{}, err + } + if response.Err != nil { + return gomatrixserverlib.RespStateIDs{}, response.Err + } + return *response.Res, nil +} + +type getEvent struct { + S gomatrixserverlib.ServerName + EventID string + Res *gomatrixserverlib.Transaction + Err *api.FederationClientError +} + +func (h *httpFederationSenderInternalAPI) GetEvent( + ctx context.Context, s gomatrixserverlib.ServerName, eventID string, +) (gomatrixserverlib.Transaction, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "GetEvent") + defer span.Finish() + + request := getEvent{ + S: s, + EventID: eventID, + } + var response getEvent + apiURL := h.federationSenderURL + FederationSenderGetEventPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return gomatrixserverlib.Transaction{}, err + } + if response.Err != nil { + return gomatrixserverlib.Transaction{}, response.Err + } + return *response.Res, nil +} diff --git a/federationsender/inthttp/server.go b/federationsender/inthttp/server.go index b18255760..dfbff1c00 100644 --- a/federationsender/inthttp/server.go +++ b/federationsender/inthttp/server.go @@ -175,4 +175,92 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route return util.JSONResponse{Code: http.StatusOK, JSON: request} }), ) + internalAPIMux.Handle( + FederationSenderBackfillPath, + httputil.MakeInternalAPI("Backfill", func(req *http.Request) util.JSONResponse { + var request backfill + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.Backfill(req.Context(), request.S, request.RoomID, request.Limit, request.EventIDs) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.Res = &res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) + internalAPIMux.Handle( + FederationSenderLookupStatePath, + httputil.MakeInternalAPI("LookupState", func(req *http.Request) util.JSONResponse { + var request lookupState + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.LookupState(req.Context(), request.S, request.RoomID, request.EventID, request.RoomVersion) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.Res = &res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) + internalAPIMux.Handle( + FederationSenderLookupStateIDsPath, + httputil.MakeInternalAPI("LookupStateIDs", func(req *http.Request) util.JSONResponse { + var request lookupStateIDs + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.LookupStateIDs(req.Context(), request.S, request.RoomID, request.EventID) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.Res = &res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) + internalAPIMux.Handle( + FederationSenderGetEventPath, + httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse { + var request getEvent + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.GetEvent(req.Context(), request.S, request.EventID) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.Res = &res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) } diff --git a/go.mod b/go.mod index c69068059..3a9fef9f5 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd - github.com/matrix-org/gomatrixserverlib v0.0.0-20200817100842-9d02141812f2 + github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750 github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.2 diff --git a/go.sum b/go.sum index 332ae05fa..33b4f591a 100644 --- a/go.sum +++ b/go.sum @@ -567,8 +567,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200817100842-9d02141812f2 h1:9wKwfd5KDcXuqZ7/kAaYe0QM4DGM+2awjjvXQtrDa6k= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200817100842-9d02141812f2/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750 h1:k5vsLfpylXHOXgN51N0QNbak9i+4bT33Puk/ZJgcdDw= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1897f7a53..8ac1bdda2 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -22,7 +22,7 @@ type RoomserverInternalAPI struct { Cache caching.RoomServerCaches ServerName gomatrixserverlib.ServerName KeyRing gomatrixserverlib.JSONVerifier - FedClient *gomatrixserverlib.FederationClient + fsAPI fsAPI.FederationSenderInternalAPI OutputRoomEventTopic string // Kafka topic for new output room events Inviter *perform.Inviter Joiner *perform.Joiner @@ -30,12 +30,11 @@ type RoomserverInternalAPI struct { Publisher *perform.Publisher Backfiller *perform.Backfiller mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent - fsAPI fsAPI.FederationSenderInternalAPI } func NewRoomserverAPI( cfg *config.RoomServer, roomserverDB storage.Database, producer sarama.SyncProducer, - outputRoomEventTopic string, caches caching.RoomServerCaches, fedClient *gomatrixserverlib.FederationClient, + outputRoomEventTopic string, caches caching.RoomServerCaches, keyRing gomatrixserverlib.JSONVerifier, ) *RoomserverInternalAPI { a := &RoomserverInternalAPI{ @@ -45,7 +44,6 @@ func NewRoomserverAPI( Cache: caches, ServerName: cfg.Matrix.ServerName, KeyRing: keyRing, - FedClient: fedClient, OutputRoomEventTopic: outputRoomEventTopic, // perform-er structs get initialised when we have a federation sender to use } @@ -83,7 +81,7 @@ func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSen r.Backfiller = &perform.Backfiller{ ServerName: r.ServerName, DB: r.DB, - FedClient: r.FedClient, + FSAPI: r.fsAPI, KeyRing: r.KeyRing, } } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index ebb66ef42..d345e9c73 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/auth" @@ -18,7 +19,7 @@ import ( type Backfiller struct { ServerName gomatrixserverlib.ServerName DB storage.Database - FedClient *gomatrixserverlib.FederationClient + FSAPI federationSenderAPI.FederationSenderInternalAPI KeyRing gomatrixserverlib.JSONVerifier } @@ -81,7 +82,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform if info == nil || info.IsStub { return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) } - requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities) + requester := newBackfillRequester(r.DB, r.FSAPI, r.ServerName, req.BackwardsExtremities) // Request 100 items regardless of what the query asks for. // We don't want to go much higher than this. // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass @@ -166,7 +167,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue // already found } logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) - res, err := r.FedClient.GetEvent(ctx, srv, id) + res, err := r.FSAPI.GetEvent(ctx, srv, id) if err != nil { logger.WithError(err).Warn("failed to get event from server") continue @@ -201,7 +202,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom // backfillRequester implements gomatrixserverlib.BackfillRequester type backfillRequester struct { db storage.Database - fedClient *gomatrixserverlib.FederationClient + fsAPI federationSenderAPI.FederationSenderInternalAPI thisServer gomatrixserverlib.ServerName bwExtrems map[string][]string @@ -211,10 +212,10 @@ type backfillRequester struct { eventIDMap map[string]gomatrixserverlib.Event } -func newBackfillRequester(db storage.Database, fedClient *gomatrixserverlib.FederationClient, thisServer gomatrixserverlib.ServerName, bwExtrems map[string][]string) *backfillRequester { +func newBackfillRequester(db storage.Database, fsAPI federationSenderAPI.FederationSenderInternalAPI, thisServer gomatrixserverlib.ServerName, bwExtrems map[string][]string) *backfillRequester { return &backfillRequester{ db: db, - fedClient: fedClient, + fsAPI: fsAPI, thisServer: thisServer, eventIDToBeforeStateIDs: make(map[string][]string), eventIDMap: make(map[string]gomatrixserverlib.Event), @@ -258,7 +259,7 @@ FederationHit: logrus.WithField("event_id", targetEvent.EventID()).Info("Requesting /state_ids at event") for _, srv := range b.servers { // hit any valid server c := gomatrixserverlib.FederatedStateProvider{ - FedClient: b.fedClient, + FedClient: b.fsAPI, RememberAuthEvents: false, Server: srv, } @@ -331,7 +332,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr } c := gomatrixserverlib.FederatedStateProvider{ - FedClient: b.fedClient, + FedClient: b.fsAPI, RememberAuthEvents: false, Server: b.servers[0], } @@ -430,10 +431,10 @@ FindSuccessor: // Backfill performs a backfill request to the given server. // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, - fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) { + limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) { - tx, err := b.fedClient.Backfill(ctx, server, roomID, limit, fromEventIDs) - return &tx, err + tx, err := b.fsAPI.Backfill(ctx, server, roomID, limit, fromEventIDs) + return tx, err } func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.Event, error) { diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index a428ad57b..2eabf4504 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -38,7 +38,6 @@ func AddInternalRoutes(router *mux.Router, intAPI api.RoomserverInternalAPI) { func NewInternalAPI( base *setup.BaseDendrite, keyRing gomatrixserverlib.JSONVerifier, - fedClient *gomatrixserverlib.FederationClient, ) api.RoomserverInternalAPI { cfg := &base.Cfg.RoomServer @@ -49,6 +48,6 @@ func NewInternalAPI( return internal.NewRoomserverAPI( cfg, roomserverDB, base.KafkaProducer, string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)), - base.Caches, fedClient, keyRing, + base.Caches, keyRing, ) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index bcd9afb38..0deb7acb1 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -112,7 +112,7 @@ func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js Cfg: cfg, } - rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{}, nil) + rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{}) hevents := mustLoadEvents(t, ver, events) _, err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil) if err != nil { From 3b0774805cd06e1d9094a5b0773126cbfb573abb Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 2 Sep 2020 16:18:08 +0100 Subject: [PATCH 05/12] Version imprint (#1383) * Versions * Update build.sh --- build.sh | 9 +++++++-- federationapi/routing/version.go | 11 ++++++++++- internal/setup/base.go | 2 ++ internal/version.go | 26 ++++++++++++++++++++++++++ 4 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 internal/version.go diff --git a/build.sh b/build.sh index 087f4ae72..34e4b1153 100755 --- a/build.sh +++ b/build.sh @@ -3,6 +3,11 @@ # Put installed packages into ./bin export GOBIN=$PWD/`dirname $0`/bin -go install -v $PWD/`dirname $0`/cmd/... +export BRANCH=`(git symbolic-ref --short HEAD | cut -d'/' -f 3 )|| ""` +export BUILD=`git rev-parse --short HEAD || ""` -GOOS=js GOARCH=wasm go build -o main.wasm ./cmd/dendritejs +export FLAGS="-X github.com/matrix-org/dendrite/internal.branch=$BRANCH -X github.com/matrix-org/dendrite/internal.build=$BUILD" + +go install -trimpath -ldflags "$FLAGS" -v $PWD/`dirname $0`/cmd/... + +GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o main.wasm ./cmd/dendritejs diff --git a/federationapi/routing/version.go b/federationapi/routing/version.go index 14ecd21e1..906fc2b9b 100644 --- a/federationapi/routing/version.go +++ b/federationapi/routing/version.go @@ -17,6 +17,7 @@ package routing import ( "net/http" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/util" ) @@ -31,5 +32,13 @@ type server struct { // Version returns the server version func Version() util.JSONResponse { - return util.JSONResponse{Code: http.StatusOK, JSON: &version{server{"dev", "Dendrite"}}} + return util.JSONResponse{ + Code: http.StatusOK, + JSON: &version{ + server{ + Name: "Dendrite", + Version: internal.VersionString(), + }, + }, + } } diff --git a/internal/setup/base.go b/internal/setup/base.go index 7bf06e748..ec2bbc4cf 100644 --- a/internal/setup/base.go +++ b/internal/setup/base.go @@ -100,6 +100,8 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, useHTTPAPIs boo internal.SetupHookLogging(cfg.Logging, componentName) internal.SetupPprof() + logrus.Infof("Dendrite version %s", internal.VersionString()) + closer, err := cfg.SetupTracing("Dendrite" + componentName) if err != nil { logrus.WithError(err).Panicf("failed to start opentracing") diff --git a/internal/version.go b/internal/version.go new file mode 100644 index 000000000..851a09384 --- /dev/null +++ b/internal/version.go @@ -0,0 +1,26 @@ +package internal + +import "fmt" + +// -ldflags "-X github.com/matrix-org/dendrite/internal.branch=master" +var branch string + +// -ldflags "-X github.com/matrix-org/dendrite/internal.build=alpha" +var build string + +const ( + VersionMajor = 0 + VersionMinor = 0 + VersionPatch = 0 +) + +func VersionString() string { + version := fmt.Sprintf("%d.%d.%d", VersionMajor, VersionMinor, VersionPatch) + if branch != "" { + version += fmt.Sprintf("-%s", branch) + } + if build != "" { + version += fmt.Sprintf("+%s", build) + } + return version +} From f06637435b2124c89dfdd96cd723f54cc7055602 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 2 Sep 2020 16:52:06 +0100 Subject: [PATCH 06/12] Fix #1381 (#1384) --- clientapi/routing/login.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index d2bc9337d..772775aa0 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -41,15 +41,13 @@ type flows struct { } type flow struct { - Type string `json:"type"` - Stages []string `json:"stages"` + Type string `json:"type"` } func passwordLogin() flows { f := flows{} s := flow{ - Type: "m.login.password", - Stages: []string{"m.login.password"}, + Type: "m.login.password", } f.Flows = append(f.Flows, s) return f From 9d9e854fe042cd2c83cf694d6b3e4c8e7046cde1 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Wed, 2 Sep 2020 17:13:15 +0100 Subject: [PATCH 07/12] Add Queryer and Inputer and factor out more RSAPI stuff (#1382) * Add Queryer and use embedded structs * Add Inputer and factor out more RS API stuff This neatly splits up the RS API based on the functionality it provides, whilst providing a useful place for code sharing via the `helpers` package. --- clientapi/routing/membership.go | 2 +- clientapi/routing/profile.go | 2 +- clientapi/routing/redaction.go | 2 +- clientapi/routing/sendevent.go | 2 +- clientapi/threepid/invites.go | 2 +- federationapi/routing/join.go | 2 +- federationapi/routing/leave.go | 2 +- internal/eventutil/events.go | 53 ++++++++---- roomserver/internal/api.go | 81 +++++++---------- roomserver/internal/helpers/helpers.go | 53 ++++++++++++ roomserver/internal/{ => input}/input.go | 17 +++- .../internal/{ => input}/input_events.go | 6 +- .../{ => input}/input_latest_events.go | 6 +- .../internal/{ => input}/input_membership.go | 10 +-- .../internal/perform/perform_backfill.go | 14 +++ roomserver/internal/perform/perform_invite.go | 27 ++++-- roomserver/internal/perform/perform_join.go | 58 ++++++++++--- roomserver/internal/perform/perform_leave.go | 34 ++++---- .../internal/perform/perform_publish.go | 14 +++ roomserver/internal/{ => query}/query.go | 86 +++++-------------- roomserver/internal/{ => query}/query_test.go | 4 +- 21 files changed, 292 insertions(+), 185 deletions(-) rename roomserver/internal/{ => input}/input.go (83%) rename roomserver/internal/{ => input}/input_events.go (98%) rename roomserver/internal/{ => input}/input_latest_events.go (99%) rename roomserver/internal/{ => input}/input_membership.go (96%) rename roomserver/internal/{ => query}/query.go (85%) rename roomserver/internal/{ => query}/query_test.go (98%) diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 5d635c018..cba19a24b 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -270,7 +270,7 @@ func buildMembershipEvent( return nil, err } - return eventutil.BuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) + return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) } // loadProfile lookups the profile of a given user from the database and returns diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index faf92451e..4c7895bd3 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -375,7 +375,7 @@ func buildMembershipEvents( return nil, err } - event, err := eventutil.BuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) + event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) if err != nil { return nil, err } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index bb5265135..a825da64d 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -115,7 +115,7 @@ func SendRedaction( } var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.BuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 9cf517cff..a25979ea0 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -158,7 +158,7 @@ func generateSendEvent( } var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.BuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return nil, &util.JSONResponse{ Code: http.StatusNotFound, diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index f1d54a47b..2ffb6bb09 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -354,7 +354,7 @@ func emit3PIDInviteEvent( } queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.BuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes) if err != nil { return err } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index ffdadd522..6cac12451 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -95,7 +95,7 @@ func MakeJoin( queryRes := api.QueryLatestEventsAndStateResponse{ RoomVersion: verRes.RoomVersion, } - event, err := eventutil.BuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index d2fbfc712..511623445 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -61,7 +61,7 @@ func MakeLeave( } var queryRes api.QueryLatestEventsAndStateResponse - event, err := eventutil.BuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index 35c7f33d8..0b878961e 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -30,13 +30,13 @@ import ( // doesn't exist var ErrRoomNoExists = errors.New("Room does not exist") -// BuildEvent builds a Matrix event using the event builder and roomserver query +// QueryAndBuildEvent builds a Matrix event using the event builder and roomserver query // API client provided. If also fills roomserver query API response (if provided) // in case the function calling FillBuilder needs to use it. // Returns ErrRoomNoExists if the state of the room could not be retrieved because // the room doesn't exist // Returns an error if something else went wrong -func BuildEvent( +func QueryAndBuildEvent( ctx context.Context, builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse, @@ -45,11 +45,25 @@ func BuildEvent( queryRes = &api.QueryLatestEventsAndStateResponse{} } - ver, err := AddPrevEventsToEvent(ctx, builder, rsAPI, queryRes) + eventsNeeded, err := queryRequiredEventsForBuilder(ctx, builder, rsAPI, queryRes) if err != nil { // This can pass through a ErrRoomNoExists to the caller return nil, err } + return BuildEvent(ctx, builder, cfg, evTime, eventsNeeded, queryRes) +} + +// BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse +// provided. +func BuildEvent( + ctx context.Context, + builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, + eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, +) (*gomatrixserverlib.HeaderedEvent, error) { + err := addPrevEventsToEvent(builder, eventsNeeded, queryRes) + if err != nil { + return nil, err + } event, err := builder.Build( evTime, cfg.ServerName, cfg.KeyID, @@ -59,23 +73,23 @@ func BuildEvent( return nil, err } - h := event.Headered(ver) + h := event.Headered(queryRes.RoomVersion) return &h, nil } -// AddPrevEventsToEvent fills out the prev_events and auth_events fields in builder -func AddPrevEventsToEvent( +// queryRequiredEventsForBuilder queries the roomserver for auth/prev events needed for this builder. +func queryRequiredEventsForBuilder( ctx context.Context, builder *gomatrixserverlib.EventBuilder, rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse, -) (gomatrixserverlib.RoomVersion, error) { +) (*gomatrixserverlib.StateNeeded, error) { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) if err != nil { - return "", fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) + return nil, fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) } if len(eventsNeeded.Tuples()) == 0 { - return "", errors.New("expecting state tuples for event builder, got none") + return nil, errors.New("expecting state tuples for event builder, got none") } // Ask the roomserver for information about this room @@ -83,17 +97,22 @@ func AddPrevEventsToEvent( RoomID: builder.RoomID, StateToFetch: eventsNeeded.Tuples(), } - if err = rsAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes); err != nil { - return "", fmt.Errorf("rsAPI.QueryLatestEventsAndState: %w", err) - } + return &eventsNeeded, rsAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes) +} +// addPrevEventsToEvent fills out the prev_events and auth_events fields in builder +func addPrevEventsToEvent( + builder *gomatrixserverlib.EventBuilder, + eventsNeeded *gomatrixserverlib.StateNeeded, + queryRes *api.QueryLatestEventsAndStateResponse, +) error { if !queryRes.RoomExists { - return "", ErrRoomNoExists + return ErrRoomNoExists } eventFormat, err := queryRes.RoomVersion.EventFormat() if err != nil { - return "", fmt.Errorf("queryRes.RoomVersion.EventFormat: %w", err) + return fmt.Errorf("queryRes.RoomVersion.EventFormat: %w", err) } builder.Depth = queryRes.Depth @@ -103,13 +122,13 @@ func AddPrevEventsToEvent( for i := range queryRes.StateEvents { err = authEvents.AddEvent(&queryRes.StateEvents[i].Event) if err != nil { - return "", fmt.Errorf("authEvents.AddEvent: %w", err) + return fmt.Errorf("authEvents.AddEvent: %w", err) } } refs, err := eventsNeeded.AuthEventReferences(&authEvents) if err != nil { - return "", fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err) + return fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err) } truncAuth, truncPrev := truncateAuthAndPrevEvents(refs, queryRes.LatestEvents) @@ -129,7 +148,7 @@ func AddPrevEventsToEvent( builder.PrevEvents = v2PrevRefs } - return queryRes.RoomVersion, nil + return nil } // truncateAuthAndPrevEvents limits the number of events we add into diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 8ac1bdda2..93c0be77b 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -2,20 +2,28 @@ package internal import ( "context" - "sync" "github.com/Shopify/sarama" fsAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/perform" + "github.com/matrix-org/dendrite/roomserver/internal/query" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" ) // RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI type RoomserverInternalAPI struct { + *input.Inputer + *query.Queryer + *perform.Inviter + *perform.Joiner + *perform.Leaver + *perform.Publisher + *perform.Backfiller DB storage.Database Cfg *config.RoomServer Producer sarama.SyncProducer @@ -24,12 +32,6 @@ type RoomserverInternalAPI struct { KeyRing gomatrixserverlib.JSONVerifier fsAPI fsAPI.FederationSenderInternalAPI OutputRoomEventTopic string // Kafka topic for new output room events - Inviter *perform.Inviter - Joiner *perform.Joiner - Leaver *perform.Leaver - Publisher *perform.Publisher - Backfiller *perform.Backfiller - mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent } func NewRoomserverAPI( @@ -38,13 +40,21 @@ func NewRoomserverAPI( keyRing gomatrixserverlib.JSONVerifier, ) *RoomserverInternalAPI { a := &RoomserverInternalAPI{ - DB: roomserverDB, - Cfg: cfg, - Producer: producer, - Cache: caches, - ServerName: cfg.Matrix.ServerName, - KeyRing: keyRing, - OutputRoomEventTopic: outputRoomEventTopic, + DB: roomserverDB, + Cfg: cfg, + Cache: caches, + ServerName: cfg.Matrix.ServerName, + KeyRing: keyRing, + Queryer: &query.Queryer{ + DB: roomserverDB, + Cache: caches, + }, + Inputer: &input.Inputer{ + DB: roomserverDB, + OutputRoomEventTopic: outputRoomEventTopic, + Producer: producer, + ServerName: cfg.Matrix.ServerName, + }, // perform-er structs get initialised when we have a federation sender to use } return a @@ -57,23 +67,23 @@ func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSen r.fsAPI = fsAPI r.Inviter = &perform.Inviter{ - DB: r.DB, - Cfg: r.Cfg, - FSAPI: r.fsAPI, - RSAPI: r, + DB: r.DB, + Cfg: r.Cfg, + FSAPI: r.fsAPI, + Inputer: r.Inputer, } r.Joiner = &perform.Joiner{ ServerName: r.Cfg.Matrix.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, - RSAPI: r, + Inputer: r.Inputer, } r.Leaver = &perform.Leaver{ - Cfg: r.Cfg, - DB: r.DB, - FSAPI: r.fsAPI, - RSAPI: r, + Cfg: r.Cfg, + DB: r.DB, + FSAPI: r.fsAPI, + Inputer: r.Inputer, } r.Publisher = &perform.Publisher{ DB: r.DB, @@ -101,14 +111,6 @@ func (r *RoomserverInternalAPI) PerformInvite( return r.WriteOutputEvents(req.Event.RoomID(), outputEvents) } -func (r *RoomserverInternalAPI) PerformJoin( - ctx context.Context, - req *api.PerformJoinRequest, - res *api.PerformJoinResponse, -) { - r.Joiner.PerformJoin(ctx, req, res) -} - func (r *RoomserverInternalAPI) PerformLeave( ctx context.Context, req *api.PerformLeaveRequest, @@ -123,20 +125,3 @@ func (r *RoomserverInternalAPI) PerformLeave( } return r.WriteOutputEvents(req.RoomID, outputEvents) } - -func (r *RoomserverInternalAPI) PerformPublish( - ctx context.Context, - req *api.PerformPublishRequest, - res *api.PerformPublishResponse, -) { - r.Publisher.PerformPublish(ctx, req, res) -} - -// Query a given amount (or less) of events prior to a given set of events. -func (r *RoomserverInternalAPI) PerformBackfill( - ctx context.Context, - request *api.PerformBackfillRequest, - response *api.PerformBackfillResponse, -) error { - return r.Backfiller.PerformBackfill(ctx, request, response) -} diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index d7bb40af5..b7e6ce86c 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -324,3 +324,56 @@ BFSLoop: return resultNIDs, err } + +func QueryLatestEventsAndState( + ctx context.Context, db storage.Database, + request *api.QueryLatestEventsAndStateRequest, + response *api.QueryLatestEventsAndStateResponse, +) error { + roomInfo, err := db.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if roomInfo == nil || roomInfo.IsStub { + response.RoomExists = false + return nil + } + + roomState := state.NewStateResolution(db, *roomInfo) + response.RoomExists = true + response.RoomVersion = roomInfo.RoomVersion + + var currentStateSnapshotNID types.StateSnapshotNID + response.LatestEvents, currentStateSnapshotNID, response.Depth, err = + db.LatestEventIDs(ctx, roomInfo.RoomNID) + if err != nil { + return err + } + + var stateEntries []types.StateEntry + if len(request.StateToFetch) == 0 { + // Look up all room state. + stateEntries, err = roomState.LoadStateAtSnapshot( + ctx, currentStateSnapshotNID, + ) + } else { + // Look up the current state for the requested tuples. + stateEntries, err = roomState.LoadStateAtSnapshotForStringTuples( + ctx, currentStateSnapshotNID, request.StateToFetch, + ) + } + if err != nil { + return err + } + + stateEvents, err := LoadStateEvents(ctx, db, stateEntries) + if err != nil { + return err + } + + for _, event := range stateEvents { + response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion)) + } + + return nil +} diff --git a/roomserver/internal/input.go b/roomserver/internal/input/input.go similarity index 83% rename from roomserver/internal/input.go rename to roomserver/internal/input/input.go index dbf67b792..87bdc5dbf 100644 --- a/roomserver/internal/input.go +++ b/roomserver/internal/input/input.go @@ -13,7 +13,7 @@ // limitations under the License. // Package input contains the code processes new room events -package internal +package input import ( "context" @@ -22,11 +22,22 @@ import ( "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) +type Inputer struct { + DB storage.Database + Producer sarama.SyncProducer + ServerName gomatrixserverlib.ServerName + OutputRoomEventTopic string + + mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent +} + // WriteOutputEvents implements OutputRoomEventWriter -func (r *RoomserverInternalAPI) WriteOutputEvents(roomID string, updates []api.OutputEvent) error { +func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) error { messages := make([]*sarama.ProducerMessage, len(updates)) for i := range updates { value, err := json.Marshal(updates[i]) @@ -58,7 +69,7 @@ func (r *RoomserverInternalAPI) WriteOutputEvents(roomID string, updates []api.O } // InputRoomEvents implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) InputRoomEvents( +func (r *Inputer) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input/input_events.go similarity index 98% rename from roomserver/internal/input_events.go rename to roomserver/internal/input/input_events.go index edc8b416a..69f51f4b8 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package input import ( "context" @@ -36,7 +36,7 @@ import ( // state deltas when sending to kafka streams // TODO: Break up function - we should probably do transaction ID checks before calling this. // nolint:gocyclo -func (r *RoomserverInternalAPI) processRoomEvent( +func (r *Inputer) processRoomEvent( ctx context.Context, input api.InputRoomEvent, ) (eventID string, err error) { @@ -141,7 +141,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( return event.EventID(), nil } -func (r *RoomserverInternalAPI) calculateAndSetState( +func (r *Inputer) calculateAndSetState( ctx context.Context, input api.InputRoomEvent, roomInfo types.RoomInfo, diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input/input_latest_events.go similarity index 99% rename from roomserver/internal/input_latest_events.go rename to roomserver/internal/input/input_latest_events.go index d5e38e7a4..67a7d8a40 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package input import ( "bytes" @@ -47,7 +47,7 @@ import ( // 7 <----- latest // // Can only be called once at a time -func (r *RoomserverInternalAPI) updateLatestEvents( +func (r *Inputer) updateLatestEvents( ctx context.Context, roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, @@ -87,7 +87,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( // when there are so many variables to pass around. type latestEventsUpdater struct { ctx context.Context - api *RoomserverInternalAPI + api *Inputer updater *shared.LatestEventsUpdater roomInfo *types.RoomInfo stateAtEvent types.StateAtEvent diff --git a/roomserver/internal/input_membership.go b/roomserver/internal/input/input_membership.go similarity index 96% rename from roomserver/internal/input_membership.go rename to roomserver/internal/input/input_membership.go index 57a945966..8befcd647 100644 --- a/roomserver/internal/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package input import ( "context" @@ -29,7 +29,7 @@ import ( // user affected by a change in the current state of the room. // Returns a list of output events to write to the kafka log to inform the // consumers about the invites added or retired by the change in current state. -func (r *RoomserverInternalAPI) updateMemberships( +func (r *Inputer) updateMemberships( ctx context.Context, updater *shared.LatestEventsUpdater, removed, added []types.StateEntry, @@ -78,7 +78,7 @@ func (r *RoomserverInternalAPI) updateMemberships( return updates, nil } -func (r *RoomserverInternalAPI) updateMembership( +func (r *Inputer) updateMembership( updater *shared.LatestEventsUpdater, targetUserNID types.EventStateKeyNID, remove, add *gomatrixserverlib.Event, @@ -133,11 +133,11 @@ func (r *RoomserverInternalAPI) updateMembership( } } -func (r *RoomserverInternalAPI) isLocalTarget(event *gomatrixserverlib.Event) bool { +func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool { isTargetLocalUser := false if statekey := event.StateKey(); statekey != nil { _, domain, _ := gomatrixserverlib.SplitID('@', *statekey) - isTargetLocalUser = domain == r.Cfg.Matrix.ServerName + isTargetLocalUser = domain == r.ServerName } return isTargetLocalUser } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index d345e9c73..668c80787 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -1,3 +1,17 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package perform import ( diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 7320388e7..e06ad062d 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -1,3 +1,17 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package perform import ( @@ -8,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" @@ -16,12 +31,10 @@ import ( ) type Inviter struct { - DB storage.Database - Cfg *config.RoomServer - FSAPI federationSenderAPI.FederationSenderInternalAPI - - // TODO FIXME: Remove this - RSAPI api.RoomserverInternalAPI + DB storage.Database + Cfg *config.RoomServer + FSAPI federationSenderAPI.FederationSenderInternalAPI + Inputer *input.Inputer } // nolint:gocyclo @@ -170,7 +183,7 @@ func (r *Inviter) PerformInvite( }, } inputRes := &api.InputRoomEventsResponse{} - if err = r.RSAPI.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { + if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } } else { diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index c8e6e8e60..3d1942272 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -1,3 +1,17 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package perform import ( @@ -12,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" @@ -23,8 +38,7 @@ type Joiner struct { FSAPI fsAPI.FederationSenderInternalAPI DB storage.Database - // TODO FIXME: Remove this - RSAPI api.RoomserverInternalAPI + Inputer *input.Inputer } // PerformJoin handles joining matrix rooms, including over federation by talking to the federationsender. @@ -201,15 +215,7 @@ func (r *Joiner) performJoinRoomByID( // locally on the homeserver. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - buildRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.BuildEvent( - ctx, // the request context - &eb, // the template join event - r.Cfg.Matrix, // the server configuration - time.Now(), // the event timestamp to use - r.RSAPI, // the roomserver API to use - &buildRes, // the query response - ) + event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) switch err { case nil: @@ -241,7 +247,7 @@ func (r *Joiner) performJoinRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - if err = r.RSAPI.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { var notAllowed *gomatrixserverlib.NotAllowed if errors.As(err, ¬Allowed) { return "", &api.PerformError{ @@ -306,3 +312,31 @@ func (r *Joiner) performFederatedJoinRoomByID( } return nil } + +func buildEvent( + ctx context.Context, db storage.Database, cfg *config.Global, builder *gomatrixserverlib.EventBuilder, +) (*gomatrixserverlib.HeaderedEvent, *api.QueryLatestEventsAndStateResponse, error) { + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) + if err != nil { + return nil, nil, fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err) + } + + if len(eventsNeeded.Tuples()) == 0 { + return nil, nil, errors.New("expecting state tuples for event builder, got none") + } + + var queryRes api.QueryLatestEventsAndStateResponse + err = helpers.QueryLatestEventsAndState(ctx, db, &api.QueryLatestEventsAndStateRequest{ + RoomID: builder.RoomID, + StateToFetch: eventsNeeded.Tuples(), + }, &queryRes) + if err != nil { + return nil, nil, fmt.Errorf("QueryLatestEventsAndState: %w", err) + } + + ev, err := eventutil.BuildEvent(ctx, builder, cfg, time.Now(), &eventsNeeded, &queryRes) + if err != nil { + return nil, nil, err + } + return ev, &queryRes, nil +} diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index b4053eed6..aaa3b5b16 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -1,16 +1,29 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package perform import ( "context" "fmt" "strings" - "time" fsAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" - "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" + "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" ) @@ -20,8 +33,7 @@ type Leaver struct { DB storage.Database FSAPI fsAPI.FederationSenderInternalAPI - // TODO FIXME: Remove this - RSAPI api.RoomserverInternalAPI + Inputer *input.Inputer } // WriteOutputEvents implements OutputRoomEventWriter @@ -67,7 +79,7 @@ func (r *Leaver) performLeaveRoomByID( }, } latestRes := api.QueryLatestEventsAndStateResponse{} - if err = r.RSAPI.QueryLatestEventsAndState(ctx, &latestReq, &latestRes); err != nil { + if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &latestReq, &latestRes); err != nil { return nil, err } if !latestRes.RoomExists { @@ -108,15 +120,7 @@ func (r *Leaver) performLeaveRoomByID( // a leave event. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - buildRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.BuildEvent( - ctx, // the request context - &eb, // the template leave event - r.Cfg.Matrix, // the server configuration - time.Now(), // the event timestamp to use - r.RSAPI, // the roomserver API to use - &buildRes, // the query response - ) + event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) if err != nil { return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) } @@ -135,7 +139,7 @@ func (r *Leaver) performLeaveRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - if err = r.RSAPI.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } diff --git a/roomserver/internal/perform/perform_publish.go b/roomserver/internal/perform/perform_publish.go index aab282f39..6ff42ac1a 100644 --- a/roomserver/internal/perform/perform_publish.go +++ b/roomserver/internal/perform/perform_publish.go @@ -1,3 +1,17 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package perform import ( diff --git a/roomserver/internal/query.go b/roomserver/internal/query/query.go similarity index 85% rename from roomserver/internal/query.go rename to roomserver/internal/query/query.go index 26b22c74b..b2799aefb 100644 --- a/roomserver/internal/query.go +++ b/roomserver/internal/query/query.go @@ -1,6 +1,4 @@ -// Copyright 2017 Vector Creations Ltd -// Copyright 2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package query import ( "context" "fmt" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrixserverlib" @@ -30,62 +30,22 @@ import ( "github.com/sirupsen/logrus" ) +type Queryer struct { + DB storage.Database + Cache caching.RoomServerCaches +} + // QueryLatestEventsAndState implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryLatestEventsAndState( +func (r *Queryer) QueryLatestEventsAndState( ctx context.Context, request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { - roomInfo, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - if roomInfo == nil || roomInfo.IsStub { - response.RoomExists = false - return nil - } - - roomState := state.NewStateResolution(r.DB, *roomInfo) - response.RoomExists = true - response.RoomVersion = roomInfo.RoomVersion - - var currentStateSnapshotNID types.StateSnapshotNID - response.LatestEvents, currentStateSnapshotNID, response.Depth, err = - r.DB.LatestEventIDs(ctx, roomInfo.RoomNID) - if err != nil { - return err - } - - var stateEntries []types.StateEntry - if len(request.StateToFetch) == 0 { - // Look up all room state. - stateEntries, err = roomState.LoadStateAtSnapshot( - ctx, currentStateSnapshotNID, - ) - } else { - // Look up the current state for the requested tuples. - stateEntries, err = roomState.LoadStateAtSnapshotForStringTuples( - ctx, currentStateSnapshotNID, request.StateToFetch, - ) - } - if err != nil { - return err - } - - stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) - if err != nil { - return err - } - - for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion)) - } - - return nil + return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response) } // QueryStateAfterEvents implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryStateAfterEvents( +func (r *Queryer) QueryStateAfterEvents( ctx context.Context, request *api.QueryStateAfterEventsRequest, response *api.QueryStateAfterEventsResponse, @@ -134,7 +94,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( } // QueryEventsByID implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryEventsByID( +func (r *Queryer) QueryEventsByID( ctx context.Context, request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, @@ -167,7 +127,7 @@ func (r *RoomserverInternalAPI) QueryEventsByID( } // QueryMembershipForUser implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryMembershipForUser( +func (r *Queryer) QueryMembershipForUser( ctx context.Context, request *api.QueryMembershipForUserRequest, response *api.QueryMembershipForUserResponse, @@ -204,7 +164,7 @@ func (r *RoomserverInternalAPI) QueryMembershipForUser( } // QueryMembershipsForRoom implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryMembershipsForRoom( +func (r *Queryer) QueryMembershipsForRoom( ctx context.Context, request *api.QueryMembershipsForRoomRequest, response *api.QueryMembershipsForRoomResponse, @@ -260,7 +220,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( } // QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( +func (r *Queryer) QueryServerAllowedToSeeEvent( ctx context.Context, request *api.QueryServerAllowedToSeeEventRequest, response *api.QueryServerAllowedToSeeEventResponse, @@ -293,7 +253,7 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( // QueryMissingEvents implements api.RoomserverInternalAPI // nolint:gocyclo -func (r *RoomserverInternalAPI) QueryMissingEvents( +func (r *Queryer) QueryMissingEvents( ctx context.Context, request *api.QueryMissingEventsRequest, response *api.QueryMissingEventsResponse, @@ -352,7 +312,7 @@ func (r *RoomserverInternalAPI) QueryMissingEvents( } // QueryStateAndAuthChain implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryStateAndAuthChain( +func (r *Queryer) QueryStateAndAuthChain( ctx context.Context, request *api.QueryStateAndAuthChainRequest, response *api.QueryStateAndAuthChainResponse, @@ -405,7 +365,7 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain( return err } -func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.Event, error) { +func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.Event, error) { roomState := state.NewStateResolution(r.DB, roomInfo) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { @@ -482,7 +442,7 @@ func getAuthChain( } // QueryRoomVersionCapabilities implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryRoomVersionCapabilities( +func (r *Queryer) QueryRoomVersionCapabilities( ctx context.Context, request *api.QueryRoomVersionCapabilitiesRequest, response *api.QueryRoomVersionCapabilitiesResponse, @@ -500,7 +460,7 @@ func (r *RoomserverInternalAPI) QueryRoomVersionCapabilities( } // QueryRoomVersionCapabilities implements api.RoomserverInternalAPI -func (r *RoomserverInternalAPI) QueryRoomVersionForRoom( +func (r *Queryer) QueryRoomVersionForRoom( ctx context.Context, request *api.QueryRoomVersionForRoomRequest, response *api.QueryRoomVersionForRoomResponse, @@ -522,7 +482,7 @@ func (r *RoomserverInternalAPI) QueryRoomVersionForRoom( return nil } -func (r *RoomserverInternalAPI) roomVersion(roomID string) (gomatrixserverlib.RoomVersion, error) { +func (r *Queryer) roomVersion(roomID string) (gomatrixserverlib.RoomVersion, error) { var res api.QueryRoomVersionForRoomResponse err := r.QueryRoomVersionForRoom(context.Background(), &api.QueryRoomVersionForRoomRequest{ RoomID: roomID, @@ -530,7 +490,7 @@ func (r *RoomserverInternalAPI) roomVersion(roomID string) (gomatrixserverlib.Ro return res.RoomVersion, err } -func (r *RoomserverInternalAPI) QueryPublishedRooms( +func (r *Queryer) QueryPublishedRooms( ctx context.Context, req *api.QueryPublishedRoomsRequest, res *api.QueryPublishedRoomsResponse, diff --git a/roomserver/internal/query_test.go b/roomserver/internal/query/query_test.go similarity index 98% rename from roomserver/internal/query_test.go rename to roomserver/internal/query/query_test.go index 92e008324..b4cb99b85 100644 --- a/roomserver/internal/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package query import ( "context" From d64d0c4be2ab33185b6dd837944dea3268b62c24 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Thu, 3 Sep 2020 10:07:14 +0100 Subject: [PATCH 08/12] Update complement.sh --- build/scripts/complement.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/scripts/complement.sh b/build/scripts/complement.sh index c1e52dde6..29feff304 100755 --- a/build/scripts/complement.sh +++ b/build/scripts/complement.sh @@ -10,7 +10,7 @@ cd `dirname $0`/../.. docker build -t complement-dendrite -f build/scripts/Complement.Dockerfile . # Download Complement -wget https://github.com/matrix-org/complement/archive/master.tar.gz +wget -N https://github.com/matrix-org/complement/archive/master.tar.gz tar -xzf master.tar.gz # Run the tests! From 74743ac8ae3cc439862acd15d13ba4123d745598 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 3 Sep 2020 10:12:11 +0100 Subject: [PATCH 09/12] Rate limiting (#1385) * Initial rate limiting * Move rate limiting to client API * Update rate limits to hopefully be self-cleaning * Use X-Forwarded-For, add comments * Reduce rate limit threshold * Tweak interval * Configurable backoff * Review comments, set cleanup interval to 30 seconds * Allow generate-config to produce sane CI config * Fix Complement dockerfile --- build/scripts/Complement.Dockerfile | 3 +- clientapi/routing/rate_limiting.go | 99 +++++++++++++++++++++++++++++ clientapi/routing/routing.go | 52 +++++++++++++++ cmd/generate-config/main.go | 9 +++ dendrite-config.yaml | 8 +++ internal/config/config_clientapi.go | 31 +++++++++ 6 files changed, 200 insertions(+), 2 deletions(-) create mode 100644 clientapi/routing/rate_limiting.go diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 32c5234b1..de51f16da 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -12,8 +12,7 @@ COPY . . RUN go build ./cmd/dendrite-monolith-server RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config -RUN ./generate-config > dendrite.yaml -RUN sed -i "s/disable_tls_validation: false/disable_tls_validation: true/g" dendrite.yaml +RUN ./generate-config --ci > dendrite.yaml RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key ENV SERVER_NAME=localhost diff --git a/clientapi/routing/rate_limiting.go b/clientapi/routing/rate_limiting.go new file mode 100644 index 000000000..16e3c0565 --- /dev/null +++ b/clientapi/routing/rate_limiting.go @@ -0,0 +1,99 @@ +package routing + +import ( + "net/http" + "sync" + "time" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/util" +) + +type rateLimits struct { + limits map[string]chan struct{} + limitsMutex sync.RWMutex + enabled bool + requestThreshold int64 + cooloffDuration time.Duration +} + +func newRateLimits(cfg *config.RateLimiting) *rateLimits { + l := &rateLimits{ + limits: make(map[string]chan struct{}), + enabled: cfg.Enabled, + requestThreshold: cfg.Threshold, + cooloffDuration: time.Duration(cfg.CooloffMS) * time.Millisecond, + } + if l.enabled { + go l.clean() + } + return l +} + +func (l *rateLimits) clean() { + for { + // On a 30 second interval, we'll take an exclusive write + // lock of the entire map and see if any of the channels are + // empty. If they are then we will close and delete them, + // freeing up memory. + time.Sleep(time.Second * 30) + l.limitsMutex.Lock() + for k, c := range l.limits { + if len(c) == 0 { + close(c) + delete(l.limits, k) + } + } + l.limitsMutex.Unlock() + } +} + +func (l *rateLimits) rateLimit(req *http.Request) *util.JSONResponse { + // If rate limiting is disabled then do nothing. + if !l.enabled { + return nil + } + + // Lock the map long enough to check for rate limiting. We hold it + // for longer here than we really need to but it makes sure that we + // also don't conflict with the cleaner goroutine which might clean + // up a channel after we have retrieved it otherwise. + l.limitsMutex.RLock() + defer l.limitsMutex.RUnlock() + + // First of all, work out if X-Forwarded-For was sent to us. If not + // then we'll just use the IP address of the caller. + caller := req.RemoteAddr + if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" { + caller = forwardedFor + } + + // Look up the caller's channel, if they have one. If they don't then + // let's create one. + rateLimit, ok := l.limits[caller] + if !ok { + l.limits[caller] = make(chan struct{}, l.requestThreshold) + rateLimit = l.limits[caller] + } + + // Check if the user has got free resource slots for this request. + // If they don't then we'll return an error. + select { + case rateLimit <- struct{}{}: + default: + // We hit the rate limit. Tell the client to back off. + return &util.JSONResponse{ + Code: http.StatusTooManyRequests, + JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()), + } + } + + // After the time interval, drain a resource from the rate limiting + // channel. This will free up space in the channel for new requests. + go func() { + <-time.After(l.cooloffDuration) + <-rateLimit + }() + return nil +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 24343ee19..0c63f9686 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -60,6 +60,7 @@ func Setup( keyAPI keyserverAPI.KeyInternalAPI, extRoomsProvider api.ExtraPublicRoomsProvider, ) { + rateLimits := newRateLimits(&cfg.RateLimiting) userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg) publicAPIMux.Handle("/versions", @@ -92,6 +93,9 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/join/{roomIDOrAlias}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -108,6 +112,9 @@ func Setup( ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/join", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -119,6 +126,9 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/leave", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -139,6 +149,9 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/invite", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -253,14 +266,23 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return Register(req, userAPI, accountDB, cfg) })).Methods(http.MethodPost, http.MethodOptions) v1mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return LegacyRegister(req, userAPI, cfg) })).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return RegisterAvailable(req, cfg, accountDB) })).Methods(http.MethodGet, http.MethodOptions) @@ -332,6 +354,9 @@ func Setup( r0mux.Handle("/rooms/{roomID}/typing/{userID}", httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -385,6 +410,9 @@ func Setup( r0mux.Handle("/account/whoami", httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return Whoami(req, device) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -393,6 +421,9 @@ func Setup( r0mux.Handle("/login", httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return Login(req, accountDB, userAPI, cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) @@ -447,6 +478,9 @@ func Setup( r0mux.Handle("/profile/{userID}/avatar_url", httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -469,6 +503,9 @@ func Setup( r0mux.Handle("/profile/{userID}/displayname", httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) @@ -506,6 +543,9 @@ func Setup( // Riot logs get flooded unless this is handled r0mux.Handle("/presence/{userID}/status", httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } // TODO: Set presence (probably the responsibility of a presence server not clientapi) return util.JSONResponse{ Code: http.StatusOK, @@ -516,6 +556,9 @@ func Setup( r0mux.Handle("/voip/turnServer", httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return RequestTurnServer(req, device, cfg) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -582,6 +625,9 @@ func Setup( r0mux.Handle("/user_directory/search", httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } postContent := struct { SearchString string `json:"search_term"` Limit int `json:"limit"` @@ -623,6 +669,9 @@ func Setup( r0mux.Handle("/rooms/{roomID}/read_markers", httputil.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } // TODO: return the read_markers. return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}} }), @@ -721,6 +770,9 @@ func Setup( r0mux.Handle("/capabilities", httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } return GetCapabilities(req, rsAPI) }), ).Methods(http.MethodGet) diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index cff376d8c..78ed3af6c 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "fmt" "github.com/matrix-org/dendrite/internal/config" @@ -8,6 +9,9 @@ import ( ) func main() { + defaultsForCI := flag.Bool("ci", false, "sane defaults for CI testing") + flag.Parse() + cfg := &config.Dendrite{} cfg.Defaults() cfg.Global.TrustedIDServers = []string{ @@ -56,6 +60,11 @@ func main() { }, } + if *defaultsForCI { + cfg.ClientAPI.RateLimiting.Enabled = false + cfg.FederationSender.DisableTLSValidation = true + } + j, err := yaml.Marshal(cfg) if err != nil { panic(err) diff --git a/dendrite-config.yaml b/dendrite-config.yaml index 23f142a83..570669c1a 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -133,6 +133,14 @@ client_api: turn_username: "" turn_password: "" + # Settings for rate-limited endpoints. Rate limiting will kick in after the + # threshold number of "slots" have been taken by requests from a specific + # host. Each "slot" will be released after the cooloff time in milliseconds. + rate_limiting: + enabled: true + threshold: 5 + cooloff_ms: 500 + # Configuration for the Current State Server. current_state_server: internal_api: diff --git a/internal/config/config_clientapi.go b/internal/config/config_clientapi.go index f7878276a..521154911 100644 --- a/internal/config/config_clientapi.go +++ b/internal/config/config_clientapi.go @@ -34,6 +34,9 @@ type ClientAPI struct { // TURN options TURN TURN `yaml:"turn"` + + // Rate-limiting options + RateLimiting RateLimiting `yaml:"rate_limiting"` } func (c *ClientAPI) Defaults() { @@ -47,6 +50,7 @@ func (c *ClientAPI) Defaults() { c.RecaptchaBypassSecret = "" c.RecaptchaSiteVerifyAPI = "" c.RegistrationDisabled = false + c.RateLimiting.Defaults() } func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { @@ -61,6 +65,7 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", string(c.RecaptchaSiteVerifyAPI)) } c.TURN.Verify(configErrs) + c.RateLimiting.Verify(configErrs) } type TURN struct { @@ -90,3 +95,29 @@ func (c *TURN) Verify(configErrs *ConfigErrors) { } } } + +type RateLimiting struct { + // Is rate limiting enabled or disabled? + Enabled bool `yaml:"enabled"` + + // How many "slots" a user can occupy sending requests to a rate-limited + // endpoint before we apply rate-limiting + Threshold int64 `yaml:"threshold"` + + // The cooloff period in milliseconds after a request before the "slot" + // is freed again + CooloffMS int64 `yaml:"cooloff_ms"` +} + +func (r *RateLimiting) Verify(configErrs *ConfigErrors) { + if r.Enabled { + checkPositive(configErrs, "client_api.rate_limiting.threshold", r.Threshold) + checkPositive(configErrs, "client_api.rate_limiting.cooloff_ms", r.CooloffMS) + } +} + +func (r *RateLimiting) Defaults() { + r.Enabled = true + r.Threshold = 5 + r.CooloffMS = 500 +} From 6150de6cb3611ffc61ce10ed6714f65e51e38e78 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 3 Sep 2020 15:22:16 +0100 Subject: [PATCH 10/12] FIFO ordering of input events (#1386) * Initial FIFOing of roomserver inputs * Remove EventID response from api.InputRoomEventsResponse * Don't send back event ID unnecessarily * Fix ordering hopefully * Reduce copies, use buffered task channel to reduce contention on other rooms * Fix error handling --- clientapi/routing/createroom.go | 3 +- clientapi/routing/membership.go | 5 +- clientapi/routing/profile.go | 4 +- clientapi/routing/redaction.go | 3 +- clientapi/routing/sendevent.go | 9 ++- clientapi/threepid/invites.go | 3 +- federationapi/routing/join.go | 5 +- federationapi/routing/leave.go | 5 +- federationapi/routing/send.go | 3 +- federationapi/routing/threepid.go | 4 +- roomserver/api/input.go | 1 - roomserver/api/wrapper.go | 11 ++- roomserver/internal/input/input.go | 82 ++++++++++++++++++++--- roomserver/internal/input/input_events.go | 4 +- roomserver/roomserver_test.go | 3 +- 15 files changed, 99 insertions(+), 46 deletions(-) diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index 57fc3f33a..af43064fe 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -342,8 +342,7 @@ func createRoom( } // send events to the room server - _, err = roomserverAPI.SendEvents(req.Context(), rsAPI, builtEvents, cfg.Matrix.ServerName, nil) - if err != nil { + if err = roomserverAPI.SendEvents(req.Context(), rsAPI, builtEvents, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index cba19a24b..202662ab6 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -75,13 +75,12 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us return jsonerror.InternalServerError() } - _, err = roomserverAPI.SendEvents( + if err = roomserverAPI.SendEvents( ctx, rsAPI, []gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, cfg.Matrix.ServerName, nil, - ) - if err != nil { + ); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 4c7895bd3..bc51b0b51 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -171,7 +171,7 @@ func SetAvatarURL( return jsonerror.InternalServerError() } - if _, err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -289,7 +289,7 @@ func SetDisplayName( return jsonerror.InternalServerError() } - if _, err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index a825da64d..178bfafc9 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -122,8 +122,7 @@ func SendRedaction( JSON: jsonerror.NotFound("Room does not exist"), } } - _, err = roomserverAPI.SendEvents(context.Background(), rsAPI, []gomatrixserverlib.HeaderedEvent{*e}, cfg.Matrix.ServerName, nil) - if err != nil { + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, []gomatrixserverlib.HeaderedEvent{*e}, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index a25979ea0..9744a5640 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -90,27 +90,26 @@ func SendEvent( // pass the new event to the roomserver and receive the correct event ID // event ID in case of duplicate transaction is discarded - eventID, err := api.SendEvents( + if err := api.SendEvents( req.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, txnAndSessionID, - ) - if err != nil { + ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } util.GetLogger(req.Context()).WithFields(logrus.Fields{ - "event_id": eventID, + "event_id": e.EventID(), "room_id": roomID, "room_version": verRes.RoomVersion, }).Info("Sent event to roomserver") res := util.JSONResponse{ Code: http.StatusOK, - JSON: sendEventResponse{eventID}, + JSON: sendEventResponse{e.EventID()}, } // Add response to transactionsCache if txnID != nil { diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 2ffb6bb09..b9575a284 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -359,7 +359,7 @@ func emit3PIDInviteEvent( return err } - _, err = api.SendEvents( + return api.SendEvents( ctx, rsAPI, []gomatrixserverlib.HeaderedEvent{ (*event).Headered(queryRes.RoomVersion), @@ -367,5 +367,4 @@ func emit3PIDInviteEvent( cfg.Matrix.ServerName, nil, ) - return err } diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 6cac12451..36afe30ab 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -266,15 +266,14 @@ func SendJoin( // We are responsible for notifying other servers that the user has joined // the room, so set SendAsServer to cfg.Matrix.ServerName if !alreadyJoined { - _, err = api.SendEvents( + if err = api.SendEvents( httpReq.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ event.Headered(stateAndAuthChainResponse.RoomVersion), }, cfg.Matrix.ServerName, nil, - ) - if err != nil { + ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index 511623445..8bb0a8a94 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -247,15 +247,14 @@ func SendLeave( // Send the events to the room server. // We are responsible for notifying other servers that the user has left // the room, so set SendAsServer to cfg.Matrix.ServerName - _, err = api.SendEvents( + if err = api.SendEvents( httpReq.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ event.Headered(verRes.RoomVersion), }, cfg.Matrix.ServerName, nil, - ) - if err != nil { + ); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("producer.SendEvents failed") return jsonerror.InternalServerError() } diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index cad779219..570062adc 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -382,7 +382,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro } // pass the event to the roomserver - _, err := api.SendEvents( + return api.SendEvents( t.context, t.rsAPI, []gomatrixserverlib.HeaderedEvent{ e.Headered(stateResp.RoomVersion), @@ -390,7 +390,6 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro api.DoNotSendToOtherServers, nil, ) - return err } func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error { diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index e8d9a9397..ec6cc1488 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -89,7 +89,7 @@ func CreateInvitesFrom3PIDInvites( } // Send all the events - if _, err := api.SendEvents(req.Context(), rsAPI, evs, cfg.Matrix.ServerName, nil); err != nil { + if err := api.SendEvents(req.Context(), rsAPI, evs, cfg.Matrix.ServerName, nil); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -172,7 +172,7 @@ func ExchangeThirdPartyInvite( } // Send the event to the roomserver - if _, err = api.SendEvents( + if err = api.SendEvents( httpReq.Context(), rsAPI, []gomatrixserverlib.HeaderedEvent{ signedEvent.Event.Headered(verRes.RoomVersion), diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 05c981df4..73c4994a7 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -83,5 +83,4 @@ type InputRoomEventsRequest struct { // InputRoomEventsResponse is a response to InputRoomEvents type InputRoomEventsResponse struct { - EventID string `json:"event_id"` } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 207c12c8f..16f5e8e18 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -26,7 +26,7 @@ import ( func SendEvents( ctx context.Context, rsAPI RoomserverInternalAPI, events []gomatrixserverlib.HeaderedEvent, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, -) (string, error) { +) error { ires := make([]InputRoomEvent, len(events)) for i, event := range events { ires[i] = InputRoomEvent{ @@ -77,19 +77,16 @@ func SendEventWithState( StateEventIDs: stateEventIDs, }) - _, err = SendInputRoomEvents(ctx, rsAPI, ires) - return err + return SendInputRoomEvents(ctx, rsAPI, ires) } // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( ctx context.Context, rsAPI RoomserverInternalAPI, ires []InputRoomEvent, -) (eventID string, err error) { +) error { request := InputRoomEventsRequest{InputRoomEvents: ires} var response InputRoomEventsResponse - err = rsAPI.InputRoomEvents(ctx, &request, &response) - eventID = response.EventID - return + return rsAPI.InputRoomEvents(ctx, &request, &response) } // SendInvite event to the roomserver. diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 87bdc5dbf..7a44ff42c 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,12 +19,14 @@ import ( "context" "encoding/json" "sync" + "time" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" + "go.uber.org/atomic" ) type Inputer struct { @@ -33,7 +35,36 @@ type Inputer struct { ServerName gomatrixserverlib.ServerName OutputRoomEventTopic string - mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent + workers sync.Map // room ID -> *inputWorker +} + +type inputTask struct { + ctx context.Context + event *api.InputRoomEvent + wg *sync.WaitGroup + err error // written back by worker, only safe to read when all tasks are done +} + +type inputWorker struct { + r *Inputer + running atomic.Bool + input chan *inputTask +} + +func (w *inputWorker) start() { + if !w.running.CAS(false, true) { + return + } + defer w.running.Store(false) + for { + select { + case task := <-w.input: + _, task.err = w.r.processRoomEvent(task.ctx, task.event) + task.wg.Done() + case <-time.After(time.Second * 5): + return + } + } } // WriteOutputEvents implements OutputRoomEventWriter @@ -73,19 +104,54 @@ func (r *Inputer) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) (err error) { +) error { + // Create a wait group. Each task that we dispatch will call Done on + // this wait group so that we know when all of our events have been + // processed. + wg := &sync.WaitGroup{} + wg.Add(len(request.InputRoomEvents)) + tasks := make([]*inputTask, len(request.InputRoomEvents)) + for i, e := range request.InputRoomEvents { + // Work out if we are running per-room workers or if we're just doing + // it on a global basis (e.g. SQLite). roomID := "global" if r.DB.SupportsConcurrentRoomInputs() { roomID = e.Event.RoomID() } - mutex, _ := r.mutexes.LoadOrStore(roomID, &sync.Mutex{}) - mutex.(*sync.Mutex).Lock() - if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil { - mutex.(*sync.Mutex).Unlock() - return err + + // Look up the worker, or create it if it doesn't exist. This channel + // is buffered to reduce the chance that we'll be blocked by another + // room - the channel will be quite small as it's just pointer types. + w, _ := r.workers.LoadOrStore(roomID, &inputWorker{ + r: r, + input: make(chan *inputTask, 10), + }) + worker := w.(*inputWorker) + + // Create a task. This contains the input event and a reference to + // the wait group, so that the worker can notify us when this specific + // task has been finished. + tasks[i] = &inputTask{ + ctx: ctx, + event: &request.InputRoomEvents[i], + wg: wg, + } + + // Send the task to the worker. + go worker.start() + worker.input <- tasks[i] + } + + // Wait for all of the workers to return results about our tasks. + wg.Wait() + + // If any of the tasks returned an error, we should probably report + // that back to the caller. + for _, task := range tasks { + if task.err != nil { + return task.err } - mutex.(*sync.Mutex).Unlock() } return nil } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 69f51f4b8..6ee679da6 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -38,7 +38,7 @@ import ( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, - input api.InputRoomEvent, + input *api.InputRoomEvent, ) (eventID string, err error) { // Parse and validate the event JSON headered := input.Event @@ -143,7 +143,7 @@ func (r *Inputer) processRoomEvent( func (r *Inputer) calculateAndSetState( ctx context.Context, - input api.InputRoomEvent, + input *api.InputRoomEvent, roomInfo types.RoomInfo, stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 0deb7acb1..786d4f31f 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -114,8 +114,7 @@ func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{}) hevents := mustLoadEvents(t, ver, events) - _, err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil) - if err != nil { + if err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil { t.Errorf("failed to SendEvents: %s", err) } return rsAPI, dp, hevents From b20386123e0cbdc53016231f0087d0047b5667e9 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Thu, 3 Sep 2020 17:20:54 +0100 Subject: [PATCH 11/12] Move currentstateserver API to roomserver (#1387) * Move currentstateserver API to roomserver Stub out DB functions for now, nothing uses the roomserver version yet. * Allow it to startup * Implement some current-state-server storage interface functions * Add missing package --- currentstateserver/acls/acls.go | 12 +- federationapi/routing/send_test.go | 24 +++ roomserver/acls/acls.go | 164 ++++++++++++++++++ roomserver/acls/acls_test.go | 105 +++++++++++ roomserver/api/api.go | 14 ++ roomserver/api/api_trace.go | 41 +++++ roomserver/api/query.go | 104 +++++++++++ roomserver/api/wrapper.go | 99 +++++++++++ roomserver/internal/api.go | 6 +- roomserver/internal/query/query.go | 102 ++++++++++- roomserver/inthttp/client.go | 72 ++++++++ roomserver/inthttp/server.go | 78 +++++++++ roomserver/storage/interface.go | 19 ++ .../storage/postgres/membership_table.go | 24 +++ roomserver/storage/postgres/rooms_table.go | 48 +++++ roomserver/storage/shared/storage.go | 80 +++++++++ .../storage/sqlite3/membership_table.go | 24 +++ roomserver/storage/sqlite3/rooms_table.go | 49 ++++++ roomserver/storage/tables/interface.go | 3 + 19 files changed, 1062 insertions(+), 6 deletions(-) create mode 100644 roomserver/acls/acls.go create mode 100644 roomserver/acls/acls_test.go diff --git a/currentstateserver/acls/acls.go b/currentstateserver/acls/acls.go index 12619f5fc..775b6c73a 100644 --- a/currentstateserver/acls/acls.go +++ b/currentstateserver/acls/acls.go @@ -23,17 +23,25 @@ import ( "strings" "sync" - "github.com/matrix-org/dendrite/currentstateserver/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) +type ServerACLDatabase interface { + // GetKnownRooms returns a list of all rooms we know about. + GetKnownRooms(ctx context.Context) ([]string, error) + // GetStateEvent returns the state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) +} + type ServerACLs struct { acls map[string]*serverACL // room ID -> ACL aclsMutex sync.RWMutex // protects the above } -func NewServerACLs(db storage.Database) *ServerACLs { +func NewServerACLs(db ServerACLDatabase) *ServerACLs { ctx := context.TODO() acls := &ServerACLs{ acls: make(map[string]*serverACL), diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index fa745e286..6dc8621b2 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -296,6 +296,30 @@ func (t *testRoomserverAPI) RemoveRoomAlias( return fmt.Errorf("not implemented") } +func (t *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { + return nil +} + +func (t *testRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { + return fmt.Errorf("not implemented") +} + +func (t *testRoomserverAPI) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error { + return nil +} + type testStateAPI struct { } diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go new file mode 100644 index 000000000..775b6c73a --- /dev/null +++ b/roomserver/acls/acls.go @@ -0,0 +1,164 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package acls + +import ( + "context" + "encoding/json" + "fmt" + "net" + "regexp" + "strings" + "sync" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +type ServerACLDatabase interface { + // GetKnownRooms returns a list of all rooms we know about. + GetKnownRooms(ctx context.Context) ([]string, error) + // GetStateEvent returns the state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) +} + +type ServerACLs struct { + acls map[string]*serverACL // room ID -> ACL + aclsMutex sync.RWMutex // protects the above +} + +func NewServerACLs(db ServerACLDatabase) *ServerACLs { + ctx := context.TODO() + acls := &ServerACLs{ + acls: make(map[string]*serverACL), + } + // Look up all of the rooms that the current state server knows about. + rooms, err := db.GetKnownRooms(ctx) + if err != nil { + logrus.WithError(err).Fatalf("Failed to get known rooms") + } + // For each room, let's see if we have a server ACL state event. If we + // do then we'll process it into memory so that we have the regexes to + // hand. + for _, room := range rooms { + state, err := db.GetStateEvent(ctx, room, "m.room.server_acl", "") + if err != nil { + logrus.WithError(err).Errorf("Failed to get server ACLs for room %q", room) + continue + } + if state != nil { + acls.OnServerACLUpdate(&state.Event) + } + } + return acls +} + +type ServerACL struct { + Allowed []string `json:"allow"` + Denied []string `json:"deny"` + AllowIPLiterals bool `json:"allow_ip_literals"` +} + +type serverACL struct { + ServerACL + allowedRegexes []*regexp.Regexp + deniedRegexes []*regexp.Regexp +} + +func compileACLRegex(orig string) (*regexp.Regexp, error) { + escaped := regexp.QuoteMeta(orig) + escaped = strings.Replace(escaped, "\\?", ".", -1) + escaped = strings.Replace(escaped, "\\*", ".*", -1) + return regexp.Compile(escaped) +} + +func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) { + acls := &serverACL{} + if err := json.Unmarshal(state.Content(), &acls.ServerACL); err != nil { + logrus.WithError(err).Errorf("Failed to unmarshal state content for server ACLs") + return + } + // The spec calls only for * (zero or more chars) and ? (exactly one char) + // to be supported as wildcard components, so we will escape all of the regex + // special characters and then replace * and ? with their regex counterparts. + // https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl + for _, orig := range acls.Allowed { + if expr, err := compileACLRegex(orig); err != nil { + logrus.WithError(err).Errorf("Failed to compile allowed regex") + } else { + acls.allowedRegexes = append(acls.allowedRegexes, expr) + } + } + for _, orig := range acls.Denied { + if expr, err := compileACLRegex(orig); err != nil { + logrus.WithError(err).Errorf("Failed to compile denied regex") + } else { + acls.deniedRegexes = append(acls.deniedRegexes, expr) + } + } + logrus.WithFields(logrus.Fields{ + "allow_ip_literals": acls.AllowIPLiterals, + "num_allowed": len(acls.allowedRegexes), + "num_denied": len(acls.deniedRegexes), + }).Debugf("Updating server ACLs for %q", state.RoomID()) + s.aclsMutex.Lock() + defer s.aclsMutex.Unlock() + s.acls[state.RoomID()] = acls +} + +func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerName, roomID string) bool { + s.aclsMutex.RLock() + // First of all check if we have an ACL for this room. If we don't then + // no servers are banned from the room. + acls, ok := s.acls[roomID] + if !ok { + s.aclsMutex.RUnlock() + return false + } + s.aclsMutex.RUnlock() + // Split the host and port apart. This is because the spec calls on us to + // validate the hostname only in cases where the port is also present. + if serverNameOnly, _, err := net.SplitHostPort(string(serverName)); err == nil { + serverName = gomatrixserverlib.ServerName(serverNameOnly) + } + // Check if the hostname is an IPv4 or IPv6 literal. We cheat here by adding + // a /0 prefix length just to trick ParseCIDR into working. If we find that + // the server is an IP literal and we don't allow those then stop straight + // away. + if _, _, err := net.ParseCIDR(fmt.Sprintf("%s/0", serverName)); err == nil { + if !acls.AllowIPLiterals { + return true + } + } + // Check if the hostname matches one of the denied regexes. If it does then + // the server is banned from the room. + for _, expr := range acls.deniedRegexes { + if expr.MatchString(string(serverName)) { + return true + } + } + // Check if the hostname matches one of the allowed regexes. If it does then + // the server is NOT banned from the room. + for _, expr := range acls.allowedRegexes { + if expr.MatchString(string(serverName)) { + return false + } + } + // If we've got to this point then we haven't matched any regexes or an IP + // hostname if disallowed. The spec calls for default-deny here. + return true +} diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go new file mode 100644 index 000000000..9fb6a5581 --- /dev/null +++ b/roomserver/acls/acls_test.go @@ -0,0 +1,105 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package acls + +import ( + "regexp" + "testing" +) + +func TestOpenACLsWithBlacklist(t *testing.T) { + roomID := "!test:test.com" + allowRegex, err := compileACLRegex("*") + if err != nil { + t.Fatalf(err.Error()) + } + denyRegex, err := compileACLRegex("foo.com") + if err != nil { + t.Fatalf(err.Error()) + } + + acls := ServerACLs{ + acls: make(map[string]*serverACL), + } + + acls.acls[roomID] = &serverACL{ + ServerACL: ServerACL{ + AllowIPLiterals: true, + }, + allowedRegexes: []*regexp.Regexp{allowRegex}, + deniedRegexes: []*regexp.Regexp{denyRegex}, + } + + if acls.IsServerBannedFromRoom("1.2.3.4", roomID) { + t.Fatal("Expected 1.2.3.4 to be allowed but wasn't") + } + if acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) { + t.Fatal("Expected 1.2.3.4:2345 to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("foo.com", roomID) { + t.Fatal("Expected foo.com to be banned but wasn't") + } + if !acls.IsServerBannedFromRoom("foo.com:3456", roomID) { + t.Fatal("Expected foo.com:3456 to be banned but wasn't") + } + if acls.IsServerBannedFromRoom("bar.com", roomID) { + t.Fatal("Expected bar.com to be allowed but wasn't") + } + if acls.IsServerBannedFromRoom("bar.com:4567", roomID) { + t.Fatal("Expected bar.com:4567 to be allowed but wasn't") + } +} + +func TestDefaultACLsWithWhitelist(t *testing.T) { + roomID := "!test:test.com" + allowRegex, err := compileACLRegex("foo.com") + if err != nil { + t.Fatalf(err.Error()) + } + + acls := ServerACLs{ + acls: make(map[string]*serverACL), + } + + acls.acls[roomID] = &serverACL{ + ServerACL: ServerACL{ + AllowIPLiterals: false, + }, + allowedRegexes: []*regexp.Regexp{allowRegex}, + deniedRegexes: []*regexp.Regexp{}, + } + + if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) { + t.Fatal("Expected 1.2.3.4 to be banned but wasn't") + } + if !acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) { + t.Fatal("Expected 1.2.3.4:2345 to be banned but wasn't") + } + if acls.IsServerBannedFromRoom("foo.com", roomID) { + t.Fatal("Expected foo.com to be allowed but wasn't") + } + if acls.IsServerBannedFromRoom("foo.com:3456", roomID) { + t.Fatal("Expected foo.com:3456 to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("bar.com", roomID) { + t.Fatal("Expected bar.com to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("baz.com", roomID) { + t.Fatal("Expected baz.com to be allowed but wasn't") + } + if !acls.IsServerBannedFromRoom("qux.com:4567", roomID) { + t.Fatal("Expected qux.com:4567 to be allowed but wasn't") + } +} diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 0fe30b8b5..96bdc767e 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -106,6 +106,20 @@ type RoomserverInternalAPI interface { response *QueryStateAndAuthChainResponse, ) error + // QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from + // the response. + QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error + // QueryRoomsForUser retrieves a list of room IDs matching the given query. + QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error + // QueryBulkStateContent does a bulk query for state event content in the given rooms. + QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error + // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. + QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error + // QueryKnownUsers returns a list of users that we know about from our joined rooms. + QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error + // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. + QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error + // Query a given amount (or less) of events prior to a given set of events. PerformBackfill( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 9b53aa88c..25da2e8e0 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -236,6 +236,47 @@ func (t *RoomserverInternalAPITrace) RemoveRoomAlias( return err } +func (t *RoomserverInternalAPITrace) QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error { + err := t.Impl.QueryCurrentState(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryCurrentState req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryRoomsForUser retrieves a list of room IDs matching the given query. +func (t *RoomserverInternalAPITrace) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error { + err := t.Impl.QueryRoomsForUser(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryRoomsForUser req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryBulkStateContent does a bulk query for state event content in the given rooms. +func (t *RoomserverInternalAPITrace) QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error { + err := t.Impl.QueryBulkStateContent(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryBulkStateContent req=%+v res=%+v", js(req), js(res)) + return err +} + +// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. +func (t *RoomserverInternalAPITrace) QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error { + err := t.Impl.QuerySharedUsers(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QuerySharedUsers req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryKnownUsers returns a list of users that we know about from our joined rooms. +func (t *RoomserverInternalAPITrace) QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error { + err := t.Impl.QueryKnownUsers(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryKnownUsers req=%+v res=%+v", js(req), js(res)) + return err +} + +// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. +func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error { + err := t.Impl.QueryServerBannedFromRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryServerBannedFromRoom req=%+v res=%+v", js(req), js(res)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 4e1d09c30..d0d0474d8 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -17,6 +17,11 @@ package api import ( + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/gomatrixserverlib" ) @@ -225,3 +230,102 @@ type QueryPublishedRoomsResponse struct { // The list of published rooms. RoomIDs []string } + +type QuerySharedUsersRequest struct { + UserID string + ExcludeRoomIDs []string + IncludeRoomIDs []string +} + +type QuerySharedUsersResponse struct { + UserIDsToCount map[string]int +} + +type QueryRoomsForUserRequest struct { + UserID string + // The desired membership of the user. If this is the empty string then no rooms are returned. + WantMembership string +} + +type QueryRoomsForUserResponse struct { + RoomIDs []string +} + +type QueryBulkStateContentRequest struct { + // Returns state events in these rooms + RoomIDs []string + // If true, treats the '*' StateKey as "all state events of this type" rather than a literal value of '*' + AllowWildcards bool + // The state events to return. Only a small subset of tuples are allowed in this request as only certain events + // have their content fields extracted. Specifically, the tuple Type must be one of: + // m.room.avatar + // m.room.create + // m.room.canonical_alias + // m.room.guest_access + // m.room.history_visibility + // m.room.join_rules + // m.room.member + // m.room.name + // m.room.topic + // Any other tuple type will result in the query failing. + StateTuples []gomatrixserverlib.StateKeyTuple +} +type QueryBulkStateContentResponse struct { + // map of room ID -> tuple -> content_value + Rooms map[string]map[gomatrixserverlib.StateKeyTuple]string +} + +type QueryCurrentStateRequest struct { + RoomID string + StateTuples []gomatrixserverlib.StateKeyTuple +} + +type QueryCurrentStateResponse struct { + StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent +} + +type QueryKnownUsersRequest struct { + UserID string `json:"user_id"` + SearchString string `json:"search_string"` + Limit int `json:"limit"` +} + +type QueryKnownUsersResponse struct { + Users []authtypes.FullyQualifiedProfile `json:"profiles"` +} + +type QueryServerBannedFromRoomRequest struct { + ServerName gomatrixserverlib.ServerName `json:"server_name"` + RoomID string `json:"room_id"` +} + +type QueryServerBannedFromRoomResponse struct { + Banned bool `json:"banned"` +} + +// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode. +func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { + se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents)) + for k, v := range r.StateEvents { + // use 0x1F (unit separator) as the delimiter between type/state key, + se[fmt.Sprintf("%s\x1F%s", k.EventType, k.StateKey)] = v + } + return json.Marshal(se) +} + +func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { + res := make(map[string]*gomatrixserverlib.HeaderedEvent) + err := json.Unmarshal(data, &res) + if err != nil { + return err + } + r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(res)) + for k, v := range res { + fields := strings.Split(k, "\x1F") + r.StateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: fields[0], + StateKey: fields[1], + }] = v + } + return nil +} diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 16f5e8e18..82a4a5719 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -133,3 +133,102 @@ func GetEvent(ctx context.Context, rsAPI RoomserverInternalAPI, eventID string) } return &res.Events[0] } + +// GetStateEvent returns the current state event in the room or nil. +func GetStateEvent(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent { + var res QueryCurrentStateResponse + err := rsAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{ + RoomID: roomID, + StateTuples: []gomatrixserverlib.StateKeyTuple{tuple}, + }, &res) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to QueryCurrentState") + return nil + } + ev, ok := res.StateEvents[tuple] + if ok { + return ev + } + return nil +} + +// IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs. +func IsServerBannedFromRoom(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, serverName gomatrixserverlib.ServerName) bool { + req := &QueryServerBannedFromRoomRequest{ + ServerName: serverName, + RoomID: roomID, + } + res := &QueryServerBannedFromRoomResponse{} + if err := rsAPI.QueryServerBannedFromRoom(ctx, req, res); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to QueryServerBannedFromRoom") + return true + } + return res.Banned +} + +// PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the +// published room directory. +// due to lots of switches +// nolint:gocyclo +func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) { + avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""} + nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""} + canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""} + topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""} + guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""} + visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""} + joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""} + + var stateRes QueryBulkStateContentResponse + err := rsAPI.QueryBulkStateContent(ctx, &QueryBulkStateContentRequest{ + RoomIDs: roomIDs, + AllowWildcards: true, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple, + {EventType: gomatrixserverlib.MRoomMember, StateKey: "*"}, + }, + }, &stateRes) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed") + return nil, err + } + chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs)) + i := 0 + for roomID, data := range stateRes.Rooms { + pub := gomatrixserverlib.PublicRoom{ + RoomID: roomID, + } + joinCount := 0 + var joinRule, guestAccess string + for tuple, contentVal := range data { + if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" { + joinCount++ + continue + } + switch tuple { + case avatarTuple: + pub.AvatarURL = contentVal + case nameTuple: + pub.Name = contentVal + case topicTuple: + pub.Topic = contentVal + case canonicalTuple: + pub.CanonicalAlias = contentVal + case visibilityTuple: + pub.WorldReadable = contentVal == "world_readable" + // need both of these to determine whether guests can join + case joinRuleTuple: + joinRule = contentVal + case guestTuple: + guestAccess = contentVal + } + } + if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" { + pub.GuestCanJoin = true + } + pub.JoinedMembersCount = joinCount + chunk[i] = pub + i++ + } + return chunk, nil +} diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 93c0be77b..bdea650ea 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -7,6 +7,7 @@ import ( fsAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/internal/perform" @@ -46,8 +47,9 @@ func NewRoomserverAPI( ServerName: cfg.Matrix.ServerName, KeyRing: keyRing, Queryer: &query.Queryer{ - DB: roomserverDB, - Cache: caches, + DB: roomserverDB, + Cache: caches, + ServerACLs: acls.NewServerACLs(roomserverDB), }, Inputer: &input.Inputer{ DB: roomserverDB, diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index b2799aefb..f76c93166 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -16,9 +16,12 @@ package query import ( "context" + "errors" "fmt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" @@ -31,8 +34,9 @@ import ( ) type Queryer struct { - DB storage.Database - Cache caching.RoomServerCaches + DB storage.Database + Cache caching.RoomServerCaches + ServerACLs *acls.ServerACLs } // QueryLatestEventsAndState implements api.RoomserverInternalAPI @@ -502,3 +506,97 @@ func (r *Queryer) QueryPublishedRooms( res.RoomIDs = rooms return nil } + +func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { + res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) + for _, tuple := range req.StateTuples { + ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey) + if err != nil { + return err + } + if ev != nil { + res.StateEvents[tuple] = ev + } + } + return nil +} + +func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { + roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership) + if err != nil { + return err + } + res.RoomIDs = roomIDs + return nil +} + +func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { + users, err := r.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit) + if err != nil { + return err + } + for _, user := range users { + res.Users = append(res.Users, authtypes.FullyQualifiedProfile{ + UserID: user, + }) + } + return nil +} + +func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { + events, err := r.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards) + if err != nil { + return err + } + res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string) + for _, ev := range events { + if res.Rooms[ev.RoomID] == nil { + res.Rooms[ev.RoomID] = make(map[gomatrixserverlib.StateKeyTuple]string) + } + room := res.Rooms[ev.RoomID] + room[gomatrixserverlib.StateKeyTuple{ + EventType: ev.EventType, + StateKey: ev.StateKey, + }] = ev.ContentValue + res.Rooms[ev.RoomID] = room + } + return nil +} + +func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { + roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") + if err != nil { + return err + } + roomIDs = append(roomIDs, req.IncludeRoomIDs...) + excludeMap := make(map[string]bool) + for _, roomID := range req.ExcludeRoomIDs { + excludeMap[roomID] = true + } + // filter out excluded rooms + j := 0 + for i := range roomIDs { + // move elements to include to the beginning of the slice + // then trim elements on the right + if !excludeMap[roomIDs[i]] { + roomIDs[j] = roomIDs[i] + j++ + } + } + roomIDs = roomIDs[:j] + + users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs) + if err != nil { + return err + } + res.UserIDsToCount = users + return nil +} + +func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error { + if r.ServerACLs == nil { + return errors.New("no server ACL tracking") + } + res.Banned = r.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID) + return nil +} diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 1657bcdeb..b414b0d8c 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -43,6 +43,12 @@ const ( RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities" RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom" RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms" + RoomserverQueryCurrentStatePath = "/roomserver/queryCurrentState" + RoomserverQueryRoomsForUserPath = "/roomserver/queryRoomsForUser" + RoomserverQueryBulkStateContentPath = "/roomserver/queryBulkStateContent" + RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers" + RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers" + RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" ) type httpRoomserverInternalAPI struct { @@ -371,3 +377,69 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom( } return err } + +func (h *httpRoomserverInternalAPI) QueryCurrentState( + ctx context.Context, + request *api.QueryCurrentStateRequest, + response *api.QueryCurrentStateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpRoomserverInternalAPI) QueryRoomsForUser( + ctx context.Context, + request *api.QueryRoomsForUserRequest, + response *api.QueryRoomsForUserResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpRoomserverInternalAPI) QueryBulkStateContent( + ctx context.Context, + request *api.QueryBulkStateContentRequest, + response *api.QueryBulkStateContentResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpRoomserverInternalAPI) QuerySharedUsers( + ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpRoomserverInternalAPI) QueryKnownUsers( + ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom( + ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 0ac36a2a4..ebfb296d8 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -312,4 +312,82 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(RoomserverQueryCurrentStatePath, + httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse { + request := api.QueryCurrentStateRequest{} + response := api.QueryCurrentStateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQueryRoomsForUserPath, + httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse { + request := api.QueryRoomsForUserRequest{} + response := api.QueryRoomsForUserResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQueryBulkStateContentPath, + httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse { + request := api.QueryBulkStateContentRequest{} + response := api.QueryBulkStateContentResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQuerySharedUsersPath, + httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse { + request := api.QuerySharedUsersRequest{} + response := api.QuerySharedUsersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQuerySharedUsersPath, + httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse { + request := api.QueryKnownUsersRequest{} + response := api.QueryKnownUsersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath, + httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse { + request := api.QueryServerBannedFromRoomRequest{} + response := api.QueryServerBannedFromRoomResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index ef7a9f090..c4119f7ed 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -17,6 +17,7 @@ package storage import ( "context" + "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" @@ -138,4 +139,22 @@ type Database interface { PublishRoom(ctx context.Context, roomID string, publish bool) error // Returns a list of room IDs for rooms which are published. GetPublishedRooms(ctx context.Context) ([]string, error) + + // TODO: factor out - from currentstateserver + + // GetStateEvent returns the state event of a given type for a given room with a given state key + // If no event could be found, returns nil + // If there was an issue during the retrieval, returns an error + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) + // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). + GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) + // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. + // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. + GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) + // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. + JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) + // GetKnownUsers searches all users that userID knows about. + GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) + // GetKnownRooms returns a list of all rooms we know about. + GetKnownRooms(ctx context.Context) ([]string, error) } diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 13cef638f..0799647e9 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -99,6 +99,9 @@ const updateMembershipSQL = "" + "UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" + " WHERE room_nid = $1 AND target_nid = $2" +const selectRoomsWithMembershipSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -108,6 +111,7 @@ type membershipStatements struct { selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomStmt *sql.Stmt updateMembershipStmt *sql.Stmt + selectRoomsWithMembershipStmt *sql.Stmt } func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -126,6 +130,7 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, + {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, }.Prepare(db) } @@ -222,3 +227,22 @@ func (s *membershipStatements) UpdateMembership( ) return err } + +func (s *membershipStatements) SelectRoomsWithMembership( + ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, +) ([]types.RoomNID, error) { + rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 13c8e703d..9d359146a 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -21,6 +21,7 @@ import ( "errors" "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" @@ -74,6 +75,12 @@ const selectRoomVersionForRoomNIDSQL = "" + const selectRoomInfoSQL = "" + "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" +const selectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms" + +const bulkSelectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt @@ -82,6 +89,8 @@ type roomStatements struct { updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt + selectRoomIDsStmt *sql.Stmt + bulkSelectRoomIDsStmt *sql.Stmt } func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -98,9 +107,27 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, + {&s.selectRoomIDsStmt, selectRoomIDsSQL}, + {&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL}, }.Prepare(db) } +func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { + rows, err := s.selectRoomIDsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, @@ -197,3 +224,24 @@ func (s *roomStatements) SelectRoomVersionForRoomNID( } return roomVersion, err } + +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { + var array pq.Int64Array + for _, nid := range roomNIDs { + array = append(array, int64(nid)) + } + rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 6e0ebd2c2..5c447d66f 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" + csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" @@ -711,3 +712,82 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { } return &evs[0] } + +// GetStateEvent returns the current state event of a given type for a given room with a given state key +// If no event could be found, returns nil +// If there was an issue during the retrieval, returns an error +func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { + /* + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err != nil { + return nil, err + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey) + if err != nil { + return nil, err + } + blockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID}) + if err != nil { + return nil, err + } + */ + return nil, nil +} + +// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). +func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { + var membershipState tables.MembershipState + switch membership { + case "join": + membershipState = tables.MembershipStateJoin + case "invite": + membershipState = tables.MembershipStateInvite + case "leave": + membershipState = tables.MembershipStateLeaveOrBan + case "ban": + membershipState = tables.MembershipStateLeaveOrBan + default: + return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership) + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err) + } + roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState) + if err != nil { + return nil, err + } + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs) + if err != nil { + return nil, err + } + if len(roomIDs) != len(roomNIDs) { + return nil, fmt.Errorf("GetRoomsByMembership: missing room IDs, got %d want %d", len(roomIDs), len(roomNIDs)) + } + return roomIDs, nil +} + +// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. +// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. +func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]csstables.StrippedEvent, error) { + return nil, fmt.Errorf("not implemented yet") +} + +// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. +func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { + return nil, fmt.Errorf("not implemented yet") +} + +// GetKnownUsers searches all users that userID knows about. +func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { + return nil, fmt.Errorf("not implemented yet") +} + +// GetKnownRooms returns a list of all rooms we know about. +func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { + return d.RoomsTable.SelectRoomIDs(ctx) +} diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index b3ee69c00..e850c80bb 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -75,6 +75,9 @@ const updateMembershipSQL = "" + "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" + " WHERE room_nid = $4 AND target_nid = $5" +const selectRoomsWithMembershipSQL = "" + + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -84,6 +87,7 @@ type membershipStatements struct { selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt selectMembershipsFromRoomStmt *sql.Stmt selectLocalMembershipsFromRoomStmt *sql.Stmt + selectRoomsWithMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt } @@ -105,6 +109,7 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, + {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, }.Prepare(db) } @@ -203,3 +208,22 @@ func (s *membershipStatements) UpdateMembership( ) return err } + +func (s *membershipStatements) SelectRoomsWithMembership( + ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState, +) ([]types.RoomNID, error) { + rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 4c1699d00..daacf86fa 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -21,7 +21,9 @@ import ( "encoding/json" "errors" "fmt" + "strings" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" @@ -64,6 +66,12 @@ const selectRoomVersionForRoomNIDSQL = "" + const selectRoomInfoSQL = "" + "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" +const selectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms" + +const bulkSelectRoomIDsSQL = "" + + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt @@ -73,6 +81,7 @@ type roomStatements struct { updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt + selectRoomIDsStmt *sql.Stmt } func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -91,9 +100,27 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, + {&s.selectRoomIDsStmt, selectRoomIDsSQL}, }.Prepare(db) } +func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) { + rows, err := s.selectRoomIDsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} + func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDsJSON string @@ -203,3 +230,25 @@ func (s *roomStatements) SelectRoomVersionForRoomNID( } return roomVersion, err } + +func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) { + iRoomNIDs := make([]interface{}, len(roomNIDs)) + for i, v := range roomNIDs { + iRoomNIDs[i] = v + } + sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1) + rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed") + var roomIDs []string + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, nil +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index c599dd3fe..126c27b57 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -65,6 +65,8 @@ type Rooms interface { UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) + SelectRoomIDs(ctx context.Context) ([]string, error) + BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) } type Transactions interface { @@ -120,6 +122,7 @@ type Membership interface { SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error + SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) } type Published interface { From 33b8143a9597ff8c6b75ea47a588d50dc6e72259 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Thu, 3 Sep 2020 18:27:02 +0100 Subject: [PATCH 12/12] Implement more CSS storage functions in roomserver (#1388) --- .../storage/postgres/membership_table.go | 59 +++++++ roomserver/storage/postgres/rooms_table.go | 26 +++ roomserver/storage/shared/storage.go | 150 +++++++++++++++--- .../storage/sqlite3/membership_table.go | 58 +++++++ roomserver/storage/sqlite3/rooms_table.go | 25 +++ roomserver/storage/tables/interface.go | 5 + 6 files changed, 303 insertions(+), 20 deletions(-) diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 0799647e9..5164f654f 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -18,7 +18,9 @@ package postgres import ( "context" "database/sql" + "fmt" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -62,6 +64,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( ); ` +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" + // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + @@ -102,6 +108,16 @@ const updateMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" +// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is +// joined to. Since this information is used to populate the user directory, we will +// only return users that the user would ordinarily be able to see anyway. +var selectKnownUsersSQL = "" + + "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + + "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + " WHERE room_nid = ANY(" + + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -112,6 +128,8 @@ type membershipStatements struct { selectLocalMembershipsFromRoomStmt *sql.Stmt updateMembershipStmt *sql.Stmt selectRoomsWithMembershipStmt *sql.Stmt + selectJoinedUsersSetForRoomsStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt } func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -131,6 +149,8 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, + {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, + {&s.selectKnownUsersStmt, selectKnownUsersSQL}, }.Prepare(db) } @@ -246,3 +266,42 @@ func (s *membershipStatements) SelectRoomsWithMembership( } return roomNIDs, nil } + +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { + roomIDarray := make([]int64, len(roomNIDs)) + for i := range roomNIDs { + roomIDarray[i] = int64(roomNIDs[i]) + } + rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") + result := make(map[types.EventStateKeyNID]int) + for rows.Next() { + var userID types.EventStateKeyNID + var count int + if err := rows.Scan(&userID, &count); err != nil { + return nil, err + } + result[userID] = count + } + return result, rows.Err() +} + +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + result := []string{} + defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 9d359146a..ef1b7891a 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -81,6 +81,9 @@ const selectRoomIDsSQL = "" + const bulkSelectRoomIDsSQL = "" + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" +const bulkSelectRoomNIDsSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt @@ -91,6 +94,7 @@ type roomStatements struct { selectRoomInfoStmt *sql.Stmt selectRoomIDsStmt *sql.Stmt bulkSelectRoomIDsStmt *sql.Stmt + bulkSelectRoomNIDsStmt *sql.Stmt } func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -109,6 +113,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.selectRoomInfoStmt, selectRoomInfoSQL}, {&s.selectRoomIDsStmt, selectRoomIDsSQL}, {&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL}, + {&s.bulkSelectRoomNIDsStmt, bulkSelectRoomNIDsSQL}, }.Prepare(db) } @@ -245,3 +250,24 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types } return roomIDs, nil } + +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { + var array pq.StringArray + for _, roomID := range roomIDs { + array = append(array, roomID) + } + rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err = rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 5c447d66f..a3b33a4fe 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "sort" csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/internal/caching" @@ -13,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "github.com/tidwall/gjson" ) @@ -717,25 +719,42 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { - /* - roomInfo, err := d.RoomInfo(ctx, roomID) - if err != nil { - return nil, err + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err != nil { + return nil, err + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey) + if err != nil { + return nil, err + } + entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + if err != nil { + return nil, err + } + // return the event requested + for _, e := range entries { + if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { + data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID}) + if err != nil { + return nil, err + } + if len(data) == 0 { + return nil, fmt.Errorf("GetStateEvent: no json for event nid %d", e.EventNID) + } + ev, err := gomatrixserverlib.NewEventFromTrustedJSON(data[0].EventJSON, false, roomInfo.RoomVersion) + if err != nil { + return nil, err + } + h := ev.Headered(roomInfo.RoomVersion) + return &h, nil } - eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) - if err != nil { - return nil, err - } - stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey) - if err != nil { - return nil, err - } - blockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID}) - if err != nil { - return nil, err - } - */ - return nil, nil + } + + return nil, fmt.Errorf("GetStateEvent: no event type '%s' with key '%s' exists in room %s", evType, stateKey, roomID) } // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). @@ -779,15 +798,106 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { - return nil, fmt.Errorf("not implemented yet") + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs) + if err != nil { + return nil, err + } + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs) + if err != nil { + return nil, err + } + stateKeyNIDs := make([]types.EventStateKeyNID, len(userNIDToCount)) + i := 0 + for nid := range userNIDToCount { + stateKeyNIDs[i] = nid + i++ + } + nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs) + if err != nil { + return nil, err + } + if len(nidToUserID) != len(userNIDToCount) { + return nil, fmt.Errorf("found %d users but only have state key nids for %d of them", len(userNIDToCount), len(nidToUserID)) + } + result := make(map[string]int, len(userNIDToCount)) + for nid, count := range userNIDToCount { + result[nidToUserID[nid]] = count + } + return result, nil } // GetKnownUsers searches all users that userID knows about. func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { - return nil, fmt.Errorf("not implemented yet") + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) + if err != nil { + return nil, err + } + return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit) } // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDs(ctx) } + +// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops +// it should live in this package! + +func (d *Database) loadStateAtSnapshot( + ctx context.Context, stateNID types.StateSnapshotNID, +) ([]types.StateEntry, error) { + stateBlockNIDLists, err := d.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := d.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + +type stateEntryListMap []types.StateEntryList + +func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { + list := []types.StateEntryList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateBlockNID >= stateBlockNID + }) + if i < len(list) && list[i].StateBlockNID == stateBlockNID { + ok = true + stateEntries = list[i].StateEntries + } + return +} + +type stateEntryByStateKeySorter []types.StateEntry + +func (s stateEntryByStateKeySorter) Len() int { return len(s) } +func (s stateEntryByStateKeySorter) Less(i, j int) bool { + return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) +} +func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index e850c80bb..0d5ce516d 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -18,6 +18,8 @@ package sqlite3 import ( "context" "database/sql" + "fmt" + "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -38,6 +40,10 @@ const membershipSchema = ` ); ` +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" + // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + @@ -78,6 +84,16 @@ const updateMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" +// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is +// joined to. Since this information is used to populate the user directory, we will +// only return users that the user would ordinarily be able to see anyway. +var selectKnownUsersSQL = "" + + "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + + "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + " WHERE room_nid IN (" + + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -89,6 +105,7 @@ type membershipStatements struct { selectLocalMembershipsFromRoomStmt *sql.Stmt selectRoomsWithMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt } func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -110,6 +127,7 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, + {&s.selectKnownUsersStmt, selectKnownUsersSQL}, }.Prepare(db) } @@ -227,3 +245,43 @@ func (s *membershipStatements) SelectRoomsWithMembership( } return roomNIDs, nil } + +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { + iRoomNIDs := make([]interface{}, len(roomNIDs)) + for i, v := range roomNIDs { + iRoomNIDs[i] = v + } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) + rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") + result := make(map[types.EventStateKeyNID]int) + for rows.Next() { + var userID types.EventStateKeyNID + var count int + if err := rows.Scan(&userID, &count); err != nil { + return nil, err + } + result[userID] = count + } + return result, rows.Err() +} + +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + result := []string{} + defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index daacf86fa..b4564aff9 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -72,6 +72,9 @@ const selectRoomIDsSQL = "" + const bulkSelectRoomIDsSQL = "" + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" +const bulkSelectRoomNIDsSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt @@ -252,3 +255,25 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types } return roomIDs, nil } + +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { + iRoomIDs := make([]interface{}, len(roomIDs)) + for i, v := range roomIDs { + iRoomIDs[i] = v + } + sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) + rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err = rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 126c27b57..a142f2b1a 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -67,6 +67,7 @@ type Rooms interface { SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) SelectRoomIDs(ctx context.Context) ([]string, error) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) + BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) } type Transactions interface { @@ -123,6 +124,10 @@ type Membership interface { SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) + // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the + // counts of how many rooms they are joined. + SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) } type Published interface {