mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-16 11:23:11 -06:00
Move QueryVariadic etc into common, other device fixes
This commit is contained in:
parent
f168da4d63
commit
1d1446253e
|
|
@ -17,9 +17,9 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
|
||||||
"github.com/matrix-org/dendrite/common"
|
"github.com/matrix-org/dendrite/common"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
|
@ -70,9 +70,10 @@ const deleteDevicesByLocalpartSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1"
|
"DELETE FROM device_devices WHERE localpart = $1"
|
||||||
|
|
||||||
const deleteDevicesSQL = "" +
|
const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
|
db *sql.DB
|
||||||
insertDeviceStmt *sql.Stmt
|
insertDeviceStmt *sql.Stmt
|
||||||
selectDevicesCountStmt *sql.Stmt
|
selectDevicesCountStmt *sql.Stmt
|
||||||
selectDeviceByTokenStmt *sql.Stmt
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
|
|
@ -81,11 +82,11 @@ type devicesStatements struct {
|
||||||
updateDeviceNameStmt *sql.Stmt
|
updateDeviceNameStmt *sql.Stmt
|
||||||
deleteDeviceStmt *sql.Stmt
|
deleteDeviceStmt *sql.Stmt
|
||||||
deleteDevicesByLocalpartStmt *sql.Stmt
|
deleteDevicesByLocalpartStmt *sql.Stmt
|
||||||
deleteDevicesStmt *sql.Stmt
|
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
s.db = db
|
||||||
_, err = db.Exec(devicesSchema)
|
_, err = db.Exec(devicesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
@ -114,9 +115,6 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
||||||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.serverName = server
|
s.serverName = server
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -158,8 +156,19 @@ func (s *devicesStatements) deleteDevice(
|
||||||
func (s *devicesStatements) deleteDevices(
|
func (s *devicesStatements) deleteDevices(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||||
) error {
|
) error {
|
||||||
stmt := common.TxStmt(txn, s.deleteDevicesStmt)
|
orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1)
|
||||||
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
|
prep, err := s.db.Prepare(orig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
stmt := common.TxStmt(txn, prep)
|
||||||
|
params := make([]interface{}, len(devices)+1)
|
||||||
|
params[0] = localpart
|
||||||
|
for i, v := range devices {
|
||||||
|
params[i+1] = v
|
||||||
|
}
|
||||||
|
params = append(params, params...)
|
||||||
|
_, err = stmt.ExecContext(ctx, params...)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
@ -81,3 +82,20 @@ func IsUniqueConstraintViolationErr(err error) bool {
|
||||||
pqErr, ok := err.(*pq.Error)
|
pqErr, ok := err.(*pq.Error)
|
||||||
return ok && pqErr.Code == "23505"
|
return ok && pqErr.Code == "23505"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hack of the century
|
||||||
|
func QueryVariadic(count int) string {
|
||||||
|
return QueryVariadicOffset(count, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func QueryVariadicOffset(count, offset int) string {
|
||||||
|
str := "("
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
str += fmt.Sprintf("$%d", i+offset+1)
|
||||||
|
if i < (count - 1) {
|
||||||
|
str += ", "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
str += ")"
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON(
|
||||||
for k, v := range eventNIDs {
|
for k, v := range eventNIDs {
|
||||||
iEventNIDs[k] = v
|
iEventNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
|
||||||
|
|
||||||
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
|
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -109,7 +110,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
||||||
for k, v := range eventStateKeys {
|
for k, v := range eventStateKeys {
|
||||||
iEventStateKeys[k] = v
|
iEventStateKeys[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeys)), 1)
|
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeys)), 1)
|
||||||
|
|
||||||
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -135,7 +136,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
||||||
for k, v := range eventStateKeyNIDs {
|
for k, v := range eventStateKeyNIDs {
|
||||||
iEventStateKeyNIDs[k] = v
|
iEventStateKeyNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", queryVariadic(len(eventStateKeyNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeyNIDs)), 1)
|
||||||
|
|
||||||
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
|
rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -126,7 +126,7 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID(
|
||||||
for k, v := range eventTypes {
|
for k, v := range eventTypes {
|
||||||
iEventTypes[k] = v
|
iEventTypes[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", queryVariadic(len(iEventTypes)), 1)
|
selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", common.QueryVariadic(len(iEventTypes)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
|
|
@ -184,7 +184,7 @@ func (s *eventStatements) bulkSelectStateEventByID(
|
||||||
for k, v := range eventIDs {
|
for k, v := range eventIDs {
|
||||||
iEventIDs[k] = v
|
iEventIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
|
||||||
selectPrep, err := txn.Prepare(selectOrig)
|
selectPrep, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -237,7 +237,7 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
|
||||||
for k, v := range eventIDs {
|
for k, v := range eventIDs {
|
||||||
iEventIDs[k] = v
|
iEventIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
|
||||||
selectPrep, err := txn.Prepare(selectOrig)
|
selectPrep, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -318,7 +318,7 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
|
||||||
for k, v := range eventNIDs {
|
for k, v := range eventNIDs {
|
||||||
iEventNIDs[k] = v
|
iEventNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
|
||||||
//////////////
|
//////////////
|
||||||
|
|
||||||
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
|
rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||||
|
|
@ -364,7 +364,7 @@ func (s *eventStatements) bulkSelectEventReference(
|
||||||
for k, v := range eventNIDs {
|
for k, v := range eventNIDs {
|
||||||
iEventNIDs[k] = v
|
iEventNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
|
||||||
selectPrep, err := txn.Prepare(selectOrig)
|
selectPrep, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -398,7 +398,7 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
|
||||||
for k, v := range eventNIDs {
|
for k, v := range eventNIDs {
|
||||||
iEventNIDs[k] = v
|
iEventNIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", queryVariadic(len(iEventNIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
|
||||||
selectPrep, err := txn.Prepare(selectOrig)
|
selectPrep, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -435,7 +435,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
|
||||||
for k, v := range eventIDs {
|
for k, v := range eventIDs {
|
||||||
iEventIDs[k] = v
|
iEventIDs[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", queryVariadic(len(iEventIDs)), 1)
|
selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
|
||||||
selectPrep, err := txn.Prepare(selectOrig)
|
selectPrep, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ package sqlite3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type statements struct {
|
type statements struct {
|
||||||
|
|
@ -59,20 +58,3 @@ func (s *statements) prepare(db *sql.DB) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hack of the century
|
|
||||||
func queryVariadic(count int) string {
|
|
||||||
return queryVariadicOffset(count, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func queryVariadicOffset(count, offset int) string {
|
|
||||||
str := "("
|
|
||||||
for i := 0; i < count; i++ {
|
|
||||||
str += fmt.Sprintf("$%d", i+offset+1)
|
|
||||||
if i < (count - 1) {
|
|
||||||
str += ", "
|
|
||||||
}
|
|
||||||
}
|
|
||||||
str += ")"
|
|
||||||
return str
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -133,7 +133,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
|
||||||
for k, v := range stateBlockNIDs {
|
for k, v := range stateBlockNIDs {
|
||||||
nids[k] = v
|
nids[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", queryVariadic(len(nids)), 1)
|
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(nids)), 1)
|
||||||
selectPrep, err := s.db.Prepare(selectOrig)
|
selectPrep, err := s.db.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -190,9 +190,9 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
|
||||||
sort.Sort(tuples)
|
sort.Sort(tuples)
|
||||||
|
|
||||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||||
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", queryVariadic(len(stateBlockNIDs)), 1)
|
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(stateBlockNIDs)), 1)
|
||||||
sqlStatement = strings.Replace(sqlStatement, "($2)", queryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
|
sqlStatement = strings.Replace(sqlStatement, "($2)", common.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
|
||||||
sqlStatement = strings.Replace(sqlStatement, "($3)", queryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
|
sqlStatement = strings.Replace(sqlStatement, "($3)", common.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
|
||||||
|
|
||||||
var params []interface{}
|
var params []interface{}
|
||||||
for _, val := range stateBlockNIDs {
|
for _, val := range stateBlockNIDs {
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/common"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -89,7 +90,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
|
||||||
for k, v := range stateNIDs {
|
for k, v := range stateNIDs {
|
||||||
nids[k] = v
|
nids[k] = v
|
||||||
}
|
}
|
||||||
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", queryVariadic(len(nids)), 1)
|
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", common.QueryVariadic(len(nids)), 1)
|
||||||
selectStmt, err := txn.Prepare(selectOrig)
|
selectStmt, err := txn.Prepare(selectOrig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue