99
1010import coremltools as ct
1111import torch
12-
12+ import torch . nn as nn
1313from executorch .backends .apple .coreml .compiler .coreml_preprocess import (
1414 CoreMLBackend ,
1515 MULTIMETHOD_WEIGHT_SHARING_STRATEGY ,
1616)
1717from executorch .backends .apple .coreml .partition import CoreMLPartitioner
1818from executorch .exir import EdgeCompileConfig , to_edge_transform_and_lower
19+ from executorch .exir .graph_break import remove_graph_break_ops
1920
2021
2122def is_fbcode ():
@@ -320,6 +321,92 @@ def test_multifunction_one_blob_simple_model(self):
320321 )
321322 )
322323
324+ def test_multifunction_one_blob_multiple_partitions (self ):
325+ """Test ONE_BLOB with multiple partitions per method.
326+
327+ Uses graph breaks to force the CoreML partitioner to create multiple
328+ partitions within each method (forward and prefill). The two partitions
329+ have a different number of inputs and outputs so their metadata
330+ (input/output name lists) differ.
331+
332+ Partition 0: 1 input (x) → 2 outputs (a, b)
333+ Partition 1: 2 inputs (a, b) → 1 output (result)
334+ """
335+
336+ class _GraphBreak (nn .Module ):
337+ def forward (self , x ):
338+ return torch .ops .executorch_utils .graph_break .Tensor (x )
339+
340+ class MultiPartitionModel (nn .Module ):
341+ def __init__ (self ):
342+ super ().__init__ ()
343+ self .linear_a = nn .Linear (16 , 16 )
344+ self .linear_b = nn .Linear (16 , 16 )
345+ self .graph_break_a = _GraphBreak ()
346+ self .graph_break_b = _GraphBreak ()
347+ self .linear_out = nn .Linear (32 , 16 )
348+
349+ def forward (self , x ):
350+ a = self .linear_a (x )
351+ b = self .linear_b (x )
352+ a = self .graph_break_a (a )
353+ b = self .graph_break_b (b )
354+ combined = torch .cat ([a , b ], dim = - 1 )
355+ return self .linear_out (combined )
356+
357+ model = MultiPartitionModel ()
358+ model .eval ()
359+
360+ decode_inputs = (torch .randn (1 , 1 , 16 ),)
361+ prefill_inputs = (torch .randn (1 , 8 , 16 ),)
362+
363+ exported_programs = {
364+ "forward" : torch .export .export (model , decode_inputs ),
365+ "prefill" : torch .export .export (model , prefill_inputs ),
366+ }
367+
368+ partitioner = CoreMLPartitioner (
369+ compile_specs = self ._get_compile_specs (
370+ strategy = MULTIMETHOD_WEIGHT_SHARING_STRATEGY .ONE_BLOB ,
371+ ),
372+ )
373+
374+ edge_manager = to_edge_transform_and_lower (
375+ exported_programs ,
376+ partitioner = [partitioner ],
377+ compile_config = self .edge_compile_config ,
378+ )
379+
380+ self .assertIn ("forward" , edge_manager .methods )
381+ self .assertIn ("prefill" , edge_manager .methods )
382+
383+ remove_graph_break_ops (edge_manager )
384+
385+ et_program = edge_manager .to_executorch ()
386+
387+ if _TEST_RUNTIME :
388+ runtime = Runtime .get ()
389+ program = runtime .load_program (et_program .buffer )
390+
391+ self .assertIn ("forward" , program .method_names )
392+ self .assertIn ("prefill" , program .method_names )
393+
394+ forward_method = program .load_method ("forward" )
395+ decode_output = forward_method .execute (decode_inputs )
396+ expected_decode = model (* decode_inputs )
397+ self .assertTrue (
398+ torch .allclose (decode_output [0 ], expected_decode , atol = 1e-4 , rtol = 1e-4 )
399+ )
400+
401+ prefill_method = program .load_method ("prefill" )
402+ prefill_output = prefill_method .execute (prefill_inputs )
403+ expected_prefill = model (* prefill_inputs )
404+ self .assertTrue (
405+ torch .allclose (
406+ prefill_output [0 ], expected_prefill , atol = 1e-4 , rtol = 1e-4
407+ )
408+ )
409+
323410
324411if __name__ == "__main__" :
325412 test_runner = TestCoreMLMultifunction ()
@@ -328,4 +415,5 @@ def test_multifunction_one_blob_simple_model(self):
328415 test_runner .test_multifunction_without_weight_sharing ()
329416 test_runner .test_multifunction_with_constant_methods ()
330417 test_runner .test_multifunction_one_blob_simple_model ()
418+ test_runner .test_multifunction_one_blob_multiple_partitions ()
331419 print ("All tests passed!" )
0 commit comments