diff --git a/federationapi/storage/postgres/outbound_peeks_table.go b/federationapi/storage/postgres/outbound_peeks_table.go index c22d893f7..5df684318 100644 --- a/federationapi/storage/postgres/outbound_peeks_table.go +++ b/federationapi/storage/postgres/outbound_peeks_table.go @@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectOutboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewOutboundPeekSQL = "" + "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteOutboundPeekSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteOutboundPeeksSQL = "" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" @@ -74,25 +74,14 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err return } - if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { - return - } - if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { - return - } - if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { - return - } - if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertOutboundPeekStmt, insertOutboundPeekSQL}, + {&s.selectOutboundPeekStmt, selectOutboundPeekSQL}, + {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL}, + {&s.renewOutboundPeekStmt, renewOutboundPeekSQL}, + {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL}, + {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL}, + }.Prepare(db) } func (s *outboundPeeksStatements) InsertOutboundPeek( diff --git a/federationapi/storage/sqlite3/outbound_peeks_table.go b/federationapi/storage/sqlite3/outbound_peeks_table.go index e29026fab..33f452b68 100644 --- a/federationapi/storage/sqlite3/outbound_peeks_table.go +++ b/federationapi/storage/sqlite3/outbound_peeks_table.go @@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectOutboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewOutboundPeekSQL = "" + "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteOutboundPeekSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteOutboundPeeksSQL = "" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" @@ -74,25 +74,14 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er return } - if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { - return - } - if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { - return - } - if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { - return - } - if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertOutboundPeekStmt, insertOutboundPeekSQL}, + {&s.selectOutboundPeekStmt, selectOutboundPeekSQL}, + {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL}, + {&s.renewOutboundPeekStmt, renewOutboundPeekSQL}, + {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL}, + {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL}, + }.Prepare(db) } func (s *outboundPeeksStatements) InsertOutboundPeek( diff --git a/federationapi/storage/tables/outbound_peeks_table_test.go b/federationapi/storage/tables/outbound_peeks_table_test.go new file mode 100644 index 000000000..11679addb --- /dev/null +++ b/federationapi/storage/tables/outbound_peeks_table_test.go @@ -0,0 +1,148 @@ +package tables_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + var tab tables.FederationOutboundPeeks + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresOutboundPeeksTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteOutboundPeeksTable(db) + } + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + return tab, close +} + +func TestOutboundPeeksTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateDatabase(t, dbType) + defer closeDB() + + // Insert a peek + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + if err := tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + outboundPeek1, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if outboundPeek1.PeekID != peekID { + t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID) + } + if outboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID) + } + if outboundPeek1.ServerName != serverName { + t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName) + } + if outboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval) + } + + // Renew the peek + if err := tab.RenewOutboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + outboundPeek2, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(outboundPeek1, outboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if outboundPeek1.ServerName != outboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName) + } + if outboundPeek1.RoomID != outboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID) + } + + // delete the peek + if err := tab.DeleteOutboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil { + t.Fatal(err) + } + + // There should be no peek anymore + peek, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if peek != nil { + t.Fatalf("got a peek which should be deleted: %+v", peek) + } + + // insert some peeks + var peekIDs []string + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err := tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + outboundPeeks, err := tab.SelectOutboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) + } + for i := range outboundPeeks { + if outboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("") + } + } + + // And delete them again + if err := tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil { + t.Fatal(err) + } + + // they should be gone now + outboundPeeks, err = tab.SelectOutboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) > 0 { + t.Fatal("got outbound peeks which should be deleted") + } + + }) +}