@@ -616,7 +616,13 @@ def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
616
616
new_index = index .subs (dict (zip (index_vars , reindex (new_index_vars ))))
617
617
return new_index
618
618
619
- def indexing (self , index : sympy .Expr , copy_shape = None , dense_indexing = False ):
619
+ def indexing (
620
+ self ,
621
+ index : sympy .Expr ,
622
+ copy_shape = None ,
623
+ dense_indexing = False ,
624
+ load_index0 = False ,
625
+ ):
620
626
"""
621
627
Compute the index and mask to pass to tl.load() or tl.store()
622
628
"""
@@ -632,6 +638,7 @@ def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=False):
632
638
or indirect_indexing
633
639
or self ._load_mask is not None
634
640
) and index != 0
641
+
635
642
have_dense = True
636
643
have_loop_vars = False
637
644
mask = []
@@ -646,7 +653,7 @@ def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=False):
646
653
mask .append (f"{ tree .prefix } mask" )
647
654
dense_mask .append (f"{ tree .prefix } mask" )
648
655
649
- if need_dense and not have_dense :
656
+ if ( need_dense and not have_dense ) or load_index0 :
650
657
mask = dense_mask
651
658
index_str = f"{ index_str } + tl.zeros({ self .dense_size_str ()} , tl.int32)"
652
659
elif not have_loop_vars and copy_shape :
@@ -692,7 +699,7 @@ def mask_loads(self, mask):
692
699
def load (self , name : str , index : sympy .Expr , upcast : bool = False ):
693
700
var = self .args .input (name )
694
701
indirect_indexing = self .is_indirect_indexing (index )
695
- index , mask = self .indexing (index )
702
+ index , mask = self .indexing (index , load_index0 = ( index == 0 ) )
696
703
if "rmask" in mask :
697
704
# This eviction policy heuristic is untested.
698
705
# ptillet suggested we should try only doing this for
0 commit comments