Skip to content

Commit be64439

Browse files
committed
fixed inheritance of Grid class (#56)
- use suggestion by @AstroMike in #56 - add test for using a DerivedClass
1 parent 39a3bf8 commit be64439

2 files changed

Lines changed: 45 additions & 29 deletions

File tree

gridData/core.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def resample(self, edges):
230230
coordinates = ndmeshgrid(*midpoints)
231231
# feed a meshgrid to generate all points
232232
newgrid = self.interpolated(*coordinates)
233-
return Grid(newgrid, edges)
233+
return self.__class__(newgrid, edges)
234234

235235
def resample_factor(self, factor):
236236
"""Resample to a new regular grid.
@@ -611,76 +611,76 @@ def __ne__(self, other):
611611

612612
def __add__(self, other):
613613
self.check_compatible(other)
614-
return Grid(self.grid + _grid(other), edges=self.edges)
614+
return self.__class__(self.grid + _grid(other), edges=self.edges)
615615

616616
def __sub__(self, other):
617617
self.check_compatible(other)
618-
return Grid(self.grid - _grid(other), edges=self.edges)
618+
return self.__class__(self.grid - _grid(other), edges=self.edges)
619619

620620
def __mul__(self, other):
621621
self.check_compatible(other)
622-
return Grid(self.grid * _grid(other), edges=self.edges)
622+
return self.__class__(self.grid * _grid(other), edges=self.edges)
623623

624624
def __truediv__(self, other):
625625
# truediv will always do true division (in Python 2 and Python 3);
626626
# we use from __future__ include division everywhere
627627
self.check_compatible(other)
628-
return Grid(self.grid / _grid(other), edges=self.edges)
628+
return self.__class__(self.grid / _grid(other), edges=self.edges)
629629

630630
def __div__(self, other):
631631
# in Python 2 only (without __future__.division): will do "classic division"
632632
# https://docs.python.org/2/reference/datamodel.html#object.__div__
633633
if not six.PY2:
634634
raise NotImplementedError("__div__ is only available in Python 2, use __truediv__")
635635
self.check_compatible(other)
636-
return Grid(self.grid.__div__(_grid(other)), edges=self.edges)
636+
return self.__class__(self.grid.__div__(_grid(other)), edges=self.edges)
637637

638638
def __floordiv__(self, other):
639639
self.check_compatible(other)
640-
return Grid(self.grid // _grid(other), edges=self.edges)
640+
return self.__class__(self.grid // _grid(other), edges=self.edges)
641641

642642
def __pow__(self, other):
643643
self.check_compatible(other)
644-
return Grid(numpy.power(self.grid, _grid(other)), edges=self.edges)
644+
return self.__class__(numpy.power(self.grid, _grid(other)), edges=self.edges)
645645

646646
def __radd__(self, other):
647647
self.check_compatible(other)
648-
return Grid(_grid(other) + self.grid, edges=self.edges)
648+
return self.__class__(_grid(other) + self.grid, edges=self.edges)
649649

650650
def __rsub__(self, other):
651651
self.check_compatible(other)
652-
return Grid(_grid(other) - self.grid, edges=self.edges)
652+
return self.__class__(_grid(other) - self.grid, edges=self.edges)
653653

654654
def __rmul__(self, other):
655655
self.check_compatible(other)
656-
return Grid(_grid(other) * self.grid, edges=self.edges)
656+
return self.__class__(_grid(other) * self.grid, edges=self.edges)
657657

658658
def __rtruediv__(self, other):
659659
self.check_compatible(other)
660-
return Grid(_grid(other) / self.grid, edges=self.edges)
660+
return self.__class__(_grid(other) / self.grid, edges=self.edges)
661661

662662
def __rdiv__(self, other):
663663
# in Python 2 only (without __future__.division): will do "classic division"
664664
# https://docs.python.org/2/reference/datamodel.html#object.__div__
665665
if not six.PY2:
666666
raise NotImplementedError("__rdiv__ is only available in Python 2, use __rtruediv__")
667667
self.check_compatible(other)
668-
return Grid(self.grid.__rdiv__(_grid(other)), edges=self.edges)
668+
return self.__class__(self.grid.__rdiv__(_grid(other)), edges=self.edges)
669669

670670
def __rfloordiv__(self, other):
671671
self.check_compatible(other)
672-
return Grid(_grid(other) // self.grid, edges=self.edges)
672+
return self.__class__(_grid(other) // self.grid, edges=self.edges)
673673

674674
def __rpow__(self, other):
675675
self.check_compatible(other)
676-
return Grid(numpy.power(_grid(other), self.grid), edges=self.edges)
676+
return self.__class__(numpy.power(_grid(other), self.grid), edges=self.edges)
677677

678678
def __repr__(self):
679679
try:
680680
bins = self.grid.shape
681681
except AttributeError:
682682
bins = "no"
683-
return '<Grid with ' + str(bins) + ' bins>'
683+
return '<{0} with {1!r} bins>'.format(self.__class__, bins)
684684

685685

686686
def ndmeshgrid(*arrs):

gridData/tests/test_grid.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,24 @@
22
import six
33

44
import numpy as np
5-
from numpy.testing import assert_array_equal, assert_array_almost_equal
5+
from numpy.testing import (assert_array_equal, assert_array_almost_equal,
6+
assert_almost_equal)
67

78
import pytest
89

910
from gridData import Grid
1011

11-
class TestGrid(object):
12-
@staticmethod
13-
@pytest.fixture(scope="class")
14-
def data():
15-
d = dict(
16-
griddata=np.arange(1, 28).reshape(3, 3, 3),
17-
origin=np.zeros(3),
18-
delta=np.ones(3))
19-
d['grid'] = Grid(d['griddata'], origin=d['origin'],
20-
delta=d['delta'])
21-
return d
12+
@pytest.fixture(scope="class")
13+
def data():
14+
d = dict(
15+
griddata=np.arange(1, 28).reshape(3, 3, 3),
16+
origin=np.zeros(3),
17+
delta=np.ones(3))
18+
d['grid'] = Grid(d['griddata'], origin=d['origin'],
19+
delta=d['delta'])
20+
return d
2221

22+
class TestGrid(object):
2323
def test_init(self, data):
2424
g = Grid(data['griddata'], origin=data['origin'],
2525
delta=1)
@@ -47,7 +47,7 @@ def test_addition(self, data):
4747
g = g + data['grid']
4848
assert_array_equal(g.grid.flat, (2 + (2 * data['griddata'])).flat)
4949

50-
def test_substraction(self, data):
50+
def test_subtraction(self, data):
5151
g = data['grid'] - data['grid']
5252
assert_array_equal(g.grid.flat, np.zeros(27))
5353
g = 2 - data['grid']
@@ -140,3 +140,19 @@ def test_resample_factor(self, data):
140140
# check that the edges are the same
141141
assert_array_almost_equal(g.grid[::5, ::5, ::5],
142142
data['grid'].grid[::2, ::2, ::2])
143+
144+
145+
def test_inheritance(data):
146+
class DerivedGrid(Grid):
147+
pass
148+
149+
dg = DerivedGrid(data['griddata'], origin=data['origin'],
150+
delta=data['delta'])
151+
result = dg + dg - 2.5 * dg / (dg + 5.3)
152+
153+
assert isinstance(result, DerivedGrid)
154+
155+
g = data['grid']
156+
ref = g + g - 2.5 * g / (dg + 5.3)
157+
assert_almost_equal(result.grid, ref.grid)
158+

0 commit comments

Comments
 (0)