diff --git a/roomserver/api/api.go b/roomserver/api/api.go index e6d37e8f1..1ee1bd449 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -166,6 +166,11 @@ type RoomserverInternalAPI interface { // PerformForget forgets a rooms history for a specific user PerformForget(ctx context.Context, req *PerformForgetRequest, resp *PerformForgetResponse) error + QueryEventsAfter( + ctx context.Context, + req *QueryEventsAfterEventIDRequest, + res *QueryEventsAfterEventIDesponse, + ) error // Asks for the default room version as preferred by the server. QueryRoomVersionCapabilities( diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 16f52abb7..a6c03a2ea 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -17,6 +17,11 @@ type RoomserverInternalAPITrace struct { Impl RoomserverInternalAPI } +func (t *RoomserverInternalAPITrace) QueryEventsAfter(ctx context.Context, req *QueryEventsAfterEventIDRequest, res *QueryEventsAfterEventIDesponse) error { + util.GetLogger(ctx).Infof("QueryEventsAfter req=%+v res=%+v", js(req), js(res)) + return t.Impl.QueryEventsAfter(ctx, req, res) +} + func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.FederationInternalAPI, keyRing *gomatrixserverlib.KeyRing) { t.Impl.SetFederationAPI(fsAPI, keyRing) } diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 96d6711c6..a937d1e40 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -101,6 +101,17 @@ type QueryEventsByIDResponse struct { Events []*gomatrixserverlib.HeaderedEvent `json:"events"` } +// QueryEventsByIDRequest is a request to QueryEventsByID +type QueryEventsAfterEventIDRequest struct { + // The event IDs to look up. + EventIDs string `json:"event_id"` +} + +// QueryEventsByIDResponse is a response to QueryEventsByID +type QueryEventsAfterEventIDesponse struct { + Events []*gomatrixserverlib.ClientEvent `json:"events"` +} + // QueryMembershipForUserRequest is a request to QueryMembership type QueryMembershipForUserRequest struct { // ID of the room to fetch membership from diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index fd963ad83..3823a3fa0 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -198,3 +198,11 @@ func (r *RoomserverInternalAPI) PerformForget( ) error { return r.Forgetter.PerformForget(ctx, req, resp) } + +func (r *RoomserverInternalAPI) QueryEventsAfter( + ctx context.Context, + req *api.QueryEventsAfterEventIDRequest, + res *api.QueryEventsAfterEventIDesponse, +) error { + return r.Queryer.QueryEventsAfter(ctx, req, res) +} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index c8bbe7705..14b581c2f 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -724,3 +724,23 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq res.AuthChain = hchain return nil } + +func (r *Queryer) QueryEventsAfter( + ctx context.Context, + req *api.QueryEventsAfterEventIDRequest, + res *api.QueryEventsAfterEventIDesponse, +) error { + eventNIDs, err := r.DB.SelectPreviousEventNIDs(ctx, req.EventIDs) + if err != nil { + return err + } + events, err := r.DB.Events(ctx, eventNIDs) + if err != nil { + return err + } + for _, event := range events { + ev := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll) + res.Events = append(res.Events, &ev) + } + return nil +} diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index a61404efe..9e447dbb3 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -57,6 +57,7 @@ const ( RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers" RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom" RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" + RoomserverQueryEventsAfterPath = "/roomserver/queryEventsAfter" ) type httpRoomserverInternalAPI struct { @@ -534,5 +535,12 @@ func (h *httpRoomserverInternalAPI) PerformForget(ctx context.Context, req *api. apiURL := h.roomserverURL + RoomserverPerformForgetPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) - +} + +func (h *httpRoomserverInternalAPI) QueryEventsAfter(ctx context.Context, req *api.QueryEventsAfterEventIDRequest, res *api.QueryEventsAfterEventIDesponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryEventsAfter") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryEventsAfterPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 691a45830..65680eb71 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -464,4 +464,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(RoomserverQueryEventsAfterPath, + httputil.MakeInternalAPI("queryEventsAfterPath", func(req *http.Request) util.JSONResponse { + request := api.QueryEventsAfterEventIDRequest{} + response := api.QueryEventsAfterEventIDesponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := r.QueryEventsAfter(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index a9851e05b..2aded4288 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -64,6 +64,7 @@ type Database interface { // Look up the Events for a list of numeric event IDs. // Returns a sorted list of events. Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) + SelectPreviousEventNIDs(ctx context.Context, eventID string) ([]types.EventNID, error) // Look up snapshot NID for an event ID string SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index bd4e853eb..48ee8f2f1 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -59,9 +59,14 @@ const selectPreviousEventExistsSQL = "" + "SELECT 1 FROM roomserver_previous_events" + " WHERE previous_event_id = $1 AND previous_reference_sha256 = $2" +const selectPreviousEventNIDsSQL = "" + + "SELECT event_nids FROM roomserver_previous_events" + + " WHERE previous_event_id = $1" + type previousEventStatements struct { insertPreviousEventStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt + selectPreviousEventNIDsStmt *sql.Stmt } func createPrevEventsTable(db *sql.DB) error { @@ -75,6 +80,7 @@ func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { return s, sqlutil.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, + {&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL}, }.Prepare(db) } @@ -101,3 +107,18 @@ func (s *previousEventStatements) SelectPreviousEventExists( stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) } + +// SelectPreviousEventNIDs returns all eventNIDs for a given eventID +func (s *previousEventStatements) SelectPreviousEventNIDs(ctx context.Context, txn *sql.Tx, eventID string) ([]types.EventNID, error) { + stmt := sqlutil.TxStmt(txn, s.selectPreviousEventNIDsStmt) + row := stmt.QueryRowContext(ctx, eventID) + var eventNIDs []uint8 + if err := row.Scan(&eventNIDs); err != nil { + return nil, err + } + result := []types.EventNID{} + for _, nid := range eventNIDs { + result = append(result, types.EventNID(nid)) + } + return result, nil +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e96c77afa..762d735de 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1173,6 +1173,10 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget }) } +func (d *Database) SelectPreviousEventNIDs(ctx context.Context, eventID string) ([]types.EventNID, error) { + return d.PrevEventsTable.SelectPreviousEventNIDs(ctx, nil, eventID) +} + // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // it should live in this package! diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index 7304bf0d5..52af372b5 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "fmt" + "strconv" "strings" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -53,7 +54,7 @@ const insertPreviousEventSQL = ` const selectPreviousEventNIDsSQL = ` SELECT event_nids FROM roomserver_previous_events - WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 + WHERE previous_event_id = $1 ` // Check if the event is referenced by another event in the table. @@ -129,3 +130,24 @@ func (s *previousEventStatements) SelectPreviousEventExists( stmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) } + +// SelectPreviousEventNIDs returns all eventNIDs for a given eventID +func (s *previousEventStatements) SelectPreviousEventNIDs(ctx context.Context, txn *sql.Tx, eventID string) ([]types.EventNID, error) { + stmt := sqlutil.TxStmt(txn, s.selectPreviousEventNIDsStmt) + row := stmt.QueryRowContext(ctx, eventID) + var eventNIDs string + if err := row.Scan(&eventNIDs); err != nil { + return nil, err + } + result := []types.EventNID{} + nids := strings.Split(eventNIDs, ",") + for _, nid := range nids { + i, err := strconv.Atoi(nid) + if err != nil { + return nil, err + } + result = append(result, types.EventNID(i)) + } + + return result, nil +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index fed39b944..05d6f9ef7 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -100,6 +100,7 @@ type PreviousEvents interface { // Check if the event reference exists // Returns sql.ErrNoRows if the event reference doesn't exist. SelectPreviousEventExists(ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error + SelectPreviousEventNIDs(ctx context.Context, txn *sql.Tx, eventID string) ([]types.EventNID, error) } type Invites interface {