Skip to content

Commit 6ab6ed3

Browse files
committed
Merge branch 'improve-pair-mask-handling'
* improve argument conversion in PairQuantity.setPairMask * expose internal helper `isiterable` * raise error for invalid numpy int conversion Resolve #19.
2 parents f05127a + 12379de commit 6ab6ed3

5 files changed

Lines changed: 40 additions & 18 deletions

File tree

src/diffpy/srreal/tests/testpairquantity.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import unittest
77
import pickle
8+
import numpy
89

910
from diffpy.srreal.pairquantity import PairQuantity
1011
from diffpy.srreal.srreal_ext import BasePairQuantity
@@ -76,6 +77,22 @@ def test_setStructure(self):
7677
return
7778

7879

80+
def test_setPairMask_args(self):
81+
"""check argument type handling in setPairMask
82+
"""
83+
spm = self.pq.setPairMask
84+
gpm = self.pq.getPairMask
85+
self.assertRaises(TypeError, spm, 0.0, 0, False)
86+
self.assertRaises(TypeError, spm, numpy.complex(0.5), 0, False)
87+
self.assertTrue(gpm(0, 0))
88+
spm(numpy.int32(1), 0, True, others=False)
89+
self.assertTrue(gpm(0, 1))
90+
self.assertTrue(gpm(1, 0))
91+
self.assertFalse(gpm(0, 0))
92+
self.assertFalse(gpm(2, 7))
93+
return
94+
95+
7996
def test_getStructure(self):
8097
"""check PairQuantity.getStructure()
8198
"""

src/extensions/srreal_converters.cpp

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <diffpy/srreal/StructureAdapter.hpp>
3232

3333
#include "srreal_converters.hpp"
34+
#include "srreal_validators.hpp"
3435

3536
#include "srreal_numpy_symbol.hpp"
3637
// numpy/arrayobject.h needs to be included after srreal_numpy_symbol.hpp,
@@ -73,19 +74,6 @@ boost::python::object newNumPyArray(int dim, const int* sz, int typenum)
7374
return rv;
7475
}
7576

76-
77-
bool isiterable(boost::python::object obj)
78-
{
79-
using namespace boost::python;
80-
#if PY_MAJOR_VERSION >= 3
81-
object Iterable = import("collections.abc").attr("Iterable");
82-
#else
83-
object Iterable = import("collections").attr("Iterable");
84-
#endif
85-
bool rv = (1 == PyObject_IsInstance(obj.ptr(), Iterable.ptr()));
86-
return rv;
87-
}
88-
8977
} // namespace
9078

9179
namespace srrealmodule {
@@ -304,6 +292,7 @@ int extractint(boost::python::object obj)
304292
if (PyArray_CheckScalar(pobj))
305293
{
306294
int rv = PyArray_PyIntAsInt(pobj);
295+
if (rv == -1 && PyErr_Occurred()) python::throw_error_already_set();
307296
return rv;
308297
}
309298
// nothing worked, call geti which will raise an exception

src/extensions/srreal_validators.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*****************************************************************************/
1818

1919
#include <boost/python/errors.hpp>
20+
#include <boost/python/import.hpp>
2021

2122
#include "srreal_validators.hpp"
2223

@@ -43,6 +44,19 @@ void ensure_non_negative(int value)
4344
}
4445
}
4546

47+
48+
bool isiterable(boost::python::object obj)
49+
{
50+
using boost::python::import;
51+
#if PY_MAJOR_VERSION >= 3
52+
object Iterable = import("collections.abc").attr("Iterable");
53+
#else
54+
object Iterable = import("collections").attr("Iterable");
55+
#endif
56+
bool rv = (1 == PyObject_IsInstance(obj.ptr(), Iterable.ptr()));
57+
return rv;
58+
}
59+
4660
} // namespace srrealmodule
4761

4862
// End of file

src/extensions/srreal_validators.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace srrealmodule {
2323

2424
void ensure_index_bounds(int idx, int lo, int hi);
2525
void ensure_non_negative(int value);
26+
bool isiterable(boost::python::object obj);
2627

2728
} // namespace srrealmodule
2829

src/extensions/wrap_PairQuantity.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
#include "srreal_converters.hpp"
4141
#include "srreal_pickling.hpp"
42+
#include "srreal_validators.hpp"
4243

4344
#include <diffpy/srreal/PairQuantity.hpp>
4445

@@ -488,13 +489,13 @@ void set_pair_mask(PairQuantity& obj,
488489
python::object others)
489490
{
490491
if (!others.is_none()) mask_all_pairs(obj, others);
491-
python::extract<int> geti(i);
492-
python::extract<int> getj(j);
493492
bool mask = msk;
494-
// short circuit for normal call
495-
if (geti.check() && getj.check())
493+
// short circuit for normal call with scalar values
494+
if (!isiterable(i) && !isiterable(j))
496495
{
497-
obj.setPairMask(geti(), getj(), mask);
496+
const int i1 = extractint(i);
497+
const int j1 = extractint(j);
498+
obj.setPairMask(i1, j1, mask);
498499
return;
499500
}
500501
std::vector<int> iindices = parsepairindex(i);

0 commit comments

Comments
 (0)