diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 0e7389e5..d6e32742 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -86,6 +86,22 @@ concept InputIterator = requires(It it) { namespace detail { +/* + * FIXME: SimpleArray currently stores strides as small_vector. Negative + * NumPy strides are temporarily preserved by wrapping them into size_t on + * input and converting them back here. Clean this up by making the internal + * stride storage small_vector. + */ +inline ssize_t stride_to_signed(size_t stride) noexcept +{ + if (stride <= static_cast(std::numeric_limits::max())) + { + return static_cast(stride); + } + + return -static_cast(~stride + 1); +} + template size_t buffer_offset_impl(S const &) { diff --git a/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp b/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp index 01d14de4..c3e2ce48 100644 --- a/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp +++ b/cpp/modmesh/buffer/pymod/TypeBroadcast.hpp @@ -38,53 +38,78 @@ namespace modmesh namespace python { +namespace detail +{ + +inline modmesh::detail::shape_type shape_from_slices( + std::vector const & slices) +{ + modmesh::detail::shape_type shape(slices.size()); + for (size_t i = 0; i < slices.size(); ++i) + { + shape[i] = static_cast(slices[i][3]); + } + return shape; +} + +} /* end namespace detail */ + template struct TypeBroadcastImpl { using slice_type = modmesh::detail::slice_type; using shape_type = modmesh::detail::shape_type; - // NOLINTNEXTLINE(misc-no-recursion) - static void copy_idx(SimpleArray & arr_out, std::vector const & slices, pybind11::array_t const * arr_in, shape_type left_shape, shape_type sidx, int dim) + static ssize_t input_offset(pybind11::array_t const & arr_in, shape_type const & sidx) { - using out_type = typename std::remove_reference_t; - - if (dim < 0) + ssize_t offset = 0; + for (pybind11::ssize_t i = 0; i < arr_in.ndim(); ++i) { - return; + auto const index = static_cast(sidx[i]); + offset += arr_in.strides(i) / arr_in.itemsize() * index; } + return offset; + } - for (size_t i = 0; i < left_shape[dim]; ++i) + static ssize_t offset_from_slices(SimpleArray const & arr, std::vector const & slices, shape_type const & sidx) + { + ssize_t offset = 0; + for (size_t i = 0; i < arr.ndim(); ++i) { - sidx[dim] = i; + auto const slice_index = static_cast(sidx[i]); + ssize_t const index = slices[i][0] + slice_index * slices[i][2]; + offset += modmesh::detail::stride_to_signed(arr.stride(i)) * index; + } + return offset; + } - size_t offset_in = 0; - for (pybind11::ssize_t it = 0; it < arr_in->ndim(); ++it) - { - offset_in += arr_in->strides(it) / arr_in->itemsize() * sidx[it]; - } - const D * ptr_in = arr_in->data() + offset_in; + // NOLINTNEXTLINE(misc-no-recursion) + static void copy_idx(SimpleArray & arr_out, std::vector const & slices, pybind11::array_t const * arr_in, shape_type left_shape, shape_type sidx, int dim) + { + using out_type = typename std::remove_reference_t; - size_t offset_out = 0; - for (size_t it = 0; it < arr_out.ndim(); ++it) - { - auto step = slices[it][2]; - offset_out += arr_out.stride(it) * sidx[it] * step; - } + if (dim < 0) + { + D const * ptr_in = arr_in->data() + input_offset(*arr_in, sidx); + ssize_t const offset_out = offset_from_slices(arr_out, slices, sidx); constexpr bool valid_conversion = (!is_complex_v && !is_complex_v) || (is_complex_v && is_complex_v && std::is_same_v); if constexpr (valid_conversion) { // FIXME: NOLINTNEXTLINE(bugprone-signed-char-misuse,cert-str34-c) - arr_out.at(offset_out) = static_cast(*ptr_in); + arr_out.data()[offset_out] = static_cast(*ptr_in); } else { throw std::runtime_error("Cannot convert between complex and non-complex types"); } + return; + } - // recursion here + for (size_t i = 0; i < left_shape[dim]; ++i) + { + sidx[dim] = i; copy_idx(arr_out, slices, arr_in, left_shape, sidx, dim - 1); } } @@ -94,27 +119,8 @@ struct TypeBroadcastImpl // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) auto * arr_new = reinterpret_cast const *>(&arr_in); - shape_type left_shape(arr_out.ndim()); - for (size_t i = 0; i < arr_out.ndim(); i++) - { - slice_type const & slice = slices[i]; - if ((slice[1] - slice[0]) % slice[2] == 0) - { - left_shape[i] = (slice[1] - slice[0]) / slice[2]; - } - else - { - left_shape[i] = (slice[1] - slice[0]) / slice[2] + 1; - } - } - - shape_type sidx_init(arr_out.ndim()); - - for (size_t i = 0; i < arr_out.ndim(); ++i) - { - sidx_init[i] = 0; - } - + shape_type const left_shape = modmesh::python::detail::shape_from_slices(slices); + shape_type const sidx_init(arr_out.ndim()); copy_idx(arr_out, slices, arr_new, left_shape, sidx_init, static_cast(arr_out.ndim()) - 1); } }; /* end struct TypeBroadcastImpl */ @@ -133,20 +139,7 @@ struct TypeBroadcast right_shape[i] = arr_in.shape(i); } - shape_type left_shape(arr_out.ndim()); - // TODO: range check - for (size_t i = 0; i < arr_out.ndim(); i++) - { - const slice_type & slice = slices[i]; - if ((slice[1] - slice[0]) % slice[2] == 0) - { - left_shape[i] = (slice[1] - slice[0]) / slice[2]; - } - else - { - left_shape[i] = (slice[1] - slice[0]) / slice[2] + 1; - } - } + shape_type left_shape = modmesh::python::detail::shape_from_slices(slices); if (arr_out.ndim() != static_cast(arr_in.ndim())) { diff --git a/cpp/modmesh/buffer/pymod/array_common.hpp b/cpp/modmesh/buffer/pymod/array_common.hpp index f18c7aa8..91d9b946 100644 --- a/cpp/modmesh/buffer/pymod/array_common.hpp +++ b/cpp/modmesh/buffer/pymod/array_common.hpp @@ -254,7 +254,7 @@ class ArrayPropertyHelper const auto arr_in = py_value.cast(); auto slices = make_default_slices(arr_out); - copy_slice(slices[0], slice_in); + copy_slice(slices[0], slice_in, arr_out.shape(0)); broadcast_array_using_slice(arr_out, slices, arr_in); return; @@ -366,24 +366,37 @@ class ArrayPropertyHelper slices.reserve(arr.ndim()); for (size_t i = 0; i < arr.ndim(); ++i) { - slice_type default_slice(3); + auto const dim = static_cast(arr.shape(i)); + slice_type default_slice(4); default_slice[0] = 0; // start - default_slice[1] = static_cast(arr.shape(i)); // stop + default_slice[1] = dim; // stop default_slice[2] = 1; // step + default_slice[3] = dim; // length slices.push_back(std::move(default_slice)); } return slices; } - static void copy_slice(slice_type & slice_out, pybind11::slice const & slice_in) + static void copy_slice(slice_type & slice_out, pybind11::slice const & slice_in, size_t length) { - auto start = std::string(pybind11::str(slice_in.attr("start"))); - auto stop = std::string(pybind11::str(slice_in.attr("stop"))); - auto step = std::string(pybind11::str(slice_in.attr("step"))); + pybind11::ssize_t start = 0; + pybind11::ssize_t stop = 0; + pybind11::ssize_t step = 0; + pybind11::ssize_t slicelength = 0; + auto const signed_length = static_cast(length); + if (!slice_in.compute(signed_length, + &start, + &stop, + &step, + &slicelength)) + { + throw pybind11::error_already_set(); + } - slice_out[0] = start == "None" ? slice_out[0] : std::stoi(start); - slice_out[1] = stop == "None" ? slice_out[1] : std::stoi(stop); - slice_out[2] = step == "None" ? slice_out[2] : std::stoi(step); + slice_out[0] = start; + slice_out[1] = stop; + slice_out[2] = step; + slice_out[3] = slicelength; } static void slice_syntax_check(pybind11::tuple const & tuple, size_t ndim) @@ -409,7 +422,7 @@ class ArrayPropertyHelper } } - if (ellipsis_cnt + slice_cnt > ndim) + if (slice_cnt > ndim) { throw std::runtime_error("syntax error. dimensions mismatches"); } @@ -426,6 +439,8 @@ class ArrayPropertyHelper { namespace py = pybind11; + slice_syntax_check(tuple, ndim); + // copy slices from the front until an ellipsis bool ellipsis_flag = false; for (auto it = tuple.begin(); it != tuple.end(); it++) @@ -440,7 +455,7 @@ class ArrayPropertyHelper auto & slice_out = slices[it - tuple.begin()]; const auto slice_in = (*it).cast(); - copy_slice(slice_out, slice_in); + copy_slice(slice_out, slice_in, slices[it - tuple.begin()][3]); } // copy slices from the back until an ellipsis @@ -457,7 +472,7 @@ class ArrayPropertyHelper auto & slice_out = slices[ndim - size - 1]; const auto slice_in = (*it).cast(); - copy_slice(slice_out, slice_in); + copy_slice(slice_out, slice_in, slices[ndim - size - 1][3]); } } } diff --git a/tests/test_buffer.py b/tests/test_buffer.py index 68f91d91..f7cbec77 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -953,24 +953,34 @@ def test_SimpleArray_broadcast_slice_complex_ndarray(self): sarr[:, :] = rhs np.testing.assert_array_equal(rhs, sarr.ndarray) - def test_SimpleArray_broadcast_slice_shape(self): - ndarr = np.arange(2 * 3 * 4, dtype='float64').reshape((2, 3, 4)) + def test_SimpleArray_broadcast_slice_from_strided_ndarray(self): + ndarr_shape = (16, 18, 20, 22, 24) + ndarr = np.arange(np.prod(ndarr_shape), dtype='float64') + ndarr = ndarr.reshape(ndarr_shape) + sliced_ndarr = ndarr[15::-2, 1::2, 19::-2, 1::2, 23::-2] + expected = np.array(sliced_ndarr, copy=True) + sarr = modmesh.SimpleArrayFloat64(array=sliced_ndarr) + self.assertEqual( + sliced_ndarr[1, 2, 3, 4, 5], + sarr[1, 2, 3, 4, 5]) - sarr = modmesh.SimpleArrayFloat64((4, 6, 8)) - with self.assertRaisesRegex( - RuntimeError, - r"Broadcast input array from shape\(2, 3, 4\) " - r"into shape\(2, 2, 2\)" - ): - sarr[::2, ::3, ::4] = ndarr[...] + key = np.index_exp[::-2, 1:-1:2, -8:-1:2, 10:2:-3, ...] + value_shape = expected[key].shape - sarr = modmesh.SimpleArrayFloat64((4, 6, 8)) - with self.assertRaisesRegex( - RuntimeError, - r"Broadcast input array from shape\(2, 3, 4\) " - r"into shape\(2, 6, 8\)" - ): - sarr[::2, ::1, ...] = ndarr[...] + value = np.empty(value_shape, dtype='float64') + next_value = 1000.0 + for i0 in range(value_shape[0]): + for i1 in range(value_shape[1]): + for i2 in range(value_shape[2]): + for i3 in range(value_shape[3]): + for i4 in range(value_shape[4]): + value[i0, i1, i2, i3, i4] = next_value + next_value += 1 + + sarr[key] = value + expected[key] = value + + np.testing.assert_array_equal(sliced_ndarr, expected) def test_SimpleArray_broadcast_slice_ghost_1d(self): import math @@ -1011,6 +1021,64 @@ def test_SimpleArray_broadcast_slice_ghost_md(self): for k in range(4): self.assertEqual(ndarr2[i, j, k], sarr[i - G, j, k]) + def test_SimpleArray_broadcast_slice_negative_bounds(self): + ndarr_shape = (8, 9, 10, 11, 12) + ndarr = np.arange(np.prod(ndarr_shape), dtype='float64') + ndarr = ndarr.reshape(ndarr_shape) + expected = ndarr.copy() + sarr = modmesh.SimpleArrayFloat64(array=ndarr) + + key = np.index_exp[-8:-1:2, 8:1:-2, 1:-1:3, ..., -2:-8:-2] + value_shape = expected[key].shape + + value = np.empty(value_shape, dtype='float64') + next_value = 1000.0 + for i0 in range(value_shape[0]): + for i1 in range(value_shape[1]): + for i2 in range(value_shape[2]): + for i3 in range(value_shape[3]): + for i4 in range(value_shape[4]): + value[i0, i1, i2, i3, i4] = next_value + next_value += 1 + + sarr[key] = value + expected[key] = value + + np.testing.assert_array_equal(sarr.ndarray, expected) + np.testing.assert_array_equal(ndarr, expected) + + def test_SimpleArray_broadcast_slice_shape(self): + ndarr = np.arange(2 * 3 * 4, dtype='float64').reshape((2, 3, 4)) + + sarr = modmesh.SimpleArrayFloat64((4, 6, 8)) + with self.assertRaisesRegex( + RuntimeError, + r"Broadcast input array from shape\(2, 3, 4\) " + r"into shape\(2, 2, 2\)" + ): + sarr[::2, ::3, ::4] = ndarr[...] + + sarr = modmesh.SimpleArrayFloat64((4, 6, 8)) + with self.assertRaisesRegex( + RuntimeError, + r"Broadcast input array from shape\(2, 3, 4\) " + r"into shape\(2, 6, 8\)" + ): + sarr[::2, ::1, ...] = ndarr[...] + + sarr = modmesh.SimpleArrayFloat64((2, 3, 4)) + with self.assertRaisesRegex( + RuntimeError, + r"syntax error\. dimensions mismatches" + ): + sarr[:, :, :, :] = np.zeros((2, 3, 4, 1), dtype='float64') + + def test_SimpleArray_broadcast_slice_zero_step(self): + ndarr = np.arange(4 * 5 * 6, dtype='float64').reshape((4, 5, 6)) + sarr = modmesh.SimpleArrayFloat64(array=ndarr) + with self.assertRaisesRegex(ValueError, "slice step cannot be zero"): + sarr[::0, :, :] = np.zeros((4, 5, 6), dtype='float64') + def test_SimpleArray_broadcast_from_list_list(self): sarr = modmesh.SimpleArrayFloat64((2, 3)) sarr[:, :] = [[1, 2, 3], [4, 5, 6]]