Move QueryVariadic etc into common, other device fixes

This commit is contained in:
Neil Alexander 2020-02-12 17:48:25 +00:00
parent f168da4d63
commit 1d1446253e
9 changed files with 52 additions and 41 deletions

View file

@ -17,9 +17,9 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings"
"time" "time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -70,9 +70,10 @@ const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1" "DELETE FROM device_devices WHERE localpart = $1"
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
type devicesStatements struct { type devicesStatements struct {
db *sql.DB
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
@ -81,11 +82,11 @@ type devicesStatements struct {
updateDeviceNameStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt
deleteDevicesStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
_, err = db.Exec(devicesSchema) _, err = db.Exec(devicesSchema)
if err != nil { if err != nil {
return return
@ -114,9 +115,6 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return return
} }
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
s.serverName = server s.serverName = server
return return
} }
@ -158,8 +156,19 @@ func (s *devicesStatements) deleteDevice(
func (s *devicesStatements) deleteDevices( func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string, ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error { ) error {
stmt := common.TxStmt(txn, s.deleteDevicesStmt) orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1)
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) prep, err := s.db.Prepare(orig)
if err != nil {
return err
}
stmt := common.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
}
params = append(params, params...)
_, err = stmt.ExecContext(ctx, params...)
return err return err
} }

View file

@ -16,6 +16,7 @@ package common
import ( import (
"database/sql" "database/sql"
"fmt"
"github.com/lib/pq" "github.com/lib/pq"
) )
@ -81,3 +82,20 @@ func IsUniqueConstraintViolationErr(err error) bool {
pqErr, ok := err.(*pq.Error) pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505" return ok && pqErr.Code == "23505"
} }
// Hack of the century
func QueryVariadic(count int) string {
return QueryVariadicOffset(count, 0)
}
func QueryVariadicOffset(count, offset int) string {
str := "("
for i := 0; i < count; i++ {
str += fmt.Sprintf("$%d", i+offset+1)
if i < (count - 1) {
str += ", "
}
}
str += ")"
return str
}

View file

@ -82,7 +82,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON(
for k, v := range eventNIDs { for k, v := range eventNIDs {
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
if err != nil { if err != nil {

View file

@ -20,6 +20,7 @@ import (
"database/sql" "database/sql"
"strings" "strings"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -109,7 +110,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
for k, v := range eventStateKeys { for k, v := range eventStateKeys {
iEventStateKeys[k] = v iEventStateKeys[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1) selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeys)), 1)
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...) rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
if err != nil { if err != nil {
@ -135,7 +136,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey(
for k, v := range eventStateKeyNIDs { for k, v := range eventStateKeyNIDs {
iEventStateKeyNIDs[k] = v iEventStateKeyNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeyNIDs)), 1)
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
if err != nil { if err != nil {

View file

@ -126,7 +126,7 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID(
for k, v := range eventTypes { for k, v := range eventTypes {
iEventTypes[k] = v iEventTypes[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", queryVariadic(len(iEventTypes)), 1) selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", common.QueryVariadic(len(iEventTypes)), 1)
selectPrep, err := s.db.Prepare(selectOrig) selectPrep, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -184,7 +184,7 @@ func (s *eventStatements) bulkSelectStateEventByID(
for k, v := range eventIDs { for k, v := range eventIDs {
iEventIDs[k] = v iEventIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
selectPrep, err := txn.Prepare(selectOrig) selectPrep, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -237,7 +237,7 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
for k, v := range eventIDs { for k, v := range eventIDs {
iEventIDs[k] = v iEventIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
selectPrep, err := txn.Prepare(selectOrig) selectPrep, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -318,7 +318,7 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
for k, v := range eventNIDs { for k, v := range eventNIDs {
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
////////////// //////////////
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
@ -364,7 +364,7 @@ func (s *eventStatements) bulkSelectEventReference(
for k, v := range eventNIDs { for k, v := range eventNIDs {
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
selectPrep, err := txn.Prepare(selectOrig) selectPrep, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -398,7 +398,7 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
for k, v := range eventNIDs { for k, v := range eventNIDs {
iEventNIDs[k] = v iEventNIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1) selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
selectPrep, err := txn.Prepare(selectOrig) selectPrep, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -435,7 +435,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
for k, v := range eventIDs { for k, v := range eventIDs {
iEventIDs[k] = v iEventIDs[k] = v
} }
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1) selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
selectPrep, err := txn.Prepare(selectOrig) selectPrep, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -17,7 +17,6 @@ package sqlite3
import ( import (
"database/sql" "database/sql"
"fmt"
) )
type statements struct { type statements struct {
@ -59,20 +58,3 @@ func (s *statements) prepare(db *sql.DB) error {
return nil return nil
} }
// Hack of the century
func queryVariadic(count int) string {
return queryVariadicOffset(count, 0)
}
func queryVariadicOffset(count, offset int) string {
str := "("
for i := 0; i < count; i++ {
str += fmt.Sprintf("$%d", i+offset+1)
if i < (count - 1) {
str += ", "
}
}
str += ")"
return str
}

View file

@ -133,7 +133,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
for k, v := range stateBlockNIDs { for k, v := range stateBlockNIDs {
nids[k] = v nids[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", queryVariadic(len(nids)), 1) selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(nids)), 1)
selectPrep, err := s.db.Prepare(selectOrig) selectPrep, err := s.db.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -190,9 +190,9 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
sort.Sort(tuples) sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", queryVariadic(len(stateBlockNIDs)), 1) sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($2)", queryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) sqlStatement = strings.Replace(sqlStatement, "($2)", common.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($3)", queryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) sqlStatement = strings.Replace(sqlStatement, "($3)", common.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
var params []interface{} var params []interface{}
for _, val := range stateBlockNIDs { for _, val := range stateBlockNIDs {

View file

@ -22,6 +22,7 @@ import (
"strings" "strings"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
@ -89,7 +90,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
for k, v := range stateNIDs { for k, v := range stateNIDs {
nids[k] = v nids[k] = v
} }
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1) selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", common.QueryVariadic(len(nids)), 1)
selectStmt, err := txn.Prepare(selectOrig) selectStmt, err := txn.Prepare(selectOrig)
if err != nil { if err != nil {
return nil, err return nil, err