Skip to content

Commit 2a78177

Browse files
committed
Add function to fix the flags of a GpuArray.
1 parent 690827d commit 2a78177

2 files changed

Lines changed: 24 additions & 39 deletions

File tree

src/gpuarray/array.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,13 @@ static inline int GpuArray_CHKFLAGS(const GpuArray *a, int flags) {
196196
*/
197197
#define GpuArray_ITEMSIZE(a) gpuarray_get_elsize((a)->typecode)
198198

199+
/**
200+
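* Recompute the flags of an array from its current strides and shape. Every flag except the writeable flag is cleared, then the contiguity and alignment flags are set again where they apply.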
201+
*
202+
* \param a GpuArray to fix flags for
203+
*/
204+
GPUARRAY_PUBLIC void GpuArray_fix_flags(GpuArray *a);
205+
199206
/**
200207
* Initialize and allocate a new empty (uninitialized data) array.
201208
*

src/gpuarray_array.c

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,17 @@ static int ga_extcopy(GpuArray *dst, const GpuArray *src) {
8585
/* Value below which a size_t multiplication will never overflow. */
8686
/* The shifted 1 must have type size_t: unsigned long can be narrower than
 * size_t (e.g. LLP64 targets), and shifting by the full width of the
 * operand is undefined behaviour. */
#define MUL_NO_OVERFLOW ((size_t)1 << (sizeof(size_t) * 4))
8787

88-
int GpuArray_empty(GpuArray *a, gpucontext *ctx,
89-
int typecode, unsigned int nd, const size_t *dims,
90-
ga_order ord) {
88+
void GpuArray_fix_flags(GpuArray *a) {
89+
/* Recompute a->flags in place: clear every flag except GA_WRITEABLE, whose current value is kept */
90+
a->flags &= GA_WRITEABLE;
91+
/* Set each contiguity/alignment flag again only when its GpuArray_is_* check passes */
92+
if (GpuArray_is_c_contiguous(a)) a->flags |= GA_C_CONTIGUOUS;
93+
if (GpuArray_is_f_contiguous(a)) a->flags |= GA_F_CONTIGUOUS;
94+
if (GpuArray_is_aligned(a)) a->flags |= GA_ALIGNED;
95+
}
96+
97+
int GpuArray_empty(GpuArray *a, gpucontext *ctx, int typecode,
98+
unsigned int nd, const size_t *dims, ga_order ord) {
9199
size_t size = gpuarray_get_elsize(typecode);
92100
unsigned int i;
93101
int res = GA_NO_ERROR;
@@ -185,9 +193,7 @@ int GpuArray_fromdata(GpuArray *a, gpudata *data, size_t offset, int typecode,
185193
memcpy(a->dimensions, dims, nd*sizeof(size_t));
186194
memcpy(a->strides, strides, nd*sizeof(ssize_t));
187195

188-
if (GpuArray_is_c_contiguous(a)) a->flags |= GA_C_CONTIGUOUS;
189-
if (GpuArray_is_f_contiguous(a)) a->flags |= GA_F_CONTIGUOUS;
190-
if (GpuArray_is_aligned(a)) a->flags |= GA_ALIGNED;
196+
GpuArray_fix_flags(a);
191197

192198
return GA_NO_ERROR;
193199
}
@@ -304,18 +310,7 @@ int GpuArray_index_inplace(GpuArray *a, const ssize_t *starts,
304310
a->dimensions = newdims;
305311
free(a->strides);
306312
a->strides = newstrs;
307-
if (GpuArray_is_c_contiguous(a))
308-
a->flags |= GA_C_CONTIGUOUS;
309-
else
310-
a->flags &= ~GA_C_CONTIGUOUS;
311-
if (GpuArray_is_f_contiguous(a))
312-
a->flags |= GA_F_CONTIGUOUS;
313-
else
314-
a->flags &= ~GA_F_CONTIGUOUS;
315-
if (GpuArray_is_aligned(a))
316-
a->flags |= GA_ALIGNED;
317-
else
318-
a->flags &= ~GA_ALIGNED;
313+
GpuArray_fix_flags(a);
319314

320315
return GA_NO_ERROR;
321316
}
@@ -582,9 +577,8 @@ int GpuArray_setarray(GpuArray *a, const GpuArray *v) {
582577
tv.nd = a->nd;
583578
tv.dimensions = a->dimensions;
584579
tv.strides = strs;
585-
/* This could be optiomized by setting the right flags */
586580
if (tv.nd != 0)
587-
tv.flags &= ~(GA_C_CONTIGUOUS|GA_F_CONTIGUOUS);
581+
GpuArray_fix_flags(&tv);
588582
err = ga_extcopy(a, &tv);
589583
free(strs);
590584
return err;
@@ -745,18 +739,7 @@ int GpuArray_reshape_inplace(GpuArray *a, unsigned int nd,
745739
a->strides = newstrides;
746740

747741
fix_flags:
748-
if (GpuArray_is_c_contiguous(a))
749-
a->flags |= GA_C_CONTIGUOUS;
750-
else
751-
a->flags &= ~GA_C_CONTIGUOUS;
752-
if (GpuArray_is_f_contiguous(a))
753-
a->flags |= GA_F_CONTIGUOUS;
754-
else
755-
a->flags &= ~GA_F_CONTIGUOUS;
756-
if (GpuArray_is_aligned(a))
757-
a->flags |= GA_ALIGNED;
758-
else
759-
a->flags &= ~GA_ALIGNED;
742+
GpuArray_fix_flags(a);
760743
return GA_NO_ERROR;
761744
}
762745

@@ -808,11 +791,7 @@ int GpuArray_transpose_inplace(GpuArray *a, const unsigned int *new_axes) {
808791
a->dimensions = newdims;
809792
a->strides = newstrs;
810793

811-
a->flags &= ~(GA_C_CONTIGUOUS|GA_F_CONTIGUOUS);
812-
if (GpuArray_is_c_contiguous(a))
813-
a->flags |= GA_C_CONTIGUOUS;
814-
if (GpuArray_is_f_contiguous(a))
815-
a->flags |= GA_F_CONTIGUOUS;
794+
GpuArray_fix_flags(a);
816795

817796
return GA_NO_ERROR;
818797
}
@@ -1016,10 +995,9 @@ int GpuArray_concatenate(GpuArray *r, const GpuArray **as, size_t n,
1016995
res_off = r->offset;
1017996
res_dims = r->dimensions;
1018997
res_flags = r->flags;
1019-
/* This could be optimized by setting the right flags */
1020-
r->flags &= ~(GA_C_CONTIGUOUS|GA_F_CONTIGUOUS);
1021998
for (i = 0; i < n; i++) {
1022999
r->dimensions = as[i]->dimensions;
1000+
GpuArray_fix_flags(r);
10231001
err = ga_extcopy(r, as[i]);
10241002
if (err != GA_NO_ERROR) {
10251003
r->dimensions = res_dims;

0 commit comments

Comments
 (0)