From 98ad6c2f3de49aef5191927eb8eee61123a67ef9 Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Sat, 7 Sep 2024 11:32:26 -0700 Subject: [PATCH 1/7] partition mutable buffer to coreml state --- .../coreml/partition/coreml_partitioner.py | 6 ++- .../coreml/test/test_coreml_partitioner.py | 49 +++++++++++++++++++ exir/backend/utils.py | 34 +++++++++++++ 3 files changed, 88 insertions(+), 1 deletion(-) diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py index ecf6d44b19c..adedbdee836 100644 --- a/backends/apple/coreml/partition/coreml_partitioner.py +++ b/backends/apple/coreml/partition/coreml_partitioner.py @@ -17,7 +17,7 @@ Partitioner, PartitionResult, ) -from executorch.exir.backend.utils import tag_constant_data +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupportBase @@ -61,6 +61,7 @@ def __init__( self, skip_ops_for_coreml_delegation: Optional[List[str]] = None, compile_specs: Optional[List[CompileSpec]] = None, + take_over_mutable_buffer: Optional[bool] = True, ) -> None: if skip_ops_for_coreml_delegation is None: skip_ops_for_coreml_delegation = [] @@ -69,6 +70,7 @@ def __init__( backend_id=CoreMLBackend.__name__, compile_specs=compile_specs if compile_specs is not None else [], ) + self.take_over_mutable_buffer = take_over_mutable_buffer def partition(self, exported_program: ExportedProgram) -> PartitionResult: # Run the CapabilityBasedPartitioner to return the largest possible @@ -89,6 +91,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: partition_tags[tag] = self.delegation_spec tag_constant_data(exported_program) + if self.take_over_mutable_buffer: + tag_mutated_buffer(exported_program) return PartitionResult( tagged_exported_program=exported_program, partition_tags=partition_tags diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py index 34cf531b261..02ae3830e65 100644 --- a/backends/apple/coreml/test/test_coreml_partitioner.py +++ b/backends/apple/coreml/test/test_coreml_partitioner.py @@ -2,13 +2,17 @@ # # Please refer to the license found in the LICENSE file in the root directory of the source tree. 
+import pytest
 import unittest
 
+import coremltools as ct
+
 import executorch.exir
 import torch
 import torchvision
 
+from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
 
 
@@ -86,8 +90,53 @@ def test_vit_skip_conv(self):
             if node.op == "call_function"
         ] == total
 
+    @pytest.mark.skipif(
+        "b" in ct.__version__ or ct.__version__ < "8.0",
+        reason="coremltools 8.0 or higher is required"
+    )
+    def test_buffer(self):
+        embedding_dim = 3
+        max_seq_len = 2
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("cache", torch.zeros((max_seq_len, embedding_dim), dtype=torch.float32))
+
+            def forward(self, q, k_val, input_pos):
+                q_T = q.transpose(0, 1)
+                k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
+                attn = k.mm(q_T)
+                return attn
+
+        model = Model()
+        model.eval()
+
+        q = torch.randn((1, embedding_dim))
+        k_val = torch.randn((1, embedding_dim))
+        input_pos = torch.tensor([0])
+        example_inputs = (q, k_val, input_pos)
+        exir_program_aten = torch.export.export(model, example_inputs)
+
+        compile_specs = CoreMLBackend.generate_compile_specs(minimum_deployment_target=ct.target.iOS18)
+        partitioner = CoreMLPartitioner(compile_specs=compile_specs)
+        edge_program_manager = executorch.exir.to_edge(
+            exir_program_aten, compile_config=self.edge_compile_config
+        )
+        delegated_program_manager = edge_program_manager.to_backend(partitioner)
+
+        assert [
+            node.target.__name__
+            for node in delegated_program_manager.exported_program().graph.nodes
+            if node.op == "call_function"
+        ] == [
+            "executorch_call_delegate",
+            "getitem",
+        ]
+
 
 if __name__ == "__main__":
     test_runner = TestCoreMLPartitioner()
     test_runner.test_add_sub_skip_mm()
     test_runner.test_vit_skip_conv()
+    test_runner.test_buffer()
diff --git a/exir/backend/utils.py b/exir/backend/utils.py
index 2b768fe7c23..fb5e16c6bd0 100644
--- a/exir/backend/utils.py
+++ b/exir/backend/utils.py
@@ -383,6 +383,40 @@ def tag_constant_data(edge_program: ExportedProgram) -> None:
                 node.meta["delegation_tag"] = user_tags.pop()
 
 
