Skip to content

Commit 04dcbc1

Browse files
committed
Grid conserves subclasses of ndarray (#56)
- uses solution proposed by @AstroMike (np.asanyarray()) - fix #56 - add test (with np.ma.MaskedArray) - update CHANGELOG
1 parent be64439 commit 04dcbc1

3 files changed

Lines changed: 22 additions & 8 deletions

File tree

CHANGELOG

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ The rules for this file:
2424
* Added missing floordivision to Grid (PR #53)
2525
* fix test on ARM (#51)
2626
* fix incorrect reading of ncstart and nrstart in CCP4 (#57)
27+
* fix that arithemtical operations broke inheritance (#56)
28+
* fix so that subclasses of ndarray are retained on input (#56)
2729

2830
Changes (do not affect user)
2931

gridData/core.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,13 @@ def __init__(self, grid=None, edges=None, origin=None, delta=None,
126126
self.load(grid)
127127
elif not (grid is None or edges is None):
128128
# set up from histogramdd-type data
129-
self.grid = numpy.asarray(grid)
129+
self.grid = numpy.asanyarray(grid)
130130
self.edges = edges
131131
self._update()
132132
elif not (grid is None or origin is None or delta is None):
133133
# setup from generic data
134-
origin = numpy.asarray(origin)
135-
delta = numpy.asarray(delta)
134+
origin = numpy.asanyarray(origin)
135+
delta = numpy.asanyarray(delta)
136136
if len(origin) != grid.ndim:
137137
raise TypeError(
138138
"Dimension of origin is not the same as grid dimension.")
@@ -148,7 +148,7 @@ def __init__(self, grid=None, edges=None, origin=None, delta=None,
148148
self.edges = [origin[dim] +
149149
(numpy.arange(m + 1) - 0.5) * delta[dim]
150150
for dim, m in enumerate(grid.shape)]
151-
self.grid = numpy.asarray(grid)
151+
self.grid = numpy.asanyarray(grid)
152152
self._update()
153153
else:
154154
# empty, must manually populate with load()
@@ -709,7 +709,7 @@ def ndmeshgrid(*arrs):
709709
for i, arr in enumerate(arrs):
710710
slc = [1] * dim
711711
slc[i] = lens[i]
712-
arr2 = numpy.asarray(arr).reshape(slc)
712+
arr2 = numpy.asanyarray(arr).reshape(slc)
713713
for j, sz in enumerate(lens):
714714
if j != i:
715715
arr2 = arr2.repeat(sz, axis=j)

gridData/tests/test_grid.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
from gridData import Grid
1111

12+
def f_arithmetic(g):
13+
return g + g - 2.5 * g / (g + 5.3)
14+
1215
@pytest.fixture(scope="class")
1316
def data():
1417
d = dict(
@@ -148,11 +151,20 @@ class DerivedGrid(Grid):
148151

149152
dg = DerivedGrid(data['griddata'], origin=data['origin'],
150153
delta=data['delta'])
151-
result = dg + dg - 2.5 * dg / (dg + 5.3)
154+
result = f_arithmetic(dg)
152155

153156
assert isinstance(result, DerivedGrid)
154157

155-
g = data['grid']
156-
ref = g + g - 2.5 * g / (dg + 5.3)
158+
ref = f_arithmetic(data['grid'])
157159
assert_almost_equal(result.grid, ref.grid)
158160

161+
def test_anyarray(data):
162+
ma = np.ma.MaskedArray(data['griddata'])
163+
mg = Grid(ma, origin=data['origin'], delta=data['delta'])
164+
165+
assert isinstance(mg.grid, ma.__class__)
166+
167+
result = f_arithmetic(mg)
168+
ref = f_arithmetic(data['grid'])
169+
170+
assert_almost_equal(result.grid, ref.grid)

0 commit comments

Comments
 (0)