From 271da7d8bc9f0f5ca98cdfeb999872a7a0b6cf74 Mon Sep 17 00:00:00 2001
From: asquared31415 <34665709+asquared31415@users.noreply.github.com>
Date: Thu, 7 Oct 2021 15:42:18 -0400
Subject: [PATCH] make #[target_feature] work with `asm` register classes

---
 compiler/rustc_ast_lowering/src/asm.rs        |  64 +--------
 compiler/rustc_hir/src/hir.rs                 |   7 +
 compiler/rustc_passes/src/intrinsicck.rs      | 136 +++++++++++++++---
 src/test/ui/asm/x86_64/bad-reg.rs             |   4 -
 src/test/ui/asm/x86_64/bad-reg.stderr         |  48 +++----
 src/test/ui/asm/x86_64/target-feature-attr.rs |  40 ++++++
 .../ui/asm/x86_64/target-feature-attr.stderr  |  26 ++++
 7 files changed, 213 insertions(+), 112 deletions(-)
 create mode 100644 src/test/ui/asm/x86_64/target-feature-attr.rs
 create mode 100644 src/test/ui/asm/x86_64/target-feature-attr.stderr

diff --git a/compiler/rustc_ast_lowering/src/asm.rs b/compiler/rustc_ast_lowering/src/asm.rs
index 7165b3bcb9fc1..957b14f348729 100644
--- a/compiler/rustc_ast_lowering/src/asm.rs
+++ b/compiler/rustc_ast_lowering/src/asm.rs
@@ -202,39 +202,20 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
 
         let mut used_input_regs = FxHashMap::default();
         let mut used_output_regs = FxHashMap::default();
