Skip to content

Commit 1e54aa4

Browse files
authored
Update reshape() implementation to avoid unnecessary deepcopy (#543)
* Update reshape() implementation to avoid unnecessary deepcopy
* Optimize reshape function and update tests
* Refactor reshaping logic and add type hints to test
1 parent 43413c0 commit 1e54aa4

2 files changed

Lines changed: 91 additions & 118 deletions

File tree

cupynumeric/_thunk/deferred.py

Lines changed: 49 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#
1515
from __future__ import annotations
1616

17+
import math
1718
import weakref
1819
from collections import Counter
1920
from collections.abc import Iterable
@@ -1091,134 +1092,64 @@ def reshape(self, newshape: NdShape, order: OrderType) -> NumPyThunk:
10911092
# performance issues, but we will revisit this decision later once
10921093
# we have enough evidence that that's not the case.
10931094

1094-
in_dim = 0
1095-
out_dim = 0
1096-
10971095
in_shape = self.shape
10981096
out_shape = newshape
10991097

1100-
in_ndim = len(in_shape)
1101-
out_ndim = len(out_shape)
1102-
1103-
groups = []
1104-
1105-
while in_dim < in_ndim and out_dim < out_ndim:
1106-
prev_in_dim = in_dim
1107-
prev_out_dim = out_dim
1108-
1109-
in_prod = 1
1098+
needs_copy = False
1099+
out_pos = 0
1100+
for in_elem in in_shape:
11101101
out_prod = 1
1111-
1112-
while True:
1113-
if in_prod < out_prod:
1114-
in_prod *= in_shape[in_dim]
1115-
in_dim += 1
1116-
else:
1117-
out_prod *= out_shape[out_dim]
1118-
out_dim += 1
1119-
if in_prod == out_prod:
1120-
if in_dim < in_ndim and in_shape[in_dim] == 1:
1121-
in_dim += 1
1122-
break
1123-
1124-
in_group = in_shape[prev_in_dim:in_dim]
1125-
out_group = out_shape[prev_out_dim:out_dim]
1126-
groups.append((in_group, out_group))
1127-
1128-
while in_dim < in_ndim:
1129-
assert in_shape[in_dim] == 1
1130-
groups.append(((1,), ()))
1131-
in_dim += 1
1132-
1133-
while out_dim < out_ndim:
1134-
assert out_shape[out_dim] == 1
1135-
groups.append(((), (1,)))
1136-
out_dim += 1
1137-
1138-
needs_linearization = any(len(src_g) > 1 for src_g, _ in groups)
1139-
needs_delinearization = any(len(tgt_g) > 1 for _, tgt_g in groups)
1140-
needs_copy = needs_linearization or needs_delinearization
1102+
while out_prod < in_elem and out_pos < len(out_shape):
1103+
out_prod *= out_shape[out_pos]
1104+
out_pos += 1
1105+
if out_prod != in_elem:
1106+
needs_copy = True
1107+
break
11411108

11421109
if needs_copy:
1143-
tmp_shape: NdShape = ()
1144-
for src_g, tgt_g in groups:
1145-
if len(src_g) > 1 and len(tgt_g) > 1:
1146-
tmp_shape += (_prod(tgt_g),)
1147-
else:
1148-
tmp_shape += tgt_g
1149-
1150-
result = runtime.create_empty_thunk(
1151-
tmp_shape, dtype=self.base.type, inputs=[self]
1110+
flat_size = math.prod(in_shape)
1111+
flat_array = runtime.create_empty_thunk(
1112+
(flat_size,), dtype=self.base.type, inputs=[self]
11521113
)
1114+
in_shape_store = flat_array.base.delinearize(0, in_shape)
1115+
out_shape_store = flat_array.base.delinearize(0, out_shape)
1116+
in_shape_array = DeferredArray(in_shape_store)
1117+
in_shape_array.copy(self, deep=True)
1118+
result = DeferredArray(out_shape_store)
1119+
return result
1120+
1121+
src = self.base
1122+
1123+
# Process each dimension from right to left
1124+
out_pos = len(out_shape) - 1
1125+
for src_dim, elem_in in zip(
1126+
range(len(in_shape) - 1, -1, -1), reversed(in_shape)
1127+
):
1128+
if out_pos >= 0 and elem_in == out_shape[out_pos]:
1129+
# Case 1: Dimensions match exactly
1130+
out_pos -= 1
1131+
continue
1132+
1133+
if elem_in == 1:
1134+
# Case 2: Input dimension is 1 (projection)
1135+
src = src.project(src_dim, 0)
1136+
continue
1137+
1138+
# Case 3: Delinearize operation
1139+
new_sizes = []
1140+
out_prod = 1
1141+
while out_prod < elem_in and out_pos >= 0:
1142+
out_prod *= out_shape[out_pos]
1143+
new_sizes.append(out_shape[out_pos])
1144+
out_pos -= 1
11531145

1154-
src = self.base
1155-
tgt = result.base # type: ignore
1156-
1157-
src_dim = 0
1158-
tgt_dim = 0
1159-
for src_g, tgt_g in groups:
1160-
diff = 1
1161-
if src_g == tgt_g:
1162-
assert len(src_g) == 1
1163-
elif len(src_g) == 0:
1164-
assert tgt_g == (1,)
1165-
src = src.promote(src_dim, 1)
1166-
elif len(tgt_g) == 0:
1167-
assert src_g == (1,)
1168-
tgt = tgt.promote(tgt_dim, 1)
1169-
elif len(src_g) == 1:
1170-
src = src.delinearize(src_dim, tgt_g)
1171-
diff = len(tgt_g)
1172-
else:
1173-
tgt = tgt.delinearize(tgt_dim, src_g)
1174-
diff = len(src_g)
1175-
1176-
src_dim += diff
1177-
tgt_dim += diff
1178-
1179-
assert src.shape == tgt.shape
1180-
1181-
src_array = DeferredArray(src)
1182-
tgt_array = DeferredArray(tgt)
1183-
tgt_array.copy(src_array, deep=True)
1184-
1185-
if needs_delinearization and needs_linearization:
1186-
src = result.base # type: ignore
1187-
src_dim = 0
1188-
for src_g, tgt_g in groups:
1189-
if len(src_g) > 1 and len(tgt_g) > 1:
1190-
src = src.delinearize(src_dim, tgt_g)
1191-
src_dim += len(tgt_g)
1192-
1193-
assert src.shape == newshape
1194-
src_array = DeferredArray(src)
1195-
result = runtime.create_empty_thunk(
1196-
newshape, dtype=self.base.type, inputs=[self]
1197-
)
1198-
result.copy(src_array, deep=True)
1199-
1200-
else:
1201-
src = self.base
1202-
src_dim = 0
1203-
for src_g, tgt_g in groups:
1204-
diff = 1
1205-
if src_g == tgt_g:
1206-
assert len(src_g) == 1
1207-
elif len(src_g) == 0:
1208-
assert tgt_g == (1,)
1209-
src = src.promote(src_dim, 1)
1210-
elif len(tgt_g) == 0:
1211-
assert src_g == (1,)
1212-
src = src.project(src_dim, 0)
1213-
diff = 0
1214-
else:
1215-
# unreachable
1216-
assert False
1217-
1218-
src_dim += diff
1146+
src = src.delinearize(src_dim, tuple(reversed(new_sizes)))
12191147

1220-
result = DeferredArray(src)
1148+
# Process remaining output dimensions by adding new dimensions (promotion)
1149+
for _ in range(out_pos + 1):
1150+
src = src.promote(0, 1)
12211151

1152+
result = DeferredArray(src)
12221153
return result
12231154

12241155
def squeeze(self, axis: int | tuple[int, ...] | None) -> DeferredArray:
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2024 NVIDIA Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
import pytest
17+
18+
import cupynumeric as num
19+
20+
21+
@pytest.mark.parametrize(
    "in_shape, out_shape",
    [
        ((1, 1, 10), (10,)),
        ((6, 1, 1), (1, 1, 6)),
        ((12, 1), (1, 3, 4)),
        ((12, 1, 4), (2, 3, 2, 4)),
    ],
)
def test_reshape_no_copy(
    in_shape: tuple[int, ...], out_shape: tuple[int, ...]
) -> None:
    """Check reshape cases that should alias the input rather than copy.

    Each (in_shape, out_shape) pair is convertible with projection,
    promotion, and delinearization alone, so ``reshape`` is expected to
    return a view of ``x``.  We verify this observationally: mutate the
    source *after* reshaping and confirm the change is visible through
    the result.  If ``y`` were a deep copy it would still be all zeros.
    """
    x = num.zeros(in_shape, dtype=num.int32)
    y = num.reshape(x, out_shape)
    # Mutate the source only after the reshape has been taken.
    x.fill(1)

    assert y.shape == out_shape
    # A view reflects the fill; a copy would sum to 0.
    assert num.sum(y) != 0


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(sys.argv))

0 commit comments

Comments
 (0)