Skip to content

Commit 5e0cd3f

Browse files
author
Martin D. Weinberg
committed
Merge branch 'WrapperFix' into devel
2 parents 6ec69f7 + d7c0e7e commit 5e0cd3f

6 files changed

Lines changed: 49 additions & 21 deletions

File tree

expui/BasisFactory.H

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,10 @@ namespace BasisClasses
259259
if (Naccel > 0) pseudo = currentAccel(time);
260260
}
261261

262+
//! Get the field label vector
263+
std::vector<std::string> getFieldLabels(void)
264+
{ return getFieldLabels(coordinates); }
265+
262266
};
263267

264268
using BasisPtr = std::shared_ptr<Basis>;

expui/Coefficients.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2792,7 +2792,7 @@ namespace CoefClasses
27922792

27932793
for (int t=0; t<ntim; t++) {
27942794
auto & cof = *(coefs[roundTime(times[t])]->coefs);
2795-
for (int i=0; i<4; i++) {
2795+
for (int i=0; i<Nfld; i++) {
27962796
for (int l=0; l<(Lmax+2)*(Lmax+1)/2; l++) {
27972797
for (int n=0; n<Nmax; n++) {
27982798
ret(i, l, n, t) = cof(i, l, n);

expui/FieldBasis.H

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,13 @@ namespace BasisClasses
9898
virtual std::vector<double>
9999
crt_eval(double x, double y, double z);
100100

101+
//! Get the field labels
102+
std::vector<std::string> getFieldLabels(const Coord ctype)
103+
{ return fieldLabels; }
104+
101105
public:
102106

103-
//! Constructor from YAML node
107+
//! Constructor from YAML node
104108
FieldBasis(const YAML::Node& conf,
105109
const std::string& name="FieldBasis") : Basis(conf, name)
106110
{ configure(); }
@@ -158,10 +162,6 @@ namespace BasisClasses
158162
{
159163
}
160164

161-
//! Get the field labels
162-
std::vector<std::string> getFieldLabels(const Coord ctype)
163-
{ return fieldLabels; }
164-
165165
//! Return current maximum harmonic order in expansion
166166
int getLmax() { return lmax; }
167167

expui/FieldBasis.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -731,8 +731,8 @@ namespace BasisClasses
731731
double r = sqrt(R*R + z*z);
732732

733733
double vr = (u*x + v*y + w*z)/r;
734-
double vt = (u*z*x + v*z*y - w*R)/R/r;
735-
double vp = (u*y - v*x)/R;
734+
double vt = (u*z*x + v*z*y - w*R*R)/R/r;
735+
double vp = (v*x - u*y)/R;
736736

737737
return {vr, vt, vp, vr*vr, vt*vt, vp*vp};
738738
}

pyEXP/BasisWrappers.cc

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -928,8 +928,25 @@ void BasisFactoryClasses(py::module &m)
928928
Returns
929929
-------
930930
None
931-
)");
931+
)")
932+
.def("getFieldLabels",
933+
[](BasisClasses::Basis& A)
934+
{
935+
return A.getFieldLabels();
936+
},
937+
R"(
938+
Provide the field labels for the basis functions
932939
940+
Parameters
941+
----------
942+
None
943+
944+
Returns
945+
-------
946+
list: str
947+
list of basis function labels
948+
)"
949+
);
933950

934951
py::class_<BasisClasses::BiorthBasis, std::shared_ptr<BasisClasses::BiorthBasis>, PyBiorthBasis, BasisClasses::Basis>
935952
(m, "BiorthBasis")
@@ -1279,9 +1296,9 @@ void BasisFactoryClasses(py::module &m)
12791296
// orthoCheck is not in the base class and needs to have different
12801297
// parameters depending on the basis type. Here, the quadrature
12811298
// is determined by the scale of the meridional grid.
1282-
.def("orthoCheck", [](BasisClasses::Cylindrical& A)
1299+
.def("orthoCheck", [](BasisClasses::Cylindrical& A, int knots)
12831300
{
1284-
return A.orthoCheck();
1301+
return A.orthoCheck(knots);
12851302
},
12861303
R"(
12871304
Check orthgonality of basis functions by quadrature
@@ -1298,7 +1315,7 @@ void BasisFactoryClasses(py::module &m)
12981315
-------
12991316
list(numpy.ndarray)
13001317
list of numpy.ndarrays from [0, ... , Mmax]
1301-
)")
1318+
)", py::arg("knots")=400)
13021319
.def_static("cacheInfo", [](std::string cachefile)
13031320
{
13041321
return BasisClasses::Cylindrical::cacheInfo(cachefile);
@@ -2044,11 +2061,11 @@ void BasisFactoryClasses(py::module &m)
20442061
// orthoCheck is not in the base class and needs to have
20452062
// different parameters depending on the basis type. Here the
20462063
// user can and will often need to specify a quadrature value.
2047-
.def("orthoCheck", [](BasisClasses::FieldBasis& A)
2064+
.def("orthoCheck", [](BasisClasses::FieldBasis& A)
20482065
{
20492066
return A.orthoCheck();
20502067
},
2051-
R"(
2068+
R"(
20522069
Check orthgonality of basis functions by quadrature
20532070
20542071
Inner-product matrix of orthogonal functions
@@ -2062,10 +2079,10 @@ void BasisFactoryClasses(py::module &m)
20622079
numpy.ndarray
20632080
orthogonality matrix
20642081
)"
2065-
);
2082+
);
20662083

20672084
py::class_<BasisClasses::VelocityBasis, std::shared_ptr<BasisClasses::VelocityBasis>, BasisClasses::FieldBasis>(m, "VelocityBasis")
2068-
.def(py::init<const std::string&>(),
2085+
.def(py::init<const std::string&>(),
20692086
R"(
20702087
Create a orthogonal velocity-field basis
20712088

pyEXP/TensorToArray.H

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ py::array_t<T> make_ndarray3(Eigen::Tensor<T, 3>& mat)
1111
// Check rank
1212
if (dims.size() != 3) {
1313
std::ostringstream sout;
14-
sout << "make_ndarray: tensor rank must be 3, found " << dims.size();
14+
sout << "make_ndarray3: tensor rank must be 3, found " << dims.size();
1515
throw std::runtime_error(sout.str());
1616
}
1717

@@ -37,10 +37,17 @@ py::array_t<T> make_ndarray4(Eigen::Tensor<T, 4>& mat)
3737
// Check rank
3838
if (dims.size() != 4) {
3939
std::ostringstream sout;
40-
sout << "make_ndarray: tensor rank must be 4, found " << dims.size();
40+
sout << "make_ndarray4: tensor rank must be 4, found " << dims.size();
4141
throw std::runtime_error(sout.str());
4242
}
4343

44+
// Sanity check
45+
for (int i=0; i<mat.size(); i++) {
46+
if (isnan(std::abs(mat.data()[i]))) {
47+
throw std::runtime_error("make_ndarray4: NaN encountered");
48+
}
49+
}
50+
4451
// Make the memory mapping
4552
return py::array_t<T>
4653
(
@@ -107,11 +114,11 @@ Eigen::Tensor<T, 4> make_tensor4(py::array_t<T> array)
107114

108115
// Build result tensor with col-major ordering
109116
Eigen::Tensor<T, 4> tensor(shape[0], shape[1], shape[2], shape[3]);
110-
for (int i=0, l=0; i < shape[0]; i++) {
117+
for (int i=0, c=0; i < shape[0]; i++) {
111118
for (int j=0; j < shape[1]; j++) {
112119
for (int k=0; k < shape[2]; k++) {
113-
for (int l=0; l < shape[3]; k++) {
114-
tensor(i, j, k, l) = data[l++];
120+
for (int l=0; l < shape[3]; l++, c++) {
121+
tensor(i, j, k, l) = data[c];
115122
}
116123
}
117124
}

0 commit comments

Comments
 (0)