From 87283e9de785f5153c5cf9b326d2640e202a36b3 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Fri, 6 Mar 2020 14:31:12 +0000 Subject: [PATCH] bugfix: fix sytest 155 by actually returning depth+1 and not 0 --- roomserver/storage/sqlite3/events_table.go | 10 ++++++---- syncapi/routing/messages.go | 4 ++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 56e379100..4fa095913 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -111,7 +111,6 @@ type eventStatements struct { bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt bulkSelectEventNIDStmt *sql.Stmt - selectMaxEventDepthStmt *sql.Stmt } func (s *eventStatements) prepare(db *sql.DB) (err error) { @@ -135,7 +134,6 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, - {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, }.prepare(db) } @@ -462,8 +460,12 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { var result int64 - selectStmt := common.TxStmt(txn, s.selectMaxEventDepthStmt) - err := selectStmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result) + iEventIDs := make([]interface{}, len(eventNIDs)) + for i, v := range eventNIDs { + iEventIDs[i] = v + } + sqlStr := strings.Replace(selectMaxEventDepthSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) + err := txn.QueryRowContext(ctx, sqlStr, iEventIDs...).Scan(&result) if err != nil { return 0, err } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 7bbe16f3c..83bf75b2e 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -16,6 +16,7 @@ package routing import ( "context" + "fmt" "net/http" "sort" "strconv" @@ -176,6 +177,7 @@ func (r *messagesReq) retrieveEvents() ( r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, ) if err != nil { + err = fmt.Errorf("GetEventsInRange: %s", err) return } @@ -226,12 +228,14 @@ func (r *messagesReq) retrieveEvents() ( r.ctx, events[0].EventID(), ) if err != nil { + err = fmt.Errorf("EventPositionInTopology: for start event %s: %s", events[0].EventID(), err) return } endPos, err := r.db.EventPositionInTopology( r.ctx, events[len(events)-1].EventID(), ) if err != nil { + err = fmt.Errorf("EventPositionInTopology: for end event %s: %s", events[len(events)-1].EventID(), err) return } // Generate pagination tokens to send to the client using the positions