@@ -21,6 +21,48 @@ using namespace codac2;
2121namespace py = pybind11;
2222using namespace pybind11 ::literals;
2323
24+ template <typename T>
25+ void export_contract (py::class_<CtcDeriv>& exported)
26+ {
27+ exported
28+
29+ .def (" contract" ,
30+ [](const CtcDeriv& ctc, Slice<T>& x, const Slice<T>& v, const std::vector<Index_type>& ctc_indices)
31+ -> py::tuple
32+ {
33+ if constexpr (std::is_same_v<T,IntervalVector>)
34+ ctc.contract (x, v, matlab::convert_indices (ctc_indices));
35+ else
36+ ctc.contract (x, v);
37+
38+ return py::make_tuple (
39+ py::cast (x, py::return_value_policy::reference),
40+ py::cast (v, py::return_value_policy::reference)
41+ );
42+ },
43+ VOID_CTCDERIV_CONTRACT_SLICE_T_REF_CONST_SLICE_T_REF_CONST_VECTOR_INDEX_REF_CONST,
44+ " x" _a, " v" _a, " ctc_indices" _a = std::vector<Index>())
45+
46+ .def (" contract" ,
47+ [](const CtcDeriv& ctc, SlicedTube<T>& x, const SlicedTube<T>& v, const std::vector<Index_type>& ctc_indices)
48+ -> py::tuple
49+ {
50+ if constexpr (std::is_same_v<T,IntervalVector>)
51+ ctc.contract (x, v, matlab::convert_indices (ctc_indices));
52+ else
53+ ctc.contract (x, v);
54+
55+ return py::make_tuple (
56+ py::cast (x, py::return_value_policy::reference),
57+ py::cast (v, py::return_value_policy::reference)
58+ );
59+ },
60+ VOID_CTCDERIV_CONTRACT_SLICEDTUBE_T_REF_CONST_SLICEDTUBE_T_REF_CONST_VECTOR_INDEX_REF_CONST,
61+ " x" _a, " v" _a, " ctc_indices" _a = std::vector<Index>())
62+
63+ ;
64+ }
65+
2466void export_CtcDeriv (py::module & m)
2567{
2668 py::class_<CtcDeriv> exported (m, " CtcDeriv" , CTCDERIV_MAIN);
@@ -33,32 +75,8 @@ void export_CtcDeriv(py::module& m)
3375 .def (" restrict_tdomain" , &CtcDeriv::restrict_tdomain,
3476 VOID_CTCDERIV_RESTRICT_TDOMAIN_CONST_INTERVAL_REF,
3577 " tdomain" _a)
36-
37- // Contractions on Slice objects
38-
39- .def (" contract" , (void (CtcDeriv::*)(Slice<Interval>&,const Slice<Interval>&,const std::vector<Index>&) const )&CtcDeriv::contract,
40- VOID_CTCDERIV_CONTRACT_SLICE_T_REF_CONST_SLICE_T_REF_CONST_VECTOR_INDEX_REF_CONST,
41- " x" _a, " v" _a, " ctc_indices" _a=std::vector<Index>())
42-
43- .def (" contract" , [](const CtcDeriv& ctc, Slice<IntervalVector>& x, const Slice<IntervalVector>& v, const std::vector<Index_type>& ctc_indices)
44- {
45- ctc.contract (x, v, matlab::convert_indices (ctc_indices));
46- },
47- VOID_CTCDERIV_CONTRACT_SLICE_T_REF_CONST_SLICE_T_REF_CONST_VECTOR_INDEX_REF_CONST,
48- " x" _a, " v" _a, " ctc_indices" _a=std::vector<Index>())
49-
50- // Contractions on SlicedTube objects
51-
52- .def (" contract" , (void (CtcDeriv::*)(SlicedTube<Interval>&,const SlicedTube<Interval>&,const std::vector<Index>&) const )&CtcDeriv::contract,
53- VOID_CTCDERIV_CONTRACT_SLICEDTUBE_T_REF_CONST_SLICEDTUBE_T_REF_CONST_VECTOR_INDEX_REF_CONST,
54- " x" _a, " v" _a, " ctc_indices" _a=std::vector<Index>())
55-
56- .def (" contract" , [](const CtcDeriv& ctc, SlicedTube<IntervalVector>& x, SlicedTube<IntervalVector>& v, const std::vector<Index_type>& ctc_indices)
57- {
58- return ctc.contract (x, v, matlab::convert_indices (ctc_indices));
59- },
60- VOID_CTCDERIV_CONTRACT_SLICEDTUBE_T_REF_CONST_SLICEDTUBE_T_REF_CONST_VECTOR_INDEX_REF_CONST,
61- " x" _a, " v" _a, " ctc_indices" _a=std::vector<Index>())
62-
6378 ;
79+
80+ export_contract<Interval>(exported);
81+ export_contract<IntervalVector>(exported);
6482}
0 commit comments