@@ -2046,7 +2046,6 @@ def repeat(a, repeats, axis=None):
20462046 --------
20472047 Multiple GPUs, Multiple CPUs
20482048 """
2049-
20502049 # when array is a scalar
20512050 if np .ndim (a ) == 0 :
20522051 if np .ndim (repeats ) == 0 :
@@ -2100,11 +2099,37 @@ def repeat(a, repeats, axis=None):
21002099 category = UserWarning ,
21012100 )
21022101 repeats = np .int64 (repeats )
2103- result = array ._thunk .repeat (
2104- repeats = repeats ,
2105- axis = axis ,
2106- scalar_repeats = True ,
2107- )
2102+ if repeats < 0 :
2103+ return ValueError (
2104+ "'repeats' should not be negative: {}" .format (repeats )
2105+ )
2106+
2107+ # check output shape (if it will fit to GPU or not)
2108+ out_shape = list (array .shape )
2109+ out_shape [axis ] *= repeats
2110+ out_shape = tuple (out_shape )
2111+ size = sum (out_shape ) * array .itemsize
2112+ # check if size of the output array is less 8GB. In this case we can
2113+ # use output regions, otherwise we will use statcally allocated
2114+ # array
2115+ print ("IRINA DEBUG 1" , size , (8589934592 / 2 - size ))
2116+ if size < 8589934592 / 2 :
2117+
2118+ result = array ._thunk .repeat (
2119+ repeats = repeats , axis = axis , scalar_repeats = True
2120+ )
2121+ else :
2122+ # this implementation is taken from CuPy
2123+ result = ndarray (shape = out_shape , dtype = array .dtype )
2124+ a_index = [slice (None )] * len (out_shape )
2125+ res_index = list (a_index )
2126+ offset = 0
2127+ for i in range (a ._shape [axis ]):
2128+ a_index [axis ] = slice (i , i + 1 )
2129+ res_index [axis ] = slice (offset , offset + repeats )
2130+ result [res_index ] = array [a_index ]
2131+ offset += repeats
2132+ return result
21082133 # repeats is an array
21092134 else :
21102135 # repeats should be integer type
@@ -2116,9 +2141,32 @@ def repeat(a, repeats, axis=None):
21162141 repeats = repeats .astype (np .int64 )
21172142 if repeats .shape [0 ] != array .shape [axis ]:
21182143 return ValueError ("incorrect shape of repeats array" )
2119- result = array ._thunk .repeat (
2120- repeats = repeats ._thunk , axis = axis , scalar_repeats = False
2121- )
2144+
2145+ # check output shape (if it will fit to GPU or not)
2146+ out_shape = list (array .shape )
2147+ n_repeats = sum (repeats )
2148+ out_shape [axis ] = n_repeats
2149+ out_shape = tuple (out_shape )
2150+ size = sum (out_shape ) * array .itemsize
2151+ # check if size of the output array is less 8GB. In this case we can
2152+ # use output regions, otherwise we will use statcally allocated
2153+ # array
2154+ print ("IRINA DEBUG 1" , size , (8589934592 / 2 - size ))
2155+ if size < 8589934592 / 2 :
2156+ result = array ._thunk .repeat (
2157+ repeats = repeats ._thunk , axis = axis , scalar_repeats = False
2158+ )
2159+ else : # this implementation is taken from CuPy
2160+ result = ndarray (shape = out_shape , dtype = array .dtype )
2161+ a_index = [slice (None )] * len (out_shape )
2162+ res_index = list (a_index )
2163+ offset = 0
2164+ for i in range (a ._shape [axis ]):
2165+ a_index [axis ] = slice (i , i + 1 )
2166+ res_index [axis ] = slice (offset , offset + repeats [i ])
2167+ result [res_index ] = array [a_index ]
2168+ offset += repeats [i ]
2169+ return result
21222170 return ndarray (shape = result .shape , thunk = result )
21232171
21242172
0 commit comments