|
14 | 14 | # |
15 | 15 | from __future__ import annotations |
16 | 16 |
|
| 17 | +import math |
17 | 18 | import weakref |
18 | 19 | from collections import Counter |
19 | 20 | from collections.abc import Iterable |
@@ -1091,134 +1092,64 @@ def reshape(self, newshape: NdShape, order: OrderType) -> NumPyThunk: |
1091 | 1092 | # performance issues, but we will revisit this decision later once |
1092 | 1093 | # we have enough evidence that that's not the case. |
1093 | 1094 |
|
1094 | | - in_dim = 0 |
1095 | | - out_dim = 0 |
1096 | | - |
1097 | 1095 | in_shape = self.shape |
1098 | 1096 | out_shape = newshape |
1099 | 1097 |
|
1100 | | - in_ndim = len(in_shape) |
1101 | | - out_ndim = len(out_shape) |
1102 | | - |
1103 | | - groups = [] |
1104 | | - |
1105 | | - while in_dim < in_ndim and out_dim < out_ndim: |
1106 | | - prev_in_dim = in_dim |
1107 | | - prev_out_dim = out_dim |
1108 | | - |
1109 | | - in_prod = 1 |
| 1098 | + needs_copy = False |
| 1099 | + out_pos = 0 |
| 1100 | + for in_elem in in_shape: |
1110 | 1101 | out_prod = 1 |
1111 | | - |
1112 | | - while True: |
1113 | | - if in_prod < out_prod: |
1114 | | - in_prod *= in_shape[in_dim] |
1115 | | - in_dim += 1 |
1116 | | - else: |
1117 | | - out_prod *= out_shape[out_dim] |
1118 | | - out_dim += 1 |
1119 | | - if in_prod == out_prod: |
1120 | | - if in_dim < in_ndim and in_shape[in_dim] == 1: |
1121 | | - in_dim += 1 |
1122 | | - break |
1123 | | - |
1124 | | - in_group = in_shape[prev_in_dim:in_dim] |
1125 | | - out_group = out_shape[prev_out_dim:out_dim] |
1126 | | - groups.append((in_group, out_group)) |
1127 | | - |
1128 | | - while in_dim < in_ndim: |
1129 | | - assert in_shape[in_dim] == 1 |
1130 | | - groups.append(((1,), ())) |
1131 | | - in_dim += 1 |
1132 | | - |
1133 | | - while out_dim < out_ndim: |
1134 | | - assert out_shape[out_dim] == 1 |
1135 | | - groups.append(((), (1,))) |
1136 | | - out_dim += 1 |
1137 | | - |
1138 | | - needs_linearization = any(len(src_g) > 1 for src_g, _ in groups) |
1139 | | - needs_delinearization = any(len(tgt_g) > 1 for _, tgt_g in groups) |
1140 | | - needs_copy = needs_linearization or needs_delinearization |
| 1102 | + while out_prod < in_elem and out_pos < len(out_shape): |
| 1103 | + out_prod *= out_shape[out_pos] |
| 1104 | + out_pos += 1 |
| 1105 | + if out_prod != in_elem: |
| 1106 | + needs_copy = True |
| 1107 | + break |
1141 | 1108 |
|
1142 | 1109 | if needs_copy: |
1143 | | - tmp_shape: NdShape = () |
1144 | | - for src_g, tgt_g in groups: |
1145 | | - if len(src_g) > 1 and len(tgt_g) > 1: |
1146 | | - tmp_shape += (_prod(tgt_g),) |
1147 | | - else: |
1148 | | - tmp_shape += tgt_g |
1149 | | - |
1150 | | - result = runtime.create_empty_thunk( |
1151 | | - tmp_shape, dtype=self.base.type, inputs=[self] |
| 1110 | + flat_size = math.prod(in_shape) |
| 1111 | + flat_array = runtime.create_empty_thunk( |
| 1112 | + (flat_size,), dtype=self.base.type, inputs=[self] |
1152 | 1113 | ) |
| 1114 | + in_shape_store = flat_array.base.delinearize(0, in_shape) |
| 1115 | + out_shape_store = flat_array.base.delinearize(0, out_shape) |
| 1116 | + in_shape_array = DeferredArray(in_shape_store) |
| 1117 | + in_shape_array.copy(self, deep=True) |
| 1118 | + result = DeferredArray(out_shape_store) |
| 1119 | + return result |
| 1120 | + |
| 1121 | + src = self.base |
| 1122 | + |
| 1123 | + # Process each dimension from right to left |
| 1124 | + out_pos = len(out_shape) - 1 |
| 1125 | + for src_dim, elem_in in zip( |
| 1126 | + range(len(in_shape) - 1, -1, -1), reversed(in_shape) |
| 1127 | + ): |
| 1128 | + if out_pos >= 0 and elem_in == out_shape[out_pos]: |
| 1129 | + # Case 1: Dimensions match exactly |
| 1130 | + out_pos -= 1 |
| 1131 | + continue |
| 1132 | + |
| 1133 | + if elem_in == 1: |
| 1134 | + # Case 2: Input dimension is 1 (projection) |
| 1135 | + src = src.project(src_dim, 0) |
| 1136 | + continue |
| 1137 | + |
| 1138 | + # Case 3: Delinearize operation |
| 1139 | + new_sizes = [] |
| 1140 | + out_prod = 1 |
| 1141 | + while out_prod < elem_in and out_pos >= 0: |
| 1142 | + out_prod *= out_shape[out_pos] |
| 1143 | + new_sizes.append(out_shape[out_pos]) |
| 1144 | + out_pos -= 1 |
1153 | 1145 |
|
1154 | | - src = self.base |
1155 | | - tgt = result.base # type: ignore |
1156 | | - |
1157 | | - src_dim = 0 |
1158 | | - tgt_dim = 0 |
1159 | | - for src_g, tgt_g in groups: |
1160 | | - diff = 1 |
1161 | | - if src_g == tgt_g: |
1162 | | - assert len(src_g) == 1 |
1163 | | - elif len(src_g) == 0: |
1164 | | - assert tgt_g == (1,) |
1165 | | - src = src.promote(src_dim, 1) |
1166 | | - elif len(tgt_g) == 0: |
1167 | | - assert src_g == (1,) |
1168 | | - tgt = tgt.promote(tgt_dim, 1) |
1169 | | - elif len(src_g) == 1: |
1170 | | - src = src.delinearize(src_dim, tgt_g) |
1171 | | - diff = len(tgt_g) |
1172 | | - else: |
1173 | | - tgt = tgt.delinearize(tgt_dim, src_g) |
1174 | | - diff = len(src_g) |
1175 | | - |
1176 | | - src_dim += diff |
1177 | | - tgt_dim += diff |
1178 | | - |
1179 | | - assert src.shape == tgt.shape |
1180 | | - |
1181 | | - src_array = DeferredArray(src) |
1182 | | - tgt_array = DeferredArray(tgt) |
1183 | | - tgt_array.copy(src_array, deep=True) |
1184 | | - |
1185 | | - if needs_delinearization and needs_linearization: |
1186 | | - src = result.base # type: ignore |
1187 | | - src_dim = 0 |
1188 | | - for src_g, tgt_g in groups: |
1189 | | - if len(src_g) > 1 and len(tgt_g) > 1: |
1190 | | - src = src.delinearize(src_dim, tgt_g) |
1191 | | - src_dim += len(tgt_g) |
1192 | | - |
1193 | | - assert src.shape == newshape |
1194 | | - src_array = DeferredArray(src) |
1195 | | - result = runtime.create_empty_thunk( |
1196 | | - newshape, dtype=self.base.type, inputs=[self] |
1197 | | - ) |
1198 | | - result.copy(src_array, deep=True) |
1199 | | - |
1200 | | - else: |
1201 | | - src = self.base |
1202 | | - src_dim = 0 |
1203 | | - for src_g, tgt_g in groups: |
1204 | | - diff = 1 |
1205 | | - if src_g == tgt_g: |
1206 | | - assert len(src_g) == 1 |
1207 | | - elif len(src_g) == 0: |
1208 | | - assert tgt_g == (1,) |
1209 | | - src = src.promote(src_dim, 1) |
1210 | | - elif len(tgt_g) == 0: |
1211 | | - assert src_g == (1,) |
1212 | | - src = src.project(src_dim, 0) |
1213 | | - diff = 0 |
1214 | | - else: |
1215 | | - # unreachable |
1216 | | - assert False |
1217 | | - |
1218 | | - src_dim += diff |
| 1146 | + src = src.delinearize(src_dim, tuple(reversed(new_sizes))) |
1219 | 1147 |
|
1220 | | - result = DeferredArray(src) |
| 1148 | + # Process remaining output dimensions by adding new dimensions (promotion) |
| 1149 | + for _ in range(out_pos + 1): |
| 1150 | + src = src.promote(0, 1) |
1221 | 1151 |
|
| 1152 | + result = DeferredArray(src) |
1222 | 1153 | return result |
1223 | 1154 |
|
1224 | 1155 | def squeeze(self, axis: int | tuple[int, ...] | None) -> DeferredArray: |
|
0 commit comments