Add federation peeking table tests (#2920)
As the title says, adds tests for inbound/outbound peeking federation table tests. Also removes some unused code
This commit is contained in:
parent
76db8e90de
commit
d3db542fbf
|
@ -221,28 +221,6 @@ func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverl
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
|
|
||||||
d.dbMutex.Lock()
|
|
||||||
defer d.dbMutex.Unlock()
|
|
||||||
|
|
||||||
var count int64
|
|
||||||
if pdus, ok := d.associatedPDUs[serverName]; ok {
|
|
||||||
count = int64(len(pdus))
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) {
|
|
||||||
d.dbMutex.Lock()
|
|
||||||
defer d.dbMutex.Unlock()
|
|
||||||
|
|
||||||
var count int64
|
|
||||||
if edus, ok := d.associatedEDUs[serverName]; ok {
|
|
||||||
count = int64(len(edus))
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
|
func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
|
||||||
d.dbMutex.Lock()
|
d.dbMutex.Lock()
|
||||||
defer d.dbMutex.Unlock()
|
defer d.dbMutex.Unlock()
|
||||||
|
|
|
@ -45,9 +45,6 @@ type Database interface {
|
||||||
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||||
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||||
|
|
||||||
GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
|
||||||
GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
|
||||||
|
|
||||||
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
|
|
||||||
|
|
|
@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const selectInboundPeeksSQL = "" +
|
const selectInboundPeeksSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER by creation_ts"
|
||||||
|
|
||||||
const renewInboundPeekSQL = "" +
|
const renewInboundPeekSQL = "" +
|
||||||
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||||
|
|
||||||
const deleteInboundPeekSQL = "" +
|
const deleteInboundPeekSQL = "" +
|
||||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
|
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const deleteInboundPeeksSQL = "" +
|
const deleteInboundPeeksSQL = "" +
|
||||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||||
|
@ -74,25 +74,15 @@ func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err er
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return
|
{&s.insertInboundPeekStmt, insertInboundPeekSQL},
|
||||||
}
|
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||||
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
|
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||||
return
|
{&s.selectInboundPeeksStmt, selectInboundPeeksSQL},
|
||||||
}
|
{&s.renewInboundPeekStmt, renewInboundPeekSQL},
|
||||||
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
|
{&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL},
|
||||||
return
|
{&s.deleteInboundPeekStmt, deleteInboundPeekSQL},
|
||||||
}
|
}.Prepare(db)
|
||||||
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inboundPeeksStatements) InsertInboundPeek(
|
func (s *inboundPeeksStatements) InsertInboundPeek(
|
||||||
|
|
|
@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const selectOutboundPeeksSQL = "" +
|
const selectOutboundPeeksSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
|
||||||
|
|
||||||
const renewOutboundPeekSQL = "" +
|
const renewOutboundPeekSQL = "" +
|
||||||
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||||
|
|
||||||
const deleteOutboundPeekSQL = "" +
|
const deleteOutboundPeekSQL = "" +
|
||||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
|
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const deleteOutboundPeeksSQL = "" +
|
const deleteOutboundPeeksSQL = "" +
|
||||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||||
|
@ -74,25 +74,14 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return
|
{&s.insertOutboundPeekStmt, insertOutboundPeekSQL},
|
||||||
}
|
{&s.selectOutboundPeekStmt, selectOutboundPeekSQL},
|
||||||
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
|
{&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL},
|
||||||
return
|
{&s.renewOutboundPeekStmt, renewOutboundPeekSQL},
|
||||||
}
|
{&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL},
|
||||||
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
|
{&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL},
|
||||||
return
|
}.Prepare(db)
|
||||||
}
|
|
||||||
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
||||||
|
|
|
@ -62,10 +62,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||||
" WHERE json_nid = $1"
|
" WHERE json_nid = $1"
|
||||||
|
|
||||||
const selectQueueEDUCountSQL = "" +
|
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
|
||||||
" WHERE server_name = $1"
|
|
||||||
|
|
||||||
const selectQueueServerNamesSQL = "" +
|
const selectQueueServerNamesSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
||||||
|
|
||||||
|
@ -81,7 +77,6 @@ type queueEDUsStatements struct {
|
||||||
deleteQueueEDUStmt *sql.Stmt
|
deleteQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUStmt *sql.Stmt
|
selectQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||||
selectQueueEDUCountStmt *sql.Stmt
|
|
||||||
selectQueueEDUServerNamesStmt *sql.Stmt
|
selectQueueEDUServerNamesStmt *sql.Stmt
|
||||||
selectExpiredEDUsStmt *sql.Stmt
|
selectExpiredEDUsStmt *sql.Stmt
|
||||||
deleteExpiredEDUsStmt *sql.Stmt
|
deleteExpiredEDUsStmt *sql.Stmt
|
||||||
|
@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error {
|
||||||
{&s.deleteQueueEDUStmt, deleteQueueEDUSQL},
|
{&s.deleteQueueEDUStmt, deleteQueueEDUSQL},
|
||||||
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
|
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
|
||||||
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
|
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
|
||||||
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
|
|
||||||
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
|
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
|
||||||
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
|
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
|
||||||
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
|
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
|
||||||
|
@ -186,21 +180,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queueEDUsStatements) SelectQueueEDUCount(
|
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
var count int64
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
|
|
||||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
// It's acceptable for there to be no rows referencing a given
|
|
||||||
// JSON NID but it's not an error condition. Just return as if
|
|
||||||
// there's a zero count.
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return count, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
) ([]gomatrixserverlib.ServerName, error) {
|
) ([]gomatrixserverlib.ServerName, error) {
|
||||||
|
|
|
@ -58,10 +58,6 @@ const selectQueuePDUReferenceJSONCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||||
" WHERE json_nid = $1"
|
" WHERE json_nid = $1"
|
||||||
|
|
||||||
const selectQueuePDUsCountSQL = "" +
|
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
|
||||||
" WHERE server_name = $1"
|
|
||||||
|
|
||||||
const selectQueuePDUServerNamesSQL = "" +
|
const selectQueuePDUServerNamesSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
||||||
|
|
||||||
|
@ -71,7 +67,6 @@ type queuePDUsStatements struct {
|
||||||
deleteQueuePDUsStmt *sql.Stmt
|
deleteQueuePDUsStmt *sql.Stmt
|
||||||
selectQueuePDUsStmt *sql.Stmt
|
selectQueuePDUsStmt *sql.Stmt
|
||||||
selectQueuePDUReferenceJSONCountStmt *sql.Stmt
|
selectQueuePDUReferenceJSONCountStmt *sql.Stmt
|
||||||
selectQueuePDUsCountStmt *sql.Stmt
|
|
||||||
selectQueuePDUServerNamesStmt *sql.Stmt
|
selectQueuePDUServerNamesStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,9 +90,6 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
||||||
if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
|
if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil {
|
if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -146,21 +138,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) SelectQueuePDUCount(
|
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
var count int64
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
|
|
||||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
// It's acceptable for there to be no rows referencing a given
|
|
||||||
// JSON NID but it's not an error condition. Just return as if
|
|
||||||
// there's a zero count.
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return count, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *queuePDUsStatements) SelectQueuePDUs(
|
func (s *queuePDUsStatements) SelectQueuePDUs(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
|
|
@ -162,15 +162,6 @@ func (d *Database) CleanEDUs(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPendingEDUCount returns the number of EDUs waiting to be
|
|
||||||
// sent for a given servername.
|
|
||||||
func (d *Database) GetPendingEDUCount(
|
|
||||||
ctx context.Context,
|
|
||||||
serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
return d.FederationQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPendingServerNames returns the server names that have EDUs
|
// GetPendingServerNames returns the server names that have EDUs
|
||||||
// waiting to be sent.
|
// waiting to be sent.
|
||||||
func (d *Database) GetPendingEDUServerNames(
|
func (d *Database) GetPendingEDUServerNames(
|
||||||
|
|
|
@ -141,15 +141,6 @@ func (d *Database) CleanPDUs(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPendingPDUCount returns the number of PDUs waiting to be
|
|
||||||
// sent for a given servername.
|
|
||||||
func (d *Database) GetPendingPDUCount(
|
|
||||||
ctx context.Context,
|
|
||||||
serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
return d.FederationQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPendingServerNames returns the server names that have PDUs
|
// GetPendingServerNames returns the server names that have PDUs
|
||||||
// waiting to be sent.
|
// waiting to be sent.
|
||||||
func (d *Database) GetPendingPDUServerNames(
|
func (d *Database) GetPendingPDUServerNames(
|
||||||
|
|
|
@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const selectInboundPeeksSQL = "" +
|
const selectInboundPeeksSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
|
||||||
|
|
||||||
const renewInboundPeekSQL = "" +
|
const renewInboundPeekSQL = "" +
|
||||||
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||||
|
|
||||||
const deleteInboundPeekSQL = "" +
|
const deleteInboundPeekSQL = "" +
|
||||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
|
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const deleteInboundPeeksSQL = "" +
|
const deleteInboundPeeksSQL = "" +
|
||||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||||
|
@ -74,25 +74,15 @@ func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err erro
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return
|
{&s.insertInboundPeekStmt, insertInboundPeekSQL},
|
||||||
}
|
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||||
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
|
{&s.selectInboundPeekStmt, selectInboundPeekSQL},
|
||||||
return
|
{&s.selectInboundPeeksStmt, selectInboundPeeksSQL},
|
||||||
}
|
{&s.renewInboundPeekStmt, renewInboundPeekSQL},
|
||||||
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
|
{&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL},
|
||||||
return
|
{&s.deleteInboundPeekStmt, deleteInboundPeekSQL},
|
||||||
}
|
}.Prepare(db)
|
||||||
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *inboundPeeksStatements) InsertInboundPeek(
|
func (s *inboundPeeksStatements) InsertInboundPeek(
|
||||||
|
|
|
@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const selectOutboundPeeksSQL = "" +
|
const selectOutboundPeeksSQL = "" +
|
||||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
|
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts"
|
||||||
|
|
||||||
const renewOutboundPeekSQL = "" +
|
const renewOutboundPeekSQL = "" +
|
||||||
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||||
|
|
||||||
const deleteOutboundPeekSQL = "" +
|
const deleteOutboundPeekSQL = "" +
|
||||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
|
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||||
|
|
||||||
const deleteOutboundPeeksSQL = "" +
|
const deleteOutboundPeeksSQL = "" +
|
||||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||||
|
@ -74,25 +74,14 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return
|
{&s.insertOutboundPeekStmt, insertOutboundPeekSQL},
|
||||||
}
|
{&s.selectOutboundPeekStmt, selectOutboundPeekSQL},
|
||||||
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
|
{&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL},
|
||||||
return
|
{&s.renewOutboundPeekStmt, renewOutboundPeekSQL},
|
||||||
}
|
{&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL},
|
||||||
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
|
{&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL},
|
||||||
return
|
}.Prepare(db)
|
||||||
}
|
|
||||||
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
||||||
|
|
|
@ -63,10 +63,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||||
" WHERE json_nid = $1"
|
" WHERE json_nid = $1"
|
||||||
|
|
||||||
const selectQueueEDUCountSQL = "" +
|
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
|
||||||
" WHERE server_name = $1"
|
|
||||||
|
|
||||||
const selectQueueServerNamesSQL = "" +
|
const selectQueueServerNamesSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
||||||
|
|
||||||
|
@ -82,7 +78,6 @@ type queueEDUsStatements struct {
|
||||||
// deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic
|
// deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
selectQueueEDUStmt *sql.Stmt
|
selectQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||||
selectQueueEDUCountStmt *sql.Stmt
|
|
||||||
selectQueueEDUServerNamesStmt *sql.Stmt
|
selectQueueEDUServerNamesStmt *sql.Stmt
|
||||||
selectExpiredEDUsStmt *sql.Stmt
|
selectExpiredEDUsStmt *sql.Stmt
|
||||||
deleteExpiredEDUsStmt *sql.Stmt
|
deleteExpiredEDUsStmt *sql.Stmt
|
||||||
|
@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error {
|
||||||
{&s.insertQueueEDUStmt, insertQueueEDUSQL},
|
{&s.insertQueueEDUStmt, insertQueueEDUSQL},
|
||||||
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
|
{&s.selectQueueEDUStmt, selectQueueEDUSQL},
|
||||||
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
|
{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
|
||||||
{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
|
|
||||||
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
|
{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
|
||||||
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
|
{&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL},
|
||||||
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
|
{&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL},
|
||||||
|
@ -198,21 +192,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queueEDUsStatements) SelectQueueEDUCount(
|
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
var count int64
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
|
|
||||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
// It's acceptable for there to be no rows referencing a given
|
|
||||||
// JSON NID but it's not an error condition. Just return as if
|
|
||||||
// there's a zero count.
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return count, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
) ([]gomatrixserverlib.ServerName, error) {
|
) ([]gomatrixserverlib.ServerName, error) {
|
||||||
|
|
|
@ -66,10 +66,6 @@ const selectQueuePDUsReferenceJSONCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||||
" WHERE json_nid = $1"
|
" WHERE json_nid = $1"
|
||||||
|
|
||||||
const selectQueuePDUsCountSQL = "" +
|
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
|
||||||
" WHERE server_name = $1"
|
|
||||||
|
|
||||||
const selectQueuePDUsServerNamesSQL = "" +
|
const selectQueuePDUsServerNamesSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
||||||
|
|
||||||
|
@ -79,7 +75,6 @@ type queuePDUsStatements struct {
|
||||||
selectQueueNextTransactionIDStmt *sql.Stmt
|
selectQueueNextTransactionIDStmt *sql.Stmt
|
||||||
selectQueuePDUsStmt *sql.Stmt
|
selectQueuePDUsStmt *sql.Stmt
|
||||||
selectQueueReferenceJSONCountStmt *sql.Stmt
|
selectQueueReferenceJSONCountStmt *sql.Stmt
|
||||||
selectQueuePDUsCountStmt *sql.Stmt
|
|
||||||
selectQueueServerNamesStmt *sql.Stmt
|
selectQueueServerNamesStmt *sql.Stmt
|
||||||
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
|
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
}
|
}
|
||||||
|
@ -107,9 +102,6 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
||||||
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
|
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
|
if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -179,21 +171,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) SelectQueuePDUCount(
|
|
||||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
|
||||||
) (int64, error) {
|
|
||||||
var count int64
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
|
|
||||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
// It's acceptable for there to be no rows referencing a given
|
|
||||||
// JSON NID but it's not an error condition. Just return as if
|
|
||||||
// there's a zero count.
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return count, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *queuePDUsStatements) SelectQueuePDUs(
|
func (s *queuePDUsStatements) SelectQueuePDUs(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
|
|
@ -2,10 +2,12 @@ package storage_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||||
|
@ -80,3 +82,167 @@ func TestExpireEDUs(t *testing.T) {
|
||||||
assert.Equal(t, 2, len(data))
|
assert.Equal(t, 2, len(data))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOutboundPeeking(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, closeDB := mustCreateFederationDatabase(t, dbType)
|
||||||
|
defer closeDB()
|
||||||
|
peekID := util.RandomString(8)
|
||||||
|
var renewalInterval int64 = 1000
|
||||||
|
|
||||||
|
// Add outbound peek
|
||||||
|
if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// select the newly inserted peek
|
||||||
|
outboundPeek1, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert fields are set as expected
|
||||||
|
if outboundPeek1.PeekID != peekID {
|
||||||
|
t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RoomID != room.ID {
|
||||||
|
t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID)
|
||||||
|
}
|
||||||
|
if outboundPeek1.ServerName != serverName {
|
||||||
|
t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RenewalInterval != renewalInterval {
|
||||||
|
t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval)
|
||||||
|
}
|
||||||
|
// Renew the peek
|
||||||
|
if err = db.RenewOutboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the values changed
|
||||||
|
outboundPeek2, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if reflect.DeepEqual(outboundPeek1, outboundPeek2) {
|
||||||
|
t.Fatal("expected a change peek, but they are the same")
|
||||||
|
}
|
||||||
|
if outboundPeek1.ServerName != outboundPeek2.ServerName {
|
||||||
|
t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RoomID != outboundPeek2.RoomID {
|
||||||
|
t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert some peeks
|
||||||
|
peekIDs := []string{peekID}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
peekID = util.RandomString(8)
|
||||||
|
if err = db.AddOutboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
peekIDs = append(peekIDs, peekID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now select them
|
||||||
|
outboundPeeks, err := db.GetOutboundPeeks(ctx, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(outboundPeeks) != len(peekIDs) {
|
||||||
|
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
|
||||||
|
}
|
||||||
|
for i := range outboundPeeks {
|
||||||
|
if outboundPeeks[i].PeekID != peekIDs[i] {
|
||||||
|
t.Fatalf("unexpected peek ID: %s, want %s", outboundPeeks[i].PeekID, peekIDs[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInboundPeeking(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, closeDB := mustCreateFederationDatabase(t, dbType)
|
||||||
|
defer closeDB()
|
||||||
|
peekID := util.RandomString(8)
|
||||||
|
var renewalInterval int64 = 1000
|
||||||
|
|
||||||
|
// Add inbound peek
|
||||||
|
if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// select the newly inserted peek
|
||||||
|
inboundPeek1, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert fields are set as expected
|
||||||
|
if inboundPeek1.PeekID != peekID {
|
||||||
|
t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RoomID != room.ID {
|
||||||
|
t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID)
|
||||||
|
}
|
||||||
|
if inboundPeek1.ServerName != serverName {
|
||||||
|
t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RenewalInterval != renewalInterval {
|
||||||
|
t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval)
|
||||||
|
}
|
||||||
|
// Renew the peek
|
||||||
|
if err = db.RenewInboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the values changed
|
||||||
|
inboundPeek2, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if reflect.DeepEqual(inboundPeek1, inboundPeek2) {
|
||||||
|
t.Fatal("expected a change peek, but they are the same")
|
||||||
|
}
|
||||||
|
if inboundPeek1.ServerName != inboundPeek2.ServerName {
|
||||||
|
t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RoomID != inboundPeek2.RoomID {
|
||||||
|
t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert some peeks
|
||||||
|
peekIDs := []string{peekID}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
peekID = util.RandomString(8)
|
||||||
|
if err = db.AddInboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
peekIDs = append(peekIDs, peekID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now select them
|
||||||
|
inboundPeeks, err := db.GetInboundPeeks(ctx, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(inboundPeeks) != len(peekIDs) {
|
||||||
|
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
|
||||||
|
}
|
||||||
|
for i := range inboundPeeks {
|
||||||
|
if inboundPeeks[i].PeekID != peekIDs[i] {
|
||||||
|
t.Fatalf("unexpected peek ID: %s, want %s", inboundPeeks[i].PeekID, peekIDs[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
148
federationapi/storage/tables/inbound_peeks_table_test.go
Normal file
148
federationapi/storage/tables/inbound_peeks_table_test.go
Normal file
|
@ -0,0 +1,148 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) {
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, sqlutil.NewExclusiveWriter())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open database: %s", err)
|
||||||
|
}
|
||||||
|
var tab tables.FederationInboundPeeks
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
tab, err = postgres.NewPostgresInboundPeeksTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
tab, err = sqlite3.NewSQLiteInboundPeeksTable(db)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create table: %s", err)
|
||||||
|
}
|
||||||
|
return tab, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInboundPeeksTable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, closeDB := mustCreateInboundpeeksTable(t, dbType)
|
||||||
|
defer closeDB()
|
||||||
|
|
||||||
|
// Insert a peek
|
||||||
|
peekID := util.RandomString(8)
|
||||||
|
var renewalInterval int64 = 1000
|
||||||
|
if err := tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// select the newly inserted peek
|
||||||
|
inboundPeek1, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert fields are set as expected
|
||||||
|
if inboundPeek1.PeekID != peekID {
|
||||||
|
t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RoomID != room.ID {
|
||||||
|
t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID)
|
||||||
|
}
|
||||||
|
if inboundPeek1.ServerName != serverName {
|
||||||
|
t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RenewalInterval != renewalInterval {
|
||||||
|
t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Renew the peek
|
||||||
|
if err = tab.RenewInboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the values changed
|
||||||
|
inboundPeek2, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if reflect.DeepEqual(inboundPeek1, inboundPeek2) {
|
||||||
|
t.Fatal("expected a change peek, but they are the same")
|
||||||
|
}
|
||||||
|
if inboundPeek1.ServerName != inboundPeek2.ServerName {
|
||||||
|
t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName)
|
||||||
|
}
|
||||||
|
if inboundPeek1.RoomID != inboundPeek2.RoomID {
|
||||||
|
t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete the peek
|
||||||
|
if err = tab.DeleteInboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// There should be no peek anymore
|
||||||
|
peek, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if peek != nil {
|
||||||
|
t.Fatalf("got a peek which should be deleted: %+v", peek)
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert some peeks
|
||||||
|
var peekIDs []string
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
peekID = util.RandomString(8)
|
||||||
|
if err = tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
peekIDs = append(peekIDs, peekID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now select them
|
||||||
|
inboundPeeks, err := tab.SelectInboundPeeks(ctx, nil, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(inboundPeeks) != len(peekIDs) {
|
||||||
|
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks))
|
||||||
|
}
|
||||||
|
for i := range inboundPeeks {
|
||||||
|
if inboundPeeks[i].PeekID != peekIDs[i] {
|
||||||
|
t.Fatalf("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// And delete them again
|
||||||
|
if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// they should be gone now
|
||||||
|
inboundPeeks, err = tab.SelectInboundPeeks(ctx, nil, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(inboundPeeks) > 0 {
|
||||||
|
t.Fatal("got inbound peeks which should be deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
|
@ -28,7 +28,6 @@ type FederationQueuePDUs interface {
|
||||||
InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
|
InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
|
||||||
DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
||||||
SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
||||||
SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
|
|
||||||
SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
||||||
SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
||||||
}
|
}
|
||||||
|
@ -38,7 +37,6 @@ type FederationQueueEDUs interface {
|
||||||
DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
||||||
SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
||||||
SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
||||||
SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
|
|
||||||
SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
||||||
SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error)
|
SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error)
|
||||||
DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error
|
DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error
|
||||||
|
|
147
federationapi/storage/tables/outbound_peeks_table_test.go
Normal file
147
federationapi/storage/tables/outbound_peeks_table_test.go
Normal file
|
@ -0,0 +1,147 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) {
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, sqlutil.NewExclusiveWriter())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open database: %s", err)
|
||||||
|
}
|
||||||
|
var tab tables.FederationOutboundPeeks
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
tab, err = postgres.NewPostgresOutboundPeeksTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
tab, err = sqlite3.NewSQLiteOutboundPeeksTable(db)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create table: %s", err)
|
||||||
|
}
|
||||||
|
return tab, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOutboundPeeksTable(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
room := test.NewRoom(t, alice)
|
||||||
|
_, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, closeDB := mustCreateOutboundpeeksTable(t, dbType)
|
||||||
|
defer closeDB()
|
||||||
|
|
||||||
|
// Insert a peek
|
||||||
|
peekID := util.RandomString(8)
|
||||||
|
var renewalInterval int64 = 1000
|
||||||
|
if err := tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// select the newly inserted peek
|
||||||
|
outboundPeek1, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert fields are set as expected
|
||||||
|
if outboundPeek1.PeekID != peekID {
|
||||||
|
t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RoomID != room.ID {
|
||||||
|
t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID)
|
||||||
|
}
|
||||||
|
if outboundPeek1.ServerName != serverName {
|
||||||
|
t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RenewalInterval != renewalInterval {
|
||||||
|
t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Renew the peek
|
||||||
|
if err = tab.RenewOutboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the values changed
|
||||||
|
outboundPeek2, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if reflect.DeepEqual(outboundPeek1, outboundPeek2) {
|
||||||
|
t.Fatal("expected a change peek, but they are the same")
|
||||||
|
}
|
||||||
|
if outboundPeek1.ServerName != outboundPeek2.ServerName {
|
||||||
|
t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName)
|
||||||
|
}
|
||||||
|
if outboundPeek1.RoomID != outboundPeek2.RoomID {
|
||||||
|
t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete the peek
|
||||||
|
if err = tab.DeleteOutboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// There should be no peek anymore
|
||||||
|
peek, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if peek != nil {
|
||||||
|
t.Fatalf("got a peek which should be deleted: %+v", peek)
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert some peeks
|
||||||
|
var peekIDs []string
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
peekID = util.RandomString(8)
|
||||||
|
if err = tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
peekIDs = append(peekIDs, peekID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now select them
|
||||||
|
outboundPeeks, err := tab.SelectOutboundPeeks(ctx, nil, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(outboundPeeks) != len(peekIDs) {
|
||||||
|
t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks))
|
||||||
|
}
|
||||||
|
for i := range outboundPeeks {
|
||||||
|
if outboundPeeks[i].PeekID != peekIDs[i] {
|
||||||
|
t.Fatalf("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// And delete them again
|
||||||
|
if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// they should be gone now
|
||||||
|
outboundPeeks, err = tab.SelectOutboundPeeks(ctx, nil, room.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(outboundPeeks) > 0 {
|
||||||
|
t.Fatal("got outbound peeks which should be deleted")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in a new issue