Skip to content

Commit 92a39b8

Browse files
committed
[fnc] unaryize_function + binding
1 parent 466428f commit 92a39b8

2 files changed

Lines changed: 59 additions & 0 deletions

File tree

python/src/core/functions/analytic/codac2_py_AnalyticFunction.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
#include <codac2_Parallelepiped.h>
1919
#include <codac2_AnalyticFunction.h>
2020
#include <codac2_analytic_variables.h>
21+
#include <codac2_analytic_flat_input_layout.h>
2122
#include "codac2_py_AnalyticFunction_docs.h" // Generated file from Doxygen XML (doxygen2docstring.py)
2223
#include "codac2_py_AnalyticFunction_impl_docs.h" // Generated file from Doxygen XML (doxygen2docstring.py)
2324
#include "codac2_py_FunctionBase_docs.h" // Generated file from Doxygen XML (doxygen2docstring.py)
25+
#include "codac2_py_analytic_flat_input_layout_docs.h" // Generated file from Doxygen XML (doxygen2docstring.py)
2426
#include "codac2_py_AnalyticExprWrapper.h"
2527
#include "codac2_py_cast.h"
2628

@@ -486,4 +488,12 @@ void export_AnalyticFunction(py::module& m, const std::string& export_name)
486488
},
487489
OSTREAM_REF_OPERATOROUT_OSTREAM_REF_CONST_ANALYTICFUNCTION_U_REF)
488490
;
491+
492+
493+
m.def("unaryize_function", (AnalyticFunction<ScalarType> (*)(const AnalyticFunction<ScalarType>&))&codac2::unaryize_function,
494+
ANALYTICFUNCTION_T_UNARYIZE_FUNCTION_CONST_ANALYTICFUNCTION_T_REF,
495+
"f"_a);
496+
m.def("unaryize_function", (AnalyticFunction<VectorType> (*)(const AnalyticFunction<VectorType>&))&codac2::unaryize_function,
497+
ANALYTICFUNCTION_T_UNARYIZE_FUNCTION_CONST_ANALYTICFUNCTION_T_REF,
498+
"f"_a);
489499
}

src/core/functions/analytic/codac2_analytic_flat_input_layout.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,53 @@ namespace codac2
172172
std::unordered_map<Index,FlatInputBinding> _bindings; //!< Bindings indexed by input expression identifier.
173173
Index _size = 0; //!< Total number of scalar inputs in the flattened domain.
174174
};
175+
176+
template<typename T>
177+
static std::shared_ptr<ExprBase> as_expr_base(const AnalyticExprWrapper<T>& e)
178+
{
179+
return std::static_pointer_cast<ExprBase>(
180+
std::shared_ptr<AnalyticExpr<T>>(e)
181+
);
182+
}
183+
184+
template<typename T>
185+
requires std::is_base_of_v<AnalyticTypeBase,T>
186+
class AnalyticFunction;
187+
188+
template<typename T>
189+
inline AnalyticFunction<T> unaryize_function(const AnalyticFunction<T>& f)
190+
{
191+
if(f.nb_args() == 0 || (f.nb_args() == 1 && std::dynamic_pointer_cast<VectorVar>(f.args()[0])))
192+
return f;
193+
194+
FlatInputLayout layout(f.args());
195+
VectorVar flat_x(layout.size(), "x");
196+
197+
auto y = std::dynamic_pointer_cast<AnalyticExpr<T>>(f.expr()->copy());
198+
assert(y && "unaryize_function: unable to copy analytic expression");
199+
200+
for(const auto& arg : f.args())
201+
{
202+
const auto& b = layout.binding_of(arg->unique_id());
203+
204+
if(std::dynamic_pointer_cast<ScalarVar>(arg))
205+
y->replace_arg(arg->unique_id(), as_expr_base(flat_x[b.offset]));
206+
207+
else if(std::dynamic_pointer_cast<VectorVar>(arg))
208+
{
209+
assert(b.cols == 1 && "unaryize_function: invalid flat binding for vector input");
210+
y->replace_arg(
211+
arg->unique_id(),
212+
as_expr_base(flat_x.subvector(b.offset, b.offset + b.rows - 1))
213+
);
214+
}
215+
216+
else
217+
{
218+
assert(false && "unaryize_function: only scalar/vector input arguments are currently supported");
219+
}
220+
}
221+
222+
return AnalyticFunction<T>({flat_x}, y);
223+
}
175224
}

0 commit comments

Comments
 (0)