Skip to content

Commit bf5eb42

Browse files
committed
Fix set_retained_vns_del!
1 parent e0086e1 commit bf5eb42

File tree

2 files changed

+31
-44
lines changed

2 files changed

+31
-44
lines changed

src/varinfo.jl

Lines changed: 13 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1813,55 +1813,24 @@ end
18131813
"""
18141814
set_retained_vns_del!(vi::VarInfo)
18151815
1816-
Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`.
1816+
Set the `"del"` flag of variables in `vi` with `order > num_produce` to `true`.
1817+
1818+
Will error if `vi` does not have an accumulator for `VariableOrder`.
18171819
"""
1818-
function set_retained_vns_del!(vi::UntypedVarInfo)
1819-
idcs = _getidcs(vi)
1820-
if get_num_produce(vi) == 0
1821-
for i in length(idcs):-1:1
1822-
vi.metadata.flags["del"][idcs[i]] = true
1823-
end
1824-
else
1825-
for i in 1:length(vi.orders)
1826-
if i in idcs && vi.orders[i] > get_num_produce(vi)
1827-
vi.metadata.flags["del"][i] = true
1828-
end
1820+
function set_retained_vns_del!(vi::VarInfo)
1821+
if !hasacc(vi, Val(:VariableOrder))
1822+
msg = "`vi` must have an accumulator for VariableOrder to set the `del` flag."
1823+
raise(ArgumentError(msg))
1824+
end
1825+
num_produce = get_num_produce(vi)
1826+
for vn in keys(vi)
1827+
order = getorder(vi, vn)
1828+
if order > num_produce
1829+
set_flag!(vi, vn, "del")
18291830
end
18301831
end
18311832
return nothing
18321833
end
1833-
function set_retained_vns_del!(vi::NTVarInfo)
1834-
idcs = _getidcs(vi)
1835-
return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi))
1836-
end
1837-
@generated function _set_retained_vns_del!(
1838-
metadata, idcs::NamedTuple{names}, num_produce
1839-
) where {names}
1840-
expr = Expr(:block)
1841-
for f in names
1842-
f_idcs = :(idcs.$f)
1843-
f_orders = :(metadata.$f.orders)
1844-
f_flags = :(metadata.$f.flags)
1845-
push!(
1846-
expr.args,
1847-
quote
1848-
# Set the flag for variables with symbol `f`
1849-
if num_produce == 0
1850-
for i in length($f_idcs):-1:1
1851-
$f_flags["del"][$f_idcs[i]] = true
1852-
end
1853-
else
1854-
for i in 1:length($f_orders)
1855-
if i in $f_idcs && $f_orders[i] > num_produce
1856-
$f_flags["del"][i] = true
1857-
end
1858-
end
1859-
end
1860-
end,
1861-
)
1862-
end
1863-
return expr
1864-
end
18651834

18661835
# TODO: Maybe rename or something?
18671836
"""

test/varinfo.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,10 +1099,28 @@ end
10991099
@test DynamicPPL.getorder(vi, vn_z3) == 3
11001100
@test DynamicPPL.get_num_produce(vi) == 3
11011101

1102+
@test !DynamicPPL.is_flagged(vi, vn_z1, "del")
1103+
@test !DynamicPPL.is_flagged(vi, vn_a1, "del")
1104+
@test !DynamicPPL.is_flagged(vi, vn_b, "del")
1105+
@test !DynamicPPL.is_flagged(vi, vn_z2, "del")
1106+
@test !DynamicPPL.is_flagged(vi, vn_a2, "del")
1107+
@test !DynamicPPL.is_flagged(vi, vn_z3, "del")
1108+
1109+
vi = DynamicPPL.reset_num_produce!!(vi)
1110+
vi = DynamicPPL.increment_num_produce!!(vi)
1111+
DynamicPPL.set_retained_vns_del!(vi)
1112+
@test !DynamicPPL.is_flagged(vi, vn_z1, "del")
1113+
@test !DynamicPPL.is_flagged(vi, vn_a1, "del")
1114+
@test DynamicPPL.is_flagged(vi, vn_b, "del")
1115+
@test DynamicPPL.is_flagged(vi, vn_z2, "del")
1116+
@test DynamicPPL.is_flagged(vi, vn_a2, "del")
1117+
@test DynamicPPL.is_flagged(vi, vn_z3, "del")
1118+
11021119
vi = DynamicPPL.reset_num_produce!!(vi)
11031120
DynamicPPL.set_retained_vns_del!(vi)
11041121
@test DynamicPPL.is_flagged(vi, vn_z1, "del")
11051122
@test DynamicPPL.is_flagged(vi, vn_a1, "del")
1123+
@test DynamicPPL.is_flagged(vi, vn_b, "del")
11061124
@test DynamicPPL.is_flagged(vi, vn_z2, "del")
11071125
@test DynamicPPL.is_flagged(vi, vn_a2, "del")
11081126
@test DynamicPPL.is_flagged(vi, vn_z3, "del")

0 commit comments

Comments
 (0)