Skip to content

Commit 915723b

Browse files
committed
up
1 parent bed9aee commit 915723b

5 files changed

Lines changed: 248 additions & 73 deletions

File tree

backends/apple/coreml/compiler/coreml_preprocess.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,37 @@ class COMPILE_SPEC_KEYS(Enum):
4848

4949

5050
class MULTIMETHOD_WEIGHT_SHARING_STRATEGY(Enum):
51-
# Methods are processed independently with no weight sharing.
51+
"""Strategy for sharing weights across methods in multi-method models.
52+
53+
When exporting a model with multiple methods (e.g., prefill and decode),
54+
these strategies control how CoreML models are organized and how weights
55+
are shared. Different strategies have different tradeoffs — experiment
56+
with them to find the best fit for your use case.
57+
58+
DISABLED:
59+
Each method is compiled into its own independent CoreML model.
60+
No weight sharing occurs; weights are duplicated across methods.
61+
Simplest strategy with no constraints on model structure.
62+
63+
POSITIONAL:
64+
Partitions are aligned by index across methods. Partition 0 from
65+
all methods are combined into one multifunction CoreML model,
66+
partition 1 into another, and so on. This enables weight sharing
67+
for parameters that appear at the same partition index. Requires
68+
all methods to have the same number of partitions.
69+
70+
ONE_BLOB:
71+
All partitions from all methods are packed into a single
72+
multifunction CoreML model. This maximizes weight sharing
73+
opportunities (any parameter can be shared across any method)
74+
and does not require partition counts to match. However, it may
75+
result in longer compile times and higher peak memory since the
76+
entire model — including any method-specific (non-shared) weights
77+
— lives in a single blob.
78+
"""
79+
5280
DISABLED = "disabled"
53-
# Partitions must align positionally across methods; enables weight sharing
54-
# via NamedDataStore. Raises an error if partition counts don't match.
5581
POSITIONAL = "positional"
56-
# All partitions from all methods are combined into a single multifunction
57-
# model. No partition count alignment is required. Function names use
58-
# "{method_name}__{partition_idx}" encoding.
5982
ONE_BLOB = "one_blob"
6083

6184

@@ -843,7 +866,9 @@ def _preprocess_positional(
843866
f"Method '{method_name}' has {len(programs)} partitions, but "
844867
f"'{first_method}' has {num_partitions}. POSITIONAL weight sharing "
845868
"strategy requires all methods to have the same number of partitions. "
846-
"Use MULTIMETHOD_WEIGHT_SHARING_STRATEGY.DISABLED if methods should "
869+
"Use MULTIMETHOD_WEIGHT_SHARING_STRATEGY.ONE_BLOB (which supports "
870+
"different partition counts per method) or "
871+
"MULTIMETHOD_WEIGHT_SHARING_STRATEGY.DISABLED if methods should "
847872
"be processed independently."
848873
)
849874

@@ -1034,7 +1059,7 @@ def _preprocess_one_blob(
10341059
method_spec = method_model.get_spec()
10351060
input_names = [inp.name for inp in method_spec.description.input]
10361061
output_names = [out.name for out in method_spec.description.output]
1037-
methods_metadata[method_name] = MethodMetadata(
1062+
methods_metadata[function_name] = MethodMetadata(
10381063
inputNames=input_names,
10391064
outputNames=output_names,
10401065
)

backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -656,16 +656,22 @@ - (nullable ETCoreMLAsset *)modelAssetWithMetadata:(const ModelMetadata&)metadat
656656
return nil;
657657
}
658658

659-
std::string method_name_str = [methodName UTF8String];
660-
const MethodMetadata* method_metadata = metadataValue.get_method_metadata(method_name_str);
659+
if (functionName == nil || functionName.length == 0) {
660+
ETCoreMLLogErrorAndSetNSError(error,
661+
ETCoreMLErrorCorruptedModel,
662+
"functionName must be non-nil and non-empty for multifunction model metadata lookup.");
663+
return nil;
664+
}
665+
std::string lookup_key = [functionName UTF8String];
666+
const MethodMetadata* method_metadata = metadataValue.get_method_metadata(lookup_key);
661667
if (method_metadata != nullptr) {
662668
metadataValue.input_names = method_metadata->input_names;
663669
metadataValue.output_names = method_metadata->output_names;
664670
} else {
665671
ETCoreMLLogErrorAndSetNSError(error,
666672
ETCoreMLErrorCorruptedModel,
667-
"Method '%@' not found in multifunction model metadata.",
668-
methodName);
673+
"Function '%@' not found in multifunction model metadata.",
674+
functionName);
669675
return nil;
670676
}
671677
}

