|
14 | 14 | # |
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | | -import math |
18 | 17 | import weakref |
19 | 18 | from collections import Counter |
20 | 19 | from collections.abc import Iterable |
@@ -1092,64 +1091,134 @@ def reshape(self, newshape: NdShape, order: OrderType) -> NumPyThunk: |
1092 | 1091 | # performance issues, but we will revisit this decision later once |
1093 | 1092 | # we have enough evidence that that's not the case. |
1094 | 1093 |
|
| 1094 | + in_dim = 0 |
| 1095 | + out_dim = 0 |
| 1096 | + |
1095 | 1097 | in_shape = self.shape |
1096 | 1098 | out_shape = newshape |
1097 | 1099 |
|
1098 | | - needs_copy = False |
1099 | | - out_pos = 0 |
1100 | | - for in_elem in in_shape: |
| 1100 | + in_ndim = len(in_shape) |
| 1101 | + out_ndim = len(out_shape) |
| 1102 | + |
| 1103 | + groups = [] |
| 1104 | + |
| 1105 | + while in_dim < in_ndim and out_dim < out_ndim: |
| 1106 | + prev_in_dim = in_dim |
| 1107 | + prev_out_dim = out_dim |
| 1108 | + |
| 1109 | + in_prod = 1 |
1101 | 1110 | out_prod = 1 |
1102 | | - while out_prod < in_elem and out_pos < len(out_shape): |
1103 | | - out_prod *= out_shape[out_pos] |
1104 | | - out_pos += 1 |
1105 | | - if out_prod != in_elem: |
1106 | | - needs_copy = True |
1107 | | - break |
| 1111 | + |
| 1112 | + while True: |
| 1113 | + if in_prod < out_prod: |
| 1114 | + in_prod *= in_shape[in_dim] |
| 1115 | + in_dim += 1 |
| 1116 | + else: |
| 1117 | + out_prod *= out_shape[out_dim] |
| 1118 | + out_dim += 1 |
| 1119 | + if in_prod == out_prod: |
| 1120 | + if in_dim < in_ndim and in_shape[in_dim] == 1: |
| 1121 | + in_dim += 1 |
| 1122 | + break |
| 1123 | + |
| 1124 | + in_group = in_shape[prev_in_dim:in_dim] |
| 1125 | + out_group = out_shape[prev_out_dim:out_dim] |
| 1126 | + groups.append((in_group, out_group)) |
| 1127 | + |
| 1128 | + while in_dim < in_ndim: |
| 1129 | + assert in_shape[in_dim] == 1 |
| 1130 | + groups.append(((1,), ())) |
| 1131 | + in_dim += 1 |
| 1132 | + |
| 1133 | + while out_dim < out_ndim: |
| 1134 | + assert out_shape[out_dim] == 1 |
| 1135 | + groups.append(((), (1,))) |
| 1136 | + out_dim += 1 |
| 1137 | + |
| 1138 | + needs_linearization = any(len(src_g) > 1 for src_g, _ in groups) |
| 1139 | + needs_delinearization = any(len(tgt_g) > 1 for _, tgt_g in groups) |
| 1140 | + needs_copy = needs_linearization or needs_delinearization |
1108 | 1141 |
|
1109 | 1142 | if needs_copy: |
1110 | | - flat_size = math.prod(in_shape) |
1111 | | - flat_array = runtime.create_empty_thunk( |
1112 | | - (flat_size,), dtype=self.base.type, inputs=[self] |
| 1143 | + tmp_shape: NdShape = () |
| 1144 | + for src_g, tgt_g in groups: |
| 1145 | + if len(src_g) > 1 and len(tgt_g) > 1: |
| 1146 | + tmp_shape += (_prod(tgt_g),) |
| 1147 | + else: |
| 1148 | + tmp_shape += tgt_g |
| 1149 | + |
| 1150 | + result = runtime.create_empty_thunk( |
| 1151 | + tmp_shape, dtype=self.base.type, inputs=[self] |
1113 | 1152 | ) |
1114 | | - in_shape_store = flat_array.base.delinearize(0, in_shape) |
1115 | | - out_shape_store = flat_array.base.delinearize(0, out_shape) |
1116 | | - in_shape_array = DeferredArray(in_shape_store) |
1117 | | - in_shape_array.copy(self, deep=True) |
1118 | | - result = DeferredArray(out_shape_store) |
1119 | | - return result |
1120 | | - |
1121 | | - src = self.base |
1122 | | - |
1123 | | - # Process each dimension from right to left |
1124 | | - out_pos = len(out_shape) - 1 |
1125 | | - for src_dim, elem_in in zip( |
1126 | | - range(len(in_shape) - 1, -1, -1), reversed(in_shape) |
1127 | | - ): |
1128 | | - if out_pos >= 0 and elem_in == out_shape[out_pos]: |
1129 | | - # Case 1: Dimensions match exactly |
1130 | | - out_pos -= 1 |
1131 | | - continue |
1132 | | - |
1133 | | - if elem_in == 1: |
1134 | | - # Case 2: Input dimension is 1 (projection) |
1135 | | - src = src.project(src_dim, 0) |
1136 | | - continue |
1137 | | - |
1138 | | - # Case 3: Delinearize operation |
1139 | | - new_sizes = [] |
1140 | | - out_prod = 1 |
1141 | | - while out_prod < elem_in and out_pos >= 0: |
1142 | | - out_prod *= out_shape[out_pos] |
1143 | | - new_sizes.append(out_shape[out_pos]) |
1144 | | - out_pos -= 1 |
1145 | 1153 |
|
1146 | | - src = src.delinearize(src_dim, tuple(reversed(new_sizes))) |
| 1154 | + src = self.base |
| 1155 | + tgt = result.base # type: ignore |
| 1156 | + |
| 1157 | + src_dim = 0 |
| 1158 | + tgt_dim = 0 |
| 1159 | + for src_g, tgt_g in groups: |
| 1160 | + diff = 1 |
| 1161 | + if src_g == tgt_g: |
| 1162 | + assert len(src_g) == 1 |
| 1163 | + elif len(src_g) == 0: |
| 1164 | + assert tgt_g == (1,) |
| 1165 | + src = src.promote(src_dim, 1) |
| 1166 | + elif len(tgt_g) == 0: |
| 1167 | + assert src_g == (1,) |
| 1168 | + tgt = tgt.promote(tgt_dim, 1) |
| 1169 | + elif len(src_g) == 1: |
| 1170 | + src = src.delinearize(src_dim, tgt_g) |
| 1171 | + diff = len(tgt_g) |
| 1172 | + else: |
| 1173 | + tgt = tgt.delinearize(tgt_dim, src_g) |
| 1174 | + diff = len(src_g) |
| 1175 | + |
| 1176 | + src_dim += diff |
| 1177 | + tgt_dim += diff |
| 1178 | + |
| 1179 | + assert src.shape == tgt.shape |
| 1180 | + |
| 1181 | + src_array = DeferredArray(src) |
| 1182 | + tgt_array = DeferredArray(tgt) |
| 1183 | + tgt_array.copy(src_array, deep=True) |
| 1184 | + |
| 1185 | + if needs_delinearization and needs_linearization: |
| 1186 | + src = result.base # type: ignore |
| 1187 | + src_dim = 0 |
| 1188 | + for src_g, tgt_g in groups: |
| 1189 | + if len(src_g) > 1 and len(tgt_g) > 1: |
| 1190 | + src = src.delinearize(src_dim, tgt_g) |
| 1191 | + src_dim += len(tgt_g) |
| 1192 | + |
| 1193 | + assert src.shape == newshape |
| 1194 | + src_array = DeferredArray(src) |
| 1195 | + result = runtime.create_empty_thunk( |
| 1196 | + newshape, dtype=self.base.type, inputs=[self] |
| 1197 | + ) |
| 1198 | + result.copy(src_array, deep=True) |
| 1199 | + |
| 1200 | + else: |
| 1201 | + src = self.base |
| 1202 | + src_dim = 0 |
| 1203 | + for src_g, tgt_g in groups: |
| 1204 | + diff = 1 |
| 1205 | + if src_g == tgt_g: |
| 1206 | + assert len(src_g) == 1 |
| 1207 | + elif len(src_g) == 0: |
| 1208 | + assert tgt_g == (1,) |
| 1209 | + src = src.promote(src_dim, 1) |
| 1210 | + elif len(tgt_g) == 0: |
| 1211 | + assert src_g == (1,) |
| 1212 | + src = src.project(src_dim, 0) |
| 1213 | + diff = 0 |
| 1214 | + else: |
| 1215 | + # unreachable |
| 1216 | + assert False |
| 1217 | + |
| 1218 | + src_dim += diff |
1147 | 1219 |
|
1148 | | - # Process remaining output dimensions by adding new dimensions (promotion) |
1149 | | - for _ in range(out_pos + 1): |
1150 | | - src = src.promote(0, 1) |
| 1220 | + result = DeferredArray(src) |
1151 | 1221 |
|
1152 | | - result = DeferredArray(src) |
1153 | 1222 | return result |
1154 | 1223 |
|
1155 | 1224 | def squeeze(self, axis: int | tuple[int, ...] | None) -> DeferredArray: |
|
0 commit comments