bugfix: fix sytest 155 by actually returning depth+1 and not 0

This commit is contained in:
Kegan Dougal 2020-03-06 14:31:12 +00:00
parent a97b8eafd4
commit 87283e9de7
2 changed files with 10 additions and 4 deletions

View file

@ -111,7 +111,6 @@ type eventStatements struct {
bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt
bulkSelectEventIDStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt
bulkSelectEventNIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt
selectMaxEventDepthStmt *sql.Stmt
} }
func (s *eventStatements) prepare(db *sql.DB) (err error) { func (s *eventStatements) prepare(db *sql.DB) (err error) {
@ -135,7 +134,6 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
{&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
}.prepare(db) }.prepare(db)
} }
@ -462,8 +460,12 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) {
var result int64 var result int64
selectStmt := common.TxStmt(txn, s.selectMaxEventDepthStmt) iEventIDs := make([]interface{}, len(eventNIDs))
err := selectStmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result) for i, v := range eventNIDs {
iEventIDs[i] = v
}
sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
err := txn.QueryRowContext(ctx, sqlStr, iEventIDs...).Scan(&result)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View file

@ -16,6 +16,7 @@ package routing
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"sort" "sort"
"strconv" "strconv"
@ -176,6 +177,7 @@ func (r *messagesReq) retrieveEvents() (
r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering,
) )
if err != nil { if err != nil {
err = fmt.Errorf("GetEventsInRange: %s", err)
return return
} }
@ -226,12 +228,14 @@ func (r *messagesReq) retrieveEvents() (
r.ctx, events[0].EventID(), r.ctx, events[0].EventID(),
) )
if err != nil { if err != nil {
err = fmt.Errorf("EventPositionInTopology: for start event %s: %s", events[0].EventID(), err)
return return
} }
endPos, err := r.db.EventPositionInTopology( endPos, err := r.db.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(), r.ctx, events[len(events)-1].EventID(),
) )
if err != nil { if err != nil {
err = fmt.Errorf("EventPositionInTopology: for end event %s: %s", events[len(events)-1].EventID(), err)
return return
} }
// Generate pagination tokens to send to the client using the positions // Generate pagination tokens to send to the client using the positions