Skip to content

Commit 88a8679

Browse files
committedApr 29, 2025·
transmutability: uninit transition matches unit byte only
The previous implementation was inconsistent about transitions that apply for an init byte. For example, when answering a query, an init byte could use corresponding init transition. Init byte could also use uninit transition, but only when the corresponding init transition was absent. This behaviour was incompatible with DFA union construction. Define an uninit transition to match an uninit byte only and update implementation accordingly. To describe that `Tree::uninit` is valid for any value, build an automaton that accepts any byte value. Additionally, represent byte ranges uniformly as a pair of integers to avoid special case for uninit byte.
·
1.90.01.88.0
1 parent 21079f5 commit 88a8679

File tree

9 files changed

+243
-352
lines changed

9 files changed

+243
-352
lines changed
 

‎compiler/rustc_transmute/Cargo.toml‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ edition = "2024"
55

66
[dependencies]
77
# tidy-alphabetical-start
8-
itertools = "0.12"
98
rustc_abi = { path = "../rustc_abi", optional = true }
109
rustc_data_structures = { path = "../rustc_data_structures" }
1110
rustc_hir = { path = "../rustc_hir", optional = true }
@@ -15,6 +14,11 @@ smallvec = "1.8.1"
1514
tracing = "0.1"
1615
# tidy-alphabetical-end
1716

17+
[dev-dependencies]
18+
# tidy-alphabetical-start
19+
itertools = "0.12"
20+
# tidy-alphabetical-end
21+
1822
[features]
1923
rustc = [
2024
"dep:rustc_abi",

‎compiler/rustc_transmute/src/layout/dfa.rs‎

Lines changed: 108 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::fmt;
2-
use std::ops::RangeInclusive;
2+
use std::iter::Peekable;
33
use std::sync::atomic::{AtomicU32, Ordering};
44

55
use super::{Byte, Ref, Tree, Uninhabited};
@@ -211,15 +211,15 @@ where
211211
let b_transitions =
212212
b_src.and_then(|b_src| b.transitions.get(&b_src)).unwrap_or(&empty_transitions);
213213

214-
let byte_transitions =
215-
a_transitions.byte_transitions.union(&b_transitions.byte_transitions);
216-
217-
let byte_transitions = byte_transitions.map_states(|(a_dst, b_dst)| {
218-
assert!(a_dst.is_some() || b_dst.is_some());
214+
let byte_transitions = a_transitions.byte_transitions.union(
215+
&b_transitions.byte_transitions,
216+
|a_dst, b_dst| {
217+
assert!(a_dst.is_some() || b_dst.is_some());
219218

220-
queue.enqueue(a_dst, b_dst);
221-
mapped((a_dst, b_dst))
222-
});
219+
queue.enqueue(a_dst, b_dst);
220+
mapped((a_dst, b_dst))
221+
},
222+
);
223223

224224
let ref_transitions =
225225
a_transitions.ref_transitions.keys().chain(b_transitions.ref_transitions.keys());
@@ -245,18 +245,6 @@ where
245245
Self { transitions, start, accept }
246246
}
247247

