@@ -276,14 +276,21 @@ static int check_basic(GpuElemwise *ge, void **args, int flags,
276276 GpuArray * a = NULL , * v ;
277277 unsigned int i , j , p , num_arrays = 0 , nd = 0 , nnd ;
278278 int call32 = 1 ;
279+ unsigned int nd_i = 0 ;
280+ size_t v_dim_j = 0 ;
279281
280282 /* Go through the list and grab some info */
281283 for (i = 0 ; i < ge -> n ; i ++ ) {
282284 if (is_array (ge -> args [i ])) {
285+ nd_i = ((GpuArray * )args [i ])-> nd ;
283286 if (num_arrays == 0 )
284- nd = ((GpuArray * )args [i ])-> nd ;
285- else if (((GpuArray * )args [i ])-> nd != nd )
286- return error_fmt (ctx -> err , GA_VALUE_ERROR , "Arg %u has differing nd = %u" , i , ((GpuArray * )args [i ])-> nd );
287+ nd = nd_i ;
288+ else if (nd_i != nd ) {
289+ if (flags & GE_PADSHAPE )
290+ nd = nd_i > nd ? nd_i : nd ;
291+ else
292+ return error_fmt (ctx -> err , GA_VALUE_ERROR , "Arg %u has differing nd = %u" , i , nd_i );
293+ }
287294 ++ num_arrays ;
288295 if (a == NULL && is_output (ge -> args [i ]))
289296 a = (GpuArray * )args [i ];
@@ -301,15 +308,19 @@ static int check_basic(GpuElemwise *ge, void **args, int flags,
301308 return error_sys (ctx -> err , "ge_grow" );
302309 }
303310
304- /* Now we know that all array arguments have the same number of
311+ /* Now we know that all array arguments have at most nd
305312 dimensions and that the expected output size is the size of a */
306313
307314 /* And copy their initial values in */
308315 memcpy (ge -> dims , a -> dimensions , nd * sizeof (size_t ));
309316 p = 0 ;
310317 for (i = 0 ; i < ge -> n ; i ++ ) {
311318 if (is_array (ge -> args [i ])) {
312- memcpy (ge -> strides [p ], ((GpuArray * )args [i ])-> strides , nd * sizeof (ssize_t ));
319+ /* Left-pad strides with zero on implicitly broadcasted dimensions */
320+ memset (ge -> strides [p ], 0 , nd * sizeof (ssize_t ));
321+ nd_i = ((GpuArray * )args [i ])-> nd ;
322+ memcpy ((char * )(ge -> strides [p ]) + (nd - nd_i )* sizeof (ssize_t ),
323+ ((GpuArray * )args [i ])-> strides , nd_i * sizeof (ssize_t ));
313324 p ++ ;
314325 }
315326 }
@@ -326,16 +337,23 @@ static int check_basic(GpuElemwise *ge, void **args, int flags,
326337 for (i = 0 ; i < ge -> n ; i ++ ) {
327338 if (is_array (ge -> args [i ])) {
328339 v = (GpuArray * )args [i ];
329- if (ge -> dims [j ] != v -> dimensions [j ]) {
340+ nd_i = v -> nd ;
341+ /* Pad shape with 1 if needed for implicitly broadcasted dimensions
342+ and shift if needed */
343+ if (j < nd - nd_i )
344+ v_dim_j = 1 ;
345+ else
346+ v_dim_j = v -> dimensions [j - (nd - nd_i )];
347+ if (ge -> dims [j ] != v_dim_j ) {
330348 /* We can't broadcast outputs */
331349 if (ISCLR (flags , GE_BROADCAST ) || is_output (ge -> args [i ]) ||
332- v -> dimensions [ j ] != 1 ) {
333- return error_fmt (ctx -> err , GA_VALUE_ERROR , "Mismatched dimension %u for input %u (expected %" SPREFIX "u got %" SPREFIX "u)" , j , i , ge -> dims [j ], v -> dimensions [ j ] );
350+ v_dim_j != 1 ) {
351+ return error_fmt (ctx -> err , GA_VALUE_ERROR , "Mismatched dimension %u for input %u (expected %" SPREFIX "u got %" SPREFIX "u)" , j , i , ge -> dims [j ], v_dim_j );
334352 }
335353 }
336354 /* If the dimension is 1 set the strides to 0 regardless since
337355 it won't change anything in the non-broadcast case. */
338- if (v -> dimensions [ j ] == 1 ) {
356+ if (v_dim_j == 1 ) {
339357 ge -> strides [p ][j ] = 0 ;
340358 }
341359 call32 &= v -> offset < ADDR32_MAX ;
0 commit comments