backends/apple/coreml/test/test_coreml_multifunction.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99

1010
import coremltools as ct
1111
import torch
12-
12+
import torch.nn as nn
1313
from executorch.backends.apple.coreml.compiler.coreml_preprocess import (
1414
CoreMLBackend,
1515
MULTIMETHOD_WEIGHT_SHARING_STRATEGY,
1616
)
1717
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
1818
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
19+
from executorch.exir.graph_break import remove_graph_break_ops
1920

2021

2122
def is_fbcode():
@@ -320,6 +321,92 @@ def test_multifunction_one_blob_simple_model(self):
320321
)
321322
)
322323

324+
def test_multifunction_one_blob_multiple_partitions(self):
325+
"""Test ONE_BLOB with multiple partitions per method.
326+
327+
Uses graph breaks to force the CoreML partitioner to create multiple
328+
partitions within each method (forward and prefill). The two partitions
329+
have a different number of inputs and outputs so their metadata
330+
(input/output name lists) differ.
331+
332+
Partition 0: 1 input (x) → 2 outputs (a, b)
333+
Partition 1: 2 inputs (a, b) → 1 output (result)
334+
"""
335+
336+
class _GraphBreak(nn.Module):
337+
def forward(self, x):
338+
return torch.ops.executorch_utils.graph_break.Tensor(x)
339+
340+
class MultiPartitionModel(nn.Module):
341+
def __init__(self):
342+
super().__init__()
343+
self.linear_a = nn.Linear(16, 16)
344+
self.linear_b = nn.Linear(16, 16)
345+
self.graph_break_a = _GraphBreak()
346+
self.graph_break_b = _GraphBreak()
347+
self.linear_out = nn.Linear(32, 16)
348+
349+
def forward(self, x):
350+
a = self.linear_a(x)
351+
b = self.linear_b(x)
352+
a = self.graph_break_a(a)
353+
b = self.graph_break_b(b)
354+
combined = torch.cat([a, b], dim=-1)
355+
return self.linear_out(combined)
356+
357+
model = MultiPartitionModel()
358+
model.eval()
359+
360+
decode_inputs = (torch.randn(1, 1, 16),)
361+
prefill_inputs = (torch.randn(1, 8, 16),)
362+
363+
exported_programs = {
364+
"forward": torch.export.export(model, decode_inputs),
365+
"prefill": torch.export.export(model, prefill_inputs),
366+
}
367+
368+
partitioner = CoreMLPartitioner(
369+
compile_specs=self._get_compile_specs(
370+
strategy=MULTIMETHOD_WEIGHT_SHARING_STRATEGY.ONE_BLOB,
371+
),
372+
)
373+
374+
edge_manager = to_edge_transform_and_lower(
375+
exported_programs,
376+
partitioner=[partitioner],
377+
compile_config=self.edge_compile_config,
378+
)
379+
380+
self.assertIn("forward", edge_manager.methods)
381+
self.assertIn("prefill", edge_manager.methods)
382+
383+
remove_graph_break_ops(edge_manager)
384+
385+
et_program = edge_manager.to_executorch()
386+
387+
if _TEST_RUNTIME:
388+
runtime = Runtime.get()
389+
program = runtime.load_program(et_program.buffer)
390+
391+
self.assertIn("forward", program.method_names)
392+
self.assertIn("prefill", program.method_names)
393+
394+
forward_method = program.load_method("forward")
395+
decode_output = forward_method.execute(decode_inputs)
396+
expected_decode = model(*decode_inputs)
397+
self.assertTrue(
398+
torch.allclose(decode_output[0], expected_decode, atol=1e-4, rtol=1e-4)
399+
)
400+
401+
prefill_method = program.load_method("prefill")
402+
prefill_output = prefill_method.execute(prefill_inputs)
403+
expected_prefill = model(*prefill_inputs)
404+
self.assertTrue(
405+
torch.allclose(
406+
prefill_output[0], expected_prefill, atol=1e-4, rtol=1e-4
407+
)
408+
)
409+
323410