+def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
+    """
+    Util function for partitioners. This function tags the mutated buffer nodes
+    whose users all belong within the same partition. It should be called after all
+    other nodes have been tagged. Any buffer which is used as input to a subgraph
+    will be tagged with the same tag as that subgraph. A buffer used across
+    different partitions is an error case, since the underlying data would then be
+    owned by multiple delegates.
+    """
+    for node in edge_program.graph.nodes:
+        # Determine whether this node is a mutated buffer
+        is_mutated_buffer_node = False
+        if node.op == "placeholder" and is_buffer(edge_program, node):
+            for node_user in node.users:
+                if node_user.name in edge_program.graph_signature.buffers_to_mutate:
+                    is_mutated_buffer_node = True
+                    break
+        # This node is a mutated buffer, tag it
+        if is_mutated_buffer_node:
+            user_tags = set()
+            for user in node.users:
+                user_tag = user.meta.get("delegation_tag", None)
+                if user_tag is not None:
+                    user_tags.add(user_tag)
+            if len(user_tags) > 1:
+                logging.info(
+                    f"The buffer node is used across multiple partitions, including {user_tags}. "
+                    "If the data is too large and copying is not preferred, please tag the "
+                    "node with node.meta['no_copy'] = True and it won't be copied."
+ ) + # tag the data node with the same tag as the last user + if len(user_tags) > 0: + node.meta["delegation_tag"] = user_tags.pop() + + # TODO - style: use templated types class DelegateMappingBuilder: """ From 8a6d0de8e0bf4e66c17eb74ed5c07c6dfb68a268 Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Sat, 7 Sep 2024 11:33:08 -0700 Subject: [PATCH 2/7] delegate llama mutable buffer to coreml --- examples/models/llama2/export_llama_lib.py | 9 ++++++++- extension/llm/export/partitioner_lib.py | 18 ++++++------------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index e56d2fe848b..b49f78369e6 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -281,6 +281,13 @@ def build_args_parser() -> argparse.ArgumentParser: parser.add_argument("-V", "--vulkan", action="store_true") parser.add_argument("--mps", action="store_true") parser.add_argument("--coreml", action="store_true") + parser.add_argument( + "--coreml-disable-state", + dest="coreml_enable_state", + default=True, # Enable this by default + action="store_false", + help="Delegate mutable buffer to Core ML state", + ) parser.add_argument( "--qnn", action="store_true", @@ -504,7 +511,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 if args.coreml: coreml_partitioner = get_coreml_partitioner( - args.use_kv_cache, args.pt2e_quantize + args.use_kv_cache and args.coreml_enable_state, args.pt2e_quantize ) partitioners.append(coreml_partitioner) modelname = f"coreml_{modelname}" diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index e75d5bef3fb..8605dee4d5d 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -56,11 +56,8 @@ def get_mps_partitioner(use_kv_cache: bool = False): def get_coreml_partitioner( - use_kv_cache: bool = False, pt2e_quantize: Optional[str] = None + enable_state: bool = False, pt2e_quantize: Optional[str] = None ): - assert ( - use_kv_cache is True - ), "CoreML backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment" try: import coremltools as ct from executorch.backends.apple.coreml.compiler import ( # pyre-ignore @@ -75,7 +72,10 @@ def get_coreml_partitioner( ) minimum_deployment_target = ct.target.iOS15 - # In Core ML, quantization in introduced in iOS 16 + # In Core ML, stateful execution is introduced in iOS 18 + if enable_state: + minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18) + # In Core ML, quantization is introduced in iOS 16 if pt2e_quantize is not None: minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16) # In Core ML, 8-bit activation quantization is introduced in iOS 17 @@ -84,13 +84,6 @@ def get_coreml_partitioner( # In Core ML, 4-bit weight compression is introduced in iOS 18 if pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"): minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18) - # In Core ML, stateful execution is introduced in iOS 18 - # TODO (https://github.com/pytorch/executorch/issues/4209) - # For now, since mutable buffer is kept in executorch runtime, - # state is out of place and can be handled by older iOS. - # Once mutable buffer can be handed over to delegate, i.e. 
state becomes in-place, we will have - # if use_kv_cache: - # minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18) compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16] minimum_deployment_target=minimum_deployment_target, @@ -101,6 +94,7 @@ def get_coreml_partitioner( ) return CoreMLPartitioner( # pyre-fixme[16] compile_specs=compile_specs, + take_over_mutable_buffer=enable_state, ) From 6a7cccab347bd3248868b1f60b94bcb455bf48be Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Sat, 7 Sep 2024 21:23:44 -0700 Subject: [PATCH 3/7] fix lint --- .../apple/coreml/test/test_coreml_partitioner.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py index 02ae3830e65..06f4a3ed647 100644 --- a/backends/apple/coreml/test/test_coreml_partitioner.py +++ b/backends/apple/coreml/test/test_coreml_partitioner.py @@ -2,13 +2,14 @@ # # Please refer to the license found in the LICENSE file in the root directory of the source tree. -import pytest import unittest import coremltools as ct import executorch.exir +import pytest + import torch import torchvision @@ -92,7 +93,7 @@ def test_vit_skip_conv(self): @pytest.mark.skipif( "b" in ct.__version__ or ct.__version__ < "8.0", - reason="coremltools 8.0 or higher is required" + reason="coremltools 8.0 or higher is required", ) def test_buffer(self): embedding_dim = 3 @@ -101,7 +102,10 @@ def test_buffer(self): class Model(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("cache", torch.zeros((max_seq_len, embedding_dim), dtype=torch.float32)) + self.register_buffer( + "cache", + torch.zeros((max_seq_len, embedding_dim), dtype=torch.float32), + ) def forward(self, q, k_val, input_pos): q_T = q.transpose(0, 1) @@ -118,7 +122,9 @@ def forward(self, q, k_val, input_pos): example_inputs = (q, k_val, input_pos) exir_program_aten = torch.export.export(model, example_inputs) - compile_specs = CoreMLBackend.generate_compile_specs(minimum_deployment_target=ct.target.iOS18) + compile_specs = CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18 + ) partitioner = CoreMLPartitioner(compile_specs=compile_specs) edge_program_manager = executorch.exir.to_edge( exir_program_aten, compile_config=self.edge_compile_config From c20ef9fcaf2e3e4d8b92d3e577965a7c52a5bcbf Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Sun, 8 Sep 2024 12:25:28 -0700 Subject: [PATCH 4/7] support embedding quantize --- examples/models/llama2/export_llama_lib.py | 4 +++- extension/llm/export/partitioner_lib.py | 14 ++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index b49f78369e6..a501ae8cd99 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -511,7 +511,9 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 if args.coreml: coreml_partitioner = get_coreml_partitioner( - args.use_kv_cache and args.coreml_enable_state, args.pt2e_quantize + args.use_kv_cache and args.coreml_enable_state, + args.embedding_quantize, + args.pt2e_quantize, ) partitioners.append(coreml_partitioner) modelname = f"coreml_{modelname}" diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index 8605dee4d5d..e8b75d1f477 100644 --- a/extension/llm/export/partitioner_lib.py +++ 
b/extension/llm/export/partitioner_lib.py
@@ -56,7 +56,9 @@ def get_mps_partitioner(use_kv_cache: bool = False):
 
 
 def get_coreml_partitioner(
-    enable_state: bool = False, pt2e_quantize: Optional[str] = None
+    enable_state: bool = False,
+    embedding_quantize: Optional[str] = None,
+    pt2e_quantize: Optional[str] = None,
 ):
     try:
         import coremltools as ct
