@@ -289,6 +289,72 @@ def test_permute_transpose_fusion(self) -> None:
289289 graph_copy , converted_graph , (x_input ,), "FuseCascadedTransposeOrPermuteOps"
290290 )
291291
292+ def test_cascaded_permutes_multiple_users (self ) -> None :
293+ # Test case where intermediate permute has multiple users.
294+ # x
295+ # |
296+ # permute1
297+ # / | \
298+ # permute2 permute3 permute4
299+ # | | |
300+ # out0 out1 permute5
301+ # |
302+ # out2
303+
304+ builder = GraphBuilder ()
305+ x_input = torch .randn (2 , 3 , 8 , 8 , dtype = torch .float32 )
306+ x = builder .placeholder ("x" , x_input )
307+ permute1 = builder .call_operator (
308+ op = exir_ops .edge .aten .permute_copy .default ,
309+ args = (x , [0 , 2 , 3 , 1 ]),
310+ )
311+ permute2 = builder .call_operator (
312+ op = exir_ops .edge .aten .permute_copy .default ,
313+ args = (permute1 , [0 , 3 , 1 , 2 ]),
314+ )
315+ permute3 = builder .call_operator (
316+ op = exir_ops .edge .aten .permute_copy .default ,
317+ args = (permute1 , [0 , 1 , 3 , 2 ]),
318+ )
319+ permute4 = builder .call_operator (
320+ op = exir_ops .edge .aten .permute_copy .default ,
321+ args = (permute1 , [3 , 2 , 1 , 0 ]),
322+ )
323+ permute5 = builder .call_operator (
324+ op = exir_ops .edge .aten .permute_copy .default ,
325+ args = (permute4 , [1 , 2 , 3 , 0 ]),
326+ )
327+ builder .output ([permute2 , permute3 , permute5 ])
328+ original_graph = builder .get_graph_module ()
329+ graph_copy = copy .deepcopy (original_graph )
330+
331+ p = FuseCascadedTransposeOrPermuteOps ()
332+ result = p .call (original_graph )
333+ self .assertTrue (result .modified )
334+ converted_graph = result .graph_module
335+
336+ # permute2 becomes a no-op, permute3 and permute5 fused with preceding permutes
337+ # into new single permutes.
338+ output0 , output1 , output2 = converted_graph .graph .output_node ().args [0 ]
339+ # out0: permute1 + permute2 = identity, so it connects to the graph input.
340+ graph_input = converted_graph .graph .find_nodes (op = "placeholder" )[0 ]
341+ self .assertIs (output0 , graph_input )
342+ # out1: permute1 [0,2,3,1] + permute3 [0,1,3,2] fused to [0,2,1,3].
343+ self .assertEqual (output1 .target , exir_ops .edge .aten .permute_copy .default )
344+ self .assertIs (output1 .args [0 ], graph_input )
345+ self .assertEqual (output1 .args [1 ], [0 , 2 , 1 , 3 ])
346+ # out2: permute1 [0,2,3,1] + permute4 [3,2,1,0] + permute5 [1,2,3,0]
347+ # fused to [3,2,0,1].
348+ self .assertEqual (output2 .target , exir_ops .edge .aten .permute_copy .default )
349+ self .assertIs (output2 .args [0 ], graph_input )
350+ self .assertEqual (output2 .args [1 ], [3 , 2 , 0 , 1 ])
351+ validate_numerics (
352+ graph_copy ,
353+ converted_graph ,
354+ (x_input ,),
355+ "FuseCascadedTransposeOrPermuteOps_multiple_users" ,
356+ )
357+
292358 def test_view_fusion (self ) -> None :
293359 builder = GraphBuilder ()
294360 x_input = torch .randn (8 , 5 , 3 , dtype = torch .float32 )
0 commit comments