Skip to content

Commit ab68a07

Browse files
committed
fixed union fields chain
1 parent 42f6e92 commit ab68a07

File tree

2 files changed

+77
-19
lines changed

2 files changed

+77
-19
lines changed

crates/hir-ty/src/diagnostics/unsafe_check.rs

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -297,25 +297,9 @@ impl<'db> UnsafeVisitor<'db> {
297297
return;
298298
}
299299
Expr::Field { .. } => {
300-
if matches!(
301-
self.infer.field_resolution(*expr),
302-
Some(Either::Left(FieldId { parent: VariantId::UnionId(_), .. }))
303-
) {
304-
match &self.body.exprs[*expr] {
305-
Expr::Field { expr, .. } => {
306-
// Visit the base expression (e.g., `self` in `self.field`) for safety,
307-
// but don't trigger the union field access error since we're just
308-
// creating a raw pointer, not actually reading the field
309-
self.walk_expr(*expr);
310-
}
311-
_ => {
312-
self.body.walk_child_exprs_without_pats(*expr, |child| {
313-
// If it's not a field access for some reason, fall back to normal walking
314-
// This shouldn't happen based on how this function is called
315-
self.walk_expr(child)
316-
});
317-
}
318-
}
300+
if self.contains_union_field_access(*expr) {
301+
// Walk the entire field access chain without triggering union field errors
302+
self.walk_field_chain_for_raw_ptr(*expr);
319303
return;
320304
}
321305
}
@@ -424,4 +408,43 @@ impl<'db> UnsafeVisitor<'db> {
424408
}
425409
}
426410
}
411+
412+
fn contains_union_field_access(&mut self, expr: ExprId) -> bool {
413+
match &self.body.exprs[expr] {
414+
Expr::Field { expr: base_expr, .. } => {
415+
// Check if this field access is from a union
416+
if matches!(
417+
self.infer.field_resolution(expr),
418+
Some(Either::Left(FieldId { parent: VariantId::UnionId(_), .. }))
419+
) {
420+
true
421+
} else {
422+
// Recursively check the base expression
423+
self.contains_union_field_access(*base_expr)
424+
}
425+
}
426+
_ => false,
427+
}
428+
}
429+
430+
/// Walks a field access chain for raw pointer creation, avoiding union field access errors
431+
fn walk_field_chain_for_raw_ptr(&mut self, expr: ExprId) {
432+
match &self.body.exprs[expr] {
433+
Expr::Field { expr: base_expr, .. } => {
434+
// First, recursively handle the base expression
435+
self.walk_field_chain_for_raw_ptr(*base_expr);
436+
437+
// Then handle any non-field child expressions of this field access
438+
self.body.walk_child_exprs_without_pats(expr, |child| {
439+
if child != *base_expr {
440+
self.walk_expr(child);
441+
}
442+
});
443+
}
444+
_ => {
445+
// We've reached the base expression (not a field access), walk it normally
446+
self.walk_expr(expr);
447+
}
448+
}
449+
}
427450
}

crates/ide-diagnostics/src/handlers/missing_unsafe.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,41 @@ fn main() {
869869
)
870870
}
871871

872+
#[test]
873+
fn union_fields_chain_is_allowed() {
874+
check_diagnostics(
875+
r#"
876+
union Inner {
877+
a: u8,
878+
}
879+
880+
union MoreInner {
881+
moreinner: ManuallyDrop<Inner>,
882+
}
883+
884+
union LessOuter {
885+
lessouter: ManuallyDrop<MoreInner>,
886+
}
887+
888+
union Outer {
889+
outer: ManuallyDrop<LessOuter>,
890+
}
891+
892+
fn main() {
893+
let super_outer = Outer {
894+
outer: ManuallyDrop::new(LessOuter {
895+
lessouter: ManuallyDrop::new(MoreInner {
896+
moreinner: ManuallyDrop::new(Inner { a: 42 }),
897+
}),
898+
}),
899+
};
900+
901+
let ptr = &raw const super_outer.outer.lessouter.moreinner.a;
902+
}
903+
"#,
904+
);
905+
}
906+
872907
#[test]
873908
fn raw_ref_reborrow_is_safe() {
874909
check_diagnostics(

0 commit comments

Comments
 (0)