Skip to content

Commit eac2b31

Browse files
committed
make interp_array more generic
Conflicts: src/interp1d.rs
1 parent d3d1e78 commit eac2b31

2 files changed

Lines changed: 17 additions & 12 deletions

File tree

src/interp1d.rs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use std::{fmt::Debug, ops::Sub};
1212

1313
use ndarray::{
1414
s, Array, ArrayBase, ArrayView, Axis, AxisDescription, Data, DimAdd, Dimension, IntoDimension,
15-
Ix1, NdIndex, OwnedRepr, RemoveAxis, Slice,
15+
Ix1, OwnedRepr, RemoveAxis, Slice,
1616
};
1717
use num_traits::{cast, Num, NumCast};
1818

@@ -183,25 +183,20 @@ where
183183
/// let result = interpolator.interp_array(&query).unwrap();
184184
/// # assert_abs_diff_eq!(result, expected, epsilon=f64::EPSILON);
185185
/// ```
186-
pub fn interp_array<Dq>(
186+
pub fn interp_array<Sq, Dq>(
187187
&self,
188-
xs: &ArrayBase<Sx, Dq>,
188+
xs: &ArrayBase<Sq, Dq>,
189189
) -> Result<Array<Sd::Elem, <Dq as DimAdd<D::Smaller>>::Output>, InterpolateError>
190190
where
191-
D: RemoveAxis,
191+
Sq: Data<Elem = Sd::Elem>,
192192
Dq: Dimension + DimAdd<D::Smaller>,
193-
Dq::Pattern: NdIndex<Dq>,
194193
{
195194
let mut dim = <Dq as DimAdd<D::Smaller>>::Output::default();
196195
dim.as_array_view_mut()
197196
.into_iter()
198-
.zip(
199-
xs.shape()
200-
.iter()
201-
.chain(self.data.raw_dim().as_array_view().slice(s![1..])),
202-
)
203-
.for_each(|(new_axis, len)| {
204-
*new_axis = *len;
197+
.zip(xs.shape().iter().chain(self.data.shape()[1..].iter()))
198+
.for_each(|(new_axis, &len)| {
199+
*new_axis = len;
205200
});
206201
let mut ys = Array::zeros(dim);
207202

tests/interp1d.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,13 @@ fn interp_multi_fn() {
193193
epsilon = f64::EPSILON
194194
);
195195
}
196+
197+
#[test]
198+
fn interp_array_with_differnt_repr(){
199+
let interp = Interp1D::builder(array![1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 4.0, 3.0, 2.0, 1.0])
200+
.build()
201+
.unwrap();
202+
let x_query = array![[1.0, 2.0, 9.0], [4.0, 5.0, 7.5]];
203+
let y_expect = array![[2.0, 3.0, 1.0], [5.0, 5.0, 2.5]];
204+
assert_eq!(interp.interp_array(&x_query.view()).unwrap(), y_expect);
205+
}

0 commit comments

Comments
 (0)