@@ -76,13 +78,17 @@ def get_coreml_partitioner(
     if enable_state:
         minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
     # In Core ML, quantization is introduced in iOS 16
-    if pt2e_quantize is not None:
+    if embedding_quantize is not None or pt2e_quantize is not None:
         minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16)
     # In Core ML, 8-bit activation quantization is introduced in iOS 17
-    if pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
+    if (
+        embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 8
+    ) or pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"):
         minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17)
     # In Core ML, 4-bit weight compression is introduced in iOS 18
-    if pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"):
+    if (
+        embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 4
+    ) or pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"):
         minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)
 
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]

From 1735a201adeb4bad62d6bfad37de6bfca5051931 Mon Sep 17 00:00:00 2001
From: yifan_shen3
Date: Sun, 8 Sep 2024 13:00:28 -0700
Subject: [PATCH 5/7] try to fix CI: 1. pin coremltools 8.0b2; 2. refrain from
 enabling stateful llama by default until the CI machine is upgraded to macOS 15
---
 backends/apple/coreml/scripts/install_requirements.sh | 2 +-
 backends/apple/coreml/test/test_coreml_partitioner.py | 6 ------
 examples/models/llama2/export_llama_lib.py            | 8 +-------
 3 files changed, 2 insertions(+), 14 deletions(-)

diff --git a/backends/apple/coreml/scripts/install_requirements.sh b/backends/apple/coreml/scripts/install_requirements.sh
index 0018b5ffc2d..ac0320e5920 100755
--- a/backends/apple/coreml/scripts/install_requirements.sh
+++ b/backends/apple/coreml/scripts/install_requirements.sh
@@ -24,7 +24,7 @@ rm -rf "$COREML_DIR_PATH/third-party"
 mkdir "$COREML_DIR_PATH/third-party"
 
 echo "${green}ExecuTorch: Cloning coremltools."
