From faeb0b4ba0149eb49512283734997c502d3e4a24 Mon Sep 17 00:00:00 2001 From: qwqtoday Date: Wed, 31 Jul 2024 13:55:58 +0800 Subject: [PATCH] half done --- syncapi/routing/threads.go | 77 +++++++++++++++++++++ syncapi/storage/postgres/relations_table.go | 72 ++++++++++++++++++- syncapi/storage/shared/storage_sync.go | 63 +++++++++++++++++ syncapi/storage/tables/interface.go | 7 +- 4 files changed, 216 insertions(+), 3 deletions(-) create mode 100644 syncapi/routing/threads.go diff --git a/syncapi/routing/threads.go b/syncapi/routing/threads.go new file mode 100644 index 000000000..6d0e2518a --- /dev/null +++ b/syncapi/routing/threads.go @@ -0,0 +1,77 @@ +package routing + +import ( + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/synctypes" + "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "net/http" + "strconv" +) + +type ThreadsResponse struct { + Chunk []synctypes.ClientEvent `json:"chunk"` + NextBatch string `json:"next_batch,omitempty"` +} + +func Threads( + req *http.Request, + device userapi.Device, + syncDB storage.Database, + rsAPI api.SyncRoomserverAPI, + rawRoomID string) util.JSONResponse { + var err error + roomID, err := spec.NewRoomID(rawRoomID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.InvalidParam("invalid room ID"), + } + } + + limit, err := strconv.ParseUint(req.URL.Query().Get("limit"), 10, 64) + if err != nil { + limit = 50 + } + if limit > 100 { + limit = 100 + } + + from := req.URL.Query().Get("from") + include := req.URL.Query().Get("include") + + snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) + if err != nil { + logrus.WithError(err).Error("Failed to get snapshot for relations") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + + res := &ThreadsResponse{ + Chunk: []synctypes.ClientEvent{}, + } + + if include == "participated" { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.Unknown("internal server error"), + } + } + var events []types.StreamEvent + events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor( + req.Context(), roomID.String(), "", relType, eventType, from, to, dir == "b", limit, + ) + } +} diff --git a/syncapi/storage/postgres/relations_table.go b/syncapi/storage/postgres/relations_table.go index 5a76e9c33..861eb3d3f 100644 --- a/syncapi/storage/postgres/relations_table.go +++ b/syncapi/storage/postgres/relations_table.go @@ -17,9 +17,10 @@ package postgres import ( "context" "database/sql" - + "encoding/json" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" + types2 "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -63,6 +64,23 @@ const selectRelationsInRangeDescSQL = "" + " AND id >= $5 AND id < $6" + " ORDER BY id DESC LIMIT $7" +const selectThreadsSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $2 AND syncapi_relations.id < $3" + + " ORDER BY syncapi_relations.id LIMIT $4" + +const selectThreadsWithSenderSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_output_room_events.sender = $2" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $3 AND syncapi_relations.id < $4" + + " ORDER BY syncapi_relations.id LIMIT $5" + const selectMaxRelationIDSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_relations" @@ -70,6 +88,8 @@ type relationsStatements struct { insertRelationStmt *sql.Stmt selectRelationsInRangeAscStmt *sql.Stmt selectRelationsInRangeDescStmt *sql.Stmt + selectThreadsStmt *sql.Stmt + selectThreadsWithSenderStmt *sql.Stmt deleteRelationStmt *sql.Stmt selectMaxRelationIDStmt *sql.Stmt } @@ -84,6 +104,8 @@ func NewPostgresRelationsTable(db *sql.DB) (tables.Relations, error) { {&s.insertRelationStmt, insertRelationSQL}, {&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL}, {&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL}, + {&s.selectThreadsStmt, selectThreadsSQL}, + {&s.selectThreadsWithSenderStmt, selectThreadsWithSenderSQL}, {&s.deleteRelationStmt, deleteRelationSQL}, {&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL}, }.Prepare(db) @@ -149,6 +171,54 @@ func (s *relationsStatements) SelectRelationsInRange( return result, lastPos, rows.Err() } +func (s *relationsStatements) SelectThreads( + ctx context.Context, + txn *sql.Tx, + roomID, userID string, + r types.Range, + limit int, +) ([]map[string]any, types.StreamPosition, error) { + var lastPos types.StreamPosition + var stmt *sql.Stmt + var rows *sql.Rows + var err error + + if userID == "" { + stmt = sqlutil.TxStmt(txn, s.selectThreadsStmt) + rows, err = stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) + } else { + stmt = sqlutil.TxStmt(txn, s.selectThreadsWithSenderStmt) + rows, err = stmt.QueryContext(ctx, roomID, userID, r.Low(), r.High(), limit) + } + if err != nil { + return nil, lastPos, err + } + + defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed") + var result []map[string]any + var ( + id types.StreamPosition + childEventID string + sender string + eventId string + headeredEventJson string + ) + + for rows.Next() { + if err = rows.Scan(&id, &childEventID, &sender, &eventId, &headeredEventJson); err != nil { + return nil, lastPos, err + } + if id > lastPos { + lastPos = id + } + var event types2.HeaderedEvent + json.Unmarshal(event) + result = append(result) + } + + return result, lastPos, rows.Err() +} + func (s *relationsStatements) SelectMaxRelationID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index cd17fdc69..0c70025e3 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -811,3 +811,66 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, return events, prevBatch, nextBatch, nil } + +func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID string, from types.StreamPosition, limit int) ( + events []types.StreamEvent, prevBatch, nextBatch string, err error, +) { + r := types.Range{ + From: from, + } + + if r.From == 0 { + // If we're working backwards (dir=b) and there's no ?from= specified then + // we will automatically want to work backwards from the current position, + // so find out what that is. + if r.From, err = d.MaxStreamPositionForRelations(ctx); err != nil { + return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err) + } + // The result normally isn't inclusive of the event *at* the ?from= + // position, so add 1 here so that we include the most recent relation. + r.From++ + } + + // First look up any relations from the database. We add one to the limit here + // so that we can tell if we're overflowing, as we will only set the "next_batch" + // in the response if we are. + relations, _, err := d.Relations.SelectThreads(ctx, d.txn, roomID, userID, limit+1) + if err != nil { + return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err) + } + + // If we specified a relation type then just get those results, otherwise collate + // them from all of the returned relation types. + entries := []types.RelationEntry{} + for _, e := range relations { + entries = append(entries, e...) + } + + // If there were no entries returned, there were no relations, so stop at this point. + if len(entries) == 0 { + return nil, "", "", nil + } + + // Otherwise, let's try and work out what sensible prev_batch and next_batch values + // could be. We've requested an extra event by adding one to the limit already so + // that we can determine whether or not to provide a "next_batch", so trim off that + // event off the end if needs be. + if len(entries) > limit { + entries = entries[:len(entries)-1] + nextBatch = fmt.Sprintf("%d", entries[len(entries)-1].Position) + } + // TODO: set prevBatch? doesn't seem to affect the tests... + + // Extract all of the event IDs from the relation entries so that we can pull the + // events out of the database. Then go and fetch the events. + eventIDs := make([]string, 0, len(entries)) + for _, entry := range entries { + eventIDs = append(eventIDs, entry.EventID) + } + events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, true) + if err != nil { + return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err) + } + + return events, prevBatch, nextBatch, nil +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 45117d6d3..adf46a70f 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -223,10 +223,10 @@ type Presence interface { } type Relations interface { - // Inserts a relation which refers from the child event ID to the event ID in the given room. + // InsertRelation Inserts a relation which refers from the child event ID to the event ID in the given room. // If the relation already exists then this function will do nothing and return no error. InsertRelation(ctx context.Context, txn *sql.Tx, roomID, eventID, childEventID, childEventType, relType string) (err error) - // Deletes a relation which already exists as the result of an event redaction. If the relation + // DeleteRelation Deletes a relation which already exists as the result of an event redaction. If the relation // does not exist then this function will do nothing and return no error. DeleteRelation(ctx context.Context, txn *sql.Tx, roomID, childEventID string) error // SelectRelationsInRange will return relations grouped by relation type within the given range. @@ -235,6 +235,9 @@ type Relations interface { // will be returned, inclusive of the "to" position but excluding the "from" position. The stream // position returned is the maximum position of the returned results. SelectRelationsInRange(ctx context.Context, txn *sql.Tx, roomID, eventID, relType, eventType string, r types.Range, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error) + // SelectThreads this will find some threads from a room + // if userID is not empty then it will only include the threads that the user has participated + SelectThreads(ctx context.Context, txn *sql.Tx, roomID, userID string, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error) // SelectMaxRelationID returns the maximum ID of all relations, used to determine what the boundaries // should be if there are no boundaries supplied (i.e. we want to work backwards but don't have a // "from" or want to work forwards and don't have a "to").