From 867c59a20a97efd317a96fed5732f96354e000b2 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 19 Feb 2020 15:01:14 +0000 Subject: [PATCH] sqlite: Fix account data table in syncapi by commiting insert txns! --- syncapi/storage/sqlite3/account_data_table.go | 51 ++++++++++++------- syncapi/storage/sqlite3/syncserver.go | 32 ++++++------ 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 3274e66ea..71105d0c3 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -19,15 +19,13 @@ import ( "context" "database/sql" - "github.com/lib/pq" - "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` CREATE TABLE IF NOT EXISTS syncapi_account_data_type ( - id INTEGER PRIMARY KEY AUTOINCREMENT, + id INTEGER PRIMARY KEY, user_id TEXT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, @@ -43,9 +41,7 @@ const insertAccountDataSQL = "" + const selectAccountDataInRangeSQL = "" + "SELECT room_id, type FROM syncapi_account_data_type" + " WHERE user_id = $1 AND id > $2 AND id <= $3" + - " AND ( $4 IS NULL OR type IN ($4) )" + - " AND ( $5 IS NULL OR NOT(type IN ($5)) )" + - " ORDER BY id ASC LIMIT $6" + " ORDER BY id ASC" const selectMaxAccountDataIDSQL = "" + "SELECT MAX(id) FROM syncapi_account_data_type" @@ -53,8 +49,8 @@ const selectMaxAccountDataIDSQL = "" + type accountDataStatements struct { streamIDStatements *streamIDStatements insertAccountDataStmt *sql.Stmt - selectAccountDataInRangeStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt + selectAccountDataInRangeStmt *sql.Stmt } func (s *accountDataStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { @@ -66,10 +62,10 @@ func (s *accountDataStatements) prepare(db *sql.DB, streamID *streamIDStatements if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { return } - if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { + if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { return } - if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { + if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { return } return @@ -83,8 +79,7 @@ func (s *accountDataStatements) insertAccountData( if err != nil { return } - insertStmt := common.TxStmt(txn, s.insertAccountDataStmt) - _, err = insertStmt.ExecContext(ctx, pos, userID, roomID, dataType) + _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) return } @@ -103,14 +98,13 @@ func (s *accountDataStatements) selectAccountDataInRange( oldPos-- } - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos, - pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.Types)), - pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.NotTypes)), - accountDataFilterPart.Limit, - ) + rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos) if err != nil { return } + defer rows.Close() // nolint: errcheck + + var entries int for rows.Next() { var dataType string @@ -120,22 +114,41 @@ func (s *accountDataStatements) selectAccountDataInRange( return } + // check if we should add this by looking at the filter. + // It would be nice if we could do this in SQL-land, but the mix of variadic + // and positional parameters makes the query annoyingly hard to do, it's easier + // and clearer to do it in Go-land. If there are no filters for [not]types then + // this gets skipped. + for _, includeType := range accountDataFilterPart.Types { + if includeType != dataType { // TODO: wildcard support + continue + } + } + for _, excludeType := range accountDataFilterPart.NotTypes { + if excludeType == dataType { // TODO: wildcard support + continue + } + } + if len(data[roomID]) > 0 { data[roomID] = append(data[roomID], dataType) } else { data[roomID] = []string{dataType} } + entries++ + if entries >= accountDataFilterPart.Limit { + break + } } - return + return data, nil } func (s *accountDataStatements) selectMaxAccountDataID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 - stmt := common.TxStmt(txn, s.selectMaxAccountDataIDStmt) - err = stmt.QueryRowContext(ctx).Scan(&nullableID) + err = txn.Stmt(s.selectMaxAccountDataIDStmt).QueryRowContext(ctx).Scan(&nullableID) if nullableID.Valid { id = nullableID.Int64 } diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 4714de764..6ad3419c7 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -336,8 +336,12 @@ func (d *SyncServerDatasource) GetEventsInRange( } // SyncPosition returns the latest positions for syncing. -func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) { - return d.syncPositionTx(ctx, nil) +func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.PaginationToken, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + tok, err = d.syncPositionTx(ctx, txn) + return err + }) + return } // BackwardExtremitiesForRoom returns the event IDs of all of the backward @@ -376,8 +380,12 @@ func (d *SyncServerDatasource) EventPositionInTopology( } // SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. -func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { - return d.syncStreamPositionTx(ctx, nil) +func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (pos types.StreamPosition, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + pos, err = d.syncStreamPositionTx(ctx, txn) + return err + }) + return } func (d *SyncServerDatasource) syncStreamPositionTx( @@ -734,18 +742,10 @@ func (d *SyncServerDatasource) GetAccountDataInRange( func (d *SyncServerDatasource) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, ) (sp types.StreamPosition, err error) { - txn, err := d.db.BeginTx(ctx, nil) - if err != nil { - return types.StreamPosition(0), err - } - var succeeded bool - defer func() { - txerr := common.EndTransaction(txn, &succeeded) - if err == nil && txerr != nil { - err = txerr - } - }() - sp, err = d.accountData.insertAccountData(ctx, txn, userID, roomID, dataType) + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + sp, err = d.accountData.insertAccountData(ctx, txn, userID, roomID, dataType) + return err + }) return }