mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-16 11:23:11 -06:00
163 lines
4.3 KiB
Go
163 lines
4.3 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
|
|
"github.com/lib/pq"
|
|
"github.com/matrix-org/dendrite/common"
|
|
"github.com/matrix-org/dendrite/syncapi/types"
|
|
)
|
|
|
|
// "Send-to-device" is abbreviated as STD (or Std) throughout this file.
|
|
|
|
// sendToDeviceSchema is the DDL executed by prepare to create the
// syncapi_send_to_device table if it does not exist yet.
//
// Each row is one send-to-device event addressed to a single
// (target_user_id, target_device_id) pair; the combination of txn_id, target
// device and target user is unique. del_read and max_read start at 0 and the
// current sequence value respectively and are overwritten by
// updateSTDEventSQL; deleteSTDSQL filters on both.
//
// The schema reads the syncapi_stream_id sequence through nextval/currval but
// does not create it, so the sequence has to exist before this DDL runs.
//
// The literal must contain nothing but SQL: it is passed to db.Exec verbatim.
const sendToDeviceSchema = `
CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
	id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'),
	txn_id TEXT NOT NULL,
	sender TEXT NOT NULL,
	event_type TEXT NOT NULL,
	target_device_id TEXT NOT NULL,
	target_user_id TEXT NOT NULL,
	event_json TEXT NOT NULL,
	del_read INTEGER DEFAULT 0,
	max_read BIGINT DEFAULT currval('syncapi_stream_id') ,
	CONSTRAINT syncapi_send_to_device_unique UNIQUE (txn_id, target_device_id, target_user_id)
);
`
|
|
|
|
// insertSTDSQL stores one send-to-device event for a single target
// user/device and returns the id assigned to the new row.
// Parameters: $1 sender, $2 event type, $3 target user ID, $4 target device
// ID, $5 transaction ID, $6 event JSON.
const insertSTDSQL = "INSERT INTO syncapi_send_to_device (" +
	" sender, event_type, target_user_id, target_device_id, txn_id, event_json" +
	") VALUES ($1, $2, $3, $4, $5, $6) RETURNING id"
|
|
|
|
// deleteSTDSQL removes the events of one target user ($1) and device ($2)
// that carry the read mark (del_read = 1) and whose max_read is strictly
// below the bound given as $3.
const deleteSTDSQL = "DELETE FROM syncapi_send_to_device WHERE target_user_id = $1 AND target_device_id = $2 AND max_read < $3 AND del_read = 1"
|
|
|
|
// selectSTDEventsInRangeSQL fetches at most 100 events addressed to one
// target user ($1) and device ($2) whose id is not greater than $3, in
// ascending id order.
const selectSTDEventsInRangeSQL = "SELECT id, sender, event_type, event_json FROM syncapi_send_to_device" +
	" WHERE target_user_id = $1 AND target_device_id = $2 AND id <= $3" +
	" ORDER BY id LIMIT 100 "
