@@ -85,9 +85,17 @@ static int ga_extcopy(GpuArray *dst, const GpuArray *src) {
8585/* Value below which a size_t multiplication will never overflow. */
8686#define MUL_NO_OVERFLOW (1UL << (sizeof(size_t) * 4))
8787
88- int GpuArray_empty (GpuArray * a , gpucontext * ctx ,
89- int typecode , unsigned int nd , const size_t * dims ,
90- ga_order ord ) {
88+ void GpuArray_fix_flags (GpuArray * a ) {
89+ /* Only keep the writable flag */
90+ a -> flags &= GA_WRITEABLE ;
91+ /* Set the other flags if applicable */
92+ if (GpuArray_is_c_contiguous (a )) a -> flags |= GA_C_CONTIGUOUS ;
93+ if (GpuArray_is_f_contiguous (a )) a -> flags |= GA_F_CONTIGUOUS ;
94+ if (GpuArray_is_aligned (a )) a -> flags |= GA_ALIGNED ;
95+ }
96+
97+ int GpuArray_empty (GpuArray * a , gpucontext * ctx , int typecode ,
98+ unsigned int nd , const size_t * dims , ga_order ord ) {
9199 size_t size = gpuarray_get_elsize (typecode );
92100 unsigned int i ;
93101 int res = GA_NO_ERROR ;
@@ -185,9 +193,7 @@ int GpuArray_fromdata(GpuArray *a, gpudata *data, size_t offset, int typecode,
185193 memcpy (a -> dimensions , dims , nd * sizeof (size_t ));
186194 memcpy (a -> strides , strides , nd * sizeof (ssize_t ));
187195
188- if (GpuArray_is_c_contiguous (a )) a -> flags |= GA_C_CONTIGUOUS ;
189- if (GpuArray_is_f_contiguous (a )) a -> flags |= GA_F_CONTIGUOUS ;
190- if (GpuArray_is_aligned (a )) a -> flags |= GA_ALIGNED ;
196+ GpuArray_fix_flags (a );
191197
192198 return GA_NO_ERROR ;
193199}
@@ -304,18 +310,7 @@ int GpuArray_index_inplace(GpuArray *a, const ssize_t *starts,
304310 a -> dimensions = newdims ;
305311 free (a -> strides );
306312 a -> strides = newstrs ;
307- if (GpuArray_is_c_contiguous (a ))
308- a -> flags |= GA_C_CONTIGUOUS ;
309- else
310- a -> flags &= ~GA_C_CONTIGUOUS ;
311- if (GpuArray_is_f_contiguous (a ))
312- a -> flags |= GA_F_CONTIGUOUS ;
313- else
314- a -> flags &= ~GA_F_CONTIGUOUS ;
315- if (GpuArray_is_aligned (a ))
316- a -> flags |= GA_ALIGNED ;
317- else
318- a -> flags &= ~GA_ALIGNED ;
313+ GpuArray_fix_flags (a );
319314
320315 return GA_NO_ERROR ;
321316}
@@ -582,9 +577,8 @@ int GpuArray_setarray(GpuArray *a, const GpuArray *v) {
582577 tv .nd = a -> nd ;
583578 tv .dimensions = a -> dimensions ;
584579 tv .strides = strs ;
585- /* This could be optiomized by setting the right flags */
586580 if (tv .nd != 0 )
587- tv . flags &= ~( GA_C_CONTIGUOUS | GA_F_CONTIGUOUS );
581+ GpuArray_fix_flags ( & tv );
588582 err = ga_extcopy (a , & tv );
589583 free (strs );
590584 return err ;
@@ -745,18 +739,7 @@ int GpuArray_reshape_inplace(GpuArray *a, unsigned int nd,
745739 a -> strides = newstrides ;
746740
747741 fix_flags :
748- if (GpuArray_is_c_contiguous (a ))
749- a -> flags |= GA_C_CONTIGUOUS ;
750- else
751- a -> flags &= ~GA_C_CONTIGUOUS ;
752- if (GpuArray_is_f_contiguous (a ))
753- a -> flags |= GA_F_CONTIGUOUS ;
754- else
755- a -> flags &= ~GA_F_CONTIGUOUS ;
756- if (GpuArray_is_aligned (a ))
757- a -> flags |= GA_ALIGNED ;
758- else
759- a -> flags &= ~GA_ALIGNED ;
742+ GpuArray_fix_flags (a );
760743 return GA_NO_ERROR ;
761744}
762745
@@ -808,11 +791,7 @@ int GpuArray_transpose_inplace(GpuArray *a, const unsigned int *new_axes) {
808791 a -> dimensions = newdims ;
809792 a -> strides = newstrs ;
810793
811- a -> flags &= ~(GA_C_CONTIGUOUS |GA_F_CONTIGUOUS );
812- if (GpuArray_is_c_contiguous (a ))
813- a -> flags |= GA_C_CONTIGUOUS ;
814- if (GpuArray_is_f_contiguous (a ))
815- a -> flags |= GA_F_CONTIGUOUS ;
794+ GpuArray_fix_flags (a );
816795
817796 return GA_NO_ERROR ;
818797}
@@ -1016,10 +995,9 @@ int GpuArray_concatenate(GpuArray *r, const GpuArray **as, size_t n,
1016995 res_off = r -> offset ;
1017996 res_dims = r -> dimensions ;
1018997 res_flags = r -> flags ;
1019- /* This could be optimized by setting the right flags */
1020- r -> flags &= ~(GA_C_CONTIGUOUS |GA_F_CONTIGUOUS );
1021998 for (i = 0 ; i < n ; i ++ ) {
1022999 r -> dimensions = as [i ]-> dimensions ;
1000+ GpuArray_fix_flags (r );
10231001 err = ga_extcopy (r , as [i ]);
10241002 if (err != GA_NO_ERROR ) {
10251003 r -> dimensions = res_dims ;
0 commit comments