mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-16 11:23:11 -06:00
Move QueryVariadic etc into common, other device fixes
This commit is contained in:
parent
f168da4d63
commit
1d1446253e
|
|
@ -17,9 +17,9 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
|
|
@ -70,9 +70,10 @@ const deleteDevicesByLocalpartSQL = "" +
|
|||
"DELETE FROM device_devices WHERE localpart = $1"
|
||||
|
||||
const deleteDevicesSQL = "" +
|
||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
|
||||
type devicesStatements struct {
|
||||
db *sql.DB
|
||||
insertDeviceStmt *sql.Stmt
|
||||
selectDevicesCountStmt *sql.Stmt
|
||||
selectDeviceByTokenStmt *sql.Stmt
|
||||
|
|
@ -81,11 +82,11 @@ type devicesStatements struct {
|
|||
updateDeviceNameStmt *sql.Stmt
|
||||
deleteDeviceStmt *sql.Stmt
|
||||
deleteDevicesByLocalpartStmt *sql.Stmt
|
||||
deleteDevicesStmt *sql.Stmt
|
||||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(devicesSchema)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -114,9 +115,6 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
|||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
s.serverName = server
|
||||
return
|
||||
}
|
||||
|
|
@ -158,8 +156,19 @@ func (s *devicesStatements) deleteDevice(
|
|||
func (s *devicesStatements) deleteDevices(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||
) error {
|
||||
stmt := common.TxStmt(txn, s.deleteDevicesStmt)
|
||||
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
|
||||
orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1)
|
||||
prep, err := s.db.Prepare(orig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stmt := common.TxStmt(txn, prep)
|
||||
params := make([]interface{}, len(devices)+1)
|
||||
params[0] = localpart
|
||||
for i, v := range devices {
|
||||
params[i+1] = v
|
||||
}
|
||||
params = append(params, params...)
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ package common
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
|
@ -81,3 +82,20 @@ func IsUniqueConstraintViolationErr(err error) bool {
|
|||
pqErr, ok := err.(*pq.Error)
|
||||
return ok && pqErr.Code == "23505"
|
||||
}
|
||||
|
||||
// Hack of the century
|
||||
func QueryVariadic(count int) string {
|
||||
return QueryVariadicOffset(count, 0)
|
||||
}
|
||||
|
||||
func QueryVariadicOffset(count, offset int) string {
|
||||
str := "("
|
||||
for i := 0; i < count; i++ {
|
||||
str += fmt.Sprintf("$%d", i+offset+1)
|
||||
if i < (count - 1) {
|
||||
str += ", "
|
||||
}
|
||||
}
|
||||
str += ")"
|
||||
return str
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON(
|
|||
for k, v := range eventNIDs {
|
||||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
|
||||
|
||||
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
|
|
@ -109,7 +110,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
|||
for k, v := range eventStateKeys {
|
||||
iEventStateKeys[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeys)), 1)
|
||||
|
||||
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||
if err != nil {
|
||||
|
|
@ -135,7 +136,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
|||
for k, v := range eventStateKeyNIDs {
|
||||
iEventStateKeyNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeyNIDs)), 1)
|
||||
|
||||
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID(
|
|||
for k, v := range eventTypes {
|
||||
iEventTypes[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", queryVariadic(len(iEventTypes)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", common.QueryVariadic(len(iEventTypes)), 1)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ func (s *eventStatements) bulkSelectStateEventByID(
|
|||
for k, v := range eventIDs {
|
||||
iEventIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
|
||||
selectPrep, err := txn.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -237,7 +237,7 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
|
|||
for k, v := range eventIDs {
|
||||
iEventIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
|
||||
selectPrep, err := txn.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -318,7 +318,7 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
|
|||
for k, v := range eventNIDs {
|
||||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
|
||||
//////////////
|
||||
|
||||
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||
|
|
@ -364,7 +364,7 @@ func (s *eventStatements) bulkSelectEventReference(
|
|||
for k, v := range eventNIDs {
|
||||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
|
||||
selectPrep, err := txn.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -398,7 +398,7 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
|
|||
for k, v := range eventNIDs {
|
||||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
|
||||
selectPrep, err := txn.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -435,7 +435,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
|
|||
for k, v := range eventIDs {
|
||||
iEventIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
|
||||
selectPrep, err := txn.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ package sqlite3
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type statements struct {
|
||||
|
|
@ -59,20 +58,3 @@ func (s *statements) prepare(db *sql.DB) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Hack of the century
|
||||
func queryVariadic(count int) string {
|
||||
return queryVariadicOffset(count, 0)
|
||||
}
|
||||
|
||||
func queryVariadicOffset(count, offset int) string {
|
||||
str := "("
|
||||
for i := 0; i < count; i++ {
|
||||
str += fmt.Sprintf("$%d", i+offset+1)
|
||||
if i < (count - 1) {
|
||||
str += ", "
|
||||
}
|
||||
}
|
||||
str += ")"
|
||||
return str
|
||||
}
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
|
|||
for k, v := range stateBlockNIDs {
|
||||
nids[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", queryVariadic(len(nids)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(nids)), 1)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -190,9 +190,9 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
|
|||
sort.Sort(tuples)
|
||||
|
||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", queryVariadic(len(stateBlockNIDs)), 1)
|
||||
sqlStatement = strings.Replace(sqlStatement, "($2)", queryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
|
||||
sqlStatement = strings.Replace(sqlStatement, "($3)", queryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
|
||||
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(stateBlockNIDs)), 1)
|
||||
sqlStatement = strings.Replace(sqlStatement, "($2)", common.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
|
||||
sqlStatement = strings.Replace(sqlStatement, "($3)", common.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
|
||||
|
||||
var params []interface{}
|
||||
for _, val := range stateBlockNIDs {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
|
|
@ -89,7 +90,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
|
|||
for k, v := range stateNIDs {
|
||||
nids[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1)
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", common.QueryVariadic(len(nids)), 1)
|
||||
selectStmt, err := txn.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
Loading…
Reference in a new issue