-git clone --depth 1 --branch 8.0b1 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH
+git clone --depth 1 --branch 8.0b2 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH
 cd $COREMLTOOLS_DIR_PATH
 
 STATUS=$?
diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py
index 06f4a3ed647..72a7fbf0932 100644
--- a/backends/apple/coreml/test/test_coreml_partitioner.py
+++ b/backends/apple/coreml/test/test_coreml_partitioner.py
@@ -8,8 +8,6 @@
 
 import executorch.exir
 
-import pytest
-
 import torch
 import torchvision
 
@@ -91,10 +89,6 @@ def test_vit_skip_conv(self):
             if node.op == "call_function"
         ] == total
 
-    @pytest.mark.skipif(
-        "b" in ct.__version__ or ct.__version__ < "8.0",
-        reason="coremltools 8.0 or higher is required",
-    )
     def test_buffer(self):
         embedding_dim = 3
         max_seq_len = 2
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index a501ae8cd99..f2bd153bf02 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -281,13 +281,7 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument("-V", "--vulkan", action="store_true")
     parser.add_argument("--mps", action="store_true")
     parser.add_argument("--coreml", action="store_true")
-    parser.add_argument(
-        "--coreml-disable-state",
-        dest="coreml_enable_state",
-        default=True,  # Enable this by default
-        action="store_false",
-        help="Delegate mutable buffer to Core ML state",
-    )
+    parser.add_argument("--coreml-enable-state", action="store_true")
     parser.add_argument(
         "--qnn",
         action="store_true",

From 5af7700af0f7bb00779e3b00cf40beaabc7f2443 Mon Sep 17 00:00:00 2001
From: yifan_shen3
Date: Sun, 8 Sep 2024 13:40:59 -0700
Subject: [PATCH 6/7] address review comments: 1. add arg help info; 2. add
 mutable buffer partition log
---
 backends/apple/coreml/partition/coreml_partitioner.py | 7 +++++++
 examples/models/llama2/export_llama_lib.py            | 6 +++++-
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py
index adedbdee836..c0b6663f729 100644
--- a/backends/apple/coreml/partition/coreml_partitioner.py
+++ b/backends/apple/coreml/partition/coreml_partitioner.py
@@ -92,6 +92,13 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
 
         tag_constant_data(exported_program)
         if self.take_over_mutable_buffer:
+            logger.info(
+                "The Core ML partitioner will take over the torch mutable buffer as Core ML state, "
+                "so if your model contains a mutable buffer, "
+                "you will need macOS 15+/iOS 18+ to execute it. "
" + "If you want your mutable buffer model to be compatible with older OS, " + "then please set `take_over_mutable_buffer=False`" + ) tag_mutated_buffer(exported_program) return PartitionResult( diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index f2bd153bf02..78d3d9605d1 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -281,7 +281,11 @@ def build_args_parser() -> argparse.ArgumentParser: parser.add_argument("-V", "--vulkan", action="store_true") parser.add_argument("--mps", action="store_true") parser.add_argument("--coreml", action="store_true") - parser.add_argument("--coreml-enable-state", action="store_true") + parser.add_argument( + "--coreml-enable-state", + action="store_true", + help="This option is only for coreml, and is only supported for MacOS15+/iOS18+", + ) parser.add_argument( "--qnn", action="store_true", From 0bfa4222216d80ecfd650620cd3330dc6bc78006 Mon Sep 17 00:00:00 2001 From: yifan_shen3 Date: Mon, 9 Sep 2024 10:56:16 -0700 Subject: [PATCH 7/7] fix CI: executorch example model test env is using older transformers, that does not support numpy 2.0 --- backends/apple/coreml/scripts/install_requirements.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/backends/apple/coreml/scripts/install_requirements.sh b/backends/apple/coreml/scripts/install_requirements.sh index ac0320e5920..b6c9a073e08 100755 --- a/backends/apple/coreml/scripts/install_requirements.sh +++ b/backends/apple/coreml/scripts/install_requirements.sh @@ -47,6 +47,11 @@ cmake --build "$COREMLTOOLS_DIR_PATH/build" --parallel echo "${green}ExecuTorch: Installing coremltools." pip install "$COREMLTOOLS_DIR_PATH" +# CoreMLTools have started supporting numpy 2.0, +# but ExecuTorch example model test env is still using older transformers, +# so for now we will need to downgrade numpy to 1.x +# TODO: Remove this numpy downgrade once later transformers starts to be used +pip install numpy==1.26.4 STATUS=$? if [ $STATUS -ne 0 ]; then echo "${red}ExecuTorch: Failed to install coremltools."