mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-06 06:23:10 -06:00
Merge 1b4fc3728f into 084181332b
This commit is contained in:
commit
7793b2d6f9
|
|
@ -157,6 +157,18 @@ func Setup(
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
v1unstablemux.Handle("/rooms/{roomId}/threads",
|
||||||
|
httputil.MakeAuthAPI("threads", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return Threads(
|
||||||
|
req, device, syncDB, rsAPI, vars["roomId"],
|
||||||
|
)
|
||||||
|
})).Methods(http.MethodGet)
|
||||||
|
|
||||||
v3mux.Handle("/search",
|
v3mux.Handle("/search",
|
||||||
httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
if !cfg.Fulltext.Enabled {
|
if !cfg.Fulltext.Enabled {
|
||||||
|
|
@ -200,4 +212,5 @@ func Setup(
|
||||||
return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, membership, notMembership, at)
|
return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, membership, notMembership, at)
|
||||||
}, httputil.WithAllowGuests()),
|
}, httputil.WithAllowGuests()),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
107
syncapi/routing/threads.go
Normal file
107
syncapi/routing/threads.go
Normal file
|
|
@ -0,0 +1,107 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
rstypes "github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/synctypes"
|
||||||
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ThreadsResponse struct {
|
||||||
|
Chunk []synctypes.ClientEvent `json:"chunk"`
|
||||||
|
NextBatch string `json:"next_batch,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func Threads(
|
||||||
|
req *http.Request,
|
||||||
|
device *userapi.Device,
|
||||||
|
syncDB storage.Database,
|
||||||
|
rsAPI api.SyncRoomserverAPI,
|
||||||
|
rawRoomID string) util.JSONResponse {
|
||||||
|
var err error
|
||||||
|
roomID, err := spec.NewRoomID(rawRoomID)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.InvalidParam("invalid room ID"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, err := strconv.ParseUint(req.URL.Query().Get("limit"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
limit = 50
|
||||||
|
}
|
||||||
|
if limit > 100 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
var from types.StreamPosition
|
||||||
|
if f := req.URL.Query().Get("from"); f != "" {
|
||||||
|
if from, err = types.NewStreamPositionFromString(f); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
include := req.URL.Query().Get("include")
|
||||||
|
|
||||||
|
snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to get snapshot for relations")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var succeeded bool
|
||||||
|
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
|
||||||
|
|
||||||
|
res := &ThreadsResponse{
|
||||||
|
Chunk: []synctypes.ClientEvent{},
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID string
|
||||||
|
if include == "participated" {
|
||||||
|
_, err := spec.NewUserID(device.UserID, true)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.Unknown("internal server error"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userID = device.UserID
|
||||||
|
} else {
|
||||||
|
userID = ""
|
||||||
|
}
|
||||||
|
var headeredEvents []*rstypes.HeaderedEvent
|
||||||
|
headeredEvents, _, res.NextBatch, err = snapshot.ThreadsFor(
|
||||||
|
req.Context(), roomID.String(), userID, from, limit,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range headeredEvents {
|
||||||
|
ce, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
|
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
res.Chunk = append(res.Chunk, *ce)
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: res,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -63,6 +63,23 @@ const selectRelationsInRangeDescSQL = "" +
|
||||||
" AND id >= $5 AND id < $6" +
|
" AND id >= $5 AND id < $6" +
|
||||||
" ORDER BY id DESC LIMIT $7"
|
" ORDER BY id DESC LIMIT $7"
|
||||||
|
|
||||||
|
const selectThreadsSQL = "" +
|
||||||
|
"SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" +
|
||||||
|
" JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" +
|
||||||
|
" WHERE syncapi_relations.room_id = $1" +
|
||||||
|
" AND syncapi_relations.rel_type = 'm.thread'" +
|
||||||
|
" AND syncapi_relations.id >= $2" +
|
||||||
|
" ORDER BY syncapi_relations.id LIMIT $3"
|
||||||
|
|
||||||
|
const selectThreadsWithSenderSQL = "" +
|
||||||
|
"SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" +
|
||||||
|
" JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" +
|
||||||
|
" WHERE syncapi_relations.room_id = $1" +
|
||||||
|
" AND syncapi_output_room_events.sender = $2" +
|
||||||
|
" AND syncapi_relations.rel_type = 'm.thread'" +
|
||||||
|
" AND syncapi_relations.id >= $3" +
|
||||||
|
" ORDER BY syncapi_relations.id LIMIT $4"
|
||||||
|
|
||||||
const selectMaxRelationIDSQL = "" +
|
const selectMaxRelationIDSQL = "" +
|
||||||
"SELECT COALESCE(MAX(id), 0) FROM syncapi_relations"
|
"SELECT COALESCE(MAX(id), 0) FROM syncapi_relations"
|
||||||
|
|
||||||
|
|
@ -70,6 +87,8 @@ type relationsStatements struct {
|
||||||
insertRelationStmt *sql.Stmt
|
insertRelationStmt *sql.Stmt
|
||||||
selectRelationsInRangeAscStmt *sql.Stmt
|
selectRelationsInRangeAscStmt *sql.Stmt
|
||||||
selectRelationsInRangeDescStmt *sql.Stmt
|
selectRelationsInRangeDescStmt *sql.Stmt
|
||||||
|
selectThreadsStmt *sql.Stmt
|
||||||
|
selectThreadsWithSenderStmt *sql.Stmt
|
||||||
deleteRelationStmt *sql.Stmt
|
deleteRelationStmt *sql.Stmt
|
||||||
selectMaxRelationIDStmt *sql.Stmt
|
selectMaxRelationIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
@ -84,6 +103,8 @@ func NewPostgresRelationsTable(db *sql.DB) (tables.Relations, error) {
|
||||||
{&s.insertRelationStmt, insertRelationSQL},
|
{&s.insertRelationStmt, insertRelationSQL},
|
||||||
{&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL},
|
{&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL},
|
||||||
{&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL},
|
{&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL},
|
||||||
|
{&s.selectThreadsStmt, selectThreadsSQL},
|
||||||
|
{&s.selectThreadsWithSenderStmt, selectThreadsWithSenderSQL},
|
||||||
{&s.deleteRelationStmt, deleteRelationSQL},
|
{&s.deleteRelationStmt, deleteRelationSQL},
|
||||||
{&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL},
|
{&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
|
|
@ -149,6 +170,49 @@ func (s *relationsStatements) SelectRelationsInRange(
|
||||||
return result, lastPos, rows.Err()
|
return result, lastPos, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *relationsStatements) SelectThreads(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
roomID, userID string,
|
||||||
|
from types.StreamPosition,
|
||||||
|
limit uint64,
|
||||||
|
) ([]string, types.StreamPosition, error) {
|
||||||
|
var lastPos types.StreamPosition
|
||||||
|
var stmt *sql.Stmt
|
||||||
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if userID == "" {
|
||||||
|
stmt = sqlutil.TxStmt(txn, s.selectThreadsStmt)
|
||||||
|
rows, err = stmt.QueryContext(ctx, roomID, from, limit)
|
||||||
|
} else {
|
||||||
|
stmt = sqlutil.TxStmt(txn, s.selectThreadsWithSenderStmt)
|
||||||
|
rows, err = stmt.QueryContext(ctx, roomID, userID, from, limit)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, lastPos, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed")
|
||||||
|
var result []string
|
||||||
|
var (
|
||||||
|
id types.StreamPosition
|
||||||
|
eventId string
|
||||||
|
)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&id, &eventId); err != nil {
|
||||||
|
return nil, lastPos, err
|
||||||
|
}
|
||||||
|
if id > lastPos {
|
||||||
|
lastPos = id
|
||||||
|
}
|
||||||
|
result = append(result, eventId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, lastPos, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *relationsStatements) SelectMaxRelationID(
|
func (s *relationsStatements) SelectMaxRelationID(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
|
|
|
||||||
|
|
@ -811,3 +811,39 @@ func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID,
|
||||||
|
|
||||||
return events, prevBatch, nextBatch, nil
|
return events, prevBatch, nextBatch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *DatabaseTransaction) ThreadsFor(ctx context.Context, roomID, userID string, from types.StreamPosition, limit uint64) (
|
||||||
|
events []*rstypes.HeaderedEvent, prevBatch, nextBatch string, err error,
|
||||||
|
) {
|
||||||
|
r := types.Range{
|
||||||
|
From: from,
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.From == 0 {
|
||||||
|
// If we're working backwards (dir=b) and there's no ?from= specified then
|
||||||
|
// we will automatically want to work backwards from the current position,
|
||||||
|
// so find out what that is.
|
||||||
|
if r.From, err = d.MaxStreamPositionForRelations(ctx); err != nil {
|
||||||
|
return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err)
|
||||||
|
}
|
||||||
|
// The result normally isn't inclusive of the event *at* the ?from=
|
||||||
|
// position, so add 1 here so that we include the most recent relation.
|
||||||
|
r.From++
|
||||||
|
}
|
||||||
|
|
||||||
|
// First look up any threads from the database. We add one to the limit here
|
||||||
|
// so that we can tell if we're overflowing, as we will only set the "next_batch"
|
||||||
|
// in the response if we are.
|
||||||
|
eventIDs, pos, err := d.Relations.SelectThreads(ctx, d.txn, roomID, userID, from, limit+1)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err = d.Events(ctx, eventIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return events, prevBatch, fmt.Sprintf("%d", pos), nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -64,11 +64,30 @@ const selectRelationsInRangeDescSQL = "" +
|
||||||
const selectMaxRelationIDSQL = "" +
|
const selectMaxRelationIDSQL = "" +
|
||||||
"SELECT COALESCE(MAX(id), 0) FROM syncapi_relations"
|
"SELECT COALESCE(MAX(id), 0) FROM syncapi_relations"
|
||||||
|
|
||||||
|
const selectThreadsSQL = "" +
|
||||||
|
"SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" +
|
||||||
|
" JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" +
|
||||||
|
" WHERE syncapi_relations.room_id = $1" +
|
||||||
|
" AND syncapi_relations.rel_type = 'm.thread'" +
|
||||||
|
" AND syncapi_relations.id >= $2" +
|
||||||
|
" ORDER BY syncapi_relations.id LIMIT $3"
|
||||||
|
|
||||||
|
const selectThreadsWithSenderSQL = "" +
|
||||||
|
"SELECT syncapi_relations.id, syncapi_relations.event_id FROM syncapi_relations" +
|
||||||
|
" JOIN syncapi_output_room_events ON syncapi_output_room_events.event_id = syncapi_relations.event_id" +
|
||||||
|
" WHERE syncapi_relations.room_id = $1" +
|
||||||
|
" AND syncapi_output_room_events.sender = $2" +
|
||||||
|
" AND syncapi_relations.rel_type = 'm.thread'" +
|
||||||
|
" AND syncapi_relations.id >= $3" +
|
||||||
|
" ORDER BY syncapi_relations.id LIMIT $4"
|
||||||
|
|
||||||
type relationsStatements struct {
|
type relationsStatements struct {
|
||||||
streamIDStatements *StreamIDStatements
|
streamIDStatements *StreamIDStatements
|
||||||
insertRelationStmt *sql.Stmt
|
insertRelationStmt *sql.Stmt
|
||||||
selectRelationsInRangeAscStmt *sql.Stmt
|
selectRelationsInRangeAscStmt *sql.Stmt
|
||||||
selectRelationsInRangeDescStmt *sql.Stmt
|
selectRelationsInRangeDescStmt *sql.Stmt
|
||||||
|
selectThreadsStmt *sql.Stmt
|
||||||
|
selectThreadsWithSenderStmt *sql.Stmt
|
||||||
deleteRelationStmt *sql.Stmt
|
deleteRelationStmt *sql.Stmt
|
||||||
selectMaxRelationIDStmt *sql.Stmt
|
selectMaxRelationIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
@ -85,6 +104,8 @@ func NewSqliteRelationsTable(db *sql.DB, streamID *StreamIDStatements) (tables.R
|
||||||
{&s.insertRelationStmt, insertRelationSQL},
|
{&s.insertRelationStmt, insertRelationSQL},
|
||||||
{&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL},
|
{&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL},
|
||||||
{&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL},
|
{&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL},
|
||||||
|
{&s.selectThreadsStmt, selectThreadsSQL},
|
||||||
|
{&s.selectThreadsWithSenderStmt, selectThreadsWithSenderSQL},
|
||||||
{&s.deleteRelationStmt, deleteRelationSQL},
|
{&s.deleteRelationStmt, deleteRelationSQL},
|
||||||
{&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL},
|
{&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
|
|
@ -154,6 +175,49 @@ func (s *relationsStatements) SelectRelationsInRange(
|
||||||
return result, lastPos, rows.Err()
|
return result, lastPos, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *relationsStatements) SelectThreads(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
roomID, userID string,
|
||||||
|
from types.StreamPosition,
|
||||||
|
limit uint64,
|
||||||
|
) ([]string, types.StreamPosition, error) {
|
||||||
|
var lastPos types.StreamPosition
|
||||||
|
var stmt *sql.Stmt
|
||||||
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if userID == "" {
|
||||||
|
stmt = sqlutil.TxStmt(txn, s.selectThreadsStmt)
|
||||||
|
rows, err = stmt.QueryContext(ctx, roomID, from, limit)
|
||||||
|
} else {
|
||||||
|
stmt = sqlutil.TxStmt(txn, s.selectThreadsWithSenderStmt)
|
||||||
|
rows, err = stmt.QueryContext(ctx, roomID, userID, from, limit)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, lastPos, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectThreads: rows.close() failed")
|
||||||
|
var result []string
|
||||||
|
var (
|
||||||
|
id types.StreamPosition
|
||||||
|
eventId string
|
||||||
|
)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&id, &eventId); err != nil {
|
||||||
|
return nil, lastPos, err
|
||||||
|
}
|
||||||
|
if id > lastPos {
|
||||||
|
lastPos = id
|
||||||
|
}
|
||||||
|
result = append(result, eventId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, lastPos, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *relationsStatements) SelectMaxRelationID(
|
func (s *relationsStatements) SelectMaxRelationID(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
|
|
|
||||||
|
|
@ -223,10 +223,10 @@ type Presence interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Relations interface {
|
type Relations interface {
|
||||||
// Inserts a relation which refers from the child event ID to the event ID in the given room.
|
// InsertRelation Inserts a relation which refers from the child event ID to the event ID in the given room.
|
||||||
// If the relation already exists then this function will do nothing and return no error.
|
// If the relation already exists then this function will do nothing and return no error.
|
||||||
InsertRelation(ctx context.Context, txn *sql.Tx, roomID, eventID, childEventID, childEventType, relType string) (err error)
|
InsertRelation(ctx context.Context, txn *sql.Tx, roomID, eventID, childEventID, childEventType, relType string) (err error)
|
||||||
// Deletes a relation which already exists as the result of an event redaction. If the relation
|
// DeleteRelation Deletes a relation which already exists as the result of an event redaction. If the relation
|
||||||
// does not exist then this function will do nothing and return no error.
|
// does not exist then this function will do nothing and return no error.
|
||||||
DeleteRelation(ctx context.Context, txn *sql.Tx, roomID, childEventID string) error
|
DeleteRelation(ctx context.Context, txn *sql.Tx, roomID, childEventID string) error
|
||||||
// SelectRelationsInRange will return relations grouped by relation type within the given range.
|
// SelectRelationsInRange will return relations grouped by relation type within the given range.
|
||||||
|
|
@ -235,6 +235,9 @@ type Relations interface {
|
||||||
// will be returned, inclusive of the "to" position but excluding the "from" position. The stream
|
// will be returned, inclusive of the "to" position but excluding the "from" position. The stream
|
||||||
// position returned is the maximum position of the returned results.
|
// position returned is the maximum position of the returned results.
|
||||||
SelectRelationsInRange(ctx context.Context, txn *sql.Tx, roomID, eventID, relType, eventType string, r types.Range, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error)
|
SelectRelationsInRange(ctx context.Context, txn *sql.Tx, roomID, eventID, relType, eventType string, r types.Range, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error)
|
||||||
|
// SelectThreads will find threads from a room, if userID is not empty
|
||||||
|
// then it will only include the threads that the user has participated in.
|
||||||
|
SelectThreads(ctx context.Context, txn *sql.Tx, roomID, userID string, from types.StreamPosition, limit uint64) ([]string, types.StreamPosition, error)
|
||||||
// SelectMaxRelationID returns the maximum ID of all relations, used to determine what the boundaries
|
// SelectMaxRelationID returns the maximum ID of all relations, used to determine what the boundaries
|
||||||
// should be if there are no boundaries supplied (i.e. we want to work backwards but don't have a
|
// should be if there are no boundaries supplied (i.e. we want to work backwards but don't have a
|
||||||
// "from" or want to work forwards and don't have a "to").
|
// "from" or want to work forwards and don't have a "to").
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,22 @@ func newRelationsTable(t *testing.T, dbType test.DBType) (tables.Relations, *sql
|
||||||
t.Fatalf("failed to open db: %s", err)
|
t.Fatalf("failed to open db: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
_, err = postgres.NewPostgresEventsTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
var stream sqlite3.StreamIDStatements
|
||||||
|
if err = stream.Prepare(db); err != nil {
|
||||||
|
t.Fatalf("failed to prepare stream stmts: %s", err)
|
||||||
|
}
|
||||||
|
_, err = sqlite3.NewSqliteEventsTable(db, &stream)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to make new table: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
var tab tables.Relations
|
var tab tables.Relations
|
||||||
|
|
||||||
switch dbType {
|
switch dbType {
|
||||||
case test.DBTypePostgres:
|
case test.DBTypePostgres:
|
||||||
tab, err = postgres.NewPostgresRelationsTable(db)
|
tab, err = postgres.NewPostgresRelationsTable(db)
|
||||||
|
|
@ -184,3 +199,53 @@ func TestRelationsTable(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const threadRelType = "m.thread"
|
||||||
|
|
||||||
|
func TestThreads(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
ctx := context.Background()
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
|
||||||
|
firstEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{
|
||||||
|
"body": "first message",
|
||||||
|
})
|
||||||
|
threadReplyEvent := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{
|
||||||
|
"body": "thread reply",
|
||||||
|
"m.relates_to": map[string]interface{}{
|
||||||
|
"event_id": firstEvent.EventID(),
|
||||||
|
"rel_type": threadRelType,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, _, close := newRelationsTable(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
err = tab.InsertRelation(ctx, nil, room.ID, firstEvent.EventID(), threadReplyEvent.EventID(), "m.room.message", threadReplyEvent.EventID())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var eventIds []string
|
||||||
|
eventIds, _, err = tab.SelectThreads(ctx, nil, room.ID, "", 0, 100)
|
||||||
|
|
||||||
|
for i, expected := range []string{
|
||||||
|
firstEvent.EventID(),
|
||||||
|
} {
|
||||||
|
eventID := eventIds[i]
|
||||||
|
if eventID != expected {
|
||||||
|
t.Fatalf("eventID mismatch: got %s, want %s", eventID, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
eventIds, _, err = tab.SelectThreads(ctx, nil, room.ID, alice.ID, 0, 100)
|
||||||
|
for i, expected := range []string{
|
||||||
|
firstEvent.EventID(),
|
||||||
|
} {
|
||||||
|
eventID := eventIds[i]
|
||||||
|
if eventID != expected {
|
||||||
|
t.Fatalf("eventID mismatch: got %s, want %s", eventID, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue