Move QueryVariadic etc into common, other device fixes

This commit is contained in:
Neil Alexander 2020-02-12 17:48:25 +00:00
parent f168da4d63
commit 1d1446253e
9 changed files with 52 additions and 41 deletions

View file

@ -17,9 +17,9 @@ package sqlite3
import (
"context"
"database/sql"
"strings"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -70,9 +70,10 @@ const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
type devicesStatements struct {
db *sql.DB
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
@ -81,11 +82,11 @@ type devicesStatements struct {
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
deleteDevicesStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
_, err = db.Exec(devicesSchema)
if err != nil {
return
@ -114,9 +115,6 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
s.serverName = server
return
}
@ -158,8 +156,19 @@ func (s *devicesStatements) deleteDevice(
func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
stmt := common.TxStmt(txn, s.deleteDevicesStmt)
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1)
prep, err := s.db.Prepare(orig)
if err != nil {
return err
}
stmt := common.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
}
params = append(params, params...)
_, err = stmt.ExecContext(ctx, params...)
return err
}

View file

@ -16,6 +16,7 @@ package common
import (
"database/sql"
"fmt"
"github.com/lib/pq"
)
@ -81,3 +82,20 @@ func IsUniqueConstraintViolationErr(err error) bool {
pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505"
}
// Hack of the century
func QueryVariadic(count int) string {
return QueryVariadicOffset(count, 0)
}
func QueryVariadicOffset(count, offset int) string {
str := "("
for i := 0; i < count; i++ {
str += fmt.Sprintf("$%d", i+offset+1)
if i < (count - 1) {
str += ", "
}
}
str += ")"
return str
}

View file

@ -82,7 +82,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON(
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
if err != nil {

View file

@ -20,6 +20,7 @@ import (
"database/sql"
"strings"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
@ -109,7 +110,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
for k, v := range eventStateKeys {
iEventStateKeys[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1)
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeys)), 1)
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
if err != nil {
@ -135,7 +136,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey(
for k, v := range eventStateKeyNIDs {
iEventStateKeyNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1)
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeyNIDs)), 1)
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
if err != nil {

View file

@ -126,7 +126,7 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID(
for k, v := range eventTypes {
iEventTypes[k] = v
}
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", queryVariadic(len(iEventTypes)), 1)
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", common.QueryVariadic(len(iEventTypes)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err

View file

@ -184,7 +184,7 @@ func (s *eventStatements) bulkSelectStateEventByID(
for k, v := range eventIDs {
iEventIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
selectPrep, err := txn.Prepare(selectOrig)
if err != nil {
return nil, err
@ -237,7 +237,7 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
for k, v := range eventIDs {
iEventIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
selectPrep, err := txn.Prepare(selectOrig)
if err != nil {
return nil, err
@ -318,7 +318,7 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
//////////////
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
@ -364,7 +364,7 @@ func (s *eventStatements) bulkSelectEventReference(
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
selectPrep, err := txn.Prepare(selectOrig)
if err != nil {
return nil, err
@ -398,7 +398,7 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
selectPrep, err := txn.Prepare(selectOrig)
if err != nil {
return nil, err
@ -435,7 +435,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
for k, v := range eventIDs {
iEventIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
selectPrep, err := txn.Prepare(selectOrig)
if err != nil {
return nil, err

View file

@ -17,7 +17,6 @@ package sqlite3
import (
"database/sql"
"fmt"
)
type statements struct {
@ -59,20 +58,3 @@ func (s *statements) prepare(db *sql.DB) error {
return nil
}
// Hack of the century
func queryVariadic(count int) string {
return queryVariadicOffset(count, 0)
}
func queryVariadicOffset(count, offset int) string {
str := "("
for i := 0; i < count; i++ {
str += fmt.Sprintf("$%d", i+offset+1)
if i < (count - 1) {
str += ", "
}
}
str += ")"
return str
}

View file

@ -133,7 +133,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
for k, v := range stateBlockNIDs {
nids[k] = v
}
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", queryVariadic(len(nids)), 1)
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(nids)), 1)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
@ -190,9 +190,9 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", queryVariadic(len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($2)", queryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($3)", queryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($2)", common.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($3)", common.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
var params []interface{}
for _, val := range stateBlockNIDs {

View file

@ -22,6 +22,7 @@ import (
"strings"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
)
@ -89,7 +90,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
for k, v := range stateNIDs {
nids[k] = v
}
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1)
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", common.QueryVariadic(len(nids)), 1)
selectStmt, err := txn.Prepare(selectOrig)
if err != nil {
return nil, err