From 9934f9543b6d7a57a7569a10346b91223824e862 Mon Sep 17 00:00:00 2001 From: qwqtoday Date: Fri, 9 Aug 2024 15:24:12 +0800 Subject: [PATCH] closer to complete --- syncapi/routing/threads.go | 42 +++++++++++++++++---- syncapi/storage/postgres/relations_table.go | 34 ++++++++--------- syncapi/storage/shared/storage_sync.go | 39 +++---------------- syncapi/storage/tables/interface.go | 2 +- 4 files changed, 57 insertions(+), 60 deletions(-) diff --git a/syncapi/routing/threads.go b/syncapi/routing/threads.go index 6d0e2518a..175ece197 100644 --- a/syncapi/routing/threads.go +++ b/syncapi/routing/threads.go @@ -1,6 +1,10 @@ package routing import ( + rstypes "github.com/matrix-org/dendrite/roomserver/types" + "net/http" + "strconv" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" @@ -10,8 +14,6 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/sirupsen/logrus" - "net/http" - "strconv" ) type ThreadsResponse struct { @@ -42,7 +44,13 @@ func Threads( limit = 100 } - from := req.URL.Query().Get("from") + var from types.StreamPosition + if f := req.URL.Query().Get("from"); f != "" { + if from, err = types.NewStreamPositionFromString(f); err != nil { + return util.ErrorResponse(err) + } + } + include := req.URL.Query().Get("include") snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) @@ -60,8 +68,9 @@ func Threads( Chunk: []synctypes.ClientEvent{}, } + var userID string if include == "participated" { - userID, err := spec.NewUserID(device.UserID, true) + _, err := spec.NewUserID(device.UserID, true) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid") return util.JSONResponse{ @@ -69,9 +78,26 @@ func Threads( JSON: spec.Unknown("internal server error"), } } - var events []types.StreamEvent - events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor( - req.Context(), roomID.String(), "", relType, eventType, from, to, dir == "b", limit, - ) + userID = device.UserID + } else { + userID = "" + } + var headeredEvents []*rstypes.HeaderedEvent + headeredEvents, _, res.NextBatch, err = snapshot.ThreadsFor( + req.Context(), roomID.String(), userID, from, limit, + ) + + for _, event := range headeredEvents { + ce, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }) + if err != nil { + return util.ErrorResponse(err) + } + res.Chunk = append(res.Chunk, *ce) + } + + return util.JSONResponse{ + JSON: res, } } diff --git a/syncapi/storage/postgres/relations_table.go b/syncapi/storage/postgres/relations_table.go index 861eb3d3f..0607c9c26 100644 --- a/syncapi/storage/postgres/relations_table.go +++ b/syncapi/storage/postgres/relations_table.go @@ -17,10 +17,9 @@ package postgres import ( "context" "database/sql" - "encoding/json" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - types2 "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -65,21 +64,21 @@ const selectRelationsInRangeDescSQL = "" + " ORDER BY id DESC LIMIT $7" const selectThreadsSQL = "" + - "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json FROM syncapi_relations" + + "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json, syncapi_output_room_events.type FROM syncapi_relations" + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + " WHERE syncapi_relations.room_id = $1" + " AND syncapi_relations.rel_type = 'm.thread'" + - " AND syncapi_relations.id >= $2 AND syncapi_relations.id < $3" + - " ORDER BY syncapi_relations.id LIMIT $4" + " AND syncapi_relations.id >= $2 AND" + + " ORDER BY syncapi_relations.id LIMIT $3" const selectThreadsWithSenderSQL = "" + - "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json FROM syncapi_relations" + + "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json, syncapi_output_room_events.type FROM syncapi_relations" + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + " WHERE syncapi_relations.room_id = $1" + " AND syncapi_output_room_events.sender = $2" + " AND syncapi_relations.rel_type = 'm.thread'" + - " AND syncapi_relations.id >= $3 AND syncapi_relations.id < $4" + - " ORDER BY syncapi_relations.id LIMIT $5" + " AND syncapi_relations.id >= $3" + + " ORDER BY syncapi_relations.id LIMIT $4" const selectMaxRelationIDSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_relations" @@ -175,9 +174,9 @@ func (s *relationsStatements) SelectThreads( ctx context.Context, txn *sql.Tx, roomID, userID string, - r types.Range, - limit int, -) ([]map[string]any, types.StreamPosition, error) { + from types.StreamPosition, + limit uint64, +) ([]string, types.StreamPosition, error) { var lastPos types.StreamPosition var stmt *sql.Stmt var rows *sql.Rows @@ -185,35 +184,34 @@ func (s *relationsStatements) SelectThreads( if userID == "" { stmt = sqlutil.TxStmt(txn, s.selectThreadsStmt) - rows, err = stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit) + rows, err = stmt.QueryContext(ctx, roomID, from, limit) } else { stmt = sqlutil.TxStmt(txn, s.selectThreadsWithSenderStmt) - rows, err = stmt.QueryContext(ctx, roomID, userID, r.Low(), r.High(), limit) + rows, err = stmt.QueryContext(ctx, roomID, userID, from, limit) } if err != nil { return nil, lastPos, err } defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed") - var result []map[string]any + var result []string var ( id types.StreamPosition childEventID string sender string eventId string headeredEventJson string + eventType string ) for rows.Next() { - if err = rows.Scan(&id, &childEventID, &sender, &eventId, &headeredEventJson); err != nil { + if err = rows.Scan(&id, &childEventID, &sender, &eventId, &headeredEventJson, &eventType); err != nil { return nil, lastPos, err } if id > lastPos { lastPos = id } - var event types2.HeaderedEvent - json.Unmarshal(event) - result = append(result) + result = append(result, eventId) } return result, lastPos, rows.Err() diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 0c70025e3..a7d783d6f 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -812,8 +812,8 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, return events, prevBatch, nextBatch, nil } -func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID string, from types.StreamPosition, limit int) ( - events []types.StreamEvent, prevBatch, nextBatch string, err error, +func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID string, from types.StreamPosition, limit uint64) ( + events []*rstypes.HeaderedEvent, prevBatch, nextBatch string, err error, ) { r := types.Range{ From: from, @@ -831,43 +831,16 @@ func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID str r.From++ } - // First look up any relations from the database. We add one to the limit here + // First look up any threads from the database. We add one to the limit here // so that we can tell if we're overflowing, as we will only set the "next_batch" // in the response if we are. - relations, _, err := d.Relations.SelectThreads(ctx, d.txn, roomID, userID, limit+1) + eventIDs, _, err := d.Relations.SelectThreads(ctx, d.txn, roomID, userID, from, limit+1) + if err != nil { return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err) } - // If we specified a relation type then just get those results, otherwise collate - // them from all of the returned relation types. - entries := []types.RelationEntry{} - for _, e := range relations { - entries = append(entries, e...) - } - - // If there were no entries returned, there were no relations, so stop at this point. - if len(entries) == 0 { - return nil, "", "", nil - } - - // Otherwise, let's try and work out what sensible prev_batch and next_batch values - // could be. We've requested an extra event by adding one to the limit already so - // that we can determine whether or not to provide a "next_batch", so trim off that - // event off the end if needs be. - if len(entries) > limit { - entries = entries[:len(entries)-1] - nextBatch = fmt.Sprintf("%d", entries[len(entries)-1].Position) - } - // TODO: set prevBatch? doesn't seem to affect the tests... - - // Extract all of the event IDs from the relation entries so that we can pull the - // events out of the database. Then go and fetch the events. - eventIDs := make([]string, 0, len(entries)) - for _, entry := range entries { - eventIDs = append(eventIDs, entry.EventID) - } - events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, true) + events, err = d.Events(ctx, eventIDs) if err != nil { return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err) } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index adf46a70f..5470349a0 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -237,7 +237,7 @@ type Relations interface { SelectRelationsInRange(ctx context.Context, txn *sql.Tx, roomID, eventID, relType, eventType string, r types.Range, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error) // SelectThreads this will find some threads from a room // if userID is not empty then it will only include the threads that the user has participated - SelectThreads(ctx context.Context, txn *sql.Tx, roomID, userID string, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error) + SelectThreads(ctx context.Context, txn *sql.Tx, roomID, userID string, from types.StreamPosition, limit uint64) ([]string, types.StreamPosition, error) // SelectMaxRelationID returns the maximum ID of all relations, used to determine what the boundaries // should be if there are no boundaries supplied (i.e. we want to work backwards but don't have a // "from" or want to work forwards and don't have a "to").