dendrite/relayapi/storage/fake_relay_db.go

96 lines
2.6 KiB
Go

package storage
import (
"context"
"database/sql"
"encoding/json"
"sync"
"github.com/matrix-org/gomatrixserverlib"
)
// testDatabase is the in-memory fake database returned by
// NewFakeRelayDatabase. All state lives in plain Go maps.
type testDatabase struct {
	// nid is the identifier InsertQueueJSON assigns to the next stored payload.
	nid int64
	// nidMutex is held by InsertQueueJSON while it allocates a NID and stores
	// the payload; none of the other methods visible here take it.
	nidMutex sync.Mutex
	// transactions maps a NID to the raw JSON stored under it.
	transactions map[int64]json.RawMessage
	// associations maps a server name to the NIDs queued for it, in the order
	// InsertQueueEntry appended them.
	associations map[gomatrixserverlib.ServerName][]int64
}
// NewFakeRelayDatabase returns an empty in-memory database whose first
// allocated NID will be 1.
func NewFakeRelayDatabase() *testDatabase {
	db := new(testDatabase) // the zero-value mutex is ready to use
	db.nid = 1
	db.transactions = map[int64]json.RawMessage{}
	db.associations = map[gomatrixserverlib.ServerName][]int64{}
	return db
}
// InsertQueueEntry appends nid to the end of the queue kept for serverName,
// creating that queue on first use. ctx, txn and transactionID are accepted
// but not used. It always returns nil.
func (d *testDatabase) InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error {
	// A missing key yields a nil slice, which append grows like any other.
	queue := d.associations[serverName]
	d.associations[serverName] = append(queue, nid)
	return nil
}
// DeleteQueueEntries removes every occurrence of each NID in jsonNIDs from the
// queue kept for serverName. The relative order of the remaining entries is
// preserved. NIDs that are not queued are ignored, and a server with no queue
// is left without one. ctx and txn are accepted but not used. It always
// returns nil.
func (d *testDatabase) DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error {
	queue, ok := d.associations[serverName]
	if !ok || len(jsonNIDs) == 0 {
		return nil
	}

	doomed := make(map[int64]struct{}, len(jsonNIDs))
	for _, nid := range jsonNIDs {
		doomed[nid] = struct{}{}
	}

	// Filter in a single pass instead of shrinking the slice while ranging
	// over it: a range loop keeps the original length, so removing mid-loop
	// reads stale, shifted elements and can slice past the shortened length
	// when a NID is queued more than once.
	kept := queue[:0]
	for _, nid := range queue {
		if _, drop := doomed[nid]; !drop {
			kept = append(kept, nid)
		}
	}
	d.associations[serverName] = kept
	return nil
}
// SelectQueueEntries returns a copy of the first limit NIDs queued for
// serverName, or all of them when fewer are queued. The result is an empty,
// non-nil slice when nothing is queued or limit is not positive. ctx and txn
// are accepted but not used. The error is always nil.
func (d *testDatabase) SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) {
	queue := d.associations[serverName]

	count := limit
	if queued := len(queue); queued < count {
		count = queued
	}

	selected := []int64{}
	if count <= 0 {
		return selected, nil
	}
	selected = append(selected, queue[:count]...)
	return selected, nil
}
// SelectQueueEntryCount reports how many NIDs are currently queued for
// serverName; a server with no queue counts as zero. ctx and txn are accepted
// but not used. The error is always nil.
func (d *testDatabase) SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) {
	queue := d.associations[serverName]
	return int64(len(queue)), nil
}
// InsertQueueJSON stores json under the next free NID and returns that NID.
// nidMutex is held for the whole call, so concurrent inserts each get their
// own NID. ctx and txn are accepted but not used. The error is always nil.
func (d *testDatabase) InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) {
	d.nidMutex.Lock()
	defer d.nidMutex.Unlock()

	assigned := d.nid
	d.nid = assigned + 1
	d.transactions[assigned] = []byte(json)
	return assigned, nil
}
// DeleteQueueJSON drops the stored payload for every NID in nids. NIDs with no
// stored payload are ignored. ctx and txn are accepted but not used. It always
// returns nil.
func (d *testDatabase) DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error {
	for i := range nids {
		delete(d.transactions, nids[i])
	}
	return nil
}
// SelectQueueJSON looks up the stored payload for each NID in jsonNIDs and
// returns the ones that exist, keyed by NID. Unknown NIDs are simply absent
// from the result. The returned byte slices are the stored values themselves,
// not copies. ctx and txn are accepted but not used. The error is always nil.
func (d *testDatabase) SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) {
	found := map[int64][]byte{}
	for _, nid := range jsonNIDs {
		payload, present := d.transactions[nid]
		if !present {
			continue
		}
		found[nid] = payload
	}
	return found, nil
}