From a9fc1ab0f230103efac5b6d8e83297b7cbbee340 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Mon, 28 Feb 2022 17:41:15 +0000 Subject: [PATCH] msc2946: Make TestClientSpacesSummary pass --- go.mod | 1 + setup/mscs/msc2946/msc2946.go | 165 +++++++++++++++++++++------------- 2 files changed, 106 insertions(+), 60 deletions(-) diff --git a/go.mod b/go.mod index 81d999814..387d52fb7 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/frankban/quicktest v1.14.0 // indirect github.com/getsentry/sentry-go v0.12.0 github.com/gologme/log v1.3.0 + github.com/google/uuid v1.2.0 // indirect github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 github.com/h2non/filetype v1.1.3 // indirect diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 3f3bf4d4f..9e10e557c 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -26,6 +26,7 @@ import ( "sync" "time" + "github.com/google/uuid" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/clientapi/jsonerror" fs "github.com/matrix-org/dendrite/federationapi/api" @@ -86,7 +87,6 @@ func federatedSpacesHandler( rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) u, err := url.Parse(fedReq.RequestURI()) if err != nil { return util.JSONResponse{ @@ -106,9 +106,10 @@ func federatedSpacesHandler( // This is somewhat equivalent to a Client-Server request with a max_depth=1. maxDepth: 1, - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, + rsAPI: rsAPI, + fsAPI: fsAPI, + // inline cache as we don't have pagination in federation mode + paginationCache: make(map[string]paginationInfo), } return w.walk() } @@ -117,8 +118,11 @@ func spacesHandler( rsAPI roomserver.RoomserverInternalAPI, fsAPI fs.FederationInternalAPI, thisServer gomatrixserverlib.ServerName, ) func(*http.Request, *userapi.Device) util.JSONResponse { + // declared outside the returned handler so it persists between calls + // TODO: clear based on... time? + paginationCache := make(map[string]paginationInfo) + return func(req *http.Request, device *userapi.Device) util.JSONResponse { - inMemoryBatchCache := make(map[string]set) // Extract the room ID from the request. Sanity check request data. params, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { @@ -126,37 +130,43 @@ func spacesHandler( } roomID := params["roomID"] w := walker{ - suggestedOnly: req.URL.Query().Get("suggested_only") == "true", - limit: parseInt(req.URL.Query().Get("limit"), 1000), - maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1), - rootRoomID: roomID, - caller: device, - thisServer: thisServer, - ctx: req.Context(), + suggestedOnly: req.URL.Query().Get("suggested_only") == "true", + limit: parseInt(req.URL.Query().Get("limit"), 1000), + maxDepth: parseInt(req.URL.Query().Get("max_depth"), -1), + paginationToken: req.URL.Query().Get("from"), + rootRoomID: roomID, + caller: device, + thisServer: thisServer, + ctx: req.Context(), - rsAPI: rsAPI, - fsAPI: fsAPI, - inMemoryBatchCache: inMemoryBatchCache, + rsAPI: rsAPI, + fsAPI: fsAPI, + paginationCache: paginationCache, } return w.walk() } } -type walker struct { - rootRoomID string - caller *userapi.Device - serverName gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName - rsAPI roomserver.RoomserverInternalAPI - fsAPI fs.FederationInternalAPI - ctx context.Context - suggestedOnly bool - limit int - maxDepth int +type paginationInfo struct { + processed set + unvisited []roomVisit +} - // user ID|device ID|batch_num => event/room IDs sent to client - inMemoryBatchCache map[string]set - mu sync.Mutex +type walker struct { + rootRoomID string + caller *userapi.Device + serverName gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName + rsAPI roomserver.RoomserverInternalAPI + fsAPI fs.FederationInternalAPI + ctx context.Context + suggestedOnly bool + limit int + maxDepth int + paginationToken string + + paginationCache map[string]paginationInfo + mu sync.Mutex } func (w *walker) callerID() string { @@ -166,25 +176,26 @@ func (w *walker) callerID() string { return string(w.serverName) } -func (w *walker) alreadySent(id string) bool { - w.mu.Lock() - defer w.mu.Unlock() - m, ok := w.inMemoryBatchCache[w.callerID()] - if !ok { - return false +func (w *walker) newPaginationCache() (string, paginationInfo) { + p := paginationInfo{ + processed: make(set), + unvisited: nil, } - return m[id] + tok := uuid.NewString() + return tok, p } -func (w *walker) markSent(id string) { +func (w *walker) loadPaginationCache(paginationToken string) *paginationInfo { w.mu.Lock() defer w.mu.Unlock() - m := w.inMemoryBatchCache[w.callerID()] - if m == nil { - m = make(set) - } - m[id] = true - w.inMemoryBatchCache[w.callerID()] = m + p := w.paginationCache[paginationToken] + return &p +} + +func (w *walker) storePaginationCache(paginationToken string, cache paginationInfo) { + w.mu.Lock() + defer w.mu.Unlock() + w.paginationCache[paginationToken] = cache } type roomVisit struct { @@ -212,13 +223,30 @@ func (w *walker) walk() util.JSONResponse { var discoveredRooms []gomatrixserverlib.MSC2946Room - // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms + var cache *paginationInfo + if w.paginationToken != "" { + cache = w.loadPaginationCache(w.paginationToken) + if cache == nil { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.InvalidArgumentValue("invalid from"), + } + } + } else { + tok, c := w.newPaginationCache() + cache = &c + w.paginationToken = tok + // Begin walking the graph starting with the room ID in the request in a queue of unvisited rooms + c.unvisited = append(c.unvisited, roomVisit{ + roomID: w.rootRoomID, + depth: 0, + }) + } + + processed := cache.processed + unvisited := cache.unvisited + // Depth first -> stack data structure - unvisited := []roomVisit{{ - roomID: w.rootRoomID, - depth: 0, - }} - processed := make(set) for len(unvisited) > 0 { if len(discoveredRooms) >= w.limit { break @@ -229,11 +257,12 @@ func (w *walker) walk() util.JSONResponse { unvisited = unvisited[:len(unvisited)-1] // If this room has already been processed, skip. // If this room exceeds the specified depth, skip. - if processed[rv.roomID] || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) { + if processed.isSet(rv.roomID) || rv.roomID == "" || (w.maxDepth > 0 && rv.depth > w.maxDepth) { continue } // Mark this room as processed. - processed[rv.roomID] = true + + processed.set(rv.roomID) // if this room is not a space room, skip. var roomType string @@ -278,13 +307,6 @@ func (w *walker) walk() util.JSONResponse { } } - // mark processed rooms for pagination purposes - for _, room := range discoveredRooms { - if !w.alreadySent(room.RoomID) { - w.markSent(room.RoomID) - } - } - // don't walk the children // if the parent is not a space room if roomType != ConstCreateEventContentValueSpace { @@ -309,12 +331,27 @@ func (w *walker) walk() util.JSONResponse { }) } } + + if len(unvisited) > 0 { + // we still have more rooms so we need to send back a pagination token, + // we probably hit a room limit + cache.processed = processed + cache.unvisited = unvisited + w.storePaginationCache(w.paginationToken, *cache) + } else { + // clear the pagination token so we don't send it back to the client + // Note we do NOT nuke the cache just in case this response is lost + // and the client retries it. + w.paginationToken = "" + } + if w.caller != nil { // return CS API format return util.JSONResponse{ Code: 200, JSON: MSC2946ClientResponse{ - Rooms: discoveredRooms, + Rooms: discoveredRooms, + NextBatch: w.paginationToken, }, } } @@ -548,7 +585,15 @@ func (w *walker) childReferences(roomID string) ([]gomatrixserverlib.MSC2946Stri return el, nil } -type set map[string]bool +type set map[string]struct{} + +func (s set) set(val string) { + s[val] = struct{}{} +} +func (s set) isSet(val string) bool { + _, ok := s[val] + return ok +} func stripped(ev *gomatrixserverlib.Event) *gomatrixserverlib.MSC2946StrippedEvent { if ev.StateKey() == nil {