diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go new file mode 100644 index 000000000..d72f1d12e --- /dev/null +++ b/syncapi/routing/relations.go @@ -0,0 +1,68 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "net/http" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/userapi/api" +) + +type RelationsResponse struct { + Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` + NextBatch string `json:"next_batch,omitempty"` + PrevBatch string `json:"prev_batch,omitempty"` +} + +// nolint:gocyclo +func Relations(req *http.Request, device *api.Device, syncDB storage.Database, roomID, eventID, relType, eventType string) util.JSONResponse { + dir := req.URL.Query().Get("dir") + from := req.URL.Query().Get("from") + to := req.URL.Query().Get("to") + limit := req.URL.Query().Get("limit") + + if dir != "b" && dir != "f" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("Bad or missing dir query parameter (should be either 'b' or 'f')"), + } + } + if dir == "" { + dir = "b" + } + + res := &RelationsResponse{} + + snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) + if err != nil { + return jsonerror.InternalServerError() + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + + _, _, _, _ = from, to, limit, dir + + succeeded = true + return util.JSONResponse{ + Code: http.StatusOK, + JSON: res, + } +} diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 069dee81f..3217c72c3 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -110,6 +110,48 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/rooms/{roomId}/relations/{eventId}", + httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + + return Relations( + req, device, syncDB, + vars["roomId"], vars["eventId"], "", "", + ) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}", + httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + + return Relations( + req, device, syncDB, + vars["roomId"], vars["eventId"], vars["relType"], "", + ) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + v3mux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}", + httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + + return Relations( + req, device, syncDB, + vars["roomId"], vars["eventId"], vars["relType"], vars["eventType"], + ) + }), + ).Methods(http.MethodGet, http.MethodOptions) + v3mux.Handle("/search", httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { if !cfg.Fulltext.Enabled { diff --git a/syncapi/storage/postgres/relations_table.go b/syncapi/storage/postgres/relations_table.go index 87064ed7c..876b5a3a1 100644 --- a/syncapi/storage/postgres/relations_table.go +++ b/syncapi/storage/postgres/relations_table.go @@ -51,12 +51,12 @@ const deleteRelationSQL = "" + const selectRelationsInRangeSQL = "" + "SELECT id, room_id, child_event_id, rel_type FROM syncapi_relations" + " WHERE room_id = $1 AND event_id = $2 AND id > $3 AND id <= $4" + - " ORDER BY id DESC" + " ORDER BY id DESC LIMIT $5" const selectRelationsByTypeInRangeSQL = "" + "SELECT id, room_id, child_event_id, rel_type FROM syncapi_relations" + " WHERE room_id = $1 AND event_id = $2 AND rel_type = $3 AND id > $4 AND id <= $5" + - " ORDER BY id DESC" + " ORDER BY id DESC LIMIT $6" const selectMaxRelationIDSQL = "" + "SELECT MAX(id) FROM syncapi_relations" @@ -117,17 +117,18 @@ func (s *relationsStatements) DeleteRelation( // SelectRelationsInRange returns a map rel_type -> []child_event_id func (s *relationsStatements) SelectRelationsInRange( - ctx context.Context, txn *sql.Tx, roomID, eventID, relType string, r types.Range, + ctx context.Context, txn *sql.Tx, roomID, eventID, relType string, + r types.Range, limit int, ) (map[string][]string, types.StreamPosition, error) { var lastPos types.StreamPosition var rows *sql.Rows var err error if relType != "" { stmt := sqlutil.TxStmt(txn, s.selectRelationsByTypeInRangeStmt) - rows, err = stmt.QueryContext(ctx, roomID, eventID, relType, r.Low(), r.High()) + rows, err = stmt.QueryContext(ctx, roomID, eventID, relType, r.Low(), r.High(), limit) } else { stmt := sqlutil.TxStmt(txn, s.selectRelationsInRangeStmt) - rows, err = stmt.QueryContext(ctx, roomID, eventID, r.Low(), r.High()) + rows, err = stmt.QueryContext(ctx, roomID, eventID, r.Low(), r.High(), limit) } if err != nil { return nil, lastPos, err diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index d5b5b3121..28e669d81 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -589,3 +589,40 @@ func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.Str func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { return d.Presence.GetMaxPresenceID(ctx, d.txn) } + +func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, limit int) ( + clientEvents []gomatrixserverlib.ClientEvent, prevBatch, nextBatch string, err error, +) { + // TODO: Nothing here is limited or setting prev_batch or next_batch + var eventIDs []string + r := types.Range{ + From: from, + To: to, + Backwards: from > to, + } + rels, _, err := d.Relations.SelectRelationsInRange(ctx, d.txn, roomID, eventID, relType, r, limit+1) + if err != nil { + return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err) + } + if relType != "" { + eventIDs = rels[relType] + } else { + for _, ids := range rels { + eventIDs = append(eventIDs, ids...) + } + } + events, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, true) + if err != nil { + return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err) + } + for _, event := range events { + if eventType != "" && event.Type() != eventType { + continue + } + clientEvents = append( + clientEvents, + gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll), + ) + } + return clientEvents, prevBatch, nextBatch, nil +} diff --git a/syncapi/storage/sqlite3/relations_table.go b/syncapi/storage/sqlite3/relations_table.go index 2aac4d0c7..02463668d 100644 --- a/syncapi/storage/sqlite3/relations_table.go +++ b/syncapi/storage/sqlite3/relations_table.go @@ -49,12 +49,12 @@ const deleteRelationSQL = "" + const selectRelationsInRangeSQL = "" + "SELECT id, room_id, child_event_id, rel_type FROM syncapi_relations" + " WHERE room_id = $1 AND event_id = $2 AND id > $3 AND id <= $4" + - " ORDER BY id DESC" + " ORDER BY id DESC LIMIT $5" const selectRelationsByTypeInRangeSQL = "" + "SELECT id, room_id, child_event_id, rel_type FROM syncapi_relations" + " WHERE room_id = $1 AND event_id = $2 AND rel_type = $3 AND id > $4 AND id <= $5" + - " ORDER BY id DESC" + " ORDER BY id DESC LIMIT $6" const selectMaxRelationIDSQL = "" + "SELECT MAX(id) FROM syncapi_relations" @@ -121,17 +121,18 @@ func (s *relationsStatements) DeleteRelation( // SelectRelationsInRange returns a map rel_type -> []child_event_id func (s *relationsStatements) SelectRelationsInRange( - ctx context.Context, txn *sql.Tx, roomID, eventID, relType string, r types.Range, + ctx context.Context, txn *sql.Tx, roomID, eventID, relType string, + r types.Range, limit int, ) (map[string][]string, types.StreamPosition, error) { var lastPos types.StreamPosition var rows *sql.Rows var err error if relType != "" { stmt := sqlutil.TxStmt(txn, s.selectRelationsByTypeInRangeStmt) - rows, err = stmt.QueryContext(ctx, roomID, eventID, relType, r.Low(), r.High()) + rows, err = stmt.QueryContext(ctx, roomID, eventID, relType, r.Low(), r.High(), limit) } else { stmt := sqlutil.TxStmt(txn, s.selectRelationsInRangeStmt) - rows, err = stmt.QueryContext(ctx, roomID, eventID, r.Low(), r.High()) + rows, err = stmt.QueryContext(ctx, roomID, eventID, r.Low(), r.High(), limit) } if err != nil { return nil, lastPos, err diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index a227f4563..1311ef22c 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -210,6 +210,6 @@ type Presence interface { type Relations interface { InsertRelation(ctx context.Context, txn *sql.Tx, roomID, eventID, childEventID, relType string) (streamPos types.StreamPosition, err error) DeleteRelation(ctx context.Context, txn *sql.Tx, roomID, childEventID string) error - SelectRelationsInRange(ctx context.Context, txn *sql.Tx, roomID, eventID, relType string, r types.Range) (map[string][]string, types.StreamPosition, error) + SelectRelationsInRange(ctx context.Context, txn *sql.Tx, roomID, eventID, relType string, r types.Range, limit int) (map[string][]string, types.StreamPosition, error) SelectMaxRelationID(ctx context.Context, txn *sql.Tx) (id int64, err error) }