Skip to content

Commit be6132a

Browse files
authored
Make interpolation outside of known map 0. (#84)
* Make interpolation outside of known map 0. * Correct behavior for resample_factor The reason tests were failing for the previous commit was that resample_factor was depending on "nearest" interpolation to correct for not altering the range where the edge points were calculated. By correctly calculating an even distribution of edge points, interpolation at "midpoints" no longer occurs outside the defined grid. * PEP8ification and additional changelog commentary * Fix the test for resampling. * Test for small resample factor * Test more values. * Update CHANGELOG
1 parent 5f344cd commit be6132a

4 files changed

Lines changed: 108 additions & 36 deletions

File tree

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ Contributors:
2020
* Eloy Félix <eloyfelix>
2121
* René Hafner (Hamburger) <renehamburger1993>
2222
* Lily Wang <lilyminium>
23+
* Josh Vermaas <jvermaas>

CHANGELOG

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ The rules for this file:
1313
* accompany each entry with github issue/PR number (Issue #xyz)
1414

1515
------------------------------------------------------------------------------
16-
??/??/2019 eloyfelix, renehamburger1993, lilyminium
16+
??/??/2019 eloyfelix, renehamburger1993, lilyminium, jvermaas
1717

1818
* 0.6.0
1919

@@ -26,6 +26,9 @@ The rules for this file:
2626

2727
* fix initialization of mutable instance variable of Grid class (metadata dict) (#71)
2828
* fix multiple __init__ calls (#73)
29+
* interpolation behavior outside of the grid changed to default to a
30+
constant rather than the nearest value (#84)
31+
* corrected resampling behavior to not draw on values outside of the grid (#84)
2932

3033

3134
05/16/2019 giacomofiorin, orbeckst

gridData/core.py

Lines changed: 89 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def interpolation_spline_order(self, x):
187187
self.__interpolation_spline_order = x
188188
self._update()
189189

190-
191190
def resample(self, edges):
192191
"""Resample data to a new grid with edges *edges*.
193192
@@ -245,7 +244,8 @@ def resample_factor(self, factor):
245244
factor : float
246245
The number of grid cells are scaled with `factor` in each
247246
dimension, i.e., ``factor * N_i`` cells along each
248-
dimension i.
247+
dimension i. Must be positive, and cannot result in fewer
248+
than 2 cells along a dimension.
249249
250250
251251
Returns
@@ -257,12 +257,32 @@ def resample_factor(self, factor):
257257
--------
258258
resample
259259
260+
.. versionchanged:: 0.6.0
261+
Previous implementations would not alter the range of the grid edges
262+
being resampled on. As a result, values at the grid edges would creep
263+
steadily inward. The new implementation recalculates the extent of
264+
grid edges for every resampling.
265+
260266
"""
261-
# new number of edges N' = (N-1)*f + 1
262-
newlengths = [(N - 1) * float(factor) + 1 for N in self._len_edges()]
263-
edges = [numpy.linspace(start, stop, num=int(N), endpoint=True)
264-
for (start, stop, N) in
265-
zip(self._min_edges(), self._max_edges(), newlengths)]
267+
if float(factor) <= 0:
268+
raise ValueError("Factor must be positive")
269+
# Determine current spacing
270+
spacing = (numpy.array(self._max_edges()) - numpy.array(self._min_edges())) / (
271+
-1 + numpy.array(self._len_edges()))
272+
# First guess at the new spacing is inversely related to the
273+
# magnification factor.
274+
newspacing = spacing / float(factor)
275+
smidpoints = numpy.array(self._midpoints())
276+
# We require that the new spacing result in an even subdivision of the
277+
# existing midpoints
278+
newspacing = (smidpoints[:, -1] - smidpoints[:, 0]) / (numpy.maximum(
279+
1, numpy.floor((smidpoints[:, -1] - smidpoints[:, 0]) / newspacing)))
280+
# How many edge points should there be? It is the number of intervals
281+
# between midpoints + 2
282+
edgelength = 2 + \
283+
numpy.round((smidpoints[:, -1] - smidpoints[:, 0]) / newspacing)
284+
edges = [numpy.linspace(start, stop, num=int(N), endpoint=True) for (start, stop, N) in zip(
285+
smidpoints[:, 0] - 0.5 * newspacing, smidpoints[:, -1] + 0.5 * newspacing, edgelength)]
266286
return self.resample(edges)
267287

268288
def _update(self):
@@ -321,6 +341,18 @@ def interpolated(self):
321341
a cubic spline interpolation can generate negative values,
322342
especially at the boundary between 0 and high values.
323343
344+
Internally, the function uses :func:`scipy.ndimage.map_coordinates`
345+
with ``mode="constant"`` whereby interpolated values outside
346+
the interpolated grid are determined by filling all values beyond
347+
the edge with the same constant value, defined by the
348+
:attr:`interpolation_cval` parameter, which when not set defaults
349+
to the minimum value in the interpolated grid.
350+
351+
.. versionchanged:: 0.6.0
352+
Interpolation outside the grid is now performed with
353+
``mode="constant"`rather than ``mode="nearest"``, eliminating
354+
extruded volumes when interpolating beyond the grid.
355+
324356
"""
325357
if self.__interpolated is None:
326358
self.__interpolated = self._interpolationFunctionFactory()
@@ -373,7 +405,13 @@ def _get_loader(self, filename, file_format=None):
373405
file_format=file_format,
374406
export=False)]
375407

376-
def _load(self, grid=None, edges=None, metadata=None, origin=None, delta=None):
408+
def _load(
409+
self,
410+
grid=None,
411+
edges=None,
412+
metadata=None,
413+
origin=None,
414+
delta=None):
377415
if edges is not None:
378416
# set up from histogramdd-type data
379417
self.grid = numpy.asanyarray(grid)
@@ -396,16 +434,17 @@ def _load(self, grid=None, edges=None, metadata=None, origin=None, delta=None):
396434
"len(grid.ndim)")
397435
# note that origin is CENTER so edges must be shifted by -0.5*delta
398436
self.edges = [origin[dim] +
399-
(numpy.arange(m + 1) - 0.5) * delta[dim]
400-
for dim, m in enumerate(grid.shape)]
437+
(numpy.arange(m + 1) - 0.5) * delta[dim]
438+
for dim, m in enumerate(grid.shape)]
401439
self.grid = numpy.asanyarray(grid)
402440
self._update()
403441
else:
404-
raise ValueError("Wrong/missing data to set up Grid. Use Grid() or "
405-
"Grid(grid=<array>, edges=<list>) or "
406-
"Grid(grid=<array>, origin=(x0, y0, z0), delta=(dx, dy, dz)):\n"
407-
"grid={0} edges={1} origin={2} delta={3}".format(
408-
grid, edges, origin, delta))
442+
raise ValueError(
443+
"Wrong/missing data to set up Grid. Use Grid() or "
444+
"Grid(grid=<array>, edges=<list>) or "
445+
"Grid(grid=<array>, origin=(x0, y0, z0), delta=(dx, dy, dz)):\n"
446+
"grid={0} edges={1} origin={2} delta={3}".format(
447+
grid, edges, origin, delta))
409448

410449
def load(self, filename, file_format=None):
411450
"""Load saved (pickled or dx) grid and edges from <filename>.pickle
@@ -533,8 +572,7 @@ def _export_dx(self, filename, type=None, typequote='"', **kwargs):
533572
'File format: http://opendx.sdsc.edu/docs/html/pages/usrgu068.htm#HDREDF',
534573
'Data are embedded in the header and tied to the grid positions.',
535574
'Data is written in C array order: In grid[x,y,z] the axis z is fastest',
536-
'varying, then y, then finally x, i.e. z is the innermost loop.'
537-
]
575+
'varying, then y, then finally x, i.e. z is the innermost loop.']
538576

539577
# write metadata in comments section
540578
if self.metadata:
@@ -614,7 +652,8 @@ def _interpolationFunctionFactory(self, spline_order=None, cval=None):
614652
import scipy.ndimage
615653

616654
if spline_order is None:
617-
# must be compatible with whatever :func:`scipy.ndimage.spline_filter` takes.
655+
# must be compatible with whatever
656+
# :func:`scipy.ndimage.spline_filter` takes.
618657
spline_order = self.interpolation_spline_order
619658
if cval is None:
620659
cval = self.interpolation_cval
@@ -650,18 +689,21 @@ def interpolatedF(*coordinates):
650689
return scipy.ndimage.map_coordinates(coeffs,
651690
_coordinates,
652691
prefilter=False,
653-
mode='nearest',
692+
mode='constant',
654693
cval=cval)
655-
# mode='wrap' would be ideal but is broken: https://github.com/scipy/scipy/issues/1323
656694
return interpolatedF
657695

658696
def __eq__(self, other):
659697
if not isinstance(other, Grid):
660698
return False
661-
return numpy.all(other.grid == self.grid) and \
662-
numpy.all(other.origin == self.origin) and \
663-
numpy.all(numpy.all(other_edge == self_edge) for other_edge, self_edge in
664-
zip(other.edges, self.edges))
699+
return numpy.all(
700+
other.grid == self.grid) and numpy.all(
701+
other.origin == self.origin) and numpy.all(
702+
numpy.all(
703+
other_edge == self_edge) for other_edge,
704+
self_edge in zip(
705+
other.edges,
706+
self.edges))
665707

666708
def __ne__(self, other):
667709
return not self.__eq__(other)
@@ -688,17 +730,25 @@ def __div__(self, other):
688730
# in Python 2 only (without __future__.division): will do "classic division"
689731
# https://docs.python.org/2/reference/datamodel.html#object.__div__
690732
if not six.PY2:
691-
raise NotImplementedError("__div__ is only available in Python 2, use __truediv__")
733+
raise NotImplementedError(
734+
"__div__ is only available in Python 2, use __truediv__")
692735
self.check_compatible(other)
693-
return self.__class__(self.grid.__div__(_grid(other)), edges=self.edges)
736+
return self.__class__(
737+
self.grid.__div__(
738+
_grid(other)),
739+
edges=self.edges)
694740

695741
def __floordiv__(self, other):
696742
self.check_compatible(other)
697743
return self.__class__(self.grid // _grid(other), edges=self.edges)
698744

699745
def __pow__(self, other):
700746
self.check_compatible(other)
701-
return self.__class__(numpy.power(self.grid, _grid(other)), edges=self.edges)
747+
return self.__class__(
748+
numpy.power(
749+
self.grid,
750+
_grid(other)),
751+
edges=self.edges)
702752

703753
def __radd__(self, other):
704754
self.check_compatible(other)
@@ -720,17 +770,25 @@ def __rdiv__(self, other):
720770
# in Python 2 only (without __future__.division): will do "classic division"
721771
# https://docs.python.org/2/reference/datamodel.html#object.__div__
722772
if not six.PY2:
723-
raise NotImplementedError("__rdiv__ is only available in Python 2, use __rtruediv__")
773+
raise NotImplementedError(
774+
"__rdiv__ is only available in Python 2, use __rtruediv__")
724775
self.check_compatible(other)
725-
return self.__class__(self.grid.__rdiv__(_grid(other)), edges=self.edges)
776+
return self.__class__(
777+
self.grid.__rdiv__(
778+
_grid(other)),
779+
edges=self.edges)
726780

727781
def __rfloordiv__(self, other):
728782
self.check_compatible(other)
729783
return self.__class__(_grid(other) // self.grid, edges=self.edges)
730784

731785
def __rpow__(self, other):
732786
self.check_compatible(other)
733-
return self.__class__(numpy.power(_grid(other), self.grid), edges=self.edges)
787+
return self.__class__(
788+
numpy.power(
789+
_grid(other),
790+
self.grid),
791+
edges=self.edges)
734792

735793
def __repr__(self):
736794
try:
@@ -753,7 +811,7 @@ def ndmeshgrid(*arrs):
753811
754812
.. SeeAlso: :func:`numpy.meshgrid` for the 2D case.
755813
"""
756-
#arrs = tuple(reversed(arrs)) <-- wrong on stackoverflow.com
814+
# arrs = tuple(reversed(arrs)) <-- wrong on stackoverflow.com
757815
arrs = tuple(arrs)
758816
lens = list(map(len, arrs))
759817
dim = len(arrs)

gridData/tests/test_grid.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,15 +161,25 @@ def test_centers(self, data):
161161
assert_array_equal(centers[-1] - g.origin,
162162
(np.array(g.grid.shape) - 1) * data['delta'])
163163

164+
def test_resample_factor_failure(self, data):
165+
pytest.importorskip('scipy')
166+
167+
with pytest.raises(ValueError):
168+
g = data['grid'].resample_factor(0)
169+
164170
def test_resample_factor(self, data):
165171
pytest.importorskip('scipy')
166172

167173
g = data['grid'].resample_factor(2)
168174
assert_array_equal(g.delta, np.ones(3) * .5)
169-
assert_array_equal(g.grid.shape, np.ones(3) * 6)
170-
# check that the edges are the same
171-
assert_array_almost_equal(g.grid[::5, ::5, ::5],
172-
data['grid'].grid[::2, ::2, ::2])
175+
# zooming in by a factor of 2. Each subinterval is
176+
# split in half, so 3 gridpoints (2 subintervals)
177+
# becomes 5 gridpoints (4 subintervals)
178+
assert_array_equal(g.grid.shape, np.ones(3) * 5)
179+
# check that the values are identical with the
180+
# correct stride.
181+
assert_array_almost_equal(g.grid[::2, ::2, ::2],
182+
data['grid'].grid)
173183

174184
def test_load_pickle(self, data, tmpdir):
175185
g = data['grid']

0 commit comments

Comments
 (0)