@@ -83,6 +83,9 @@ CPStatus BasicIndexingPropagator::propagate(CPPropagatorsState& p_state,
8383
8484 const BasicIndexingNode* bi = dynamic_cast <const BasicIndexingNode*>(basic_indexing_->node_ );
8585 assert (bi);
86+
87+ // Not caching this for now as we may need to fit these at propagate time for
88+ // dynamic arrays
8689 std::vector<BasicIndexingNode::slice_or_int> slices = bi->infer_indices ();
8790 for (ssize_t axis = 0 ; axis < array_->node_ ->ndim (); ++axis) {
8891 if (std::holds_alternative<Slice>(slices[axis])) {
@@ -98,23 +101,26 @@ CPStatus BasicIndexingPropagator::propagate(CPPropagatorsState& p_state,
98101 ssize_t bi_index = indices_to_process.front ();
99102 indices_to_process.pop_front ();
100103
101- std::vector<ssize_t > in_multi_index =
104+ // Derive the original array index based on the index of the basic indexing variable.
105+ // We unravel the basic indexing variable index, transform the multi-index into
106+ // one on the original array, and then ravel it to get the final linear index on
107+ // the array.
108+ std::vector<ssize_t > bi_multi_index =
102109 unravel_index (bi_index, basic_indexing_->node_ ->shape ());
103- std::vector<ssize_t > out_multi_index ;
110+ std::vector<ssize_t > arr_multi_index ;
104111 ssize_t bi_axis = 0 ;
105112 for (ssize_t axis = 0 ; axis < array_->node_ ->ndim (); ++axis) {
106113 if (std::holds_alternative<ssize_t >(slices[axis])) {
107- out_multi_index .push_back (std::get<ssize_t >(slices[axis]));
114+ arr_multi_index .push_back (std::get<ssize_t >(slices[axis]));
108115 continue ;
109116 }
110117 assert (std::holds_alternative<Slice>(slices[axis]));
111118 const auto & slice = std::get<Slice>(slices[axis]);
112119 assert (slice.step == 1 );
113- out_multi_index .push_back (in_multi_index [bi_axis] + slice.start );
120+ arr_multi_index .push_back (bi_multi_index [bi_axis] + slice.start );
114121 bi_axis++;
115122 }
116-
117- ssize_t array_index = ravel_multi_index (out_multi_index, array_->node_ ->shape ());
123+ ssize_t array_index = ravel_multi_index (arr_multi_index, array_->node_ ->shape ());
118124
119125 // Now we make the bounds of the array element and the basic indexing element equal
120126
0 commit comments