bugfix: fix sytest 155 by actually returning depth+1 and not 0
This commit is contained in:
parent
a97b8eafd4
commit
87283e9de7
|
@ -111,7 +111,6 @@ type eventStatements struct {
|
|||
bulkSelectEventReferenceStmt *sql.Stmt
|
||||
bulkSelectEventIDStmt *sql.Stmt
|
||||
bulkSelectEventNIDStmt *sql.Stmt
|
||||
selectMaxEventDepthStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
||||
|
@ -135,7 +134,6 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
|||
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
|
||||
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
|
||||
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
|
||||
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
|
||||
}.prepare(db)
|
||||
}
|
||||
|
||||
|
@ -462,8 +460,12 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
|
|||
|
||||
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) {
|
||||
var result int64
|
||||
selectStmt := common.TxStmt(txn, s.selectMaxEventDepthStmt)
|
||||
err := selectStmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result)
|
||||
iEventIDs := make([]interface{}, len(eventNIDs))
|
||||
for i, v := range eventNIDs {
|
||||
iEventIDs[i] = v
|
||||
}
|
||||
sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
|
||||
err := txn.QueryRowContext(ctx, sqlStr, iEventIDs...).Scan(&result)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ package routing
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
@ -176,6 +177,7 @@ func (r *messagesReq) retrieveEvents() (
|
|||
r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering,
|
||||
)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("GetEventsInRange: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -226,12 +228,14 @@ func (r *messagesReq) retrieveEvents() (
|
|||
r.ctx, events[0].EventID(),
|
||||
)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("EventPositionInTopology: for start event %s: %s", events[0].EventID(), err)
|
||||
return
|
||||
}
|
||||
endPos, err := r.db.EventPositionInTopology(
|
||||
r.ctx, events[len(events)-1].EventID(),
|
||||
)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("EventPositionInTopology: for end event %s: %s", events[len(events)-1].EventID(), err)
|
||||
return
|
||||
}
|
||||
// Generate pagination tokens to send to the client using the positions
|
||||
|
|
Loading…
Reference in a new issue