Skip to content

Commit 51d25c3

Browse files
committed
target rising
1 parent 5a2a205 commit 51d25c3

2 files changed

Lines changed: 31 additions & 18 deletions

File tree

src/stratify/_vinterp.pyx

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ def interpolate(z_target, z_src, fz_src, axis=-1, rising=None,
520520
# Dask array
521521
import dask.array as da
522522

523-
# Ensure z_target is an array.
523+
# Ensure z_target is an array that runs in the same direction as z_src otherwise flip z_src and fz_src.
524524
if not isinstance(z_target, (np.ndarray, da.Array)):
525525
z_target = np.array(z_target)
526526

@@ -622,6 +622,32 @@ cdef class _Interpolation(object):
622622
emsg = 'Shape for z_src {} is not a subset of fz_src {}.'
623623
raise ValueError(emsg.format(z_src.shape, fz_src.shape))
624624

625+
if rising is None:
626+
if z_src.shape[zp_axis] < 2:
627+
raise ValueError('The rising keyword must be defined when '
628+
'the size of the source array is <2 in '
629+
'the interpolation axis.')
630+
z_src_indexer = [0] * z_src.ndim
631+
z_src_indexer[zp_axis] = slice(0, 2)
632+
src_first_two = z_src[tuple(z_src_indexer)]
633+
rising = src_first_two[0] <= src_first_two[1]
634+
if len(z_target) < 2:
635+
tgt_rising = rising
636+
else:
637+
if z_target.ndim == 1:
638+
first_two_t = z_target[:2]
639+
else:
640+
tgt_axis = tgt_axis % z_target.ndim
641+
tgt_indexer = [slice(None)] * z_target.ndim
642+
tgt_indexer[axis] = slice(0, 2)
643+
tgt_first_two = z_tgt[tuple(indexer)].ravel()[:2]
644+
tgt_rising = tgt_first_two[0] <= tgt_first_two_t[1]
645+
if tgt_rising != rising:
646+
z_src = np.flip(z_src, axis=zp_axis)
647+
fz_src = np.flip(fz_src, axis=zp_axis)
648+
rising = tgt_rising
649+
self.rising = bool(rising)
650+
625651
if z_target.ndim == 1:
626652
z_target_size = z_target.shape[0]
627653
else:
@@ -645,7 +671,6 @@ cdef class _Interpolation(object):
645671
'got ({}) != ({}).')
646672
raise ValueError(emsg.format(sep.join(ztsp), sep.join(zssp)))
647673
z_target_size = zts[zp_axis]
648-
649674
# We are going to put the source coordinate into a 3d shape for convenience of
650675
# Cython interface. Writing generic, fast, n-dimensional Cython code
651676
# is not possible, but it is possible to always support a 3d array with
@@ -692,18 +717,6 @@ cdef class _Interpolation(object):
692717
#: The shape of the interpolated data.
693718
self.result_shape = tuple(result_shape)
694719

695-
if rising is None:
696-
if z_src.shape[zp_axis] < 2:
697-
raise ValueError('The rising keyword must be defined when '
698-
'the size of the source array is <2 in '
699-
'the interpolation axis.')
700-
z_src_indexer = [0] * z_src.ndim
701-
z_src_indexer[zp_axis] = slice(0, 2)
702-
first_two = z_src[tuple(z_src_indexer)]
703-
rising = first_two[0] <= first_two[1]
704-
705-
self.rising = bool(rising)
706-
707720
# Sometimes we want to add additional constraints on our interpolation
708721
# and extrapolation - for example, linear extrapolation requires there
709722
# to be two coordinates to interpolate from.

src/stratify/tests/test_vinterp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def test_no_levels(self):
113113
assert_array_equal(r, [])
114114

115115
def test_wrong_rising_target(self):
116-
r = self.interpolate([2, 1], [1, 2])
117-
assert_array_equal(r, [1, np.inf])
116+
r = self.interpolate([1, 2], [2, 1])
117+
assert_array_equal(r, [0.0, 1.0])
118118

119119
def test_wrong_rising_source(self):
120120
r = self.interpolate([1, 2], [2, 1], rising=True)
@@ -124,11 +124,11 @@ def test_wrong_rising_source_and_target(self):
124124
# If we overshoot the first level, there is no hope,
125125
# so we end up extrapolating.
126126
r = self.interpolate([3, 2, 1, 0], [2, 1], rising=True)
127-
assert_array_equal(r, [np.inf, np.inf, np.inf, np.inf])
127+
assert_array_equal(r, [-np.inf, -np.inf, 0.0, np.inf])
128128

129129
def test_non_monotonic_coordinate_interp(self):
130130
result = self.interpolate([15, 5, 15.0], [10.0, 20, 0, 20])
131-
assert_array_equal(result, [1, 2, 3])
131+
assert_array_equal(result, [1.0, 1.0, 2.0])
132132

133133
def test_non_monotonic_coordinate_extrap(self):
134134
result = self.interpolate([0, 15, 16, 17, 5, 15.0, 25], [10.0, 40, 0, 20])

0 commit comments

Comments
 (0)