diff --git a/SU2_CFD/include/drivers/CDiscAdjMultizoneDriver.hpp b/SU2_CFD/include/drivers/CDiscAdjMultizoneDriver.hpp index 60d040532ca..a6304f0d5a4 100644 --- a/SU2_CFD/include/drivers/CDiscAdjMultizoneDriver.hpp +++ b/SU2_CFD/include/drivers/CDiscAdjMultizoneDriver.hpp @@ -36,38 +36,21 @@ * \brief Block Gauss-Seidel driver for multizone / multiphysics discrete adjoint problems. * \ingroup DiscAdj */ -class CDiscAdjMultizoneDriver : public CMultizoneDriver { -protected: -#ifdef CODI_FORWARD_TYPE - using Scalar = su2double; -#else - using Scalar = passivedouble; -#endif - - class AdjointProduct : public CMatrixVectorProduct { - public: - CDiscAdjMultizoneDriver* const driver; - const unsigned short iZone = 0; - mutable unsigned long iInnerIter = 0; - - AdjointProduct(CDiscAdjMultizoneDriver* d, unsigned short i) : driver(d), iZone(i) {} - - inline void operator()(const CSysVector & u, CSysVector & v) const override { - driver->SetAllSolutions(iZone, true, u); - driver->Iterate(iZone, iInnerIter, true); - driver->GetAllSolutions(iZone, true, v); - v -= u; - ++iInnerIter; - } - }; +template class AdjointProduct; +template class Identity; - class Identity : public CPreconditioner { - public: - inline bool IsIdentity() const override { return true; } - inline void operator()(const CSysVector & u, CSysVector & v) const override { v = u; } - }; +class CDiscAdjMultizoneDriver : public CMultizoneDriver { +protected: + #ifdef CODI_FORWARD_TYPE + using Scalar = su2double; + #else + using Scalar = passivedouble; + #endif + + friend class AdjointProduct; + friend class Identity; /*! * \brief Kinds of recordings. */ @@ -304,3 +287,29 @@ class CDiscAdjMultizoneDriver : public CMultizoneDriver { } }; + + +template +class AdjointProduct : public CMatrixVectorProduct { +public: + CDiscAdjMultizoneDriver* const driver; + const unsigned short iZone = 0; + mutable unsigned long iInnerIter = 0; + + AdjointProduct(CDiscAdjMultizoneDriver* d, unsigned short i) : driver(d), iZone(i) {} + + inline void operator()(const CSysVector & u, CSysVector & v) const override { + driver->SetAllSolutions(iZone, true, u); + driver->Iterate(iZone, iInnerIter, true); + driver->GetAllSolutions(iZone, true, v); + v -= u; + ++iInnerIter; + } +}; + +template +class Identity : public CPreconditioner { +public: + inline bool IsIdentity() const override { return true; } + inline void operator()(const CSysVector & u, CSysVector & v) const override { v = u; } +}; diff --git a/SU2_CFD/src/drivers/CDiscAdjMultizoneDriver.cpp b/SU2_CFD/src/drivers/CDiscAdjMultizoneDriver.cpp index 61673bdad1b..aea7594f9c3 100644 --- a/SU2_CFD/src/drivers/CDiscAdjMultizoneDriver.cpp +++ b/SU2_CFD/src/drivers/CDiscAdjMultizoneDriver.cpp @@ -376,7 +376,7 @@ void CDiscAdjMultizoneDriver::KrylovInnerIters(unsigned short iZone) { GetAllSolutions(iZone, true, AdjSol[iZone]); const bool monitor = config_container[iZone]->GetWrt_ZoneConv(); - const auto product = AdjointProduct(this, iZone); + const auto product = AdjointProduct(this, iZone); /*--- Manipulate the screen output frequency to avoid printing garbage. ---*/ const auto wrtFreq = config_container[iZone]->GetScreen_Wrt_Freq(2); @@ -388,7 +388,7 @@ void CDiscAdjMultizoneDriver::KrylovInnerIters(unsigned short iZone) { Scalar eps_l = 0.0; Scalar tol_l = KrylovTol / eps; auto iter = min(totalIter-2ul, config_container[iZone]->GetnQuasiNewtonSamples()-2ul); - iter = LinSolver[iZone].FGCRODR_LinSolver(AdjRHS[iZone], AdjSol[iZone], product, Identity(), + iter = LinSolver[iZone].FGCRODR_LinSolver(AdjRHS[iZone], AdjSol[iZone], product, Identity(), tol_l, iter, eps_l, monitor, config_container[iZone], FgcrodrMode::SAME_MAT, iter); totalIter -= iter+1; diff --git a/SU2_PY/pySU2/pySU2ad.i b/SU2_PY/pySU2/pySU2ad.i index 7959e06231d..e3394aa5a0e 100644 --- a/SU2_PY/pySU2/pySU2ad.i +++ b/SU2_PY/pySU2/pySU2ad.i @@ -39,6 +39,7 @@ threads="1" %{ #include "../../Common/include/containers/CPyWrapperMatrixView.hpp" #include "../../SU2_CFD/include/drivers/CDiscAdjSinglezoneDriver.hpp" +#include "../../SU2_CFD/include/drivers/CDiscAdjMultizoneDriver.hpp" #include "../../SU2_CFD/include/drivers/CDriver.hpp" #include "../../SU2_CFD/include/drivers/CDriverBase.hpp" #include "../../SU2_CFD/include/drivers/CMultizoneDriver.hpp" @@ -98,4 +99,5 @@ const unsigned int ZONE_1 = 1; /*!< \brief Definition of the first grid domain. %include "../../SU2_CFD/include/drivers/CSinglezoneDriver.hpp" %include "../../SU2_CFD/include/drivers/CMultizoneDriver.hpp" %include "../../SU2_CFD/include/drivers/CDiscAdjSinglezoneDriver.hpp" +%include "../../SU2_CFD/include/drivers/CDiscAdjMultizoneDriver.hpp" %include "../../SU2_DEF/include/drivers/CDiscAdjDeformationDriver.hpp"