Skip to content

Commit bd5f558

Browse files
authored
Revert "Update reshape() implementation to avoid unnecessary deepcopy (#543)" (#559)
This reverts commit 1e54aa4.
1 parent 1e54aa4 commit bd5f558

2 files changed

Lines changed: 118 additions & 91 deletions

File tree

cupynumeric/_thunk/deferred.py

Lines changed: 118 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#
1515
from __future__ import annotations
1616

17-
import math
1817
import weakref
1918
from collections import Counter
2019
from collections.abc import Iterable
@@ -1092,64 +1091,134 @@ def reshape(self, newshape: NdShape, order: OrderType) -> NumPyThunk:
10921091
# performance issues, but we will revisit this decision later once
10931092
# we have enough evidence that that's not the case.
10941093

1094+
in_dim = 0
1095+
out_dim = 0
1096+
10951097
in_shape = self.shape
10961098
out_shape = newshape
10971099

1098-
needs_copy = False
1099-
out_pos = 0
1100-
for in_elem in in_shape:
1100+
in_ndim = len(in_shape)
1101+
out_ndim = len(out_shape)
1102+
1103+
groups = []
1104+
1105+
while in_dim < in_ndim and out_dim < out_ndim:
1106+
prev_in_dim = in_dim
1107+
prev_out_dim = out_dim
1108+
1109+
in_prod = 1
11011110
out_prod = 1
1102-
while out_prod < in_elem and out_pos < len(out_shape):
1103-
out_prod *= out_shape[out_pos]
1104-
out_pos += 1
1105-
if out_prod != in_elem:
1106-
needs_copy = True
1107-
break
1111+
1112+
while True:
1113+
if in_prod < out_prod:
1114+
in_prod *= in_shape[in_dim]
1115+
in_dim += 1
1116+
else:
1117+
out_prod *= out_shape[out_dim]
1118+
out_dim += 1
1119+
if in_prod == out_prod:
1120+
if in_dim < in_ndim and in_shape[in_dim] == 1:
1121+
in_dim += 1
1122+
break
1123+
1124+
in_group = in_shape[prev_in_dim:in_dim]
1125+
out_group = out_shape[prev_out_dim:out_dim]
1126+
groups.append((in_group, out_group))
1127+
1128+
while in_dim < in_ndim:
1129+
assert in_shape[in_dim] == 1
1130+
groups.append(((1,), ()))
1131+
in_dim += 1
1132+
1133+
while out_dim < out_ndim:
1134+
assert out_shape[out_dim] == 1
1135+
groups.append(((), (1,)))
1136+
out_dim += 1
1137+
1138+
needs_linearization = any(len(src_g) > 1 for src_g, _ in groups)
1139+
needs_delinearization = any(len(tgt_g) > 1 for _, tgt_g in groups)
1140+
needs_copy = needs_linearization or needs_delinearization
11081141

11091142
if needs_copy:
1110-
flat_size = math.prod(in_shape)
1111-
flat_array = runtime.create_empty_thunk(
1112-
(flat_size,), dtype=self.base.type, inputs=[self]
1143+
tmp_shape: NdShape = ()
1144+
for src_g, tgt_g in groups:
1145+
if len(src_g) > 1 and len(tgt_g) > 1:
1146+
tmp_shape += (_prod(tgt_g),)
1147+
else:
1148+
tmp_shape += tgt_g
1149+
1150+
result = runtime.create_empty_thunk(
1151+
tmp_shape, dtype=self.base.type, inputs=[self]
11131152
)
1114-
in_shape_store = flat_array.base.delinearize(0, in_shape)
1115-
out_shape_store = flat_array.base.delinearize(0, out_shape)
1116-
in_shape_array = DeferredArray(in_shape_store)
1117-
in_shape_array.copy(self, deep=True)
1118-
result = DeferredArray(out_shape_store)
1119-
return result
1120-
1121-
src = self.base
1122-
1123-
# Process each dimension from right to left
1124-
out_pos = len(out_shape) - 1
1125-
for src_dim, elem_in in zip(
1126-
range(len(in_shape) - 1, -1, -1), reversed(in_shape)
1127-
):
1128-
if out_pos >= 0 and elem_in == out_shape[out_pos]:
1129-
# Case 1: Dimensions match exactly
1130-
out_pos -= 1
1131-
continue
1132-
1133-
if elem_in == 1:
1134-
# Case 2: Input dimension is 1 (projection)
1135-
src = src.project(src_dim, 0)
1136-
continue
1137-
1138-
# Case 3: Delinearize operation
1139-
new_sizes = []
1140-
out_prod = 1
1141-
while out_prod < elem_in and out_pos >= 0:
1142-
out_prod *= out_shape[out_pos]
1143-
new_sizes.append(out_shape[out_pos])
1144-
out_pos -= 1
11451153

1146-
src = src.delinearize(src_dim, tuple(reversed(new_sizes)))
1154+
src = self.base
1155+
tgt = result.base # type: ignore
1156+
1157+
src_dim = 0
1158+
tgt_dim = 0
1159+
for src_g, tgt_g in groups:
1160+
diff = 1
1161+
if src_g == tgt_g:
1162+
assert len(src_g) == 1
1163+
elif len(src_g) == 0:
1164+
assert tgt_g == (1,)
1165+
src = src.promote(src_dim, 1)
1166+
elif len(tgt_g) == 0:
1167+
assert src_g == (1,)
1168+
tgt = tgt.promote(tgt_dim, 1)
1169+
elif len(src_g) == 1:
1170+
src = src.delinearize(src_dim, tgt_g)
1171+
diff = len(tgt_g)
1172+
else:
1173+
tgt = tgt.delinearize(tgt_dim, src_g)
1174+
diff = len(src_g)
1175+
1176+
src_dim += diff
1177+
tgt_dim += diff
1178+
1179+
assert src.shape == tgt.shape
1180+
1181+
src_array = DeferredArray(src)
1182+
tgt_array = DeferredArray(tgt)
1183+
tgt_array.copy(src_array, deep=True)
1184+
1185+
if needs_delinearization and needs_linearization:
1186+
src = result.base # type: ignore
1187+
src_dim = 0
1188+
for src_g, tgt_g in groups:
1189+
if len(src_g) > 1 and len(tgt_g) > 1:
1190+
src = src.delinearize(src_dim, tgt_g)
1191+
src_dim += len(tgt_g)
1192+
1193+
assert src.shape == newshape
1194+
src_array = DeferredArray(src)
1195+
result = runtime.create_empty_thunk(
1196+
newshape, dtype=self.base.type, inputs=[self]
1197+
)
1198+
result.copy(src_array, deep=True)
1199+
1200+
else:
1201+
src = self.base
1202+
src_dim = 0
1203+
for src_g, tgt_g in groups:
1204+
diff = 1
1205+
if src_g == tgt_g:
1206+
assert len(src_g) == 1
1207+
elif len(src_g) == 0:
1208+
assert tgt_g == (1,)
1209+
src = src.promote(src_dim, 1)
1210+
elif len(tgt_g) == 0:
1211+
assert src_g == (1,)
1212+
src = src.project(src_dim, 0)
1213+
diff = 0
1214+
else:
1215+
# unreachable
1216+
assert False
1217+
1218+
src_dim += diff
11471219

1148-
# Process remaining output dimensions by adding new dimensions (promotion)
1149-
for _ in range(out_pos + 1):
1150-
src = src.promote(0, 1)
1220+
result = DeferredArray(src)
11511221

1152-
result = DeferredArray(src)
11531222
return result
11541223

11551224
def squeeze(self, axis: int | tuple[int, ...] | None) -> DeferredArray:

tests/integration/test_reshape_copy.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

0 commit comments

Comments
 (0)