324411
if __name__ == "__main__":
325412
test_runner = TestCoreMLMultifunction()
@@ -328,4 +415,5 @@ def test_multifunction_one_blob_simple_model(self):
328415
test_runner.test_multifunction_without_weight_sharing()
329416
test_runner.test_multifunction_with_constant_methods()
330417
test_runner.test_multifunction_one_blob_simple_model()
418+
test_runner.test_multifunction_one_blob_multiple_partitions()
331419
print("All tests passed!")

examples/apple/coreml/llama/export_static_llm_coreml.py

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@
2424

2525
import coremltools as ct
2626
import torch
27-
import torch.nn as nn
2827
import torch.utils._pytree as pytree
29-
3028
from executorch.backends.apple.coreml.compiler.coreml_preprocess import (
3129
CoreMLBackend,
3230
MULTIMETHOD_WEIGHT_SHARING_STRATEGY,
@@ -42,69 +40,13 @@
4240
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
4341
from executorch.exir.backend.utils import format_delegated_graph
4442
from executorch.exir.capture._config import ExecutorchBackendConfig
43+
from executorch.exir.graph_break import BlockWithGraphBreak, remove_graph_break_ops
4544
from executorch.exir.passes import MemoryPlanningPass
4645
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
4746
from executorch.extension.export_util.utils import save_pte_program
48-
from torch.library import impl, Library
4947
from torchao.quantization.granularity import PerAxis, PerGroup
5048
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
5149

52-
# Define custom graph break op
53-
lib = Library("executorch_utils", "DEF")
54-
lib.define("graph_break.Tensor(Tensor x) -> Tensor")
55-
56-
57-
@impl(lib, "graph_break.Tensor", "CompositeExplicitAutograd")
58-
def graph_break_impl(x):
59-
return x
60-
61-
62-
class ExecutorchGraphBreakModule(nn.Module):
63-
def __init__(self):
64-
super().__init__()
65-
66-
def forward(self, *args, **kwargs):
67-
return tuple(
68-
(
69-
torch.ops.executorch_utils.graph_break.Tensor(a)
70-
if isinstance(a, torch.Tensor)
71-
else a
72-
)
73-
for a in args
74-
)
75-
76-
77-
class BlockWithGraphBreak(nn.Module):
78-
def __init__(self, block: nn.Module, break_before: bool = True):
79-
super().__init__()
80-
self.graph_break = ExecutorchGraphBreakModule()
81-
self.block = block
82-
self.break_before = break_before
83-
84-
def forward(self, *args, **kwargs):
85-
if self.break_before:
86-
new_args = self.graph_break(*args)
87-
out = self.block(*new_args, **kwargs)
88-
return out
89-
else:
90-
out = self.block(*args, **kwargs)
91-
out = self.graph_break(*out)
92-
return out
93-
94-
95-
def remove_graph_break_(edge_manager):
96-
"""Remove graph break ops from all methods in the edge manager."""
97-
from executorch.exir.dialects._ops import ops as exir_ops
98-
99-
# Get all method names
100-
method_names = edge_manager.methods
101-
for method_name in method_names:
102-
ep = edge_manager.exported_program(method_name)
103-
for n in ep.graph_module.graph.nodes:
104-
if n.target == exir_ops.edge.executorch_utils.graph_break.Tensor:
105-
n.replace_all_uses_with(n.args[0])
106-
ep.graph_module.graph.eliminate_dead_code()
107-
10850

10951
def load_model(
11052
checkpoint_path: str,
@@ -695,7 +637,7 @@ def main():
695637

696638
# Convert to ExecuTorch
697639
print("\nConverting to ExecuTorch...")
698-
remove_graph_break_(edge_manager)
640+
remove_graph_break_ops(edge_manager)
699641
executorch_program = edge_manager.to_executorch(
700642
ExecutorchBackendConfig(
701643
extract_delegate_segments=True,

0 commit comments

Comments (0)