diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 2ef02efd9..c7ca019ff 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -44,7 +44,7 @@ func TestAdminResetPassword(t *testing.T) { userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) keyAPI.SetUserAPI(userAPI) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil) + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil) // Create the users in the userapi and login accessTokens := map[*test.User]string{ @@ -115,6 +115,7 @@ func TestAdminResetPassword(t *testing.T) { } for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case t.Run(tc.name, func(t *testing.T) { req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID) if tc.requestOpt != nil { @@ -215,8 +216,9 @@ func TestPurgeRoom(t *testing.T) { } for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case t.Run(tc.name, func(t *testing.T) { - req := test.NewRequest(t, http.MethodGet, "/_dendrite/admin/purgeRoom/"+tc.roomID) + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/purgeRoom/"+tc.roomID) req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin]) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 6a976e380..93f6ea901 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -169,7 +169,7 @@ func Setup( httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminPurgeRoom(req, cfg, device, rsAPI) }), - ).Methods(http.MethodGet, http.MethodOptions) + ).Methods(http.MethodPost, http.MethodOptions) dendriteAdminRouter.Handle("/admin/resetPassword/{userID}", httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go index 79469cddc..c40757893 100644 --- a/internal/sqlutil/sqlutil_test.go +++ b/internal/sqlutil/sqlutil_test.go @@ -3,10 +3,11 @@ package sqlutil import ( "context" "database/sql" + "errors" "reflect" "testing" - sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/DATA-DOG/go-sqlmock" ) func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) { @@ -164,6 +165,54 @@ func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) { } } +func TestRunLimitedVariablesExec(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + + // Query and expect two queries to be executed + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + variables := []interface{}{ + 1, 2, 3, 4, + } + + query := "DELETE FROM WHERE id IN ($1)" + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables, 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 3 parameters, still queries two times + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:3], 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 2 parameters, queries only once + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:2], 2); err != nil { + t.Fatal(err) + } + + // Test with invalid query (typo) should return an error + mock.ExpectExec(`DELTE FROM`). + WillReturnResult(sqlmock.NewResult(0, 0)). + WillReturnError(errors.New("typo in query")) + + if err = RunLimitedVariablesExec(context.Background(), "DELTE FROM", db, variables[:2], 2); err == nil { + t.Fatal("expected an error, but got none") + } +} + func assertNoError(t *testing.T, err error, msg string) { t.Helper() if err == nil { diff --git a/roomserver/storage/postgres/purge_statements.go b/roomserver/storage/postgres/purge_statements.go index fb8ad7027..efba439bd 100644 --- a/roomserver/storage/postgres/purge_statements.go +++ b/roomserver/storage/postgres/purge_statements.go @@ -95,79 +95,39 @@ func PreparePurgeStatements(db *sql.DB) (*purgeStatements, error) { }.Prepare(db) } -func (s *purgeStatements) PurgeEventJSONs( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeEventJSONStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgeEvents( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgeInvites( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgeMemberships( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgePreviousEvents( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgePreviousEventsStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgePublished( - ctx context.Context, txn *sql.Tx, roomID string, -) error { - _, err := sqlutil.TxStmt(txn, s.purgePublishedStmt).ExecContext(ctx, roomID) - return err -} - -func (s *purgeStatements) PurgeRedactions( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeRedactionStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgeRoomAliases( - ctx context.Context, txn *sql.Tx, roomID string, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeRoomAliasesStmt).ExecContext(ctx, roomID) - return err -} - func (s *purgeStatements) PurgeRoom( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, ) error { - _, err := sqlutil.TxStmt(txn, s.purgeRoomStmt).ExecContext(ctx, roomNID) - return err -} -func (s *purgeStatements) PurgeStateBlocks( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeStateBlockEntriesStmt).ExecContext(ctx, roomNID) - return err -} + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } -func (s *purgeStatements) PurgeStateSnapshots( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeStateSnapshotEntriesStmt).ExecContext(ctx, roomNID) - return err + // purge by roomNID + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateBlockEntriesStmt, + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index a0f46e238..654b078d2 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1457,40 +1457,7 @@ func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { } return fmt.Errorf("failed to lock the room: %w", err) } - if err = d.Purge.PurgeStateBlocks(ctx, txn, roomNID); err != nil { - return fmt.Errorf("failed to purge state blocks: %w", err) - } - if err = d.Purge.PurgeStateSnapshots(ctx, txn, roomNID); err != nil { - return fmt.Errorf("failed to purge state snapshots: %w", err) - } - if err = d.Purge.PurgeInvites(ctx, txn, roomNID); err != nil { - return fmt.Errorf("failed to purge invites: %w", err) - } - if err = d.Purge.PurgeMemberships(ctx, txn, roomNID); err != nil { - return fmt.Errorf("failed to purge memberships: %w", err) - } - if err = d.Purge.PurgeRoomAliases(ctx, txn, roomID); err != nil { - return fmt.Errorf("failed to purge room aliases: %w", err) - } - if err = d.Purge.PurgePublished(ctx, txn, roomID); err != nil { - return fmt.Errorf("failed to purge published: %w", err) - } - if err = d.Purge.PurgePreviousEvents(ctx, txn, roomNID); err != nil { - return fmt.Errorf("failed to purge previous events: %w", err) - } - if err = d.Purge.PurgeEventJSONs(ctx, txn, roomNID); err != nil { - return fmt.Errorf("failed to purge event JSONs: %w", err) - } - if err = d.Purge.PurgeRedactions(ctx, txn, roomNID); err != nil { - return fmt.Errorf("failed to purge redactions: %w", err) - } - if err = d.Purge.PurgeEvents(ctx, txn, roomNID); err != nil { - return fmt.Errorf("failed to purge events: %w", err) - } - if err = d.Purge.PurgeRoom(ctx, txn, roomNID); err != nil { - return fmt.Errorf("failed to purge room: %w", err) - } - return nil + return d.Purge.PurgeRoom(ctx, txn, roomNID, roomID) }) } diff --git a/roomserver/storage/sqlite3/purge_statements.go b/roomserver/storage/sqlite3/purge_statements.go index cebca0ed4..c7b4d27a5 100644 --- a/roomserver/storage/sqlite3/purge_statements.go +++ b/roomserver/storage/sqlite3/purge_statements.go @@ -89,70 +89,47 @@ func PreparePurgeStatements(db *sql.DB, stateSnapshot *stateSnapshotStatements) }.Prepare(db) } -func (s *purgeStatements) PurgeEventJSONs( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeEventJSONStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgeEvents( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgeInvites( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgeMemberships( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgePreviousEvents( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgePreviousEventsStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgePublished( - ctx context.Context, txn *sql.Tx, roomID string, -) error { - _, err := sqlutil.TxStmt(txn, s.purgePublishedStmt).ExecContext(ctx, roomID) - return err -} - -func (s *purgeStatements) PurgeRedactions( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeRedactionStmt).ExecContext(ctx, roomNID) - return err -} - -func (s *purgeStatements) PurgeRoomAliases( - ctx context.Context, txn *sql.Tx, roomID string, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeRoomAliasesStmt).ExecContext(ctx, roomID) - return err -} - func (s *purgeStatements) PurgeRoom( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, ) error { - _, err := sqlutil.TxStmt(txn, s.purgeRoomStmt).ExecContext(ctx, roomNID) - return err + + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } + + // purge by roomNID + if err := s.purgeStateBlocks(ctx, txn, roomNID); err != nil { + return err + } + + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil } -func (s *purgeStatements) PurgeStateBlocks( +func (s *purgeStatements) purgeStateBlocks( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) error { // Get all stateBlockNIDs @@ -174,10 +151,3 @@ func (s *purgeStatements) PurgeStateBlocks( query := "DELETE FROM roomserver_state_block WHERE state_block_nid IN($1)" return sqlutil.RunLimitedVariablesExec(ctx, query, txn, params, sqlutil.SQLite3MaxVariables) } - -func (s *purgeStatements) PurgeStateSnapshots( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, -) error { - _, err := sqlutil.TxStmt(txn, s.purgeStateSnapshotEntriesStmt).ExecContext(ctx, roomNID) - return err -} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index ee6feea91..64145f83d 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -175,17 +175,9 @@ type Redactions interface { } type Purge interface { - PurgeEventJSONs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) error - PurgeEvents(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) error - PurgeRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) error - PurgeStateSnapshots(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) error - PurgeStateBlocks(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) error - PurgePreviousEvents(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) error - PurgeInvites(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) error - PurgeMemberships(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) error - PurgePublished(ctx context.Context, txn *sql.Tx, roomID string) error - PurgeRedactions(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) error - PurgeRoomAliases(ctx context.Context, txn *sql.Tx, roomID string) error + PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, + ) error } // StrippedEvent represents a stripped event for returning extracted content values. diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index d30554c19..21838039a 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -131,7 +131,8 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms case api.OutputTypePurgeRoom: err = s.onPurgeRoom(s.ctx, *output.PurgeRoom) if err != nil { - return true + logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from sync API") + return true // non-fatal, as otherwise we end up in a loop of trying to purge the room } default: log.WithField("type", output.Type).Debug(