Skip to content

Commit 6c47ad6

Browse files
authored
Merge pull request #63 from MDAnalysis/subclassing-56
improves inheritance of Grid
2 parents b7e3b83 + 04dcbc1 commit 6c47ad6

4 files changed

Lines changed: 65 additions & 35 deletions

File tree

CHANGELOG

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ The rules for this file:
2424
* Added missing floordivision to Grid (PR #53)
2525
* fix test on ARM (#51)
2626
* fix incorrect reading of ncstart and nrstart in CCP4 (#57)
27+
* fix that arithemtical operations broke inheritance (#56)
28+
* fix so that subclasses of ndarray are retained on input (#56)
2729

2830
Changes (do not affect user)
2931

gridData/core.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,13 @@ def __init__(self, grid=None, edges=None, origin=None, delta=None,
126126
self.load(grid)
127127
elif not (grid is None or edges is None):
128128
# set up from histogramdd-type data
129-
self.grid = numpy.asarray(grid)
129+
self.grid = numpy.asanyarray(grid)
130130
self.edges = edges
131131
self._update()
132132
elif not (grid is None or origin is None or delta is None):
133133
# setup from generic data
134-
origin = numpy.asarray(origin)
135-
delta = numpy.asarray(delta)
134+
origin = numpy.asanyarray(origin)
135+
delta = numpy.asanyarray(delta)
136136
if len(origin) != grid.ndim:
137137
raise TypeError(
138138
"Dimension of origin is not the same as grid dimension.")
@@ -148,7 +148,7 @@ def __init__(self, grid=None, edges=None, origin=None, delta=None,
148148
self.edges = [origin[dim] +
149149
(numpy.arange(m + 1) - 0.5) * delta[dim]
150150
for dim, m in enumerate(grid.shape)]
151-
self.grid = numpy.asarray(grid)
151+
self.grid = numpy.asanyarray(grid)
152152
self._update()
153153
else:
154154
# empty, must manually populate with load()
@@ -230,7 +230,7 @@ def resample(self, edges):
230230
coordinates = ndmeshgrid(*midpoints)
231231
# feed a meshgrid to generate all points
232232
newgrid = self.interpolated(*coordinates)
233-
return Grid(newgrid, edges)
233+
return self.__class__(newgrid, edges)
234234

235235
def resample_factor(self, factor):
236236
"""Resample to a new regular grid.
@@ -611,76 +611,76 @@ def __ne__(self, other):
611611

612612
def __add__(self, other):
613613
self.check_compatible(other)
614-
return Grid(self.grid + _grid(other), edges=self.edges)
614+
return self.__class__(self.grid + _grid(other), edges=self.edges)
615615

616616
def __sub__(self, other):
617617
self.check_compatible(other)
618-
return Grid(self.grid - _grid(other), edges=self.edges)
618+
return self.__class__(self.grid - _grid(other), edges=self.edges)
619619

620620
def __mul__(self, other):
621621
self.check_compatible(other)
622-
return Grid(self.grid * _grid(other), edges=self.edges)
622+
return self.__class__(self.grid * _grid(other), edges=self.edges)
623623

624624
def __truediv__(self, other):
625625
# truediv will always do true division (in Python 2 and Python 3);
626626
# we use from __future__ include division everywhere
627627
self.check_compatible(other)
628-
return Grid(self.grid / _grid(other), edges=self.edges)
628+
return self.__class__(self.grid / _grid(other), edges=self.edges)
629629

630630
def __div__(self, other):
631631
# in Python 2 only (without __future__.division): will do "classic division"
632632
# https://docs.python.org/2/reference/datamodel.html#object.__div__
633633
if not six.PY2:
634634
raise NotImplementedError("__div__ is only available in Python 2, use __truediv__")
635635
self.check_compatible(other)
636-
return Grid(self.grid.__div__(_grid(other)), edges=self.edges)
636+
return self.__class__(self.grid.__div__(_grid(other)), edges=self.edges)
637637

638638
def __floordiv__(self, other):
639639
self.check_compatible(other)
640-
return Grid(self.grid // _grid(other), edges=self.edges)
640+
return self.__class__(self.grid // _grid(other), edges=self.edges)
641641

642642
def __pow__(self, other):
643643
self.check_compatible(other)
644-
return Grid(numpy.power(self.grid, _grid(other)), edges=self.edges)
644+
return self.__class__(numpy.power(self.grid, _grid(other)), edges=self.edges)
645645

646646
def __radd__(self, other):
647647
self.check_compatible(other)
648-
return Grid(_grid(other) + self.grid, edges=self.edges)
648+
return self.__class__(_grid(other) + self.grid, edges=self.edges)
649649

650650
def __rsub__(self, other):
651651
self.check_compatible(other)
652-
return Grid(_grid(other) - self.grid, edges=self.edges)
652+
return self.__class__(_grid(other) - self.grid, edges=self.edges)
653653

654654
def __rmul__(self, other):
655655
self.check_compatible(other)
656-
return Grid(_grid(other) * self.grid, edges=self.edges)
656+
return self.__class__(_grid(other) * self.grid, edges=self.edges)
657657

658658
def __rtruediv__(self, other):
659659
self.check_compatible(other)
660-
return Grid(_grid(other) / self.grid, edges=self.edges)
660+
return self.__class__(_grid(other) / self.grid, edges=self.edges)
661661

662662
def __rdiv__(self, other):
663663
# in Python 2 only (without __future__.division): will do "classic division"
664664
# https://docs.python.org/2/reference/datamodel.html#object.__div__
665665
if not six.PY2:
666666
raise NotImplementedError("__rdiv__ is only available in Python 2, use __rtruediv__")
667667
self.check_compatible(other)
668-
return Grid(self.grid.__rdiv__(_grid(other)), edges=self.edges)
668+
return self.__class__(self.grid.__rdiv__(_grid(other)), edges=self.edges)
669669

670670
def __rfloordiv__(self, other):
671671
self.check_compatible(other)
672-
return Grid(_grid(other) // self.grid, edges=self.edges)
672+
return self.__class__(_grid(other) // self.grid, edges=self.edges)
673673

674674
def __rpow__(self, other):
675675
self.check_compatible(other)
676-
return Grid(numpy.power(_grid(other), self.grid), edges=self.edges)
676+
return self.__class__(numpy.power(_grid(other), self.grid), edges=self.edges)
677677

678678
def __repr__(self):
679679
try:
680680
bins = self.grid.shape
681681
except AttributeError:
682682
bins = "no"
683-
return '<Grid with ' + str(bins) + ' bins>'
683+
return '<{0} with {1!r} bins>'.format(self.__class__, bins)
684684

685685

686686
def ndmeshgrid(*arrs):
@@ -709,7 +709,7 @@ def ndmeshgrid(*arrs):
709709
for i, arr in enumerate(arrs):
710710
slc = [1] * dim
711711
slc[i] = lens[i]
712-
arr2 = numpy.asarray(arr).reshape(slc)
712+
arr2 = numpy.asanyarray(arr).reshape(slc)
713713
for j, sz in enumerate(lens):
714714
if j != i:
715715
arr2 = arr2.repeat(sz, axis=j)

gridData/tests/test_ccp4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def ccp4data():
6060
('nlabl', 1),
6161
('label', ' Map from fft '),
6262
])
63-
def test_ccp4_integer_reading(ccp4data, name, value):
63+
def test_ccp4_read_header(ccp4data, name, value):
6464
if type(value) is float:
6565
assert_almost_equal(ccp4data.header[name], value, decimal=6)
6666
else:

gridData/tests/test_grid.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,27 @@
22
import six
33

44
import numpy as np
5-
from numpy.testing import assert_array_equal, assert_array_almost_equal
5+
from numpy.testing import (assert_array_equal, assert_array_almost_equal,
6+
assert_almost_equal)
67

78
import pytest
89

910
from gridData import Grid
1011

11-
class TestGrid(object):
12-
@staticmethod
13-
@pytest.fixture(scope="class")
14-
def data():
15-
d = dict(
16-
griddata=np.arange(1, 28).reshape(3, 3, 3),
17-
origin=np.zeros(3),
18-
delta=np.ones(3))
19-
d['grid'] = Grid(d['griddata'], origin=d['origin'],
20-
delta=d['delta'])
21-
return d
12+
def f_arithmetic(g):
13+
return g + g - 2.5 * g / (g + 5.3)
14+
15+
@pytest.fixture(scope="class")
16+
def data():
17+
d = dict(
18+
griddata=np.arange(1, 28).reshape(3, 3, 3),
19+
origin=np.zeros(3),
20+
delta=np.ones(3))
21+
d['grid'] = Grid(d['griddata'], origin=d['origin'],
22+
delta=d['delta'])
23+
return d
2224

25+
class TestGrid(object):
2326
def test_init(self, data):
2427
g = Grid(data['griddata'], origin=data['origin'],
2528
delta=1)
@@ -47,7 +50,7 @@ def test_addition(self, data):
4750
g = g + data['grid']
4851
assert_array_equal(g.grid.flat, (2 + (2 * data['griddata'])).flat)
4952

50-
def test_substraction(self, data):
53+
def test_subtraction(self, data):
5154
g = data['grid'] - data['grid']
5255
assert_array_equal(g.grid.flat, np.zeros(27))
5356
g = 2 - data['grid']
@@ -140,3 +143,28 @@ def test_resample_factor(self, data):
140143
# check that the edges are the same
141144
assert_array_almost_equal(g.grid[::5, ::5, ::5],
142145
data['grid'].grid[::2, ::2, ::2])
146+
147+
148+
def test_inheritance(data):
149+
class DerivedGrid(Grid):
150+
pass
151+
152+
dg = DerivedGrid(data['griddata'], origin=data['origin'],
153+
delta=data['delta'])
154+
result = f_arithmetic(dg)
155+
156+
assert isinstance(result, DerivedGrid)
157+
158+
ref = f_arithmetic(data['grid'])
159+
assert_almost_equal(result.grid, ref.grid)
160+
161+
def test_anyarray(data):
162+
ma = np.ma.MaskedArray(data['griddata'])
163+
mg = Grid(ma, origin=data['origin'], delta=data['delta'])
164+
165+
assert isinstance(mg.grid, ma.__class__)
166+
167+
result = f_arithmetic(mg)
168+
ref = f_arithmetic(data['grid'])
169+
170+
assert_almost_equal(result.grid, ref.grid)

0 commit comments

Comments
 (0)