Skip to content

Commit 6328abc

Browse files
authored
Merge pull request #451 from abergeron/fix_segfault
Fix crash in reshape.
2 parents a4c7381 + c441951 commit 6328abc

3 files changed

Lines changed: 55 additions & 31 deletions

File tree

src/gpuarray_array.c

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -696,41 +696,43 @@ int GpuArray_reshape_inplace(GpuArray *a, unsigned int nd,
696696
if (newstrides == NULL)
697697
return error_sys(ctx->err, "calloc");
698698

699-
while (ni < nd && oi < a->nd) {
700-
np = newdims[ni];
701-
op = a->dimensions[oi];
699+
if (newsize != 0) {
700+
while (ni < nd && oi < a->nd) {
701+
np = newdims[ni];
702+
op = a->dimensions[oi];
703+
704+
while (np != op) {
705+
if (np < op) {
706+
np *= newdims[nj++];
707+
} else {
708+
op *= a->dimensions[oj++];
709+
}
710+
}
702711

703-
while (np != op) {
704-
if (np < op) {
705-
np *= newdims[nj++];
706-
} else {
707-
op *= a->dimensions[oj++];
712+
for (ok = oi; ok < oj - 1; ok++) {
713+
if (ord == GA_F_ORDER) {
714+
if (a->strides[ok+1] != (ssize_t)a->dimensions[ok]*a->strides[ok])
715+
goto need_copy;
716+
} else {
717+
if (a->strides[ok] != (ssize_t)a->dimensions[ok+1]*a->strides[ok+1])
718+
goto need_copy;
719+
}
708720
}
709-
}
710721

711-
for (ok = oi; ok < oj - 1; ok++) {
712722
if (ord == GA_F_ORDER) {
713-
if (a->strides[ok+1] != (ssize_t)a->dimensions[ok]*a->strides[ok])
714-
goto need_copy;
723+
newstrides[ni] = a->strides[oi];
724+
for (nk = ni + 1; nk < nj; nk++) {
725+
newstrides[nk] = newstrides[nk - 1]*newdims[nk - 1];
726+
}
715727
} else {
716-
if (a->strides[ok] != (ssize_t)a->dimensions[ok+1]*a->strides[ok+1])
717-
goto need_copy;
718-
}
719-
}
720-
721-
if (ord == GA_F_ORDER) {
722-
newstrides[ni] = a->strides[oi];
723-
for (nk = ni + 1; nk < nj; nk++) {
724-
newstrides[nk] = newstrides[nk - 1]*newdims[nk - 1];
725-
}
726-
} else {
727-
newstrides[nj-1] = a->strides[oj-1];
728-
for (nk = nj-1; nk > ni; nk--) {
729-
newstrides[nk-1] = newstrides[nk]*newdims[nk];
728+
newstrides[nj-1] = a->strides[oj-1];
729+
for (nk = nj-1; nk > ni; nk--) {
730+
newstrides[nk-1] = newstrides[nk]*newdims[nk];
731+
}
730732
}
733+
ni = nj++;
734+
oi = oj++;
731735
}
732-
ni = nj++;
733-
oi = oj++;
734736
}
735737

736738
/* Fixup trailing ones */

tests/CMakeLists.txt

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
include(CheckSymbolExists)
1+
include(CheckCSourceCompiles)
22
find_package(PkgConfig)
33

44
pkg_search_module(CHECK check)
@@ -17,8 +17,15 @@ if(NOT CHECK_FOUND)
1717
endif()
1818

1919
if(CHECK_FOUND)
20-
set(CMAKE_REQUIRED_INCLUDE ${CHECK_INCLUDE_DIRS})
21-
CHECK_SYMBOL_EXISTS(ck_assert_ptr_ne "check.h" CHECK_FUNCS)
20+
set(CMAKE_REQUIRED_FLAGS ${CHECK_C_FLAGS} ${CHECK_LDFLAGS_OTHERS})
21+
set(CMAKE_REQUIRED_INCLUDES ${CHECK_INCLUDE_DIRS})
22+
set(CMAKE_REQUIRED_LIBRARIES ${CHECK_LIBRARIES})
23+
CHECK_C_SOURCE_COMPILES(
24+
"#include <check.h>
25+
int main() {
26+
ck_assert_ptr_ne(NULL, NULL);
27+
}"
28+
CHECK_FUNCS)
2229
if (NOT CHECK_FUNCS)
2330
set(CHECK_FOUND 0)
2431
endif()

tests/check_array.c

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,13 +270,28 @@ START_TEST(test_take1_offset) {
270270
}
271271
END_TEST
272272

273+
START_TEST(test_reshape_0) {
274+
/* This tests that we don't segfault when reshaping 0-sized arrays */
275+
const size_t odims[3] = {24, 0, 33};
276+
const size_t ndims1[3] = {0, 24, 33};
277+
const size_t ndims2[3] = {24, 33, 0};
278+
279+
GpuArray v;
280+
ga_assert_ok(GpuArray_empty(&v, ctx, GA_FLOAT, 3, odims, GA_C_ORDER));
281+
ga_assert_ok(GpuArray_reshape_inplace(&v, 3, ndims1, GA_ANY_ORDER));
282+
ga_assert_ok(GpuArray_reshape_inplace(&v, 3, odims, GA_ANY_ORDER));
283+
ga_assert_ok(GpuArray_reshape_inplace(&v, 3, ndims2, GA_ANY_ORDER));
284+
}
285+
END_TEST
286+
273287
Suite *get_suite(void) {
274288
Suite *s = suite_create("array");
275289
TCase *tc = tcase_create("take1");
276290
tcase_add_checked_fixture(tc, setup, teardown);
277291
tcase_set_timeout(tc, 8.0);
278292
tcase_add_test(tc, test_take1_ok);
279293
tcase_add_test(tc, test_take1_offset);
294+
tcase_add_test(tc, test_reshape_0);
280295
suite_add_tcase(s, tc);
281296
return s;
282297
}

0 commit comments

Comments
 (0)