From 1d1446253e78f66a04d3b4b3c18838a110ab8ef4 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 12 Feb 2020 17:48:25 +0000 Subject: [PATCH] Move QueryVariadic etc into common, other device fixes --- .../storage/devices/sqlite3/devices_table.go | 25 +++++++++++++------ common/sql.go | 18 +++++++++++++ .../storage/sqlite3/event_json_table.go | 2 +- .../storage/sqlite3/event_state_keys_table.go | 5 ++-- .../storage/sqlite3/event_types_table.go | 2 +- roomserver/storage/sqlite3/events_table.go | 12 ++++----- roomserver/storage/sqlite3/sql.go | 18 ------------- .../storage/sqlite3/state_block_table.go | 8 +++--- .../storage/sqlite3/state_snapshot_table.go | 3 ++- 9 files changed, 52 insertions(+), 41 deletions(-) diff --git a/clientapi/auth/storage/devices/sqlite3/devices_table.go b/clientapi/auth/storage/devices/sqlite3/devices_table.go index 55b8b5f4e..dc88890d3 100644 --- a/clientapi/auth/storage/devices/sqlite3/devices_table.go +++ b/clientapi/auth/storage/devices/sqlite3/devices_table.go @@ -17,9 +17,9 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" - "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -70,9 +70,10 @@ const deleteDevicesByLocalpartSQL = "" + "DELETE FROM device_devices WHERE localpart = $1" const deleteDevicesSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" type devicesStatements struct { + db *sql.DB insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt @@ -81,11 +82,11 @@ type devicesStatements struct { updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt - deleteDevicesStmt *sql.Stmt serverName gomatrixserverlib.ServerName } func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.db = db _, err = db.Exec(devicesSchema) if err != nil { return @@ -114,9 +115,6 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { return } - if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { - return - } s.serverName = server return } @@ -158,8 +156,19 @@ func (s *devicesStatements) deleteDevice( func (s *devicesStatements) deleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { - stmt := common.TxStmt(txn, s.deleteDevicesStmt) - _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) + orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1) + prep, err := s.db.Prepare(orig) + if err != nil { + return err + } + stmt := common.TxStmt(txn, prep) + params := make([]interface{}, len(devices)+1) + params[0] = localpart + for i, v := range devices { + params[i+1] = v + } + params = append(params, params...) + _, err = stmt.ExecContext(ctx, params...) return err } diff --git a/common/sql.go b/common/sql.go index 043de8cd0..975930202 100644 --- a/common/sql.go +++ b/common/sql.go @@ -16,6 +16,7 @@ package common import ( "database/sql" + "fmt" "github.com/lib/pq" ) @@ -81,3 +82,20 @@ func IsUniqueConstraintViolationErr(err error) bool { pqErr, ok := err.(*pq.Error) return ok && pqErr.Code == "23505" } + +// Hack of the century +func QueryVariadic(count int) string { + return QueryVariadicOffset(count, 0) +} + +func QueryVariadicOffset(count, offset int) string { + str := "(" + for i := 0; i < count; i++ { + str += fmt.Sprintf("$%d", i+offset+1) + if i < (count - 1) { + str += ", " + } + } + str += ")" + return str +} diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 4ccf16d3c..f6c83906a 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -82,7 +82,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON( for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) if err != nil { diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 26771de33..b8bc6c02d 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "strings" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -109,7 +110,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( for k, v := range eventStateKeys { iEventStateKeys[k] = v } - selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1) + selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeys)), 1) rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...) if err != nil { @@ -135,7 +136,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey( for k, v := range eventStateKeyNIDs { iEventStateKeyNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeyNIDs)), 1) rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) if err != nil { diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index edc759c01..edc06d4c6 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -126,7 +126,7 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID( for k, v := range eventTypes { iEventTypes[k] = v } - selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", queryVariadic(len(iEventTypes)), 1) + selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", common.QueryVariadic(len(iEventTypes)), 1) selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 15ec7b7f2..4ed1395da 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -184,7 +184,7 @@ func (s *eventStatements) bulkSelectStateEventByID( for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) + selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err @@ -237,7 +237,7 @@ func (s *eventStatements) bulkSelectStateAtEventByID( for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) + selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err @@ -318,7 +318,7 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) + selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) ////////////// rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) @@ -364,7 +364,7 @@ func (s *eventStatements) bulkSelectEventReference( for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err @@ -398,7 +398,7 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev for k, v := range eventNIDs { iEventNIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err @@ -435,7 +435,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e for k, v := range eventIDs { iEventIDs[k] = v } - selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) + selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) selectPrep, err := txn.Prepare(selectOrig) if err != nil { return nil, err diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go index 408b46e0b..0d49432b8 100644 --- a/roomserver/storage/sqlite3/sql.go +++ b/roomserver/storage/sqlite3/sql.go @@ -17,7 +17,6 @@ package sqlite3 import ( "database/sql" - "fmt" ) type statements struct { @@ -59,20 +58,3 @@ func (s *statements) prepare(db *sql.DB) error { return nil } - -// Hack of the century -func queryVariadic(count int) string { - return queryVariadicOffset(count, 0) -} - -func queryVariadicOffset(count, offset int) string { - str := "(" - for i := 0; i < count; i++ { - str += fmt.Sprintf("$%d", i+offset+1) - if i < (count - 1) { - str += ", " - } - } - str += ")" - return str -} diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 5c1829833..41a8bcff3 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -133,7 +133,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries( for k, v := range stateBlockNIDs { nids[k] = v } - selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", queryVariadic(len(nids)), 1) + selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(nids)), 1) selectPrep, err := s.db.Prepare(selectOrig) if err != nil { return nil, err @@ -190,9 +190,9 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( sort.Sort(tuples) eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", queryVariadic(len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($2)", queryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($3)", queryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) + sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(stateBlockNIDs)), 1) + sqlStatement = strings.Replace(sqlStatement, "($2)", common.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) + sqlStatement = strings.Replace(sqlStatement, "($3)", common.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) var params []interface{} for _, val := range stateBlockNIDs { diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 1cecb3778..795d6b388 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -22,6 +22,7 @@ import ( "strings" "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -89,7 +90,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( for k, v := range stateNIDs { nids[k] = v } - selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1) + selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", common.QueryVariadic(len(nids)), 1) selectStmt, err := txn.Prepare(selectOrig) if err != nil { return nil, err