-        let mut required_features: Vec<&str> = vec![];
+
         for (idx, &(ref op, op_sp)) in operands.iter().enumerate() {
             if let Some(reg) = op.reg() {
-                // Make sure we don't accidentally carry features from the
-                // previous iteration.
-                required_features.clear();
-
                 let reg_class = reg.reg_class();
                 if reg_class == asm::InlineAsmRegClass::Err {
                     continue;
                 }
 
-                // We ignore target feature requirements for clobbers: if the
-                // feature is disabled then the compiler doesn't care what we
-                // do with the registers.
-                //
-                // Note that this is only possible for explicit register
-                // operands, which cannot be used in the asm string.
-                let is_clobber = matches!(
-                    op,
-                    hir::InlineAsmOperand::Out {
-                        reg: asm::InlineAsmRegOrRegClass::Reg(_),
-                        late: _,
-                        expr: None
-                    }
-                );
-
                 // Some register classes can only be used as clobbers. This
                 // means that we disallow passing a value in/out of the asm and
                 // require that the operand name an explicit register, not a
                 // register class.
                 if reg_class.is_clobber_only(asm_arch.unwrap())
-                    && !(is_clobber && matches!(reg, asm::InlineAsmRegOrRegClass::Reg(_)))
+                    && !(op.is_clobber() && matches!(reg, asm::InlineAsmRegOrRegClass::Reg(_)))
                 {
                     let msg = format!(
                         "register class `{}` can only be used as a clobber, \
@@ -245,47 +226,6 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
                     continue;
                 }
 
-                if !is_clobber {
-                    // Validate register classes against currently enabled target
-                    // features. We check that at least one type is available for
-                    // the current target.
-                    for &(_, feature) in reg_class.supported_types(asm_arch.unwrap()) {
-                        if let Some(feature) = feature {
-                            if self.sess.target_features.contains(&Symbol::intern(feature)) {
-                                required_features.clear();
-                                break;
-                            } else {
-                                required_features.push(feature);
-                            }
-                        } else {
-                            required_features.clear();
-                            break;
-                        }
-                    }
-                    // We are sorting primitive strs here and can use unstable sort here
-                    required_features.sort_unstable();
-                    required_features.dedup();
-                    match &required_features[..] {
-                        [] => {}
-                        [feature] => {
-                            let msg = format!(
-                                "register class `{}` requires the `{}` target feature",
-                                reg_class.name(),
-                                feature
-                            );
-                            sess.struct_span_err(op_sp, &msg).emit();
-                        }
-                        features => {
-                            let msg = format!(
-                                "register class `{}` requires at least one target feature: {}",
-                                reg_class.name(),
-                                features.join(", ")
-                            );
-                            sess.struct_span_err(op_sp, &msg).emit();
-                        }
-                    }
-                }
-
                 // Check for conflicts between explicit register operands.
                 if let asm::InlineAsmRegOrRegClass::Reg(reg) = reg {
                     let (input, output) = match op {
diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs
index fdd52bd74952f..5264f7cc32612 100644
--- a/compiler/rustc_hir/src/hir.rs
+++ b/compiler/rustc_hir/src/hir.rs
@@ -2293,6 +2293,13 @@ impl<'hir> InlineAsmOperand<'hir> {
             Self::Const { .. } | Self::Sym { .. } => None,
         }
     }
+
+    pub fn is_clobber(&self) -> bool {
+        matches!(
+            self,
+            InlineAsmOperand::Out { reg: InlineAsmRegOrRegClass::Reg(_), late: _, expr: None }
+        )
+    }
 }
 
 #[derive(Debug, HashStable_Generic)]
diff --git a/compiler/rustc_passes/src/intrinsicck.rs b/compiler/rustc_passes/src/intrinsicck.rs
index 96e9a40df3685..008b856ebf2fa 100644
--- a/compiler/rustc_passes/src/intrinsicck.rs
+++ b/compiler/rustc_passes/src/intrinsicck.rs
@@ -141,6 +141,7 @@ impl ExprVisitor<'tcx> {
         template: &[InlineAsmTemplatePiece],
         is_input: bool,
         tied_input: Option<(&hir::Expr<'tcx>, Option<InlineAsmType>)>,
+        target_features: &[Symbol],
     ) -> Option<InlineAsmType> {
         // Check the type against the allowed types for inline asm.
         let ty = self.typeck_results.expr_ty_adjusted(expr);
@@ -283,17 +284,20 @@ impl ExprVisitor<'tcx> {
         };
 
         // Check whether the selected type requires a target feature. Note that
-        // this is different from the feature check we did earlier in AST
-        // lowering. While AST lowering checked that this register class is
-        // usable at all with the currently enabled features, some types may
-        // only be usable with a register class when a certain feature is
-        // enabled. We check this here since it depends on the results of typeck.
+        // this is different from the feature check we did earlier. While the
+        // previous check checked that this register class is usable at all
+        // with the currently enabled features, some types may only be usable
+        // with a register class when a certain feature is enabled. We check
+        // this here since it depends on the results of typeck.
         //
         // Also note that this check isn't run when the operand type is never
-        // (!). In that case we still need the earlier check in AST lowering to
-        // verify that the register class is usable at all.
+        // (!). In that case we still need the earlier check to verify that the
+        // register class is usable at all.
         if let Some(feature) = feature {
-            if !self.tcx.sess.target_features.contains(&Symbol::intern(feature)) {
+            let feat_sym = Symbol::intern(feature);
+            if !self.tcx.sess.target_features.contains(&feat_sym)
+                && !target_features.contains(&feat_sym)
+            {
                 let msg = &format!("`{}` target feature is not enabled", feature);
                 let mut err = self.tcx.sess.struct_span_err(expr.span, msg);
                 err.note(&format!(
@@ -349,23 +353,122 @@ impl ExprVisitor<'tcx> {
         Some(asm_ty)
     }
 
-    fn check_asm(&self, asm: &hir::InlineAsm<'tcx>) {
-        for (idx, (op, _)) in asm.operands.iter().enumerate() {
+    fn check_asm(&self, asm: &hir::InlineAsm<'tcx>, hir_id: hir::HirId) {
+        let hir = self.tcx.hir();
+        let enclosing_id = hir.enclosing_body_owner(hir_id);
+        let enclosing_def_id = hir.local_def_id(enclosing_id).to_def_id();
+        let attrs = self.tcx.codegen_fn_attrs(enclosing_def_id);
+        for (idx, (op, op_sp)) in asm.operands.iter().enumerate() {
+            // Validate register classes against currently enabled target
+            // features. We check that at least one type is available for
+            // the enabled features.
+            //
+            // We ignore target feature requirements for clobbers: if the
+            // feature is disabled then the compiler doesn't care what we
+            // do with the registers.
+            //
+            // Note that this is only possible for explicit register
+            // operands, which cannot be used in the asm string.
+            if let Some(reg) = op.reg() {
+                if !op.is_clobber() {
+                    let mut missing_required_features = vec![];
+                    let reg_class = reg.reg_class();
+                    for &(_, feature) in reg_class.supported_types(self.tcx.sess.asm_arch.unwrap())
+                    {
+                        match feature {
+                            Some(feature) => {
+                                let feat_sym = Symbol::intern(feature);
+                                if self.tcx.sess.target_features.contains(&feat_sym)
+                                    || attrs.target_features.contains(&feat_sym)
+                                {
+                                    missing_required_features.clear();
+                                    break;
+                                } else {
+                                    missing_required_features.push(feature);
+                                }
+                            }
+                            None => {
+                                missing_required_features.clear();
+                                break;
+                            }
+                        }
+                    }
+
+                    // We are sorting primitive strs here and can use unstable sort here
+                    missing_required_features.sort_unstable();
+                    missing_required_features.dedup();
+                    match &missing_required_features[..] {
+                        [] => {}
+                        [feature] => {
+                            let msg = format!(
+                                "register class `{}` requires the `{}` target feature",
+                                reg_class.name(),
+                                feature
+                            );
+                            self.tcx.sess.struct_span_err(*op_sp, &msg).emit();
+                            // register isn't enabled, don't do more checks
+                            continue;
+                        }
+                        features => {
+                            let msg = format!(
+                                "register class `{}` requires at least one of the following target features: {}",
+                                reg_class.name(),
+                                features.join(", ")
+                            );
+                            self.tcx.sess.struct_span_err(*op_sp, &msg).emit();
+                            // register isn't enabled, don't do more checks
+                            continue;
+                        }
+                    }
+                }
+            }
+
             match *op {
                 hir::InlineAsmOperand::In { reg, ref expr } => {
-                    self.check_asm_operand_type(idx, reg, expr, asm.template, true, None);
+                    self.check_asm_operand_type(
+                        idx,
+                        reg,
+                        expr,
+                        asm.template,
+                        true,
+                        None,
+                        &attrs.target_features,
+                    );
                 }
                 hir::InlineAsmOperand::Out { reg, late: _, ref expr } => {
                     if let Some(expr) = expr {
-                        self.check_asm_operand_type(idx, reg, expr, asm.template, false, None);
+                        self.check_asm_operand_type(
+                            idx,
+                            reg,
+                            expr,
+                            asm.template,
+                            false,
+                            None,
+                            &attrs.target_features,
+                        );
                     }
                 }
                 hir::InlineAsmOperand::InOut { reg, late: _, ref expr } => {
-                    self.check_asm_operand_type(idx, reg, expr, asm.template, false, None);
+                    self.check_asm_operand_type(
+                        idx,
+                        reg,
+                        expr,
+                        asm.template,
+                        false,
+                        None,
+                        &attrs.target_features,
+                    );
                 }
                 hir::InlineAsmOperand::SplitInOut { reg, late: _, ref in_expr, ref out_expr } => {
-                    let in_ty =
-                        self.check_asm_operand_type(idx, reg, in_expr, asm.template, true, None);
+                    let in_ty = self.check_asm_operand_type(
+                        idx,
+                        reg,
+                        in_expr,
+                        asm.template,
+                        true,
+                        None,
+                        &attrs.target_features,
+                    );
                     if let Some(out_expr) = out_expr {
                         self.check_asm_operand_type(
                             idx,
@@ -374,6 +477,7 @@ impl ExprVisitor<'tcx> {
                             asm.template,
                             false,
                             Some((in_expr, in_ty)),
+                            &attrs.target_features,
                         );
                     }
                 }
@@ -422,7 +526,7 @@ impl Visitor<'tcx> for ExprVisitor<'tcx> {
                 }
             }
 
-            hir::ExprKind::InlineAsm(asm) => self.check_asm(asm),
+            hir::ExprKind::InlineAsm(asm) => self.check_asm(asm, expr.hir_id),
 
             _ => {}
         }
diff --git a/src/test/ui/asm/x86_64/bad-reg.rs b/src/test/ui/asm/x86_64/bad-reg.rs
index 06af08fab80f9..91d0f8c33f979 100644
--- a/src/test/ui/asm/x86_64/bad-reg.rs
+++ b/src/test/ui/asm/x86_64/bad-reg.rs
@@ -21,10 +21,6 @@ fn main() {
         //~^ ERROR asm template modifiers are not allowed for `const` arguments
         asm!("{:a}", sym main);
         //~^ ERROR asm template modifiers are not allowed for `sym` arguments
-        asm!("{}", in(zmm_reg) foo);
-        //~^ ERROR register class `zmm_reg` requires the `avx512f` target feature
-        asm!("", in("zmm0") foo);
-        //~^ ERROR register class `zmm_reg` requires the `avx512f` target feature
         asm!("", in("ebp") foo);
         //~^ ERROR invalid register `ebp`: the frame pointer cannot be used as an operand
         asm!("", in("rsp") foo);
diff --git a/src/test/ui/asm/x86_64/bad-reg.stderr b/src/test/ui/asm/x86_64/bad-reg.stderr
index 14740bf62f8e5..102a17e981570 100644
--- a/src/test/ui/asm/x86_64/bad-reg.stderr
+++ b/src/test/ui/asm/x86_64/bad-reg.stderr
@@ -46,86 +46,74 @@ LL |         asm!("{:a}", sym main);
    |               |
    |               template modifier
 
-error: register class `zmm_reg` requires the `avx512f` target feature
-  --> $DIR/bad-reg.rs:24:20
-   |
-LL |         asm!("{}", in(zmm_reg) foo);
-   |                    ^^^^^^^^^^^^^^^
-
-error: register class `zmm_reg` requires the `avx512f` target feature
-  --> $DIR/bad-reg.rs:26:18
-   |
-LL |         asm!("", in("zmm0") foo);
-   |                  ^^^^^^^^^^^^^^
-
 error: invalid register `ebp`: the frame pointer cannot be used as an operand for inline asm
-  --> $DIR/bad-reg.rs:28:18
+  --> $DIR/bad-reg.rs:24:18
    |
 LL |         asm!("", in("ebp") foo);
    |                  ^^^^^^^^^^^^^
 
 error: invalid register `rsp`: the stack pointer cannot be used as an operand for inline asm
-  --> $DIR/bad-reg.rs:30:18
+  --> $DIR/bad-reg.rs:26:18
    |
 LL |         asm!("", in("rsp") foo);
    |                  ^^^^^^^^^^^^^
 
 error: invalid register `ip`: the instruction pointer cannot be used as an operand for inline asm
-  --> $DIR/bad-reg.rs:32:18
+  --> $DIR/bad-reg.rs:28:18
    |
 LL |         asm!("", in("ip") foo);
    |                  ^^^^^^^^^^^^
 
 error: invalid register `k0`: the k0 AVX mask register cannot be used as an operand for inline asm
-  --> $DIR/bad-reg.rs:34:18
+  --> $DIR/bad-reg.rs:30:18
    |
 LL |         asm!("", in("k0") foo);
    |                  ^^^^^^^^^^^^
 
 error: invalid register `ah`: high byte registers cannot be used as an operand on x86_64
-  --> $DIR/bad-reg.rs:36:18
+  --> $DIR/bad-reg.rs:32:18
    |
 LL |         asm!("", in("ah") foo);
    |                  ^^^^^^^^^^^^
 
 error: register class `x87_reg` can only be used as a clobber, not as an input or output
-  --> $DIR/bad-reg.rs:39:18
+  --> $DIR/bad-reg.rs:35:18
    |
 LL |         asm!("", in("st(2)") foo);
    |                  ^^^^^^^^^^^^^^^
 
 error: register class `mmx_reg` can only be used as a clobber, not as an input or output
-  --> $DIR/bad-reg.rs:41:18
+  --> $DIR/bad-reg.rs:37:18
    |
 LL |         asm!("", in("mm0") foo);
    |                  ^^^^^^^^^^^^^
 
 error: register class `x87_reg` can only be used as a clobber, not as an input or output
-  --> $DIR/bad-reg.rs:45:20
+  --> $DIR/bad-reg.rs:41:20
    |
 LL |         asm!("{}", in(x87_reg) foo);
    |                    ^^^^^^^^^^^^^^^
 
 error: register class `mmx_reg` can only be used as a clobber, not as an input or output
-  --> $DIR/bad-reg.rs:47:20
+  --> $DIR/bad-reg.rs:43:20
    |
 LL |         asm!("{}", in(mmx_reg) foo);
    |                    ^^^^^^^^^^^^^^^
 
 error: register class `x87_reg` can only be used as a clobber, not as an input or output
-  --> $DIR/bad-reg.rs:49:20
+  --> $DIR/bad-reg.rs:45:20
    |
 LL |         asm!("{}", out(x87_reg) _);
    |                    ^^^^^^^^^^^^^^
 
 error: register class `mmx_reg` can only be used as a clobber, not as an input or output
-  --> $DIR/bad-reg.rs:51:20
+  --> $DIR/bad-reg.rs:47:20
    |
 LL |         asm!("{}", out(mmx_reg) _);
    |                    ^^^^^^^^^^^^^^
 
 error: register `al` conflicts with register `ax`
-  --> $DIR/bad-reg.rs:57:33
+  --> $DIR/bad-reg.rs:53:33
    |
 LL |         asm!("", in("eax") foo, in("al") bar);
    |                  -------------  ^^^^^^^^^^^^ register `al`
@@ -133,7 +121,7 @@ LL |         asm!("", in("eax") foo, in("al") bar);
    |                  register `ax`
 
 error: register `ax` conflicts with register `ax`
-  --> $DIR/bad-reg.rs:59:33
+  --> $DIR/bad-reg.rs:55:33
    |
 LL |         asm!("", in("rax") foo, out("rax") bar);
    |                  -------------  ^^^^^^^^^^^^^^ register `ax`
@@ -141,13 +129,13 @@ LL |         asm!("", in("rax") foo, out("rax") bar);
    |                  register `ax`
    |
 help: use `lateout` instead of `out` to avoid conflict
-  --> $DIR/bad-reg.rs:59:18
+  --> $DIR/bad-reg.rs:55:18
    |
 LL |         asm!("", in("rax") foo, out("rax") bar);
    |                  ^^^^^^^^^^^^^
 
 error: register `ymm0` conflicts with register `xmm0`
-  --> $DIR/bad-reg.rs:62:34
+  --> $DIR/bad-reg.rs:58:34
    |
 LL |         asm!("", in("xmm0") foo, in("ymm0") bar);
    |                  --------------  ^^^^^^^^^^^^^^ register `ymm0`
@@ -155,7 +143,7 @@ LL |         asm!("", in("xmm0") foo, in("ymm0") bar);
    |                  register `xmm0`
 
 error: register `ymm0` conflicts with register `xmm0`
-  --> $DIR/bad-reg.rs:64:34
+  --> $DIR/bad-reg.rs:60:34
    |
 LL |         asm!("", in("xmm0") foo, out("ymm0") bar);
    |                  --------------  ^^^^^^^^^^^^^^^ register `ymm0`
@@ -163,10 +151,10 @@ LL |         asm!("", in("xmm0") foo, out("ymm0") bar);
    |                  register `xmm0`
    |
 help: use `lateout` instead of `out` to avoid conflict
-  --> $DIR/bad-reg.rs:64:18
+  --> $DIR/bad-reg.rs:60:18
    |
 LL |         asm!("", in("xmm0") foo, out("ymm0") bar);
    |                  ^^^^^^^^^^^^^^
 
-error: aborting due to 23 previous errors
+error: aborting due to 21 previous errors
 
diff --git a/src/test/ui/asm/x86_64/target-feature-attr.rs b/src/test/ui/asm/x86_64/target-feature-attr.rs
new file mode 100644
index 0000000000000..4f82cd8aab9d0
--- /dev/null
+++ b/src/test/ui/asm/x86_64/target-feature-attr.rs
@@ -0,0 +1,40 @@
+// only-x86_64
+
+#![feature(asm, avx512_target_feature)]
+
+#[target_feature(enable = "avx")]
+unsafe fn foo() {
+    let mut x = 1;
+    let y = 2;
+    asm!("vaddps {2:y}, {0:y}, {1:y}", in(ymm_reg) x, in(ymm_reg) y, lateout(ymm_reg) x);
+    assert_eq!(x, 3);
+}
+
+unsafe fn bar() {
+    let mut x = 1;
+    let y = 2;
+    asm!("vaddps {2:y}, {0:y}, {1:y}", in(ymm_reg) x, in(ymm_reg) y, lateout(ymm_reg) x);
+    //~^ ERROR: register class `ymm_reg` requires the `avx` target feature
+    //~| ERROR: register class `ymm_reg` requires the `avx` target feature
+    //~| ERROR: register class `ymm_reg` requires the `avx` target feature
+    assert_eq!(x, 3);
+}
+
+#[target_feature(enable = "avx512bw")]
+unsafe fn baz() {
+    let x = 1;
+    asm!("/* {0} */", in(kreg) x);
+}
+
+unsafe fn baz2() {
+    let x = 1;
+    asm!("/* {0} */", in(kreg) x);
+    //~^ ERROR: register class `kreg` requires at least one of the following target features: avx512bw, avx512f
+}
+
+fn main() {
+    unsafe {
+        foo();
+        bar();
+    }
+}
diff --git a/src/test/ui/asm/x86_64/target-feature-attr.stderr b/src/test/ui/asm/x86_64/target-feature-attr.stderr
new file mode 100644
index 0000000000000..295c8a97ed3bc
--- /dev/null
+++ b/src/test/ui/asm/x86_64/target-feature-attr.stderr
@@ -0,0 +1,26 @@
+error: register class `ymm_reg` requires the `avx` target feature
+  --> $DIR/target-feature-attr.rs:16:40
+   |
+LL |     asm!("vaddps {2:y}, {0:y}, {1:y}", in(ymm_reg) x, in(ymm_reg) y, lateout(ymm_reg) x);
+   |                                        ^^^^^^^^^^^^^
+
+error: register class `ymm_reg` requires the `avx` target feature
+  --> $DIR/target-feature-attr.rs:16:55
+   |
+LL |     asm!("vaddps {2:y}, {0:y}, {1:y}", in(ymm_reg) x, in(ymm_reg) y, lateout(ymm_reg) x);
+   |                                                       ^^^^^^^^^^^^^
+
+error: register class `ymm_reg` requires the `avx` target feature
+  --> $DIR/target-feature-attr.rs:16:70
+   |
+LL |     asm!("vaddps {2:y}, {0:y}, {1:y}", in(ymm_reg) x, in(ymm_reg) y, lateout(ymm_reg) x);
+   |                                                                      ^^^^^^^^^^^^^^^^^^
+
+error: register class `kreg` requires at least one of the following target features: avx512bw, avx512f
+  --> $DIR/target-feature-attr.rs:31:23
+   |
+LL |     asm!("/* {0} */", in(kreg) x);
+   |                       ^^^^^^^^^^
+
+error: aborting due to 4 previous errors
+