Fix sync API tests

This commit is contained in:
Neil Alexander 2020-07-21 12:14:45 +01:00
parent 7ea8ad2dfb
commit d1ca37a921
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
8 changed files with 30 additions and 22 deletions

View file

@ -281,16 +281,16 @@ func (d *Database) WriteEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
) )
if err != nil { if err != nil {
return err return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
} }
pduPosition = pos pduPosition = pos
if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil { if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
return err return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err)
} }
if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
return err return fmt.Errorf("d.handleBackwardExtremities: %w", err)
} }
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
@ -313,7 +313,7 @@ func (d *Database) updateRoomState(
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removedEventIDs { for _, eventID := range removedEventIDs {
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
return err return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err)
} }
} }
@ -326,13 +326,13 @@ func (d *Database) updateRoomState(
if event.Type() == "m.room.member" { if event.Type() == "m.room.member" {
value, err := event.Membership() value, err := event.Membership()
if err != nil { if err != nil {
return err return fmt.Errorf("event.Membership: %w", err)
} }
membership = &value membership = &value
} }
if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
return err return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err)
} }
} }

View file

@ -79,7 +79,7 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) _, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
return err return err
}) })
@ -110,7 +110,7 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) _, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err return err
}) })

View file

@ -200,7 +200,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
func (s *currentRoomStateStatements) DeleteRoomStateByEventID( func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
ctx context.Context, txn *sql.Tx, eventID string, ctx context.Context, txn *sql.Tx, eventID string,
) error { ) error {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
_, err := stmt.ExecContext(ctx, eventID) _, err := stmt.ExecContext(ctx, eventID)
return err return err
@ -225,7 +225,7 @@ func (s *currentRoomStateStatements) UpsertRoomState(
} }
// upsert state event // upsert state event
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, ctx,

View file

@ -95,7 +95,7 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
func (s *inviteEventsStatements) InsertInviteEvent( func (s *inviteEventsStatements) InsertInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
) (streamPos types.StreamPosition, err error) { ) (streamPos types.StreamPosition, err error) {
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
var err error var err error
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil { if err != nil {

View file

@ -295,11 +295,6 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, err return 0, err
} }
streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil {
return 0, err
}
addStateJSON, err := json.Marshal(addState) addStateJSON, err := json.Marshal(addState)
if err != nil { if err != nil {
return 0, err return 0, err
@ -309,7 +304,13 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, err return 0, err
} }
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { var streamPos types.StreamPosition
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil {
return err
}
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
_, ierr := insertStmt.ExecContext( _, ierr := insertStmt.ExecContext(
ctx, ctx,

View file

@ -107,7 +107,7 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
) (err error) { ) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, event.EventID(), event.Depth(), event.RoomID(), pos, ctx, event.EventID(), event.Depth(), event.RoomID(), pos,

View file

@ -103,7 +103,7 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage( func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) { ) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) _, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
return err return err
}) })
@ -163,7 +163,7 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
for k, v := range nids { for k, v := range nids {
params[k+1] = v params[k+1] = v
} }
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, query, params...) _, err := txn.ExecContext(ctx, query, params...)
return err return err
}) })
@ -177,7 +177,7 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
for k, v := range nids { for k, v := range nids {
params[k] = v params[k] = v
} }
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, query, params...) _, err := txn.ExecContext(ctx, query, params...)
return err return err
}) })

View file

@ -5,6 +5,7 @@ import (
"crypto/ed25519" "crypto/ed25519"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os"
"testing" "testing"
"time" "time"
@ -52,7 +53,13 @@ func MustCreateEvent(t *testing.T, roomID string, prevs []gomatrixserverlib.Head
} }
func MustCreateDatabase(t *testing.T) storage.Database { func MustCreateDatabase(t *testing.T) storage.Database {
db, err := sqlite3.NewDatabase("file::memory:") dbname := fmt.Sprintf("test_%s.db", t.Name())
if _, err := os.Stat(dbname); err == nil {
if err = os.Remove(dbname); err != nil {
t.Fatalf("tried to delete stale test database but failed: %s", err)
}
}
db, err := sqlite3.NewDatabase(fmt.Sprintf("file:%s", dbname))
if err != nil { if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err) t.Fatalf("NewSyncServerDatasource returned %s", err)
} }