|
|
|
|
// updateSTDEventSQL puts the read mark on a batch of events: for every row
// whose id is in the array $2 it sets del_read to 1 and max_read to $1.
const updateSTDEventSQL = "UPDATE syncapi_send_to_device SET del_read = 1 , max_read = $1 WHERE id = ANY($2)"
|
|
|
|
// selectMaxSTDIDSQL yields the largest id in the table; MAX() over an empty
// table yields NULL, so the result must be scanned into a nullable value.
const selectMaxSTDIDSQL = "SELECT MAX(id) FROM syncapi_send_to_device"
|
|
|
|
// stdEventsStatements bundles the prepared statements that operate on the
// syncapi_send_to_device table. All fields are nil until prepare has run
// successfully.
type stdEventsStatements struct {
	// insertStdEventStmt is prepared from insertSTDSQL.
	insertStdEventStmt *sql.Stmt
	// selectStdEventsInRangeStmt is prepared from selectSTDEventsInRangeSQL.
	selectStdEventsInRangeStmt *sql.Stmt
	// deleteStdEventStmt is prepared from deleteSTDSQL.
	deleteStdEventStmt *sql.Stmt
	// selectStdIDStmt is prepared from selectMaxSTDIDSQL.
	selectStdIDStmt *sql.Stmt
	// updateStdStmt is prepared from updateSTDEventSQL.
	updateStdStmt *sql.Stmt
}
|
|
|
|
func (s *stdEventsStatements) prepare(db *sql.DB) (err error) {
|
|
_, err = db.Exec(sendToDeviceSchema)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if s.insertStdEventStmt, err = db.Prepare(insertSTDSQL); err != nil {
|
|
return
|
|
}
|
|
if s.selectStdEventsInRangeStmt, err = db.Prepare(selectSTDEventsInRangeSQL); err != nil {
|
|
return
|
|
}
|
|
if s.deleteStdEventStmt, err = db.Prepare(deleteSTDSQL); err != nil {
|
|
return
|
|
}
|
|
if s.selectStdIDStmt, err = db.Prepare(selectMaxSTDIDSQL); err != nil {
|
|
return
|
|
}
|
|
if s.updateStdStmt, err = db.Prepare(updateSTDEventSQL); err != nil {
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
func (s *stdEventsStatements) insertStdEvent(
|
|
ctx context.Context, stdEvent types.StdHolder,
|
|
transactionID string, targetUID, targetDevice string,
|
|
) (streamPos int64, err error) {
|
|
err = s.insertStdEventStmt.QueryRowContext(
|
|
ctx,
|
|
stdEvent.Sender,
|
|
stdEvent.EventTyp,
|
|
targetUID,
|
|
targetDevice,
|
|
transactionID,
|
|
stdEvent.Event,
|
|
).Scan(&streamPos)
|
|
return
|
|
}
|
|
|
|
func (s *stdEventsStatements) deleteStdEvent(
|
|
ctx context.Context, userID, deviceID string,
|
|
idUpBound int64,
|
|
) error {
|
|
_, err := s.deleteStdEventStmt.ExecContext(ctx, userID, deviceID, idUpBound)
|
|
return err
|
|
}
|
|
|
|
func (s *stdEventsStatements) selectStdEventsInRange(
|
|
ctx context.Context, txn *sql.Tx,
|
|
targetUserID, targetDeviceID string,
|
|
endPos int64,
|
|
) ([]types.StdHolder, error) {
|
|
stdHolder := []types.StdHolder{}
|
|
stmt := common.TxStmt(txn, s.selectStdEventsInRangeStmt)
|
|
rows, err := stmt.QueryContext(ctx, targetUserID, targetDeviceID, endPos)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for rows.Next() {
|
|
holder := types.StdHolder{}
|
|
var (
|
|
id int64
|
|
sender string
|
|
eventType string
|
|
eventJSON []byte
|
|
)
|
|
if err = rows.Scan(&id, &sender, &eventType, &eventJSON); err != nil {
|
|
closeErr := rows.Close()
|
|
if closeErr != nil {
|
|
return nil, closeErr
|
|
}
|
|
return nil, err
|
|
}
|
|
holder.StreamID = id
|
|
holder.Sender = sender
|
|
holder.Event = eventJSON
|
|
holder.EventTyp = eventType
|
|
stdHolder = append(stdHolder, holder)
|
|
}
|
|
err = rows.Close()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// update events with read mark
|
|
update := []int64{}
|
|
for _, val := range stdHolder {
|
|
update = append(update, val.StreamID)
|
|
}
|
|
updateStmt := common.TxStmt(txn, s.updateStdStmt)
|
|
_, err = updateStmt.ExecContext(ctx, endPos, pq.Array(update))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return stdHolder, nil
|
|
}
|
|
|
|
func (s *stdEventsStatements) selectMaxStdID(
|
|
ctx context.Context, txn *sql.Tx,
|
|
) (id int64, err error) {
|
|
var nullableID sql.NullInt64
|
|
stmt := common.TxStmt(txn, s.selectStdIDStmt)
|
|
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
|
|
if nullableID.Valid {
|
|
id = nullableID.Int64
|
|
}
|
|
return
|
|
}
|