Skip to content

Commit bed6bfb

Browse files
committed
[xnn update prep] deprecate sdpa
ghstack-source-id: 479c94a ghstack-comment-id: 2957149191 Pull Request resolved: #11506
1 parent 1af16cd commit bed6bfb

File tree

6 files changed

+0
-311
lines changed

6 files changed

+0
-311
lines changed

backends/xnnpack/operators/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
op_quant_dequant,
4040
op_relu,
4141
op_rsqrt,
42-
op_sdpa,
4342
op_sigmoid,
4443
op_skip_ops,
4544
op_slice_copy,

backends/xnnpack/operators/op_sdpa.py

Lines changed: 0 additions & 111 deletions
This file was deleted.

backends/xnnpack/partition/config/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
QuantizedPerTensorConfig,
4444
ReciprocalSquareRootConfig,
4545
ReLUConfig,
46-
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
4746
SigmoidConfig,
4847
SliceCopyConfig,
4948
SoftmaxConfig,
@@ -99,7 +98,6 @@
9998
PreluConfig,
10099
ReciprocalSquareRootConfig,
101100
ReLUConfig,
102-
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
103101
SigmoidConfig,
104102
SliceCopyConfig,
105103
SoftmaxConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -527,33 +527,3 @@ class BMMConfig(GenericNodePartitionerConfig):
527527

528528
def supported_precision_types(self) -> List[ConfigPrecisionType]:
529529
return [ConfigPrecisionType.FP32]
530-
531-
532-
class SDPAConfig(GenericNodePartitionerConfig):
533-
target_name = "scaled_dot_product_attention.default"
534-
535-
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
536-
"""
537-
Requires Mask to have Rank 2
538-
"""
539-
if not self.check_common_constraints(node, ep):
540-
return False
541-
542-
if len(node.all_input_nodes) < 4:
543-
return False
544-
mask_node = node.all_input_nodes[3]
545-
mask_rank = mask_node.meta["val"].dim()
546-
if mask_rank != 2:
547-
why(
548-
node,
549-
reason=f"mask must have rank 2, got mask of rank {mask_rank}",
550-
)
551-
return False
552-
553-
return True
554-
555-
def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
556-
return torch.ops.aten.scaled_dot_product_attention.default
557-
558-
def supported_precision_types(self) -> List[ConfigPrecisionType]:
559-
return [ConfigPrecisionType.FP32]

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,42 +1961,6 @@ Error defineStaticSliceNode(
19611961
return Error::Ok;
19621962
}
19631963

1964-
/*
1965-
Defines Scaled Dot Product Attention (SDPA) node into the subgraph,
1966-
using the remapped ids to map the serialized ids,
1967-
to the new ids generated when defining the tensor value
1968-
*/
1969-
Error defineScaledDotProductAttentionNode(
1970-
xnn_subgraph_t subgraph_ptr,
1971-
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1972-
const NodePtr node,
1973-
const fb_xnnpack::XNNGraph* graph) noexcept {
1974-
MAYBE_UNUSED(graph);
1975-
1976-
auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention();
1977-
1978-
xnn_status status = xnn_define_scaled_dot_product_attention(
1979-
subgraph_ptr,
1980-
xnn_attention_logits_cap_type_none, // cap_type
1981-
nullptr, // cap_value - not used
1982-
remapped_ids.at(graph_node->query_id()),
1983-
remapped_ids.at(graph_node->key_id()),
1984-
remapped_ids.at(graph_node->value_id()),
1985-
remapped_ids.at(graph_node->scale_id()),
1986-
remapped_ids.at(graph_node->mask_id()),
1987-
remapped_ids.at(graph_node->output_id()),
1988-
graph_node->flags());
1989-
1990-
ET_CHECK_OR_RETURN_ERROR(
1991-
status == xnn_status_success,
1992-
Internal,
1993-
"Failed to create SDPA node %i with code: %s",
1994-
node->debug_handle(),
1995-
xnn_status_to_string(status));
1996-
1997-
return Error::Ok;
1998-
}
1999-
20001964
/*
20011965
Defines batch matrix multiply node into the subgraph,
20021966
using the remapped ids to map the serialized ids,
@@ -2097,7 +2061,6 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
20972061
_DEFINE(Concatenate4)
20982062
_DEFINE(Concatenate5)
20992063
_DEFINE(StaticSlice)
2100-
_DEFINE(ScaledDotProductAttention)
21012064
_DEFINE(BatchMatrixMultiply)
21022065
case fb_xnnpack::XNodeUnion::NONE:
21032066
default: // Adding here as a catch all, just in case

backends/xnnpack/test/ops/test_sdpa.py

Lines changed: 0 additions & 130 deletions
This file was deleted.

0 commit comments

Comments
 (0)