dendrite/syncapi/storage/send_to_device_table.go
2019-10-14 15:55:34 +05:30

163 lines
4.3 KiB
Go

package storage
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
)
// Throughout this file, "send-to-device" is abbreviated as STD.
//
// sendToDeviceSchema creates the syncapi_send_to_device table, which holds
// send-to-device events addressed to a specific (target_user_id,
// target_device_id) pair. Each row's id is drawn from the syncapi_stream_id
// sequence. The del_read and max_read columns track read state:
// updateSTDEventSQL sets del_read = 1 and max_read to the position the events
// were read up to, and deleteSTDSQL removes rows with del_read = 1 whose
// max_read is below a given bound. The unique constraint rejects a repeated
// (txn_id, target_device_id, target_user_id) triple.
//
// NOTE(review): max_read defaults to currval('syncapi_stream_id'), which
// presumably relies on id's nextval() default being evaluated earlier in the
// same INSERT — confirm this ordering holds on the PostgreSQL versions in use.
const sendToDeviceSchema = `
CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'),
txn_id TEXT NOT NULL,
sender TEXT NOT NULL,
event_type TEXT NOT NULL,
target_device_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
event_json TEXT NOT NULL,
del_read INTEGER DEFAULT 0,
max_read BIGINT DEFAULT currval('syncapi_stream_id') ,
CONSTRAINT syncapi_send_to_device_unique UNIQUE (txn_id, target_device_id, target_user_id)
);
`
// SQL statements prepared by stdEventsStatements.prepare. They all operate on
// the syncapi_send_to_device table.
const (
	// insertSTDSQL stores one event for a single target device and returns the
	// id allocated to it. Placeholders: $1 sender, $2 event_type,
	// $3 target_user_id, $4 target_device_id, $5 txn_id, $6 event_json.
	insertSTDSQL = "" +
		"INSERT INTO syncapi_send_to_device (" +
		" sender, event_type, target_user_id, target_device_id, txn_id, event_json" +
		") VALUES ($1, $2, $3, $4, $5, $6) RETURNING id"

	// deleteSTDSQL removes a device's events that have been marked as read
	// (del_read = 1) and whose max_read is strictly below $3.
	deleteSTDSQL = "" +
		"DELETE FROM syncapi_send_to_device WHERE target_user_id = $1 AND target_device_id = $2 AND max_read < $3 AND del_read = 1"

	// selectSTDEventsInRangeSQL fetches at most 100 events for a device with
	// id <= $3, in ascending id order.
	selectSTDEventsInRangeSQL = "" +
		"SELECT id, sender, event_type, event_json FROM syncapi_send_to_device" +
		" WHERE target_user_id = $1 AND target_device_id = $2 AND id <= $3" +
		" ORDER BY id LIMIT 100 "

	// updateSTDEventSQL marks the events whose ids are in the array $2 as
	// read, recording $1 as their max_read position.
	updateSTDEventSQL = "" +
		"UPDATE syncapi_send_to_device SET del_read = 1 , max_read = $1 WHERE id = ANY($2)"

	// selectMaxSTDIDSQL returns the largest id in the table (NULL when empty).
	selectMaxSTDIDSQL = "" +
		"SELECT MAX(id) FROM syncapi_send_to_device"
)
// stdEventsStatements holds the prepared statements used to read and write
// the syncapi_send_to_device table. All fields are populated by prepare.
type stdEventsStatements struct {
	// insertStdEventStmt is prepared from insertSTDSQL.
	insertStdEventStmt *sql.Stmt
	// selectStdEventsInRangeStmt is prepared from selectSTDEventsInRangeSQL.
	selectStdEventsInRangeStmt *sql.Stmt
	// deleteStdEventStmt is prepared from deleteSTDSQL.
	deleteStdEventStmt *sql.Stmt
	// selectStdIDStmt is prepared from selectMaxSTDIDSQL.
	selectStdIDStmt *sql.Stmt
	// updateStdStmt is prepared from updateSTDEventSQL.
	updateStdStmt *sql.Stmt
}
// prepare creates the syncapi_send_to_device table if it does not exist and
// then prepares each statement into its field on s. Statements are prepared in
// a fixed order, and the first error stops the process and is returned.
func (s *stdEventsStatements) prepare(db *sql.DB) (err error) {
	if _, err = db.Exec(sendToDeviceSchema); err != nil {
		return err
	}

	// Pair each destination field with the query it is prepared from.
	targets := []struct {
		stmt  **sql.Stmt
		query string
	}{
		{&s.insertStdEventStmt, insertSTDSQL},
		{&s.selectStdEventsInRangeStmt, selectSTDEventsInRangeSQL},
		{&s.deleteStdEventStmt, deleteSTDSQL},
		{&s.selectStdIDStmt, selectMaxSTDIDSQL},
		{&s.updateStdStmt, updateSTDEventSQL},
	}
	for _, target := range targets {
		if *target.stmt, err = db.Prepare(target.query); err != nil {
			return err
		}
	}
	return nil
}
// insertStdEvent stores stdEvent as a send-to-device event for targetDevice
// of targetUID, tagged with transactionID. It returns the id that the
// INSERT ... RETURNING id produced, which serves as the event's stream position.
func (s *stdEventsStatements) insertStdEvent(
	ctx context.Context, stdEvent types.StdHolder,
	transactionID string, targetUID, targetDevice string,
) (streamPos int64, err error) {
	row := s.insertStdEventStmt.QueryRowContext(
		ctx,
		stdEvent.Sender, stdEvent.EventTyp, // $1 sender, $2 event_type
		targetUID, targetDevice, // $3 target_user_id, $4 target_device_id
		transactionID, stdEvent.Event, // $5 txn_id, $6 event_json
	)
	err = row.Scan(&streamPos)
	return streamPos, err
}
// deleteStdEvent removes events for deviceID of userID that have already been
// marked as read (del_read = 1) and whose max_read is strictly below
// idUpBound. It returns any error reported by the DELETE.
func (s *stdEventsStatements) deleteStdEvent(
	ctx context.Context, userID, deviceID string,
	idUpBound int64,
) error {
	if _, execErr := s.deleteStdEventStmt.ExecContext(ctx, userID, deviceID, idUpBound); execErr != nil {
		return execErr
	}
	return nil
}
// selectStdEventsInRange returns up to 100 send-to-device events addressed to
// targetDeviceID of targetUserID whose id is <= endPos, in ascending id order.
// Every returned event is then marked as read (del_read = 1, max_read = endPos)
// so that a later deleteStdEvent call can remove it.
//
// Both statements go through common.TxStmt, which presumably binds them to txn
// when txn is non-nil — confirm against the helper. The returned slice is
// never nil on success, so an empty result is an empty slice.
func (s *stdEventsStatements) selectStdEventsInRange(
	ctx context.Context, txn *sql.Tx,
	targetUserID, targetDeviceID string,
	endPos int64,
) ([]types.StdHolder, error) {
	stmt := common.TxStmt(txn, s.selectStdEventsInRangeStmt)
	rows, err := stmt.QueryContext(ctx, targetUserID, targetDeviceID, endPos)
	if err != nil {
		return nil, err
	}

	stdHolder := []types.StdHolder{}
	for rows.Next() {
		var (
			id        int64
			sender    string
			eventType string
			eventJSON []byte
		)
		if err = rows.Scan(&id, &sender, &eventType, &eventJSON); err != nil {
			// The scan error is the one worth reporting; a Close failure here
			// would only hide it, so it is deliberately discarded.
			_ = rows.Close()
			return nil, err
		}
		stdHolder = append(stdHolder, types.StdHolder{
			StreamID: id,
			Sender:   sender,
			Event:    eventJSON,
			EventTyp: eventType,
		})
	}
	// rows.Next returns false both at the end of the result set and on a
	// failure mid-iteration; without this check a truncated result would be
	// returned (and marked read) as if it were complete.
	if err = rows.Err(); err != nil {
		_ = rows.Close() // Err is the more useful error; Close is best-effort.
		return nil, err
	}
	if err = rows.Close(); err != nil {
		return nil, err
	}

	// No events were read, so there is nothing to mark. The UPDATE with an
	// empty id array would match no rows anyway.
	if len(stdHolder) == 0 {
		return stdHolder, nil
	}

	// Mark the delivered events as read up to endPos.
	ids := make([]int64, 0, len(stdHolder))
	for i := range stdHolder {
		ids = append(ids, stdHolder[i].StreamID)
	}
	updateStmt := common.TxStmt(txn, s.updateStdStmt)
	if _, err = updateStmt.ExecContext(ctx, endPos, pq.Array(ids)); err != nil {
		return nil, err
	}
	return stdHolder, nil
}
// selectMaxStdID returns the largest id in syncapi_send_to_device. When the
// table is empty, MAX(id) is NULL and 0 is returned. The query goes through
// common.TxStmt, which presumably runs it inside txn when txn is non-nil.
func (s *stdEventsStatements) selectMaxStdID(
	ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
	var maxID sql.NullInt64
	err = common.TxStmt(txn, s.selectStdIDStmt).QueryRowContext(ctx).Scan(&maxID)
	if !maxID.Valid {
		// NULL (empty table) or a failed scan: report 0 alongside err.
		return 0, err
	}
	return maxID.Int64, err
}