From 490c40f5f38ddc7c63bff70ea9ecc677fb650c7b Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 20 Jan 2020 15:56:40 +0000 Subject: [PATCH] Merge forward --- cmd/dendrite-monolith-server/main.go | 2 +- cmd/dendrite-sync-api-server/main.go | 3 +- roomserver/api/query.go | 40 +- roomserver/query/query.go | 49 ++ syncapi/consumers/roomserver.go | 1 + syncapi/routing/messages.go | 480 ++++++++++++++++++ syncapi/routing/routing.go | 15 +- .../postgres/backward_extremities_table.go | 118 +++++ .../postgres/current_room_state_table.go | 3 +- .../postgres/output_room_events_table.go | 149 ++++-- .../output_room_events_topology_table.go | 187 +++++++ syncapi/storage/postgres/syncserver.go | 316 ++++++++++-- syncapi/storage/storage.go | 7 +- syncapi/sync/request.go | 24 + syncapi/syncapi.go | 7 +- syncapi/types/types.go | 93 ++++ 16 files changed, 1396 insertions(+), 98 deletions(-) create mode 100644 syncapi/routing/messages.go create mode 100644 syncapi/storage/postgres/backward_extremities_table.go create mode 100644 syncapi/storage/postgres/output_room_events_topology_table.go diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 5ea6b154a..7515ec5c0 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -70,7 +70,7 @@ func main() { federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query, asQuery, fedSenderAPI) mediaapi.SetupMediaAPIComponent(base, deviceDB) publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB) - syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query) + syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg) httpHandler := common.WrapHandlerInCORS(base.APIMux) diff --git a/cmd/dendrite-sync-api-server/main.go b/cmd/dendrite-sync-api-server/main.go index 1c47ec525..55e9faeef 100644 --- a/cmd/dendrite-sync-api-server/main.go +++ b/cmd/dendrite-sync-api-server/main.go @@ -26,10 +26,11 @@ func main() { deviceDB := base.CreateDeviceDB() accountDB := base.CreateAccountsDB() + federation := base.CreateFederationClient() _, _, query := base.CreateHTTPRoomserverAPIs() - syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query) + syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query, federation, cfg) base.SetupAndServeHTTP(string(base.Cfg.Bind.SyncAPI), string(base.Cfg.Listen.SyncAPI)) diff --git a/roomserver/api/query.go b/roomserver/api/query.go index e52c74ac3..a990cd0b4 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -230,6 +230,20 @@ type QueryBackfillResponse struct { Events []gomatrixserverlib.Event `json:"events"` } +// QueryServersInRoomAtEventRequest is a request to QueryServersInRoomAtEvent +type QueryServersInRoomAtEventRequest struct { + // ID of the room to retrieve member servers for. + RoomID string `json:"room_id"` + // ID of the event for which to retrieve member servers. + EventID string `json:"event_id"` +} + +// QueryServersInRoomAtEventResponse is a response to QueryServersInRoomAtEvent +type QueryServersInRoomAtEventResponse struct { + // Servers present in the room for these events. + Servers []gomatrixserverlib.ServerName `json:"servers"` +} + // RoomserverQueryAPI is used to query information from the room server. type RoomserverQueryAPI interface { // Query the latest events and state for a room from the room server. @@ -303,6 +317,12 @@ type RoomserverQueryAPI interface { request *QueryBackfillRequest, response *QueryBackfillResponse, ) error + + QueryServersInRoomAtEvent( + ctx context.Context, + request *QueryServersInRoomAtEventRequest, + response *QueryServersInRoomAtEventResponse, + ) error } // RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API. @@ -332,8 +352,11 @@ const RoomserverQueryMissingEventsPath = "/api/roomserver/queryMissingEvents" // RoomserverQueryStateAndAuthChainPath is the HTTP path for the QueryStateAndAuthChain API const RoomserverQueryStateAndAuthChainPath = "/api/roomserver/queryStateAndAuthChain" -// RoomserverQueryBackfillPath is the HTTP path for the QueryBackfill API -const RoomserverQueryBackfillPath = "/api/roomserver/QueryBackfill" +// RoomserverQueryBackfillPath is the HTTP path for the QueryMissingEvents API +const RoomserverQueryBackfillPath = "/api/roomserver/queryBackfill" + +// RoomserverQueryServersInRoomAtEvent is the HTTP path for the QueryServersInRoomAtEvent API +const RoomserverQueryServersInRoomAtEvent = "/api/roomserver/queryServersInRoomAtEvents" // NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. // If httpClient is nil then it uses the http.DefaultClient @@ -478,3 +501,16 @@ func (h *httpRoomserverQueryAPI) QueryBackfill( apiURL := h.roomserverURL + RoomserverQueryBackfillPath return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +// QueryServersInRoomAtEvent implements RoomServerQueryAPI +func (h *httpRoomserverQueryAPI) QueryServersInRoomAtEvent( + ctx context.Context, + request *QueryServersInRoomAtEventRequest, + response *QueryServersInRoomAtEventResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBackfill") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryServersInRoomAtEvent + return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/roomserver/query/query.go b/roomserver/query/query.go index caa7a95bb..9dbf3a44b 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -660,6 +660,41 @@ func getAuthChain( return authEvents, nil } +// QueryServersInRoomAtEvent implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryServersInRoomAtEvent( + ctx context.Context, + request *api.QueryServersInRoomAtEventRequest, + response *api.QueryServersInRoomAtEventResponse, +) error { + // getMembershipsBeforeEventNID requires a NID, so retrieving the NID for + // the event is necessary. + NIDs, err := r.DB.EventNIDs(ctx, []string{request.EventID}) + if err != nil { + return err + } + + // Retrieve all "m.room.member" state events of "join" membership, which + // contains the list of users in the room before the event, therefore all + // the servers in it at that moment. + events, err := r.getMembershipsBeforeEventNID(ctx, NIDs[request.EventID], true) + if err != nil { + return err + } + + // Store the server names in a temporary map to avoid duplicates. + servers := make(map[gomatrixserverlib.ServerName]bool) + for _, event := range events { + servers[event.Origin()] = true + } + + // Populate the response. + for server := range servers { + response.Servers = append(response.Servers, server) + } + + return nil +} + // SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux. // nolint: gocyclo func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { @@ -803,4 +838,18 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + servMux.Handle( + api.RoomserverQueryServersInRoomAtEvent, + common.MakeInternalAPI("QueryServersInRoomAtEvent", func(req *http.Request) util.JSONResponse { + var request api.QueryServersInRoomAtEventRequest + var response api.QueryServersInRoomAtEventResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryServersInRoomAtEvent(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index cde2f5080..c9f572d14 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -133,6 +133,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( msg.AddsStateEventIDs, msg.RemovesStateEventIDs, msg.TransactionID, + false, ) if err != nil { // panic rather than continue with an inconsistent database diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go new file mode 100644 index 000000000..c61b5322b --- /dev/null +++ b/syncapi/routing/messages.go @@ -0,0 +1,480 @@ +// Copyright 2018 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "net/http" + "sort" + "strconv" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + log "github.com/sirupsen/logrus" +) + +type messagesReq struct { + ctx context.Context + db storage.Database + queryAPI api.RoomserverQueryAPI + federation *gomatrixserverlib.FederationClient + cfg *config.Dendrite + roomID string + from *types.PaginationToken + to *types.PaginationToken + wasToProvided bool + limit int + backwardOrdering bool +} + +type messagesResp struct { + Start string `json:"start"` + End string `json:"end"` + Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` +} + +const defaultMessagesLimit = 10 + +// OnIncomingMessagesRequest implements the /messages endpoint from the +// client-server API. +// See: https://matrix.org/docs/spec/client_server/r0.4.0.html#get-matrix-client-r0-rooms-roomid-messages +func OnIncomingMessagesRequest( + req *http.Request, db storage.Database, roomID string, + federation *gomatrixserverlib.FederationClient, + queryAPI api.RoomserverQueryAPI, + cfg *config.Dendrite, +) util.JSONResponse { + var err error + + // Extract parameters from the request's URL. + // Pagination tokens. + from, err := types.NewPaginationTokenFromString(req.URL.Query().Get("from")) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()), + } + } + + // Direction to return events from. + dir := req.URL.Query().Get("dir") + if dir != "b" && dir != "f" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("Bad or missing dir query parameter (should be either 'b' or 'f')"), + } + } + // A boolean is easier to handle in this case, especially since dir is sure + // to have one of the two accepted values (so dir == "f" <=> !backwardOrdering). + backwardOrdering := (dir == "b") + + // Pagination tokens. To is optional, and its default value depends on the + // direction ("b" or "f"). + var to *types.PaginationToken + wasToProvided := true + if s := req.URL.Query().Get("to"); len(s) > 0 { + to, err = types.NewPaginationTokenFromString(s) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()), + } + } + } else { + // If "to" isn't provided, it defaults to either the earliest stream + // position (if we're going backward) or to the latest one (if we're + // going forward). + to, err = setToDefault(req.Context(), db, backwardOrdering, roomID) + if err != nil { + return httputil.LogThenError(req, err) + } + wasToProvided = false + } + + // Maximum number of events to return; defaults to 10. + limit := defaultMessagesLimit + if len(req.URL.Query().Get("limit")) > 0 { + limit, err = strconv.Atoi(req.URL.Query().Get("limit")) + + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()), + } + } + } + // TODO: Implement filtering (#587) + + // Check the room ID's format. + if _, _, err = gomatrixserverlib.SplitID('!', roomID); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("Bad room ID: " + err.Error()), + } + } + + mReq := messagesReq{ + ctx: req.Context(), + db: db, + queryAPI: queryAPI, + federation: federation, + cfg: cfg, + roomID: roomID, + from: from, + to: to, + wasToProvided: wasToProvided, + limit: limit, + backwardOrdering: backwardOrdering, + } + + clientEvents, start, end, err := mReq.retrieveEvents() + if err != nil { + return httputil.LogThenError(req, err) + } + + // Respond with the events. + return util.JSONResponse{ + Code: http.StatusOK, + JSON: messagesResp{ + Chunk: clientEvents, + Start: start.String(), + End: end.String(), + }, + } +} + +// retrieveEvents retrieve events from the local database for a request on +// /messages. If there's not enough events to retrieve, it asks another +// homeserver in the room for older events. +// Returns an error if there was an issue talking to the database or with the +// remote homeserver. +func (r *messagesReq) retrieveEvents() ( + clientEvents []gomatrixserverlib.ClientEvent, start, + end *types.PaginationToken, err error, +) { + // Retrieve the events from the local database. + streamEvents, err := r.db.GetEventsInRange( + r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, + ) + if err != nil { + return + } + + var events []gomatrixserverlib.Event + + // There can be two reasons for streamEvents to be empty: either we've + // reached the oldest event in the room (or the most recent one, depending + // on the ordering), or we've reached a backward extremity. + if len(streamEvents) == 0 { + if events, err = r.handleEmptyEventsSlice(); err != nil { + return + } + } else { + if events, err = r.handleNonEmptyEventsSlice(streamEvents); err != nil { + return + } + } + + // If we didn't get any event, we don't need to proceed any further. + if len(events) == 0 { + return []gomatrixserverlib.ClientEvent{}, r.from, r.to, nil + } + + // Sort the events to ensure we send them in the right order. We currently + // do that based on the event's timestamp. + if r.backwardOrdering { + sort.SliceStable(events, func(i int, j int) bool { + // Backward ordering is antichronological (latest event to oldest + // one). + return sortEvents(&(events[j]), &(events[i])) + }) + } else { + sort.SliceStable(events, func(i int, j int) bool { + // Forward ordering is chronological (oldest event to latest one). + return sortEvents(&(events[i]), &(events[j])) + }) + } + + // Convert all of the events into client events. + clientEvents = gomatrixserverlib.ToClientEvents(events, gomatrixserverlib.FormatAll) + // Get the position of the first and the last event in the room's topology. + // This position is currently determined by the event's depth, so we could + // also use it instead of retrieving from the database. However, if we ever + // change the way topological positions are defined (as depth isn't the most + // reliable way to define it), it would be easier and less troublesome to + // only have to change it in one place, i.e. the database. + startPos, err := r.db.EventPositionInTopology( + r.ctx, streamEvents[0].EventID(), + ) + if err != nil { + return + } + endPos, err := r.db.EventPositionInTopology( + r.ctx, streamEvents[len(streamEvents)-1].EventID(), + ) + if err != nil { + return + } + // Generate pagination tokens to send to the client using the positions + // retrieved previously. + start = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, startPos, + ) + end = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, endPos, + ) + + if r.backwardOrdering { + // A stream/topological position is a cursor located between two events. + // While they are identified in the code by the event on their right (if + // we consider a left to right chronological order), tokens need to refer + // to them by the event on their left, therefore we need to decrement the + // end position we send in the response if we're going backward. + end.Position-- + } + + // The lowest token value is 1, therefore we need to manually set it to that + // value if we're below it. + if end.Position < types.StreamPosition(1) { + end.Position = types.StreamPosition(1) + } + + return +} + +// handleEmptyEventsSlice handles the case where the initial request to the +// database returned an empty slice of events. It does so by checking whether +// the set is empty because we've reached a backward extremity, and if that is +// the case, by retrieving as much events as requested by backfilling from +// another homeserver. +// Returns an error if there was an issue talking with the database or +// backfilling. +func (r *messagesReq) handleEmptyEventsSlice() ( + events []gomatrixserverlib.Event, err error, +) { + backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID) + + // Check if we have backward extremities for this room. + if len(backwardExtremities) > 0 { + // If so, retrieve as much events as needed through backfilling. + events, err = r.backfill(backwardExtremities, r.limit) + if err != nil { + return + } + } else { + // If not, it means the slice was empty because we reached the room's + // creation, so return an empty slice. + events = []gomatrixserverlib.Event{} + } + + return +} + +// handleNonEmptyEventsSlice handles the case where the initial request to the +// database returned a non-empty slice of events. It does so by checking whether +// events are missing from the expected result, and retrieve missing events +// through backfilling if needed. +// Returns an error if there was an issue while backfilling. +func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent) ( + events []gomatrixserverlib.Event, err error, +) { + // Check if we have enough events. + isSetLargeEnough := true + if len(streamEvents) < r.limit { + if r.backwardOrdering { + if r.wasToProvided { + // The condition in the SQL query is a strict "greater than" so + // we need to check against to-1. + isSetLargeEnough = (r.to.Position-1 == types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition)) + } + } else { + isSetLargeEnough = (r.from.Position-1 == types.StreamPosition(streamEvents[0].StreamPosition)) + } + } + + // Check if the slice contains a backward extremity. + backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID) + if err != nil { + return + } + + // Backfill is needed if we've reached a backward extremity and need more + // events. It's only needed if the direction is backward. + if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering { + var pdus []gomatrixserverlib.Event + // Only ask the remote server for enough events to reach the limit. + pdus, err = r.backfill(backwardExtremities, r.limit-len(streamEvents)) + if err != nil { + return + } + + // Append the PDUs to the list to send back to the client. + events = append(events, pdus...) + } + + // Append the events ve previously retrieved locally. + events = append(events, r.db.StreamEventsToEvents(nil, streamEvents)...) + + return +} + +// containsBackwardExtremity checks if a slice of StreamEvent contains a +// backward extremity. It does so by selecting the earliest event in the slice +// and by checking the presence in the database of all of its parent events, and +// considers the event itself a backward extremity if at least one of the parent +// events doesn't exist in the database. +// Returns an error if there was an issue with talking to the database. +func (r *messagesReq) containsBackwardExtremity(events []types.StreamEvent) (bool, error) { + // Select the earliest retrieved event. + var ev *types.StreamEvent + if r.backwardOrdering { + ev = &(events[len(events)-1]) + } else { + ev = &(events[0]) + } + // Get the earliest retrieved event's parents. + prevIDs := ev.PrevEventIDs() + prevs, err := r.db.Events(r.ctx, prevIDs) + if err != nil { + return false, nil + } + // Check if we have all of the events we requested. If not, it means we've + // reached a backward extremity. + var eventInDB bool + var id string + // Iterate over the IDs we used in the request. + for _, id = range prevIDs { + eventInDB = false + // Iterate over the events we got in response. + for _, ev := range prevs { + if ev.EventID() == id { + eventInDB = true + } + } + // One occurrence of one the event's parents not being present in the + // database is enough to say that the event is a backward extremity. + if !eventInDB { + return true, nil + } + } + + return false, nil +} + +// backfill performs a backfill request over the federation on another +// homeserver in the room. +// See: https://matrix.org/docs/spec/server_server/unstable.html#get-matrix-federation-v1-backfill-roomid +// It also stores the PDUs retrieved from the remote homeserver's response to +// the database. +// Returns with an empty string if the remote homeserver didn't return with any +// event, or if there is no remote homeserver to contact. +// Returns an error if there was an issue with retrieving the list of servers in +// the room or sending the request. +func (r *messagesReq) backfill(fromEventIDs []string, limit int) ([]gomatrixserverlib.Event, error) { + // Query the list of servers in the room when one of the backward extremities + // was sent. + var serversResponse api.QueryServersInRoomAtEventResponse + serversRequest := api.QueryServersInRoomAtEventRequest{ + RoomID: r.roomID, + EventID: fromEventIDs[0], + } + if err := r.queryAPI.QueryServersInRoomAtEvent(r.ctx, &serversRequest, &serversResponse); err != nil { + return nil, err + } + + // Use the first server from the response, except if that server is us. + // In that case, use the second one if the roomserver responded with + // enough servers. If not, use an empty string to prevent the backfill + // from happening as there's no server to direct the request towards. + // TODO: Be smarter at selecting the server to direct the request + // towards. + srvToBackfillFrom := serversResponse.Servers[0] + if srvToBackfillFrom == r.cfg.Matrix.ServerName { + if len(serversResponse.Servers) > 1 { + srvToBackfillFrom = serversResponse.Servers[1] + } else { + srvToBackfillFrom = gomatrixserverlib.ServerName("") + log.Warn("Not enough servers to backfill from") + } + } + + pdus := make([]gomatrixserverlib.Event, 0) + + // If the roomserver responded with at least one server that isn't us, + // send it a request for backfill. + if len(srvToBackfillFrom) > 0 { + txn, err := r.federation.Backfill( + r.ctx, srvToBackfillFrom, r.roomID, limit, fromEventIDs, + ) + if err != nil { + return nil, err + } + + pdus = txn.PDUs + + // Store the events in the database, while marking them as unfit to show + // up in responses to sync requests. + for _, pdu := range pdus { + if _, err = r.db.WriteEvent( + r.ctx, &pdu, []gomatrixserverlib.Event{}, []string{}, []string{}, + nil, true, + ); err != nil { + return nil, err + } + } + } + + return pdus, nil +} + +// setToDefault returns the default value for the "to" query parameter of a +// request to /messages if not provided. It defaults to either the earliest +// topological position (if we're going backward) or to the latest one (if we're +// going forward). +// Returns an error if there was an issue with retrieving the latest position +// from the database +func setToDefault( + ctx context.Context, db storage.Database, backwardOrdering bool, + roomID string, +) (to *types.PaginationToken, err error) { + if backwardOrdering { + to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 1) + } else { + var pos types.StreamPosition + pos, err = db.MaxTopologicalPosition(ctx, roomID) + if err != nil { + return + } + + to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos) + } + + return +} + +// sortEvents is a function to give to sort.SliceStable, and compares the +// timestamp of two Matrix events. +// Returns true if the first event happened before the second one, false +// otherwise. +func sortEvents(e1 *gomatrixserverlib.Event, e2 *gomatrixserverlib.Event) bool { + t := e1.OriginServerTS().Time() + return e2.OriginServerTS().Time().After(t) +} diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index bd9389bdd..922a7df6e 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -22,8 +22,11 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -34,7 +37,12 @@ const pathPrefixR0 = "/_matrix/client/r0" // Due to Setup being used to call many other functions, a gocyclo nolint is // applied: // nolint: gocyclo -func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, deviceDB *devices.Database) { +func Setup( + apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, + deviceDB *devices.Database, federation *gomatrixserverlib.FederationClient, + queryAPI api.RoomserverQueryAPI, + cfg *config.Dendrite, +) { r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() authData := auth.Data{ @@ -71,4 +79,9 @@ func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, d } return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], vars["stateKey"]) })).Methods(http.MethodGet, http.MethodOptions) + + r0mux.Handle("/rooms/{roomID}/messages", common.MakeAuthAPI("room_messages", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars := mux.Vars(req) + return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, queryAPI, cfg) + })).Methods(http.MethodGet, http.MethodOptions) } diff --git a/syncapi/storage/postgres/backward_extremities_table.go b/syncapi/storage/postgres/backward_extremities_table.go new file mode 100644 index 000000000..476d26faa --- /dev/null +++ b/syncapi/storage/postgres/backward_extremities_table.go @@ -0,0 +1,118 @@ +// Copyright 2018 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" +) + +const backwardExtremitiesSchema = ` +-- Stores output room events received from the roomserver. +CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( + -- The 'room_id' key for the event. + room_id TEXT NOT NULL, + -- The event ID for the event. + event_id TEXT NOT NULL, + + PRIMARY KEY(room_id, event_id) +); +` + +const insertBackwardExtremitySQL = "" + + "INSERT INTO syncapi_backward_extremities (room_id, event_id)" + + " VALUES ($1, $2)" + +const selectBackwardExtremitiesForRoomSQL = "" + + "SELECT event_id FROM syncapi_backward_extremities WHERE room_id = $1" + +const isBackwardExtremitySQL = "" + + "SELECT EXISTS (" + + " SELECT TRUE FROM syncapi_backward_extremities" + + " WHERE room_id = $1 AND event_id = $2" + + ")" + +const deleteBackwardExtremitySQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND event_id = $2" + +type backwardExtremitiesStatements struct { + insertBackwardExtremityStmt *sql.Stmt + selectBackwardExtremitiesForRoomStmt *sql.Stmt + isBackwardExtremityStmt *sql.Stmt + deleteBackwardExtremityStmt *sql.Stmt +} + +func (s *backwardExtremitiesStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(backwardExtremitiesSchema) + if err != nil { + return + } + if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { + return + } + if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { + return + } + if s.isBackwardExtremityStmt, err = db.Prepare(isBackwardExtremitySQL); err != nil { + return + } + if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { + return + } + return +} + +func (s *backwardExtremitiesStatements) insertsBackwardExtremity( + ctx context.Context, roomID, eventID string, +) (err error) { + _, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID) + return +} + +func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom( + ctx context.Context, roomID string, +) (eventIDs []string, err error) { + eventIDs = make([]string, 0) + + rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + if err != nil { + return + } + + for rows.Next() { + var eID string + if err = rows.Scan(&eID); err != nil { + return + } + + eventIDs = append(eventIDs, eID) + } + + return +} + +func (s *backwardExtremitiesStatements) isBackwardExtremity( + ctx context.Context, roomID, eventID string, +) (isBE bool, err error) { + err = s.isBackwardExtremityStmt.QueryRowContext(ctx, roomID, eventID).Scan(&isBE) + return +} + +func (s *backwardExtremitiesStatements) deleteBackwardExtremity( + ctx context.Context, roomID, eventID string, +) (err error) { + _, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID) + return +} diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 8b2080438..f857c0141 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" ) @@ -242,7 +243,7 @@ func (s *currentRoomStateStatements) upsertRoomState( func (s *currentRoomStateStatements) selectEventsWithEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index ca2715934..e1b1df5ec 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -22,6 +22,7 @@ import ( "sort" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrix" "github.com/lib/pq" @@ -57,7 +58,12 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( add_state_ids TEXT[], remove_state_ids TEXT[], session_id BIGINT, -- The client session that sent the event, if any - transaction_id TEXT -- The transaction id used to send the event, if any + transaction_id TEXT, -- The transaction id used to send the event, if any + -- Should the event be excluded from responses to /sync requests. Useful for + -- events retrieved through backfilling that have a position in the stream + -- that relates to the moment these were retrieved rather than the moment these + -- were emitted. + exclude_from_sync BOOL DEFAULT FALSE ); -- for event selection CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_events(event_id); @@ -65,23 +71,33 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev const insertEventSQL = "" + "INSERT INTO syncapi_output_room_events (" + - "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id" + - ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id" + "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id" const selectEventsSQL = "" + - "SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" const selectRecentEventsSQL = "" + - "SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events" + + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + " ORDER BY id DESC LIMIT $4" +const selectRecentEventsForSyncSQL = "" + + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + + " ORDER BY id DESC LIMIT $4" + +const selectEarlyEventsSQL = "" + + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + " WHERE room_id = $1 AND id > $2 AND id <= $3" + + " ORDER BY id ASC LIMIT $4" + const selectMaxEventIDSQL = "" + "SELECT MAX(id) FROM syncapi_output_room_events" // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). const selectStateInRangeSQL = "" + - "SELECT id, event_json, add_state_ids, remove_state_ids" + + "SELECT id, event_json, exclude_from_sync, add_state_ids, remove_state_ids" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + " AND ( $3::text[] IS NULL OR sender = ANY($3) )" + @@ -93,11 +109,13 @@ const selectStateInRangeSQL = "" + " LIMIT $8" type outputRoomEventsStatements struct { - insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt - selectMaxEventIDStmt *sql.Stmt - selectRecentEventsStmt *sql.Stmt - selectStateInRangeStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventsStmt *sql.Stmt + selectMaxEventIDStmt *sql.Stmt + selectRecentEventsStmt *sql.Stmt + selectRecentEventsForSyncStmt *sql.Stmt + selectEarlyEventsStmt *sql.Stmt + selectStateInRangeStmt *sql.Stmt } func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { @@ -117,6 +135,12 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil { return } + if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil { + return + } + if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil { + return + } if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { return } @@ -129,7 +153,7 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { func (s *outputRoomEventsStatements) selectStateInRange( ctx context.Context, txn *sql.Tx, oldPos, newPos int64, stateFilterPart *gomatrix.FilterPart, -) (map[string]map[string]bool, map[string]streamEvent, error) { +) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectStateInRangeStmt) rows, err := stmt.QueryContext( @@ -149,19 +173,20 @@ func (s *outputRoomEventsStatements) selectStateInRange( // - For each room ID, build up an array of event IDs which represents cumulative adds/removes // For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID // if they aren't in the event ID cache. We don't handle state deletion yet. - eventIDToEvent := make(map[string]streamEvent) + eventIDToEvent := make(map[string]types.StreamEvent) // RoomID => A set (map[string]bool) of state event IDs which are between the two positions stateNeeded := make(map[string]map[string]bool) for rows.Next() { var ( - streamPos int64 - eventBytes []byte - addIDs pq.StringArray - delIDs pq.StringArray + streamPos int64 + eventBytes []byte + excludeFromSync bool + addIDs pq.StringArray + delIDs pq.StringArray ) - if err := rows.Scan(&streamPos, &eventBytes, &addIDs, &delIDs); err != nil { + if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDs, &delIDs); err != nil { return nil, nil, err } // Sanity check for deleted state and whine if we see it. We don't need to do anything @@ -192,9 +217,10 @@ func (s *outputRoomEventsStatements) selectStateInRange( } stateNeeded[ev.RoomID()] = needSet - eventIDToEvent[ev.EventID()] = streamEvent{ - Event: ev, - streamPosition: streamPos, + eventIDToEvent[ev.EventID()] = types.StreamEvent{ + Event: ev, + StreamPosition: streamPos, + ExcludeFromSync: excludeFromSync, } } @@ -221,7 +247,7 @@ func (s *outputRoomEventsStatements) selectMaxEventID( func (s *outputRoomEventsStatements) insertEvent( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string, - transactionID *api.TransactionID, + transactionID *api.TransactionID, excludeFromSync bool, ) (streamPos int64, err error) { var txnID *string var sessionID *int64 @@ -251,16 +277,53 @@ func (s *outputRoomEventsStatements) insertEvent( pq.StringArray(removeState), sessionID, txnID, + excludeFromSync, ).Scan(&streamPos) return } -// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'. +// selectRecentEvents returns the most recent events in the given room, up to a maximum of 'limit'. +// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude +// from sync. func (s *outputRoomEventsStatements) selectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, fromPos, toPos int64, limit int, -) ([]streamEvent, error) { - stmt := common.TxStmt(txn, s.selectRecentEventsStmt) + roomID string, fromPos, toPos types.StreamPosition, limit int, + chronologicalOrder bool, onlySyncEvents bool, +) ([]types.StreamEvent, error) { + var stmt *sql.Stmt + if onlySyncEvents { + stmt = common.TxStmt(txn, s.selectRecentEventsForSyncStmt) + } else { + stmt = common.TxStmt(txn, s.selectRecentEventsStmt) + } + + rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + events, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if chronologicalOrder { + // The events need to be returned from oldest to latest, which isn't + // necessary the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(events, func(i int, j int) bool { + return events[i].StreamPosition < events[j].StreamPosition + }) + } + return events, nil +} + +// selectEarlyEvents returns the earliest events in the given room, starting +// from a given position, up to a maximum of 'limit'. +func (s *outputRoomEventsStatements) selectEarlyEvents( + ctx context.Context, txn *sql.Tx, + roomID string, fromPos, toPos types.StreamPosition, limit int, +) ([]types.StreamEvent, error) { + stmt := common.TxStmt(txn, s.selectEarlyEventsStmt) rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) if err != nil { return nil, err @@ -274,16 +337,16 @@ func (s *outputRoomEventsStatements) selectRecentEvents( // necessarily the way the SQL query returns them, so a sort is necessary to // ensure the events are in the right order in the slice. sort.SliceStable(events, func(i int, j int) bool { - return events[i].streamPosition < events[j].streamPosition + return events[i].StreamPosition < events[j].StreamPosition }) return events, nil } -// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing -// from the database. +// selectEvents returns the events for the given event IDs. If an event is +// missing from the database, it will be omitted. func (s *outputRoomEventsStatements) selectEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectEventsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { @@ -293,17 +356,18 @@ func (s *outputRoomEventsStatements) selectEvents( return rowsToStreamEvents(rows) } -func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) { - var result []streamEvent +func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { + var result []types.StreamEvent for rows.Next() { var ( - streamPos int64 - eventBytes []byte - sessionID *int64 - txnID *string - transactionID *api.TransactionID + streamPos int64 + eventBytes []byte + excludeFromSync bool + sessionID *int64 + txnID *string + transactionID *api.TransactionID ) - if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &txnID); err != nil { + if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { return nil, err } // TODO: Handle redacted events @@ -319,10 +383,11 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) { } } - result = append(result, streamEvent{ - Event: ev, - streamPosition: streamPos, - transactionID: transactionID, + result = append(result, types.StreamEvent{ + Event: ev, + StreamPosition: streamPos, + TransactionID: transactionID, + ExcludeFromSync: excludeFromSync, }) } return result, nil diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go new file mode 100644 index 000000000..4362cd3f0 --- /dev/null +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -0,0 +1,187 @@ +// Copyright 2018 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const outputRoomEventsTopologySchema = ` +-- Stores output room events received from the roomserver. +CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology ( + -- The event ID for the event. + event_id TEXT PRIMARY KEY, + -- The place of the event in the room's topology. This can usually be determined + -- from the event's depth. + topological_position BIGINT NOT NULL, + -- The 'room_id' key for the event. + room_id TEXT NOT NULL +); +-- The topological order will be used in events selection and ordering +CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id); +` + +const insertEventInTopologySQL = "" + + "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id)" + + " VALUES ($1, $2, $3)" + +const selectEventIDsInRangeASCSQL = "" + + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + + " ORDER BY topological_position ASC LIMIT $4" + +const selectEventIDsInRangeDESCSQL = "" + + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + + " ORDER BY topological_position DESC LIMIT $4" + +const selectPositionInTopologySQL = "" + + "SELECT topological_position FROM syncapi_output_room_events_topology" + + " WHERE event_id = $1" + +const selectMaxPositionInTopologySQL = "" + + "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1" + +const selectEventIDsFromPositionSQL = "" + + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND topological_position = $2" + +type outputRoomEventsTopologyStatements struct { + insertEventInTopologyStmt *sql.Stmt + selectEventIDsInRangeASCStmt *sql.Stmt + selectEventIDsInRangeDESCStmt *sql.Stmt + selectPositionInTopologyStmt *sql.Stmt + selectMaxPositionInTopologyStmt *sql.Stmt + selectEventIDsFromPositionStmt *sql.Stmt +} + +func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(outputRoomEventsTopologySchema) + if err != nil { + return + } + if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { + return + } + if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { + return + } + if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { + return + } + if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { + return + } + if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { + return + } + if s.selectEventIDsFromPositionStmt, err = db.Prepare(selectEventIDsFromPositionSQL); err != nil { + return + } + return +} + +// insertEventInTopology inserts the given event in the room's topology, based +// on the event's depth. +func (s *outputRoomEventsTopologyStatements) insertEventInTopology( + ctx context.Context, event *gomatrixserverlib.Event, +) (err error) { + _, err = s.insertEventInTopologyStmt.ExecContext( + ctx, event.EventID(), event.Depth(), event.RoomID(), + ) + return +} + +// selectEventIDsInRange selects the IDs of events which positions are within a +// given range in a given room's topological order. +// Returns an empty slice if no events match the given range. +func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( + ctx context.Context, roomID string, fromPos, toPos types.StreamPosition, + limit int, chronologicalOrder bool, +) (eventIDs []string, err error) { + // Decide on the selection's order according to whether chronological order + // is requested or not. + var stmt *sql.Stmt + if chronologicalOrder { + stmt = s.selectEventIDsInRangeASCStmt + } else { + stmt = s.selectEventIDsInRangeDESCStmt + } + + // Query the event IDs. + rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + if err == sql.ErrNoRows { + // If no event matched the request, return an empty slice. + return []string{}, nil + } else if err != nil { + return + } + + // Return the IDs. + var eventID string + for rows.Next() { + if err = rows.Scan(&eventID); err != nil { + return + } + eventIDs = append(eventIDs, eventID) + } + + return +} + +// selectPositionInTopology returns the position of a given event in the +// topology of the room it belongs to. +func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( + ctx context.Context, eventID string, +) (pos types.StreamPosition, err error) { + err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos) + return +} + +func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology( + ctx context.Context, roomID string, +) (pos types.StreamPosition, err error) { + err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos) + return +} + +// selectEventIDsFromPosition returns the IDs of all events that have a given +// position in the topology of a given room. +func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition( + ctx context.Context, roomID string, pos types.StreamPosition, +) (eventIDs []string, err error) { + // Query the event IDs. + rows, err := s.selectEventIDsFromPositionStmt.QueryContext(ctx, roomID, pos) + if err == sql.ErrNoRows { + // If no event matched the request, return an empty slice. + return []string{}, nil + } else if err != nil { + return + } + // Return the IDs. + var eventID string + for rows.Next() { + if err = rows.Scan(&eventID); err != nil { + return + } + eventIDs = append(eventIDs, eventID) + } + return +} diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 3a62d1364..207d28735 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -46,26 +46,21 @@ type stateDelta struct { membershipPos int64 } -// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. -type streamEvent struct { - gomatrixserverlib.Event - streamPosition int64 - transactionID *api.TransactionID -} - -// SyncServerDatabase represents a sync server datasource which manages +// SyncServerDatasource represents a sync server datasource which manages // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { db *sql.DB common.PartitionOffsetStatements - accountData accountDataStatements - events outputRoomEventsStatements - roomstate currentRoomStateStatements - invites inviteEventsStatements - typingCache *cache.TypingCache + accountData accountDataStatements + events outputRoomEventsStatements + roomstate currentRoomStateStatements + invites inviteEventsStatements + typingCache *cache.TypingCache + topology outputRoomEventsTopologyStatements + backwardExtremities backwardExtremitiesStatements } -// NewSyncServerDatabase creates a new sync server database +// NewSyncServerDatasource creates a new sync server database func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error @@ -87,6 +82,12 @@ func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, er if err := d.invites.prepare(d.db); err != nil { return nil, err } + if err := d.topology.prepare(d.db); err != nil { + return nil, err + } + if err := d.backwardExtremities.prepare(d.db); err != nil { + return nil, err + } d.typingCache = cache.NewTypingCache() return &d, nil } @@ -109,7 +110,7 @@ func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([ // We don't include a device here as we only include transaction IDs in // incremental syncs. - return streamEventsToEvents(nil, streamEvents), nil + return d.StreamEventsToEvents(nil, streamEvents), nil } // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races @@ -120,16 +121,57 @@ func (d *SyncServerDatasource) WriteEvent( ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, - transactionID *api.TransactionID, + transactionID *api.TransactionID, excludeFromSync bool, ) (pduPosition int64, returnErr error) { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error - pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID) + pos, err := d.events.insertEvent( + ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, + ) if err != nil { return err } pduPosition = pos + if err = d.topology.insertEventInTopology(ctx, ev); err != nil { + return err + } + + // If the event is already known as a backward extremity, don't consider + // it as such anymore now that we have it. + isBackwardExtremity, err := d.backwardExtremities.isBackwardExtremity(ctx, ev.RoomID(), ev.EventID()) + if err != nil { + return err + } + if isBackwardExtremity { + if err = d.backwardExtremities.deleteBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil { + return err + } + } + + // Check if we have all of the event's previous events. If an event is + // missing, add it to the room's backward extremities. + prevEvents, err := d.events.selectEvents(ctx, nil, ev.PrevEventIDs()) + if err != nil { + return err + } + var found bool + for _, eID := range ev.PrevEventIDs() { + found = false + for _, prevEv := range prevEvents { + if eID == prevEv.EventID() { + found = true + } + } + + // If the event is missing, consider it a backward extremity. + if !found { + if err = d.backwardExtremities.insertsBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil { + return err + } + } + } + if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { // Nothing to do, the event may have just been a message event. return nil @@ -196,11 +238,138 @@ func (d *SyncServerDatasource) GetStateEventsForRoom( return } +// GetEventsInRange retrieves all of the events on a given ordering using the +// given extremities and limit. +func (d *SyncServerDatasource) GetEventsInRange( + ctx context.Context, + from, to *types.PaginationToken, + roomID string, limit int, + backwardOrdering bool, +) (events []types.StreamEvent, err error) { + // If the pagination token's type is types.PaginationTokenTypeTopology, the + // events must be retrieved from the rooms' topology table rather than the + // table contaning the syncapi server's whole stream of events. + if from.Type == types.PaginationTokenTypeTopology { + // Determine the backward and forward limit, i.e. the upper and lower + // limits to the selection in the room's topology, from the direction. + var backwardLimit, forwardLimit types.StreamPosition + if backwardOrdering { + // Backward ordering is antichronological (latest event to oldest + // one). + backwardLimit = to.Position + forwardLimit = from.Position + } else { + // Forward ordering is chronological (oldest event to latest one). + backwardLimit = from.Position + forwardLimit = to.Position + } + + // Select the event IDs from the defined range. + var eIDs []string + eIDs, err = d.topology.selectEventIDsInRange( + ctx, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering, + ) + if err != nil { + return + } + + // Retrieve the events' contents using their IDs. + events, err = d.events.selectEvents(ctx, nil, eIDs) + return + } + + // If the pagination token's type is types.PaginationTokenTypeStream, the + // events must be retrieved from the table contaning the syncapi server's + // whole stream of events. + + if backwardOrdering { + // When using backward ordering, we want the most recent events first. + if events, err = d.events.selectRecentEvents( + ctx, nil, roomID, to.Position, from.Position, limit, false, false, + ); err != nil { + return + } + } else { + // When using forward ordering, we want the least recent events first. + if events, err = d.events.selectEarlyEvents( + ctx, nil, roomID, from.Position, to.Position, limit, + ); err != nil { + return + } + } + + return +} + // SyncPosition returns the latest positions for syncing. func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.SyncPosition, error) { return d.syncPositionTx(ctx, nil) } +// BackwardExtremitiesForRoom returns the event IDs of all of the backward +// extremities we know of for a given room. +func (d *SyncServerDatasource) BackwardExtremitiesForRoom( + ctx context.Context, roomID string, +) (backwardExtremities []string, err error) { + return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, roomID) +} + +// MaxTopologicalPosition returns the highest topological position for a given +// room. +func (d *SyncServerDatasource) MaxTopologicalPosition( + ctx context.Context, roomID string, +) (types.StreamPosition, error) { + return d.topology.selectMaxPositionInTopology(ctx, roomID) +} + +// EventsAtTopologicalPosition returns all of the events matching a given +// position in the topology of a given room. +func (d *SyncServerDatasource) EventsAtTopologicalPosition( + ctx context.Context, roomID string, pos types.StreamPosition, +) ([]types.StreamEvent, error) { + eIDs, err := d.topology.selectEventIDsFromPosition(ctx, roomID, pos) + if err != nil { + return nil, err + } + + return d.events.selectEvents(ctx, nil, eIDs) +} + +func (d *SyncServerDatasource) EventPositionInTopology( + ctx context.Context, eventID string, +) (types.StreamPosition, error) { + return d.topology.selectPositionInTopology(ctx, eventID) +} + +// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. +func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { + return d.syncStreamPositionTx(ctx, nil) +} + +func (d *SyncServerDatasource) syncStreamPositionTx( + ctx context.Context, txn *sql.Tx, +) (types.StreamPosition, error) { + maxID, err := d.events.selectMaxEventID(ctx, txn) + if err != nil { + return 0, err + } + maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) + if err != nil { + return 0, err + } + if maxAccountDataID > maxID { + maxID = maxAccountDataID + } + maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) + if err != nil { + return 0, err + } + if maxInviteID > maxID { + maxID = maxInviteID + } + return types.StreamPosition(maxID), nil +} + func (d *SyncServerDatasource) syncPositionTx( ctx context.Context, txn *sql.Tx, ) (sp types.SyncPosition, err error) { @@ -399,9 +568,16 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( defer common.EndTransaction(txn, &succeeded) // Get the current sync position which we will base the sync response on. - toPos, err = d.syncPositionTx(ctx, txn) + /* + toPos, err = d.syncPositionTx(ctx, txn) + if err != nil { + return + } + */ + // Get the current stream position which we will base the sync response on. + pos, err := d.syncStreamPositionTx(ctx, txn) if err != nil { - return + return nil, types.SyncPosition{}, []string{}, err } res = types.NewResponse(toPos) @@ -423,21 +599,38 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( } // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - var recentStreamEvents []streamEvent + var recentStreamEvents []types.StreamEvent recentStreamEvents, err = d.events.selectRecentEvents( - ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom, + ctx, txn, roomID, types.StreamPosition(0), pos, + numRecentEventsPerRoom, true, true, + //ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom, ) if err != nil { return } + // Retrieve the backward topology position, i.e. the position of the + // oldest event in the room's topology. + var backwardTopologyPos types.StreamPosition + backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) + if err != nil { + return nil, types.SyncPosition{}, []string{}, err + } + if backwardTopologyPos-1 <= 0 { + backwardTopologyPos = types.StreamPosition(1) + } else { + backwardTopologyPos = backwardTopologyPos - 1 + } + // We don't include a device here as we don't need to send down // transaction IDs for complete syncs - recentEvents := streamEventsToEvents(nil, recentStreamEvents) - + recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() - if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 { + jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, backwardTopologyPos, + ).String() + if prevPDUPos := recentStreamEvents[0].StreamPosition - 1; prevPDUPos > 0 { // Use the short form of batch token for prev_batch jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) } else { @@ -598,12 +791,13 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( endPos = delta.membershipPos } recentStreamEvents, err := d.events.selectRecentEvents( - ctx, txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom, + ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos), + numRecentEventsPerRoom, true, true, ) if err != nil { return err } - recentEvents := streamEventsToEvents(device, recentStreamEvents) + recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back var prevPDUPos int64 @@ -618,18 +812,35 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( // state events but no recent events. prevPDUPos = toPos - 1 } else { - prevPDUPos = recentStreamEvents[0].streamPosition - 1 + prevPDUPos = recentStreamEvents[0].StreamPosition - 1 } if prevPDUPos <= 0 { prevPDUPos = 1 } + // Retrieve the backward topology position, i.e. the position of the + // oldest event in the room's topology. + var backwardTopologyPos types.StreamPosition + backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) + if err != nil { + return err + } + if backwardTopologyPos-1 <= 0 { + backwardTopologyPos = types.StreamPosition(1) + } else { + backwardTopologyPos = backwardTopologyPos - 1 + } + switch delta.membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() + + jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, backwardTopologyPos, + ).String() // Use the short form of batch token for prev_batch - jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) + //jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) @@ -640,8 +851,11 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( // TODO: recentEvents may contain events that this user is not allowed to see because they are // no longer in the room. lr := types.NewLeaveResponse() + lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeStream, backwardTopologyPos, + ).String() // Use the short form of batch token for prev_batch - lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) + //lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) @@ -656,9 +870,9 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( func (d *SyncServerDatasource) fetchStateEvents( ctx context.Context, txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, - eventIDToEvent map[string]streamEvent, -) (map[string][]streamEvent, error) { - stateBetween := make(map[string][]streamEvent) + eventIDToEvent map[string]types.StreamEvent, +) (map[string][]types.StreamEvent, error) { + stateBetween := make(map[string][]types.StreamEvent) missingEvents := make(map[string][]string) for roomID, ids := range roomIDToEventIDSet { events := stateBetween[roomID] @@ -700,7 +914,7 @@ func (d *SyncServerDatasource) fetchStateEvents( func (d *SyncServerDatasource) fetchMissingStateEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the // event. events, err := d.events.selectEvents(ctx, txn, eventIDs) @@ -776,19 +990,25 @@ func (d *SyncServerDatasource) getStateDeltas( if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { if membership == gomatrixserverlib.Join { // send full room state down instead of a delta - var s []streamEvent + var s []types.StreamEvent s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart) if err != nil { return nil, nil, err } + /* + s = make([]StreamEvent, len(allState)) + for i := 0; i < len(s); i++ { + s[i] = StreamEvent{Event: allState[i], StreamPosition: types.StreamPosition(0)} + } + */ state[roomID] = s continue // we'll add this room in when we do joined rooms } deltas = append(deltas, stateDelta{ membership: membership, - membershipPos: ev.streamPosition, - stateEvents: streamEventsToEvents(device, stateStreamEvents), + membershipPos: ev.StreamPosition, + stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), roomID: roomID, }) break @@ -804,7 +1024,7 @@ func (d *SyncServerDatasource) getStateDeltas( for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, stateDelta{ membership: gomatrixserverlib.Join, - stateEvents: streamEventsToEvents(device, state[joinedRoomID]), + stateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), roomID: joinedRoomID, }) } @@ -837,7 +1057,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( } deltas = append(deltas, stateDelta{ membership: gomatrixserverlib.Join, - stateEvents: streamEventsToEvents(device, s), + stateEvents: d.StreamEventsToEvents(device, s), roomID: joinedRoomID, }) } @@ -858,8 +1078,8 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. deltas = append(deltas, stateDelta{ membership: membership, - membershipPos: ev.streamPosition, - stateEvents: streamEventsToEvents(device, stateStreamEvents), + membershipPos: ev.StreamPosition, + stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), roomID: roomID, }) } @@ -875,29 +1095,29 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( func (d *SyncServerDatasource) currentStateStreamEventsForRoom( ctx context.Context, txn *sql.Tx, roomID string, stateFilterPart *gomatrix.FilterPart, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) if err != nil { return nil, err } - s := make([]streamEvent, len(allState)) + s := make([]types.StreamEvent, len(allState)) for i := 0; i < len(s); i++ { - s[i] = streamEvent{Event: allState[i], streamPosition: 0} + s[i] = types.StreamEvent{Event: allState[i], StreamPosition: 0} } return s, nil } -// streamEventsToEvents converts streamEvent to Event. If device is non-nil and +// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. -func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrixserverlib.Event { +func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event { out := make([]gomatrixserverlib.Event, len(in)) for i := 0; i < len(in); i++ { out[i] = in[i].Event - if device != nil && in[i].transactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID { + if device != nil && in[i].TransactionID != nil { + if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { err := out[i].SetUnsignedField( - "transaction_id", in[i].transactionID.TransactionID, + "transaction_id", in[i].TransactionID.TransactionID, ) if err != nil { logrus.WithFields(logrus.Fields{ diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go index 5db4b3a1b..9bd35e28a 100644 --- a/syncapi/storage/storage.go +++ b/syncapi/storage/storage.go @@ -33,7 +33,7 @@ type Database interface { common.PartitionStorer AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) - WriteEvent(ctx context.Context, ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID) (pduPosition int64, returnErr error) + WriteEvent(context.Context, *gomatrixserverlib.Event, []gomatrixserverlib.Event, []string, []string, *api.TransactionID, bool) (int64, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.Event, error) GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrix.FilterPart) (stateEvents []gomatrixserverlib.Event, err error) SyncPosition(ctx context.Context) (types.SyncPosition, error) @@ -46,6 +46,11 @@ type Database interface { SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) AddTypingUser(userID, roomID string, expireTime *time.Time) int64 RemoveTypingUser(userID, roomID string) int64 + GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + EventPositionInTopology(ctx context.Context, eventID string) (types.StreamPosition, error) + BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error) + MaxTopologicalPosition(ctx context.Context, roomID string) (types.StreamPosition, error) + StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event } // NewPublicRoomsServerDatabase opens a database connection. diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index a5d2f60f4..07b3d3cfa 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -17,6 +17,7 @@ package sync import ( "context" "errors" + "fmt" "net/http" "strconv" "strings" @@ -32,6 +33,12 @@ import ( const defaultSyncTimeout = time.Duration(0) const defaultTimelineLimit = 20 +var ( + // ErrNotStreamToken is returned if a pagination token isn't of type + // types.PaginationTokenTypeStream + ErrNotStreamToken = fmt.Errorf("The provided pagination token has the wrong prefix (should be s)") +) + // syncRequest represents a /sync request, with sensible defaults/sanity checks applied. type syncRequest struct { ctx context.Context @@ -74,6 +81,23 @@ func getTimeout(timeoutMS string) time.Duration { return time.Duration(i) * time.Millisecond } +// getPaginationToken tries to parse a 'since' token taken from the API to a +// pagination token. If the string is empty then (nil, nil) is returned. +// Returns an error if the parsed token's type isn't types.PaginationTokenTypeStream. +func getPaginationToken(since string) (*types.StreamPosition, error) { + if since == "" { + return nil, nil + } + p, err := types.NewPaginationTokenFromString(since) + if err != nil { + return nil, err + } + if p.Type != types.PaginationTokenTypeStream { + return nil, ErrNotStreamToken + } + return &(p.Position), nil +} + // getSyncStreamPosition tries to parse a 'since' token taken from the API to a // types.SyncPosition. If the string is empty then (nil, nil) is returned. // There are two forms of tokens: The full length form containing all PDU and EDU diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 4738feea2..81987f5f1 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -21,7 +21,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/syncapi/consumers" @@ -37,6 +39,8 @@ func SetupSyncAPIComponent( deviceDB *devices.Database, accountsDB *accounts.Database, queryAPI api.RoomserverQueryAPI, + federation *gomatrixserverlib.FederationClient, + cfg *config.Dendrite, ) { syncDB, err := storage.NewSyncServerDatasource(string(base.Cfg.Database.SyncAPI)) if err != nil { @@ -77,5 +81,6 @@ func SetupSyncAPIComponent( logrus.WithError(err).Panicf("failed to start typing server consumer") } - routing.Setup(base.APIMux, requestPool, syncDB, deviceDB) + routing.Setup(base.APIMux, requestPool, syncDB, deviceDB, federation, queryAPI, cfg) + //routing.Setup(base.APIMux, requestPool, syncDB, deviceDB) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index af7ec865f..08a0239b6 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -16,11 +16,23 @@ package types import ( "encoding/json" + "fmt" "strconv" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" ) +var ( + // ErrInvalidPaginationTokenType is returned when an attempt at creating a + // new instance of PaginationToken with an invalid type (i.e. neither "s" + // nor "t"). + ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)") +) + +// StreamPosition represents the offset in the sync stream a client is at. +type StreamPosition int64 + // SyncPosition contains the PDU and EDU stream sync positions for a client. type SyncPosition struct { // PDUPosition is the stream position for PDUs the client is at. @@ -29,6 +41,14 @@ type SyncPosition struct { TypingPosition int64 } +// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. +type StreamEvent struct { + gomatrixserverlib.Event + StreamPosition int64 + TransactionID *api.TransactionID + ExcludeFromSync bool +} + // String implements the Stringer interface. func (sp SyncPosition) String() string { return strconv.FormatInt(sp.PDUPosition, 10) + "_" + @@ -55,6 +75,68 @@ func (sp SyncPosition) WithUpdates(other SyncPosition) SyncPosition { return ret } +// PaginationTokenType represents the type of a pagination token. +// It can be either "s" (representing a position in the whole stream of events) +// or "t" (representing a position in a room's topology/depth). +type PaginationTokenType string + +const ( + // PaginationTokenTypeStream represents a position in the server's whole + // stream of events + PaginationTokenTypeStream PaginationTokenType = "s" + // PaginationTokenTypeTopology represents a position in a room's topology. + PaginationTokenTypeTopology PaginationTokenType = "t" +) + +// PaginationToken represents a pagination token, used for interactions with +// /sync or /messages, for example. +type PaginationToken struct { + Position StreamPosition + Type PaginationTokenType +} + +// NewPaginationTokenFromString takes a string of the form "xyyyy..." where "x" +// represents the type of a pagination token and "yyyy..." the token itself, and +// parses it in order to create a new instance of PaginationToken. Returns an +// error if the token couldn't be parsed into an int64, or if the token type +// isn't a known type (returns ErrInvalidPaginationTokenType in the latter +// case). +func NewPaginationTokenFromString(s string) (p *PaginationToken, err error) { + p = new(PaginationToken) + + // Parse the token (aka position). + position, err := strconv.ParseInt(s[1:], 10, 64) + if err != nil { + return + } + p.Position = StreamPosition(position) + + // Check if the type is among the known ones. + p.Type = PaginationTokenType(s[:1]) + if p.Type != PaginationTokenTypeStream && p.Type != PaginationTokenTypeTopology { + err = ErrInvalidPaginationTokenType + } + + return +} + +// NewPaginationTokenFromTypeAndPosition takes a PaginationTokenType and a +// StreamPosition and returns an instance of PaginationToken. +func NewPaginationTokenFromTypeAndPosition( + t PaginationTokenType, pos StreamPosition, +) (p *PaginationToken) { + return &PaginationToken{ + Type: t, + Position: pos, + } +} + +// String translates a PaginationToken to a string of the "xyyyy..." (see +// NewPaginationToken to know what it represents). +func (p *PaginationToken) String() string { + return fmt.Sprintf("%s%d", p.Type, p.Position) +} + // PrevEventRef represents a reference to a previous event in a state event upgrade type PrevEventRef struct { PrevContent json.RawMessage `json:"prev_content"` @@ -78,6 +160,17 @@ type Response struct { } `json:"rooms"` } +/* + +func NewResponse(pos StreamPosition) *Response { + res := Response{} + // Fill next_batch with a pagination token. Since this is a response to a sync request, we can assume + // we'll always return a stream token. + res.NextBatch = NewPaginationTokenFromTypeAndPosition(PaginationTokenTypeStream, pos).String() +} + +*/ + // NewResponse creates an empty response with initialised maps. func NewResponse(pos SyncPosition) *Response { res := Response{