248-
pub(crate) fn states_from(
249-
&self,
250-
state: State,
251-
src_validity: RangeInclusive<u8>,
252-
) -> impl Iterator<Item = (Byte, State)> {
253-
self.transitions
254-
.get(&state)
255-
.map(move |t| t.byte_transitions.states_from(src_validity))
256-
.into_iter()
257-
.flatten()
258-
}
259-
260248
pub(crate) fn get_uninit_edge_dst(&self, state: State) -> Option<State> {
261249
let transitions = self.transitions.get(&state)?;
262250
transitions.byte_transitions.get_uninit_edge_dst()
@@ -334,95 +322,31 @@ where
334322

335323
use edge_set::EdgeSet;
336324
mod edge_set {
337-
use std::cmp;
338-
339-
use run::*;
340-
use smallvec::{SmallVec, smallvec};
325+
use smallvec::SmallVec;
341326

342327
use super::*;
343-
mod run {
344-
use std::ops::{Range, RangeInclusive};
345-
346-
use super::*;
347-
use crate::layout::Byte;
348-
349-
/// A logical set of edges.
350-
///
351-
/// A `Run` encodes one edge for every byte value in `start..=end`
352-
/// pointing to `dst`.
353-
#[derive(Eq, PartialEq, Copy, Clone, Debug)]
354-
pub(super) struct Run<S> {
355-
// `start` and `end` are both inclusive (ie, closed) bounds, as this
356-
// is required in order to be able to store 0..=255. We provide
357-
// setters and getters which operate on closed/open ranges, which
358-
// are more intuitive and easier for performing offset math.
359-
start: u8,
360-
end: u8,
361-
pub(super) dst: S,
362-
}
363-
364-
impl<S> Run<S> {
365-
pub(super) fn new(range: RangeInclusive<u8>, dst: S) -> Self {
366-
Self { start: *range.start(), end: *range.end(), dst }
367-
}
368-
369-
pub(super) fn from_inclusive_exclusive(range: Range<u16>, dst: S) -> Self {
370-
Self {
371-
start: range.start.try_into().unwrap(),
372-
end: (range.end - 1).try_into().unwrap(),
373-
dst,
374-
}
375-
}
376-
377-
pub(super) fn contains(&self, idx: u16) -> bool {
378-
idx >= u16::from(self.start) && idx <= u16::from(self.end)
379-
}
380-
381-
pub(super) fn as_inclusive_exclusive(&self) -> (u16, u16) {
382-
(u16::from(self.start), u16::from(self.end) + 1)
383-
}
384-
385-
pub(super) fn as_byte(&self) -> Byte {
386-
Byte::new(self.start..=self.end)
387-
}
388328

389-
pub(super) fn map_state<SS>(self, f: impl FnOnce(S) -> SS) -> Run<SS> {
390-
let Run { start, end, dst } = self;
391-
Run { start, end, dst: f(dst) }
392-
}
393-
394-
/// Produces a new `Run` whose lower bound is the greater of
395-
/// `self`'s existing lower bound and `lower_bound`.
396-
pub(super) fn clamp_lower(self, lower_bound: u8) -> Self {
397-
let Run { start, end, dst } = self;
398-
Run { start: cmp::max(start, lower_bound), end, dst }
399-
}
400-
}
401-
}
402-
403-
/// The set of outbound byte edges associated with a DFA node (not including
404-
/// reference edges).
329+
/// The set of outbound byte edges associated with a DFA node.
405330
#[derive(Eq, PartialEq, Clone, Debug)]
406331
pub(super) struct EdgeSet<S = State> {
407-
// A sequence of runs stored in ascending order. Since the graph is a
408-
// DFA, these must be non-overlapping with one another.
409-
runs: SmallVec<[Run<S>; 1]>,
410-
// The edge labeled with the uninit byte, if any.
332+
// A sequence of byte edges with contiguous byte values and a common
333+
// destination is stored as a single run.
411334
//
412-
// FIXME(@joshlf): Make `State` a `NonZero` so that this is NPO'd.
413-
uninit: Option<S>,
335+
// Runs are non-empty, non-overlapping, and stored in ascending order.
336+
runs: SmallVec<[(Byte, S); 1]>,
414337
}
415338

416339
impl<S> EdgeSet<S> {
417-
pub(crate) fn new(byte: Byte, dst: S) -> Self {
418-
match byte.range() {
419-
Some(range) => Self { runs: smallvec![Run::new(range, dst)], uninit: None },
420-
None => Self { runs: SmallVec::new(), uninit: Some(dst) },
340+
pub(crate) fn new(range: Byte, dst: S) -> Self {
341+
let mut this = Self { runs: SmallVec::new() };
342+
if !range.is_empty() {
343+
this.runs.push((range, dst));
421344
}
345+
this
422346
}
423347

424348
pub(crate) fn empty() -> Self {
425-
Self { runs: SmallVec::new(), uninit: None }
349+
Self { runs: SmallVec::new() }
426350
}
427351

428352
#[cfg(test)]
@@ -431,43 +355,23 @@ mod edge_set {
431355
S: Ord,
432356
{
433357
edges.sort();
434-
Self {
435-
runs: edges
436-
.into_iter()
437-
.map(|(byte, state)| Run::new(byte.range().unwrap(), state))
438-
.collect(),
439-
uninit: None,
440-
}
358+
Self { runs: edges.into() }
441359
}
442360

443361
pub(crate) fn iter(&self) -> impl Iterator<Item = (Byte, S)>
444362
where
445363
S: Copy,
446364
{
447-
self.uninit
448-
.map(|dst| (Byte::uninit(), dst))
449-
.into_iter()
450-
.chain(self.runs.iter().map(|run| (run.as_byte(), run.dst)))
451-
}
452-
453-
pub(crate) fn states_from(
454-
&self,
455-
byte: RangeInclusive<u8>,
456-
) -> impl Iterator<Item = (Byte, S)>
457-
where
458-
S: Copy,
459-
{
460-
// FIXME(@joshlf): Optimize this. A manual scan over `self.runs` may
461-
// permit us to more efficiently discard runs which will not be
462-
// produced by this iterator.
463-
self.iter().filter(move |(o, _)| Byte::new(byte.clone()).transmutable_into(&o))
365+
self.runs.iter().copied()
464366
}
465367

466368
pub(crate) fn get_uninit_edge_dst(&self) -> Option<S>
467369
where
468370
S: Copy,
469371
{
470-
self.uninit
372+
// Uninit is ordered last.
373+
let &(range, dst) = self.runs.last()?;
374+
if range.contains_uninit() { Some(dst) } else { None }
471375
}
472376

473377
pub(crate) fn map_states<SS>(self, mut f: impl FnMut(S) -> SS) -> EdgeSet<SS> {
@@ -478,95 +382,106 @@ mod edge_set {
478382
// allocates the correct number of elements once up-front [1].
479383
//
480384
// [1] https://doc.rust-lang.org/1.85.0/src/alloc/vec/spec_from_iter_nested.rs.html#47
481-
runs: self.runs.into_iter().map(|run| run.map_state(&mut f)).collect(),
482-
uninit: self.uninit.map(f),
385+
runs: self.runs.into_iter().map(|(b, s)| (b, f(s))).collect(),
483386
}
484387
}
485388

486389
/// Unions two edge sets together.
487390
///
488391
/// If `u = a.union(b)`, then for each byte value, `u` will have an edge
489-
/// with that byte value and with the destination `(Some(_), None)`,
490-
/// `(None, Some(_))`, or `(Some(_), Some(_))` depending on whether `a`,
392+
/// with that byte value and with the destination `join(Some(_), None)`,
393+
/// `join(None, Some(_))`, or `join(Some(_), Some(_))` depending on whether `a`,
491394
/// `b`, or both have an edge with that byte value.
492395
///
493396
/// If neither `a` nor `b` have an edge with a particular byte value,
494397
/// then no edge with that value will be present in `u`.
495-
pub(crate) fn union(&self, other: &Self) -> EdgeSet<(Option<S>, Option<S>)>
398+
pub(crate) fn union(
399+
&self,
400+
other: &Self,
401+
mut join: impl FnMut(Option<S>, Option<S>) -> S,
402+
) -> EdgeSet<S>
496403
where
497404
S: Copy,
498405
{
499-
let uninit = match (self.uninit, other.uninit) {
500-
(None, None) => None,
501-
(s, o) => Some((s, o)),
502-
};
503-
504-
let mut runs = SmallVec::new();
505-
506-
// Iterate over `self.runs` and `other.runs` simultaneously,
507-
// advancing `idx` as we go. At each step, we advance `idx` as far
508-
// as we can without crossing a run boundary in either `self.runs`
509-
// or `other.runs`.
510-
511-
// INVARIANT: `idx < s[0].end && idx < o[0].end`.
512-
let (mut s, mut o) = (self.runs.as_slice(), other.runs.as_slice());
513-
let mut idx = 0u16;
514-
while let (Some((s_run, s_rest)), Some((o_run, o_rest))) =
515-
(s.split_first(), o.split_first())
516-
{
517-
let (s_start, s_end) = s_run.as_inclusive_exclusive();
518-
let (o_start, o_end) = o_run.as_inclusive_exclusive();
519-
520-
// Compute `end` as the end of the current run (which starts
521-
// with `idx`).
522-
let (end, dst) = match (s_run.contains(idx), o_run.contains(idx)) {
523-
// `idx` is in an existing run in both `s` and `o`, so `end`
524-
// is equal to the smallest of the two ends of those runs.
525-
(true, true) => (cmp::min(s_end, o_end), (Some(s_run.dst), Some(o_run.dst))),
526-
// `idx` is in an existing run in `s`, but not in any run in
527-
// `o`. `end` is either the end of the `s` run or the
528-
// beginning of the next `o` run, whichever comes first.
529-
(true, false) => (cmp::min(s_end, o_start), (Some(s_run.dst), None)),
530-
// The inverse of the previous case.
531-
(false, true) => (cmp::min(s_start, o_end), (None, Some(o_run.dst))),
532-
// `idx` is not in a run in either `s` or `o`, so advance it
533-
// to the beginning of the next run.
534-
(false, false) => {
535-
idx = cmp::min(s_start, o_start);
536-
continue;
537-
}
538-
};
406+
let xs = self.runs.iter().copied();
407+
let ys = other.runs.iter().copied();
408+
// FIXME(@joshlf): Merge contiguous runs with common destination.
409+
EdgeSet { runs: union(xs, ys).map(|(range, (x, y))| (range, join(x, y))).collect() }
410+
}
411+
}
412+
}
413+
414+
/// Merges two sorted sequences into one sorted sequence.
415+
pub(crate) fn union<S: Copy, X: Iterator<Item = (Byte, S)>, Y: Iterator<Item = (Byte, S)>>(
416+
xs: X,
417+
ys: Y,
418+
) -> UnionIter<X, Y> {
419+
UnionIter { xs: xs.peekable(), ys: ys.peekable() }
420+
}
421+
422+
pub(crate) struct UnionIter<X: Iterator, Y: Iterator> {
423+
xs: Peekable<X>,
424+
ys: Peekable<Y>,
425+
}
426+
427+
// FIXME(jswrenn) we'd likely benefit from specializing try_fold here.
428+
impl<S: Copy, X: Iterator<Item = (Byte, S)>, Y: Iterator<Item = (Byte, S)>> Iterator
429+
for UnionIter<X, Y>
430+
{
431+
type Item = (Byte, (Option<S>, Option<S>));
539432

540-
// FIXME(@joshlf): If this is contiguous with the previous run
541-
// and has the same `dst`, just merge it into that run rather
542-
// than adding a new one.
543-
runs.push(Run::from_inclusive_exclusive(idx..end, dst));
544-
idx = end;
433+
fn next(&mut self) -> Option<Self::Item> {
434+
use std::cmp::{self, Ordering};
545435

546-
if idx >= s_end {
547-
s = s_rest;
436+
let ret;
437+
match (self.xs.peek_mut(), self.ys.peek_mut()) {
438+
(None, None) => {
439+
ret = None;
440+
}
441+
(Some(x), None) => {
442+
ret = Some((x.0, (Some(x.1), None)));
443+
self.xs.next();
444+
}
445+
(None, Some(y)) => {
446+
ret = Some((y.0, (None, Some(y.1))));
447+
self.ys.next();
448+
}
449+
(Some(x), Some(y)) => {
450+
let start;
451+
let end;
452+
let dst;
453+
match x.0.start.cmp(&y.0.start) {
454+
Ordering::Less => {
455+
start = x.0.start;
456+
end = cmp::min(x.0.end, y.0.start);
457+
dst = (Some(x.1), None);
458+
}
459+
Ordering::Greater => {
460+
start = y.0.start;
461+
end = cmp::min(x.0.start, y.0.end);
462+
dst = (None, Some(y.1));
463+
}
464+
Ordering::Equal => {
465+
start = x.0.start;
466+
end = cmp::min(x.0.end, y.0.end);
467+
dst = (Some(x.1), Some(y.1));
468+
}
548469
}
549-
if idx >= o_end {
550-
o = o_rest;
470+
ret = Some((Byte { start, end }, dst));
471+
if start == x.0.start {
472+
x.0.start = end;
473+
}
474+
if start == y.0.start {
475+
y.0.start = end;
476+
}
477+
if x.0.is_empty() {
478+
self.xs.next();
479+
}
480+
if y.0.is_empty() {
481+
self.ys.next();
551482
}
552483
}
553-
554-
// At this point, either `s` or `o` have been exhausted, so the
555-
// remaining elements in the other slice are guaranteed to be
556-
// non-overlapping. We can add all remaining runs to `runs` with no
557-
// further processing.
558-
if let Ok(idx) = u8::try_from(idx) {
559-
let (slc, map) = if !s.is_empty() {
560-
let map: fn(_) -> _ = |st| (Some(st), None);
561-
(s, map)
562-
} else {
563-
let map: fn(_) -> _ = |st| (None, Some(st));
564-
(o, map)
565-
};
566-
runs.extend(slc.iter().map(|run| run.clamp_lower(idx).map_state(map)));
567-
}
568-
569-
EdgeSet { runs, uninit }
570484
}
485+
ret
571486
}
572487
}

‎compiler/rustc_transmute/src/layout/mod.rs‎

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,61 +6,61 @@ pub(crate) mod tree;
66
pub(crate) use tree::Tree;
77

88
pub(crate) mod dfa;
9-
pub(crate) use dfa::Dfa;
9+
pub(crate) use dfa::{Dfa, union};
1010

1111
#[derive(Debug)]
1212
pub(crate) struct Uninhabited;
1313

14-
/// A range of byte values, or the uninit byte.
14+
/// A range of byte values (including an uninit byte value).
1515
#[derive(Hash, Eq, PartialEq, Ord, PartialOrd, Clone, Copy)]
1616
pub(crate) struct Byte {
17-
// An inclusive-inclusive range. We use this instead of `RangeInclusive`
18-
// because `RangeInclusive: !Copy`.
17+
// An inclusive-exclusive range. We use this instead of `Range` because `Range: !Copy`.
1918
//
20-
// `None` means uninit.
21-
//
22-
// FIXME(@joshlf): Optimize this representation. Some pairs of values (where
23-
// `lo > hi`) are illegal, and we could use these to represent `None`.
24-
range: Option<(u8, u8)>,
19+
// Uninit byte value is represented by 256.
20+
pub(crate) start: u16,
21+
pub(crate) end: u16,
2522
}
2623

2724
impl Byte {
25+
const UNINIT: u16 = 256;
26+
27+
#[inline]
2828
fn new(range: RangeInclusive<u8>) -> Self {
29-
Self { range: Some((*range.start(), *range.end())) }
29+
let start: u16 = (*range.start()).into();
30+
let end: u16 = (*range.end()).into();
31+
Byte { start, end: end + 1 }
3032
}
3133

34+
#[inline]
3235
fn from_val(val: u8) -> Self {
33-
Self { range: Some((val, val)) }
36+
let val: u16 = val.into();
37+
Byte { start: val, end: val + 1 }
3438
}
3539

36-
pub(crate) fn uninit() -> Byte {
37-
Byte { range: None }
40+
#[inline]
41+
fn uninit() -> Byte {
42+
Byte { start: 0, end: Self::UNINIT + 1 }
3843
}
3944

40-
/// Returns `None` if `self` is the uninit byte.
41-
pub(crate) fn range(&self) -> Option<RangeInclusive<u8>> {
42-
self.range.map(|(lo, hi)| lo..=hi)
45+
#[inline]
46+
fn is_empty(&self) -> bool {
47+
self.start == self.end
4348
}
4449

45-
/// Are any of the values in `self` transmutable into `other`?
46-
///
47-
/// Note two special cases: An uninit byte is only transmutable into another
48-
/// uninit byte. Any byte is transmutable into an uninit byte.
49-
pub(crate) fn transmutable_into(&self, other: &Byte) -> bool {
50-
match (self.range, other.range) {
51-
(None, None) => true,
52-
(None, Some(_)) => false,
53-
(Some(_), None) => true,
54-
(Some((slo, shi)), Some((olo, ohi))) => slo <= ohi && olo <= shi,
55-
}
50+
#[inline]
51+
fn contains_uninit(&self) -> bool {
52+
self.start <= Self::UNINIT && Self::UNINIT < self.end
5653
}
5754
}
5855

5956
impl fmt::Debug for Byte {
6057
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61-
match self.range {
62-
None => write!(f, "uninit"),
63-
Some((lo, hi)) => write!(f, "{lo}..={hi}"),
58+
if self.start == Self::UNINIT && self.end == Self::UNINIT + 1 {
59+
write!(f, "uninit")
60+
} else if self.start <= Self::UNINIT && self.end == Self::UNINIT + 1 {
61+
write!(f, "{}..{}|uninit", self.start, self.end - 1)
62+
} else {
63+
write!(f, "{}..{}", self.start, self.end)
6464
}
6565
}
6666
}
@@ -72,6 +72,7 @@ impl From<RangeInclusive<u8>> for Byte {
7272
}
7373

7474
impl From<u8> for Byte {
75+
#[inline]
7576
fn from(src: u8) -> Self {
7677
Self::from_val(src)
7778
}

‎compiler/rustc_transmute/src/maybe_transmutable/mod.rs‎

Lines changed: 13 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
1-
use std::rc::Rc;
2-
use std::{cmp, iter};
3-
4-
use itertools::Either;
51
use tracing::{debug, instrument, trace};
62

73
pub(crate) mod query_context;
84
#[cfg(test)]
95
mod tests;
106

11-
use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, dfa};
7+
use crate::layout::{self, Def, Dfa, Ref, Tree, dfa, union};
128
use crate::maybe_transmutable::query_context::QueryContext;
139
use crate::{Answer, Condition, Map, Reason};
1410

@@ -197,122 +193,20 @@ where
197193
Quantifier::ForAll
198194
};
199195

200-
let c = &core::cell::RefCell::new(&mut *cache);
201196
let bytes_answer = src_quantifier.apply(
202-
// for each of the byte set transitions out of the `src_state`...
203-
self.src.bytes_from(src_state).flat_map(
204-
move |(src_validity, src_state_prime)| {
205-
// ...find all matching transitions out of `dst_state`.
206-
207-
let Some(src_validity) = src_validity.range() else {
208-
// NOTE: We construct an iterator here rather
209-
// than just computing the value directly (via
210-
// `self.answer_memo`) so that, if the iterator
211-
// we produce from this branch is
212-
// short-circuited, we don't waste time
213-
// computing `self.answer_memo` unnecessarily.
214-
// That will specifically happen if
215-
// `src_quantifier == Quantifier::ThereExists`,
216-
// since we emit `Answer::Yes` first (before
217-
// chaining `answer_iter`).
218-
let answer_iter = if let Some(dst_state_prime) =
219-
self.dst.get_uninit_edge_dst(dst_state)
220-
{
221-
Either::Left(iter::once_with(move || {
222-
let mut c = c.borrow_mut();
223-
self.answer_memo(&mut *c, src_state_prime, dst_state_prime)
224-
}))
225-
} else {
226-
Either::Right(iter::once(Answer::No(
227-
Reason::DstIsBitIncompatible,
228-
)))
229-
};
230-
231-
// When `answer == Answer::No(...)`, there are
232-
// two cases to consider:
233-
// - If `assume.validity`, then we should
234-
// succeed because the user is responsible for
235-
// ensuring that the *specific* byte value
236-
// appearing at runtime is valid for the
237-
// destination type. When `assume.validity`,
238-
// `src_quantifier ==
239-
// Quantifier::ThereExists`, so adding an
240-
// `Answer::Yes` has the effect of ensuring
241-
// that the "there exists" is always
242-
// satisfied.
243-
// - If `!assume.validity`, then we should fail.
244-
// In this case, `src_quantifier ==
245-
// Quantifier::ForAll`, so adding an
246-
// `Answer::Yes` has no effect.
247-
return Either::Left(iter::once(Answer::Yes).chain(answer_iter));
248-
};
249-
250-
#[derive(Copy, Clone, Debug)]
251-
struct Accum {
252-
// The number of matching byte edges that we
253-
// have found in the destination so far.
254-
sum: usize,
255-
found_uninit: bool,
256-
}
257-
258-
let accum1 = Rc::new(std::cell::Cell::new(Accum {
259-
sum: 0,
260-
found_uninit: false,
261-
}));
262-
let accum2 = Rc::clone(&accum1);
263-
let sv = src_validity.clone();
264-
let update_accum = move |mut accum: Accum, dst_validity: Byte| {
265-
if let Some(dst_validity) = dst_validity.range() {
266-
// Only add the part of `dst_validity` that
267-
// overlaps with `src_validity`.
268-
let start = cmp::max(*sv.start(), *dst_validity.start());
269-
let end = cmp::min(*sv.end(), *dst_validity.end());
270-
271-
// We add 1 here to account for the fact
272-
// that `end` is an inclusive bound.
273-
accum.sum += 1 + usize::from(end.saturating_sub(start));
274-
} else {
275-
accum.found_uninit = true;
197+
union(self.src.bytes_from(src_state), self.dst.bytes_from(dst_state))
198+
.filter_map(|(_range, (src_state_prime, dst_state_prime))| {
199+
match (src_state_prime, dst_state_prime) {
200+
// No matching transitions in `src`. Skip.
201+
(None, _) => None,
202+
// No matching transitions in `dst`. Fail.
203+
(Some(_), None) => Some(Answer::No(Reason::DstIsBitIncompatible)),
204+
// Matching transitions. Continue with successor states.
205+
(Some(src_state_prime), Some(dst_state_prime)) => {
206+
Some(self.answer_memo(cache, src_state_prime, dst_state_prime))
276207
}
277-
accum
278-
};
279-
280-
let answers = self
281-
.dst
282-
.states_from(dst_state, src_validity.clone())
283-
.map(move |(dst_validity, dst_state_prime)| {
284-
let mut c = c.borrow_mut();
285-
accum1.set(update_accum(accum1.get(), dst_validity));
286-
let answer =
287-
self.answer_memo(&mut *c, src_state_prime, dst_state_prime);
288-
answer
289-
})
290-
.chain(
291-
iter::once_with(move || {
292-
let src_validity_len = usize::from(*src_validity.end())
293-
- usize::from(*src_validity.start())
294-
+ 1;
295-
let accum = accum2.get();
296-
297-
// If this condition is false, then
298-
// there are some byte values in the
299-
// source which have no corresponding
300-
// transition in the destination DFA. In
301-
// that case, we add a `No` to our list
302-
// of answers. When
303-
// `!self.assume.validity`, this will
304-
// cause the query to fail.
305-
if accum.found_uninit || accum.sum == src_validity_len {
306-
None
307-
} else {
308-
Some(Answer::No(Reason::DstIsBitIncompatible))
309-
}
310-
})
311-
.flatten(),
312-
);
313-
Either::Right(answers)
314-
},
315-
),
208+
}
209+
}),
316210
);
317211

318212
// The below early returns reflect how this code would behave:

‎compiler/rustc_transmute/src/maybe_transmutable/tests.rs‎

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -400,16 +400,23 @@ mod r#ref {
400400
fn should_permit_identity_transmutation() {
401401
type Tree = crate::layout::Tree<Def, [(); 1]>;
402402

403-
let layout = Tree::Seq(vec![Tree::byte(0x00), Tree::Ref([()])]);
403+
for validity in [false, true] {
404+
let layout = Tree::Seq(vec![Tree::byte(0x00), Tree::Ref([()])]);
404405

405-
let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
406-
layout.clone(),
407-
layout,
408-
Assume::default(),
409-
UltraMinimal::default(),
410-
)
411-
.answer();
412-
assert_eq!(answer, Answer::If(crate::Condition::IfTransmutable { src: [()], dst: [()] }));
406+
let assume = Assume { validity, ..Assume::default() };
407+
408+
let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
409+
layout.clone(),
410+
layout,
411+
assume,
412+
UltraMinimal::default(),
413+
)
414+
.answer();
415+
assert_eq!(
416+
answer,
417+
Answer::If(crate::Condition::IfTransmutable { src: [()], dst: [()] })
418+
);
419+
}
413420
}
414421
}
415422

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#![crate_type = "lib"]
2+
#![feature(transmutability)]
3+
use std::mem::{Assume, MaybeUninit, TransmuteFrom};
4+
5+
pub fn is_maybe_transmutable<Src, Dst>()
6+
where Dst: TransmuteFrom<Src, { Assume::VALIDITY.and(Assume::SAFETY) }>
7+
{}
8+
9+
fn extension() {
10+
is_maybe_transmutable::<(), MaybeUninit<u8>>();
11+
is_maybe_transmutable::<MaybeUninit<u8>, [u8; 2]>(); //~ ERROR `MaybeUninit<u8>` cannot be safely transmuted into `[u8; 2]`
12+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
error[E0277]: `MaybeUninit<u8>` cannot be safely transmuted into `[u8; 2]`
2+
--> $DIR/extension.rs:11:46
3+
|
4+
LL | is_maybe_transmutable::<MaybeUninit<u8>, [u8; 2]>();
5+
| ^^^^^^^ the size of `MaybeUninit<u8>` is smaller than the size of `[u8; 2]`
6+
|
7+
note: required by a bound in `is_maybe_transmutable`
8+
--> $DIR/extension.rs:6:16
9+
|
10+
LL | pub fn is_maybe_transmutable<Src, Dst>()
11+
| --------------------- required by a bound in this function
12+
LL | where Dst: TransmuteFrom<Src, { Assume::VALIDITY.and(Assume::SAFETY) }>
13+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `is_maybe_transmutable`
14+
15+
error: aborting due to 1 previous error
16+
17+
For more information about this error, try `rustc --explain E0277`.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//@ check-pass
2+
// Regression test for issue #140337.
3+
#![crate_type = "lib"]
4+
#![feature(transmutability)]
5+
#![allow(dead_code)]
6+
use std::mem::{Assume, MaybeUninit, TransmuteFrom};
7+
8+
pub fn is_transmutable<Src, Dst>()
9+
where
10+
Dst: TransmuteFrom<Src, { Assume::SAFETY }>
11+
{}
12+
13+
#[derive(Copy, Clone)]
14+
#[repr(u8)]
15+
pub enum B0 { Value = 0 }
16+
17+
#[derive(Copy, Clone)]
18+
#[repr(u8)]
19+
pub enum B1 { Value = 1 }
20+
21+
fn main() {
22+
is_transmutable::<(B0, B0), MaybeUninit<(B0, B0)>>();
23+
is_transmutable::<(B0, B0), MaybeUninit<(B0, B1)>>();
24+
is_transmutable::<(B0, B0), MaybeUninit<(B1, B0)>>();
25+
is_transmutable::<(B0, B0), MaybeUninit<(B1, B1)>>();
26+
}

‎tests/ui/transmutability/unions/should_permit_intersecting_if_validity_is_assumed.rs‎

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,19 @@ fn test() {
3434

3535
assert::is_maybe_transmutable::<A, B>();
3636
assert::is_maybe_transmutable::<B, A>();
37+
38+
#[repr(C)]
39+
struct C {
40+
a: Ox00,
41+
b: Ox00,
42+
}
43+
44+
#[repr(C, align(2))]
45+
struct D {
46+
a: Ox00,
47+
}
48+
49+
assert::is_maybe_transmutable::<C, D>();
50+
// With Assume::VALIDITY a padding byte can hold any value.
51+
assert::is_maybe_transmutable::<D, C>();
3752
}

0 commit comments

Comments
 (0)
Please sign in to comment.