Skip to content

Commit dc4d689

Browse files
committed
[xnn update prep] deprecate sdpa
ghstack-source-id: 929a8bc
ghstack-comment-id: 2957149191
Pull Request resolved: #11506
1 parent b677429 commit dc4d689

File tree

6 files changed

+0
-311
lines changed

6 files changed

+0
-311
lines changed

backends/xnnpack/operators/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
op_quant_dequant,
4141
op_relu,
4242
op_rsqrt,
43-
op_sdpa,
4443
op_sigmoid,
4544
op_skip_ops,
4645
op_slice_copy,

backends/xnnpack/operators/op_sdpa.py

Lines changed: 0 additions & 111 deletions
This file was deleted.

backends/xnnpack/partition/config/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
QuantizedPerTensorConfig,
4545
ReciprocalSquareRootConfig,
4646
ReLUConfig,
47-
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
4847
SigmoidConfig,
4948
SliceCopyConfig,
5049
SoftmaxConfig,
@@ -103,7 +102,6 @@
103102
ReciprocalSquareRootConfig,
104103
ReLUConfig,
105104
TanhConfig,
106-
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
107105
SigmoidConfig,
108106
SliceCopyConfig,
109107
SoftmaxConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -564,33 +564,3 @@ class BMMConfig(GenericNodePartitionerConfig):
564564

565565
def supported_precision_types(self) -> List[ConfigPrecisionType]:
566566
return [ConfigPrecisionType.FP32]
567-
568-
569-
class SDPAConfig(GenericNodePartitionerConfig):
570-
target_name = "scaled_dot_product_attention.default"
571-
572-
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
573-
"""
574-
Requires Mask to have Rank 2
575-
"""
576-
if not self.check_common_constraints(node, ep):
577-
return False
578-
579-
if len(node.all_input_nodes) < 4:
580-
return False
581-
mask_node = node.all_input_nodes[3]
582-
mask_rank = mask_node.meta["val"].dim()
583-
if mask_rank != 2:
584-
why(
585-
node,
586-
reason=f"mask must have rank 2, got mask of rank {mask_rank}",
587-
)
588-
return False
589-
590-
return True
591-
592-
def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
593-
return torch.ops.aten.scaled_dot_product_attention.default
594-
595-
def supported_precision_types(self) -> List[ConfigPrecisionType]:
596-
return [ConfigPrecisionType.FP32]

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,42 +1423,6 @@ Error defineStaticSliceNode(
14231423
return Error::Ok;
14241424
}
14251425

1426-
/*
1427-
Defines Scaled Dot Product Attention (SDPA) node into the subgraph,
1428-
using the remapped ids to map the serialized ids,
1429-
to the new ids generated when defining the tensor value
1430-
*/
1431-
Error defineScaledDotProductAttentionNode(
1432-
xnn_subgraph_t subgraph_ptr,
1433-
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1434-
const NodePtr node,
1435-
const fb_xnnpack::XNNGraph* graph) noexcept {
1436-
MAYBE_UNUSED(graph);
1437-
1438-
auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention();
1439-
1440-
xnn_status status = xnn_define_scaled_dot_product_attention(
1441-
subgraph_ptr,
1442-
xnn_attention_logits_cap_type_none, // cap_type
1443-
nullptr, // cap_value - not used
1444-
remapped_ids.at(graph_node->query_id()),
1445-
remapped_ids.at(graph_node->key_id()),
1446-
remapped_ids.at(graph_node->value_id()),
1447-
remapped_ids.at(graph_node->scale_id()),
1448-
remapped_ids.at(graph_node->mask_id()),
1449-
remapped_ids.at(graph_node->output_id()),
1450-
graph_node->flags());
1451-
1452-
ET_CHECK_OR_RETURN_ERROR(
1453-
status == xnn_status_success,
1454-
Internal,
1455-
"Failed to create SDPA node %i with code: %s",
1456-
node->debug_handle(),
1457-
xnn_status_to_string(status));
1458-
1459-
return Error::Ok;
1460-
}
1461-
14621426
/*
14631427
Defines batch matrix multiply node into the subgraph,
14641428
using the remapped ids to map the serialized ids,
@@ -1788,7 +1752,6 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
17881752
_DEFINE(Concatenate4)
17891753
_DEFINE(Concatenate5)
17901754
_DEFINE(StaticSlice)
1791-
_DEFINE(ScaledDotProductAttention)
17921755
_DEFINE(BatchMatrixMultiply)
17931756
case fb_xnnpack::XNodeUnion::NONE:
17941757
default: // Adding here as a catch all, just in case

backends/xnnpack/test/ops/test_sdpa.py

Lines changed: 0 additions & 130 deletions
This file was deleted.

0 commit comments

Comments
 (0)