@@ -1659,35 +1659,62 @@ static void mark_all_scalars_precise(struct bpf_verifier_env *env,
1659
1659
}
1660
1660
}
1661
1661
1662
- static int mark_chain_precision (struct bpf_verifier_env * env , int regno )
1662
+ static int __mark_chain_precision (struct bpf_verifier_env * env , int regno ,
1663
+ int spi )
1663
1664
{
1664
1665
struct bpf_verifier_state * st = env -> cur_state ;
1665
1666
int first_idx = st -> first_insn_idx ;
1666
1667
int last_idx = env -> insn_idx ;
1667
1668
struct bpf_func_state * func ;
1668
1669
struct bpf_reg_state * reg ;
1669
- u32 reg_mask = 1u << regno ;
1670
- u64 stack_mask = 0 ;
1670
+ u32 reg_mask = regno >= 0 ? 1u << regno : 0 ;
1671
+ u64 stack_mask = spi >= 0 ? 1ull << spi : 0 ;
1671
1672
bool skip_first = true;
1673
+ bool new_marks = false;
1672
1674
int i , err ;
1673
1675
1674
1676
if (!env -> allow_ptr_leaks )
1675
1677
/* backtracking is root only for now */
1676
1678
return 0 ;
1677
1679
1678
1680
func = st -> frame [st -> curframe ];
1679
- reg = & func -> regs [regno ];
1680
- if (reg -> type != SCALAR_VALUE ) {
1681
- WARN_ONCE (1 , "backtracing misuse" );
1682
- return - EFAULT ;
1681
+ if (regno >= 0 ) {
1682
+ reg = & func -> regs [regno ];
1683
+ if (reg -> type != SCALAR_VALUE ) {
1684
+ WARN_ONCE (1 , "backtracing misuse" );
1685
+ return - EFAULT ;
1686
+ }
1687
+ if (!reg -> precise )
1688
+ new_marks = true;
1689
+ else
1690
+ reg_mask = 0 ;
1691
+ reg -> precise = true;
1683
1692
}
1684
- if (reg -> precise )
1685
- return 0 ;
1686
- func -> regs [regno ].precise = true;
1687
1693
1694
+ while (spi >= 0 ) {
1695
+ if (func -> stack [spi ].slot_type [0 ] != STACK_SPILL ) {
1696
+ stack_mask = 0 ;
1697
+ break ;
1698
+ }
1699
+ reg = & func -> stack [spi ].spilled_ptr ;
1700
+ if (reg -> type != SCALAR_VALUE ) {
1701
+ stack_mask = 0 ;
1702
+ break ;
1703
+ }
1704
+ if (!reg -> precise )
1705
+ new_marks = true;
1706
+ else
1707
+ stack_mask = 0 ;
1708
+ reg -> precise = true;
1709
+ break ;
1710
+ }
1711
+
1712
+ if (!new_marks )
1713
+ return 0 ;
1714
+ if (!reg_mask && !stack_mask )
1715
+ return 0 ;
1688
1716
for (;;) {
1689
1717
DECLARE_BITMAP (mask , 64 );
1690
- bool new_marks = false;
1691
1718
u32 history = st -> jmp_history_cnt ;
1692
1719
1693
1720
if (env -> log .level & BPF_LOG_LEVEL )
@@ -1730,12 +1757,15 @@ static int mark_chain_precision(struct bpf_verifier_env *env, int regno)
1730
1757
if (!st )
1731
1758
break ;
1732
1759
1760
+ new_marks = false;
1733
1761
func = st -> frame [st -> curframe ];
1734
1762
bitmap_from_u64 (mask , reg_mask );
1735
1763
for_each_set_bit (i , mask , 32 ) {
1736
1764
reg = & func -> regs [i ];
1737
- if (reg -> type != SCALAR_VALUE )
1765
+ if (reg -> type != SCALAR_VALUE ) {
1766
+ reg_mask &= ~(1u << i );
1738
1767
continue ;
1768
+ }
1739
1769
if (!reg -> precise )
1740
1770
new_marks = true;
1741
1771
reg -> precise = true;
@@ -1756,11 +1786,15 @@ static int mark_chain_precision(struct bpf_verifier_env *env, int regno)
1756
1786
return - EFAULT ;
1757
1787
}
1758
1788
1759
- if (func -> stack [i ].slot_type [0 ] != STACK_SPILL )
1789
+ if (func -> stack [i ].slot_type [0 ] != STACK_SPILL ) {
1790
+ stack_mask &= ~(1ull << i );
1760
1791
continue ;
1792
+ }
1761
1793
reg = & func -> stack [i ].spilled_ptr ;
1762
- if (reg -> type != SCALAR_VALUE )
1794
+ if (reg -> type != SCALAR_VALUE ) {
1795
+ stack_mask &= ~(1ull << i );
1763
1796
continue ;
1797
+ }
1764
1798
if (!reg -> precise )
1765
1799
new_marks = true;
1766
1800
reg -> precise = true;
@@ -1772,6 +1806,8 @@ static int mark_chain_precision(struct bpf_verifier_env *env, int regno)
1772
1806
reg_mask , stack_mask );
1773
1807
}
1774
1808
1809
+ if (!reg_mask && !stack_mask )
1810
+ break ;
1775
1811
if (!new_marks )
1776
1812
break ;
1777
1813
@@ -1781,6 +1817,15 @@ static int mark_chain_precision(struct bpf_verifier_env *env, int regno)
1781
1817
return 0 ;
1782
1818
}
1783
1819
1820
+ static int mark_chain_precision (struct bpf_verifier_env * env , int regno )
1821
+ {
1822
+ return __mark_chain_precision (env , regno , -1 );
1823
+ }
1824
+
1825
+ static int mark_chain_precision_stack (struct bpf_verifier_env * env , int spi )
1826
+ {
1827
+ return __mark_chain_precision (env , -1 , spi );
1828
+ }
1784
1829
1785
1830
static bool is_spillable_regtype (enum bpf_reg_type type )
1786
1831
{
@@ -7111,6 +7156,46 @@ static int propagate_liveness(struct bpf_verifier_env *env,
7111
7156
return 0 ;
7112
7157
}
7113
7158
7159
+ /* find precise scalars in the previous equivalent state and
7160
+ * propagate them into the current state
7161
+ */
7162
+ static int propagate_precision (struct bpf_verifier_env * env ,
7163
+ const struct bpf_verifier_state * old )
7164
+ {
7165
+ struct bpf_reg_state * state_reg ;
7166
+ struct bpf_func_state * state ;
7167
+ int i , err = 0 ;
7168
+
7169
+ state = old -> frame [old -> curframe ];
7170
+ state_reg = state -> regs ;
7171
+ for (i = 0 ; i < BPF_REG_FP ; i ++ , state_reg ++ ) {
7172
+ if (state_reg -> type != SCALAR_VALUE ||
7173
+ !state_reg -> precise )
7174
+ continue ;
7175
+ if (env -> log .level & BPF_LOG_LEVEL2 )
7176
+ verbose (env , "propagating r%d\n" , i );
7177
+ err = mark_chain_precision (env , i );
7178
+ if (err < 0 )
7179
+ return err ;
7180
+ }
7181
+
7182
+ for (i = 0 ; i < state -> allocated_stack / BPF_REG_SIZE ; i ++ ) {
7183
+ if (state -> stack [i ].slot_type [0 ] != STACK_SPILL )
7184
+ continue ;
7185
+ state_reg = & state -> stack [i ].spilled_ptr ;
7186
+ if (state_reg -> type != SCALAR_VALUE ||
7187
+ !state_reg -> precise )
7188
+ continue ;
7189
+ if (env -> log .level & BPF_LOG_LEVEL2 )
7190
+ verbose (env , "propagating fp%d\n" ,
7191
+ (- i - 1 ) * BPF_REG_SIZE );
7192
+ err = mark_chain_precision_stack (env , i );
7193
+ if (err < 0 )
7194
+ return err ;
7195
+ }
7196
+ return 0 ;
7197
+ }
7198
+
7114
7199
static bool states_maybe_looping (struct bpf_verifier_state * old ,
7115
7200
struct bpf_verifier_state * cur )
7116
7201
{
@@ -7203,6 +7288,14 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
7203
7288
* this state and will pop a new one.
7204
7289
*/
7205
7290
err = propagate_liveness (env , & sl -> state , cur );
7291
+
7292
+ /* if previous state reached the exit with precision and
7293
+ * current state is equivalent to it (except precsion marks)
7294
+ * the precision needs to be propagated back in
7295
+ * the current state.
7296
+ */
7297
+ err = err ? : push_jmp_history (env , cur );
7298
+ err = err ? : propagate_precision (env , & sl -> state );
7206
7299
if (err )
7207
7300
return err ;
7208
7301
return 1 ;
0 commit comments