diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index 9b10c3be530..8712c2709ac 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -96,9 +96,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even with in a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -107,6 +107,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                     precision = ConfigPrecisionType.FP32
                     logging.info(f"Overwriting precision, partitioning {node} as FP32")
                     return True, precision
+
         return False, precision
 
     def get_deps(
@@ -226,8 +227,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)
 
@@ -305,8 +309,11 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, [])
 
@@ -412,9 +419,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])
 
@@ -501,11 +510,11 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users
 
-            # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+            # When using force_non_static_weights_for_f32_linear, we want to get_deps to overwrite the source partition nodes.
             # Else we want to be greedy.
             ret_deps = (
                 list(set(deps) & set(src_partition.nodes))
-                if self.force_fp32_dynamic_linear
+                if self.force_non_static_weights_for_f32_linear
                 else list(set(deps) | set(src_partition.nodes))
             )
 
@@ -531,8 +540,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])
 
diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py
index d261416a76f..20018610fce 100644
--- a/backends/xnnpack/partition/config/xnnpack_config.py
+++ b/backends/xnnpack/partition/config/xnnpack_config.py
@@ -41,7 +41,9 @@ def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
         # Flag used in GEMMConfig()
-        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)
+        self.force_non_static_weights_for_f32_linear = kwargs.get(
+            "force_non_static_weights_for_f32_linear", False
+        )
 
     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram
diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index b56a746651c..690a1109a17 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -948,7 +948,7 @@ def test_linear_qd8_as_fp32(self):
             },
         )
 
-    def test_linear_fp32_with_force_as_mm(self):
+    def test_linear_with_force_non_static_weights_for_f32_linear(self):
         def check_signature(
             signature: ExportGraphSignature,
             force_flag: bool,
@@ -981,7 +981,7 @@ def check_signature(
             inputs = module.get_inputs()
             tester = Tester(module, inputs).export()
             partitioner = XnnpackPartitioner(
-                force_fp32_dynamic_linear=force_flag
+                force_non_static_weights_for_f32_linear=force_flag
             )
             if legacy_mode:
                 tester.to_edge()
diff --git a/backends/xnnpack/test/ops/test_lstm.py b/backends/xnnpack/test/ops/test_lstm.py
index be209082b37..6c174b16f33 100644
--- a/backends/xnnpack/test/ops/test_lstm.py
+++ b/backends/xnnpack/test/ops/test_lstm.py
@@ -43,18 +43,20 @@ def test_fp32_lstm(self):
             .run_method_and_compare_outputs()
         )
 
-    def test_fp32_lstm_force_dynamic_linear(self):
+    def test_lstm_with_force_non_static_weights_for_f32_linear(self):
         (
             Tester(self.LSTMLinear(32, 32, 10), (torch.rand(1, 32, 32),))
             .export()
             .to_edge_transform_and_lower(
                 ToEdgeTransformAndLower(
-                    partitioners=[XnnpackPartitioner(force_fp32_dynamic_linear=True)]
+                    partitioners=[
+                        XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
+                    ]
                 )
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
             # Weights are supplied as input to linears
-            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            # Biases are not owned by delegates when force_non_static_weights_for_f32_linear is set
             .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
             .to_executorch()
             .serialize()
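
Usage note (not part of the patch): a minimal sketch of how the renamed flag is passed to the partitioner after this change. Only XnnpackPartitioner, to_edge_transform_and_lower, and the force_non_static_weights_for_f32_linear keyword come from the diff / the ExecuTorch export API; the SimpleLinear module and input shapes below are hypothetical placeholders.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class SimpleLinear(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 32)

    def forward(self, x):
        return self.linear(x)


model = SimpleLinear().eval()
example_inputs = (torch.rand(1, 32, 32),)
exported = torch.export.export(model, example_inputs)

# With the flag set, fp32 linear weights are not treated as static data owned by
# the delegate; they are supplied as inputs to the lowered linear instead (see the
# _get_weight_deps / _get_bias_deps changes above).
edge = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)],
)
executorch_program = edge.to_executorch()

Callers that previously passed force_fp32_dynamic_linear=True would switch to the new keyword; the underlying partitioning behavior is unchanged by the rename.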