diff --git a/syncapi/storage/postgres/relations_table.go b/syncapi/storage/postgres/relations_table.go index 0607c9c26..2b710a3a9 100644 --- a/syncapi/storage/postgres/relations_table.go +++ b/syncapi/storage/postgres/relations_table.go @@ -64,7 +64,7 @@ const selectRelationsInRangeDescSQL = "" + " ORDER BY id DESC LIMIT $7" const selectThreadsSQL = "" + - "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json, syncapi_output_room_events.type FROM syncapi_relations" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + " WHERE syncapi_relations.room_id = $1" + " AND syncapi_relations.rel_type = 'm.thread'" + @@ -72,7 +72,7 @@ const selectThreadsSQL = "" + " ORDER BY syncapi_relations.id LIMIT $3" const selectThreadsWithSenderSQL = "" + - "SELECT syncapi_relations.id, syncapi_relations.child_event_id, syncapi_output_room_events.sender, syncapi_relations.event_id, syncapi_output_room_events.headered_event_json, syncapi_output_room_events.type FROM syncapi_relations" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + " WHERE syncapi_relations.room_id = $1" + " AND syncapi_output_room_events.sender = $2" + @@ -196,16 +196,12 @@ func (s *relationsStatements) SelectThreads( defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed") var result []string var ( - id types.StreamPosition - childEventID string - sender string - eventId string - headeredEventJson string - eventType string + id types.StreamPosition + eventId string ) for rows.Next() { - if err = rows.Scan(&id, &childEventID, &sender, &eventId, &headeredEventJson, &eventType); err != nil { + if err = rows.Scan(&id, &eventId); err != nil { return nil, lastPos, err } if id > lastPos { diff --git a/syncapi/storage/sqlite3/relations_table.go b/syncapi/storage/sqlite3/relations_table.go index 7cbb5408f..512178c5b 100644 --- a/syncapi/storage/sqlite3/relations_table.go +++ b/syncapi/storage/sqlite3/relations_table.go @@ -64,11 +64,30 @@ const selectRelationsInRangeDescSQL = "" + const selectMaxRelationIDSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_relations" +const selectThreadsSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $2 AND" + + " ORDER BY syncapi_relations.id LIMIT $3" + +const selectThreadsWithSenderSQL = "" + + "SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" + + " JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" + + " WHERE syncapi_relations.room_id = $1" + + " AND syncapi_output_room_events.sender = $2" + + " AND syncapi_relations.rel_type = 'm.thread'" + + " AND syncapi_relations.id >= $3" + + " ORDER BY syncapi_relations.id LIMIT $4" + type relationsStatements struct { streamIDStatements *StreamIDStatements insertRelationStmt *sql.Stmt selectRelationsInRangeAscStmt *sql.Stmt selectRelationsInRangeDescStmt *sql.Stmt + selectThreadsStmt *sql.Stmt + selectThreadsWithSenderStmt *sql.Stmt deleteRelationStmt *sql.Stmt selectMaxRelationIDStmt *sql.Stmt } @@ -85,6 +104,8 @@ func NewSqliteRelationsTable(db *sql.DB, streamID *StreamIDStatements) (tables.R {&s.insertRelationStmt, insertRelationSQL}, {&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL}, {&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL}, + {&s.selectThreadsStmt, selectThreadsSQL}, + {&s.selectThreadsWithSenderStmt, selectThreadsWithSenderSQL}, {&s.deleteRelationStmt, deleteRelationSQL}, {&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL}, }.Prepare(db) @@ -154,6 +175,49 @@ func (s *relationsStatements) SelectRelationsInRange( return result, lastPos, rows.Err() } +func (s *relationsStatements) SelectThreads( + ctx context.Context, + txn *sql.Tx, + roomID, userID string, + from types.StreamPosition, + limit uint64, +) ([]string, types.StreamPosition, error) { + var lastPos types.StreamPosition + var stmt *sql.Stmt + var rows *sql.Rows + var err error + + if userID == "" { + stmt = sqlutil.TxStmt(txn, s.selectThreadsStmt) + rows, err = stmt.QueryContext(ctx, roomID, from, limit) + } else { + stmt = sqlutil.TxStmt(txn, s.selectThreadsWithSenderStmt) + rows, err = stmt.QueryContext(ctx, roomID, userID, from, limit) + } + if err != nil { + return nil, lastPos, err + } + + defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed") + var result []string + var ( + id types.StreamPosition + eventId string + ) + + for rows.Next() { + if err = rows.Scan(&id, &eventId); err != nil { + return nil, lastPos, err + } + if id > lastPos { + lastPos = id + } + result = append(result, eventId) + } + + return result, lastPos, rows.Err() +} + func (s *relationsStatements) SelectMaxRelationID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) {