diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 4690fb641..9125d1979 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -115,6 +115,22 @@ def _get_weights_or_throw(weights: Optional[torch.Tensor]) -> torch.Tensor:
     return weights
 
 
+def _get_lengths_offset_per_key_or_throw(
+    lengths_offset_per_key: Optional[List[int]],
+) -> List[int]:
+    assert (
+        lengths_offset_per_key is not None
+    ), "This (Keyed)JaggedTensor doesn't have lengths_offset_per_key."
+    return lengths_offset_per_key
+
+
+def _get_stride_per_key_or_throw(stride_per_key: Optional[List[int]]) -> List[int]:
+    assert (
+        stride_per_key is not None
+    ), "This (Keyed)JaggedTensor doesn't have stride_per_key."
+    return stride_per_key
+
+
 def _get_inverse_indices_or_throw(
     inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
 ) -> Tuple[List[str], torch.Tensor]:
@@ -891,9 +907,9 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten
 
 
 def _assert_tensor_has_no_elements_or_has_integers(
-    tensor: torch.Tensor, tensor_name: str
+    tensor: Optional[torch.Tensor], tensor_name: str
 ) -> None:
-    if is_torchdynamo_compiling():
+    if is_torchdynamo_compiling() or tensor is None:
         # Skipping the check tensor.numel() == 0 to not guard on pt2 symbolic shapes.
         # TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable.
         return
@@ -921,10 +937,13 @@ def _maybe_compute_stride_kjt(
     stride: Optional[int],
     lengths: Optional[torch.Tensor],
     offsets: Optional[torch.Tensor],
+    stride_per_key_per_rank: Optional[List[List[int]]],
 ) -> int:
     if stride is None:
         if len(keys) == 0:
             stride = 0
+        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
+            stride = max([sum(s) for s in stride_per_key_per_rank])
         elif offsets is not None and offsets.numel() > 0:
             stride = (offsets.numel() - 1) // len(keys)
         elif lengths is not None:
@@ -1467,6 +1486,50 @@ def _check_attributes(
     return True
 
 
+def _maybe_compute_lengths_offset_per_key(
+    lengths_offset_per_key: Optional[List[int]],
+    stride_per_key: Optional[List[int]],
+    stride: Optional[int],
+    keys: List[str],
+) -> Optional[List[int]]:
+    if lengths_offset_per_key is not None:
+        return lengths_offset_per_key
+    elif stride_per_key is not None:
+        return _cumsum(stride_per_key)
+    elif stride is not None:
+        return _cumsum([stride] * len(keys))
+    else:
+        return None
+
+
+def _maybe_compute_stride_per_key(
+    stride_per_key: Optional[List[int]],
+    stride_per_key_per_rank: Optional[List[List[int]]],
+    stride: Optional[int],
+    keys: List[str],
+) -> Optional[List[int]]:
+    if stride_per_key is not None:
+        return stride_per_key
+    elif stride_per_key_per_rank is not None:
+        return [sum(s) for s in stride_per_key_per_rank]
+    elif stride is not None:
+        return [stride] * len(keys)
+    else:
+        return None
+
+
+def _maybe_compute_variable_stride_per_key(
+    variable_stride_per_key: Optional[bool],
+    stride_per_key_per_rank: Optional[List[List[int]]],
+) -> bool:
+    if variable_stride_per_key is not None:
+        return variable_stride_per_key
+    elif stride_per_key_per_rank is not None:
+        return True
+    else:
+        return False
+
+
 class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
     """Represents an (optionally weighted) keyed jagged tensor.
@@ -1540,62 +1603,57 @@ def __init__(
         stride: Optional[int] = None,
         stride_per_key_per_rank: Optional[List[List[int]]] = None,
         # Below exposed to ensure torch.script-able
+        stride_per_key: Optional[List[int]] = None,
         length_per_key: Optional[List[int]] = None,
+        lengths_offset_per_key: Optional[List[int]] = None,
         offset_per_key: Optional[List[int]] = None,
         index_per_key: Optional[Dict[str, int]] = None,
         jt_dict: Optional[Dict[str, JaggedTensor]] = None,
         inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
     ) -> None:
+        """
+        This is the constructor for KeyedJaggedTensor; it is jit.scriptable and PT2 compatible.
+        It is important to only assign attributes or do input checks here, to support various
+        internal inference optimizations. By convention each attribute is named the same as its
+        input arg, just with a leading underscore.
+        """
         self._keys: List[str] = keys
         self._values: torch.Tensor = values
         self._weights: Optional[torch.Tensor] = weights
-        if offsets is not None:
-            _assert_tensor_has_no_elements_or_has_integers(offsets, "offsets")
-        if lengths is not None:
-            _assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
         self._lengths: Optional[torch.Tensor] = lengths
         self._offsets: Optional[torch.Tensor] = offsets
-
-        self._stride_per_key_per_rank: List[List[int]] = []
-        self._stride_per_key: List[int] = []
-        self._variable_stride_per_key: bool = False
-        self._stride: int = -1
-
-        if stride_per_key_per_rank is not None:
-            self._stride_per_key_per_rank = stride_per_key_per_rank
-            self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank]
-            self._variable_stride_per_key = True
-            if stride is not None:
-                self._stride = stride
-            else:
-                self._stride = (
-                    max(self._stride_per_key) if len(self._stride_per_key) > 0 else 0
-                )
-        else:
-            stride = _maybe_compute_stride_kjt(keys, stride, lengths, offsets)
-            self._stride = stride
-            self._stride_per_key_per_rank = [[stride]] * len(self._keys)
-            self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank]
-
-        # lazy fields
+        self._stride: Optional[int] = stride
+        self._stride_per_key_per_rank: Optional[List[List[int]]] = (
+            stride_per_key_per_rank
+        )
+        self._stride_per_key: Optional[List[int]] = stride_per_key
         self._length_per_key: Optional[List[int]] = length_per_key
         self._offset_per_key: Optional[List[int]] = offset_per_key
+        self._lengths_offset_per_key: Optional[List[int]] = lengths_offset_per_key
         self._index_per_key: Optional[Dict[str, int]] = index_per_key
         self._jt_dict: Optional[Dict[str, JaggedTensor]] = jt_dict
         self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = (
             inverse_indices
         )
-        self._lengths_offset_per_key: List[int] = []
-
-        self._init_pt2_checks()
+        # legacy attribute, for backward compatibility
+        self._variable_stride_per_key: Optional[bool] = None
+
+        # validation logic
+        if not torch.jit.is_scripting():
+            _assert_tensor_has_no_elements_or_has_integers(offsets, "offsets")
+            _assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
+            self._init_pt2_checks()
 
     def _init_pt2_checks(self) -> None:
         if torch.jit.is_scripting() or not is_torchdynamo_compiling():
             return
-
-        pt2_checks_all_is_size(self._stride_per_key)
-        for s in self._stride_per_key_per_rank:
-            pt2_checks_all_is_size(s)
+        if self._stride_per_key is not None:
+            pt2_checks_all_is_size(self._stride_per_key)
+        if self._stride_per_key_per_rank is not None:
+            # pyre-ignore [16]
+            for s in self._stride_per_key_per_rank:
+                pt2_checks_all_is_size(s)
 
     @staticmethod
     def from_offsets_sync(
@@ -1863,16 +1921,34 @@ def weights_or_none(self) -> Optional[torch.Tensor]:
         return self._weights
 
     def stride(self) -> int:
-        return self._stride
+        stride = _maybe_compute_stride_kjt(
+            self._keys,
+            self._stride,
+            self._lengths,
+            self._offsets,
+            self._stride_per_key_per_rank,
+        )
+        self._stride = stride
+        return stride
 
     def stride_per_key(self) -> List[int]:
-        return self._stride_per_key
+        stride_per_key = _maybe_compute_stride_per_key(
+            self._stride_per_key,
+            self._stride_per_key_per_rank,
+            self.stride(),
+            self._keys,
+        )
+        self._stride_per_key = stride_per_key
+        return _get_stride_per_key_or_throw(stride_per_key)
 
     def stride_per_key_per_rank(self) -> List[List[int]]:
-        return self._stride_per_key_per_rank
+        stride_per_key_per_rank = self._stride_per_key_per_rank
+        return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
 
     def variable_stride_per_key(self) -> bool:
-        return self._variable_stride_per_key
+        if self._variable_stride_per_key is not None:
+            return self._variable_stride_per_key
+        return self._stride_per_key_per_rank is not None
 
     def inverse_indices(self) -> Tuple[List[str], torch.Tensor]:
         return _get_inverse_indices_or_throw(self._inverse_indices)
@@ -1925,9 +2001,20 @@ def offset_per_key_or_none(self) -> Optional[List[int]]:
         return self._offset_per_key
 
     def lengths_offset_per_key(self) -> List[int]:
-        if not self._lengths_offset_per_key:
-            self._lengths_offset_per_key = _cumsum(self.stride_per_key())
-        return self._lengths_offset_per_key
+        if self.variable_stride_per_key():
+            _lengths_offset_per_key = _maybe_compute_lengths_offset_per_key(
+                self._lengths_offset_per_key,
+                self.stride_per_key(),
+                None,
+                self._keys,
+            )
+        else:
+            _lengths_offset_per_key = _maybe_compute_lengths_offset_per_key(
+                self._lengths_offset_per_key, None, self.stride(), self._keys
+            )
+
+        self._lengths_offset_per_key = _lengths_offset_per_key
+        return _get_lengths_offset_per_key_or_throw(_lengths_offset_per_key)
 
     def index_per_key(self) -> Dict[str, int]:
         return self._key_indices()
@@ -1958,7 +2045,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                         offsets=self._offsets,
                         stride=self._stride,
                         stride_per_key_per_rank=stride_per_key_per_rank,
+                        stride_per_key=None,
                         length_per_key=self._length_per_key,
+                        lengths_offset_per_key=None,
                         offset_per_key=self._offset_per_key,
                         index_per_key=self._index_per_key,
                         jt_dict=self._jt_dict,
@@ -1992,7 +2081,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                         ),
                         stride=self._stride,
                         stride_per_key_per_rank=stride_per_key_per_rank,
+                        stride_per_key=None,
                         length_per_key=None,
+                        lengths_offset_per_key=None,
                         offset_per_key=None,
                         index_per_key=None,
                         jt_dict=None,
@@ -2036,7 +2127,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                         offsets=None,
                         stride=self._stride,
                         stride_per_key_per_rank=stride_per_key_per_rank,
+                        stride_per_key=None,
                         length_per_key=split_length_per_key,
+                        lengths_offset_per_key=None,
                         offset_per_key=None,
                         index_per_key=None,
                         jt_dict=None,
@@ -2070,7 +2163,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                         offsets=None,
                         stride=self._stride,
                         stride_per_key_per_rank=stride_per_key_per_rank,
+                        stride_per_key=None,
                         length_per_key=split_length_per_key,
+                        lengths_offset_per_key=None,
                         offset_per_key=None,
                         index_per_key=None,
                         jt_dict=None,
@@ -2098,10 +2193,11 @@ def permute(
         for index in indices:
             key = self.keys()[index]
             permuted_keys.append(key)
-            permuted_stride_per_key_per_rank.append(
-                self.stride_per_key_per_rank()[index]
-            )
             permuted_length_per_key.append(length_per_key[index])
+            if self.variable_stride_per_key():
+                permuted_stride_per_key_per_rank.append(
+                    self.stride_per_key_per_rank()[index]
+                )
 
         permuted_length_per_key_sum = sum(permuted_length_per_key)
         if not torch.jit.is_scripting() and is_non_strict_exporting():
@@ -2164,7 +2260,9 @@ def permute(
             offsets=None,
             stride=self._stride,
             stride_per_key_per_rank=stride_per_key_per_rank,
+            stride_per_key=None,
             length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
+            lengths_offset_per_key=None,
             offset_per_key=None,
             index_per_key=None,
             jt_dict=None,
@@ -2184,7 +2282,9 @@ def flatten_lengths(self) -> "KeyedJaggedTensor":
             offsets=None,
             stride=self._stride,
             stride_per_key_per_rank=stride_per_key_per_rank,
+            stride_per_key=None,
            length_per_key=self.length_per_key(),
+            lengths_offset_per_key=None,
             offset_per_key=None,
             index_per_key=None,
             jt_dict=None,
@@ -2304,8 +2404,10 @@ def to(
             self._stride_per_key_per_rank if self.variable_stride_per_key() else None
         )
         length_per_key = self._length_per_key
+        lengths_offset_per_key = self._lengths_offset_per_key
         offset_per_key = self._offset_per_key
         index_per_key = self._index_per_key
+        stride_per_key = self._stride_per_key
         jt_dict = self._jt_dict
         inverse_indices = self._inverse_indices
         if inverse_indices is not None:
@@ -2337,7 +2439,9 @@ def to(
             ),
             stride=self._stride,
             stride_per_key_per_rank=stride_per_key_per_rank,
+            stride_per_key=stride_per_key,
             length_per_key=length_per_key,
+            lengths_offset_per_key=lengths_offset_per_key,
             offset_per_key=offset_per_key,
             index_per_key=index_per_key,
             jt_dict=jt_dict,
@@ -2387,7 +2491,9 @@ def pin_memory(self) -> "KeyedJaggedTensor":
             offsets=offsets.pin_memory() if offsets is not None else None,
             stride=self._stride,
             stride_per_key_per_rank=stride_per_key_per_rank,
+            stride_per_key=self._stride_per_key,
             length_per_key=self._length_per_key,
+            lengths_offset_per_key=self._lengths_offset_per_key,
             offset_per_key=self._offset_per_key,
             index_per_key=self._index_per_key,
             jt_dict=None,
diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py
index d987dfbdd..7eca4af47 100644
--- a/torchrec/sparse/tests/test_jagged_tensor.py
+++ b/torchrec/sparse/tests/test_jagged_tensor.py
@@ -2121,13 +2121,13 @@ def forward(
                     lengths=input.lengths(),
                     offsets=input.offsets(),
                 )
-                return output, output._stride
+                return output, output.stride()
 
         # Case 3: KeyedJaggedTensor is used as both an input and an output of the root module.
         m = ModuleUseKeyedJaggedTensorAsInputAndOutput()
         gm = symbolic_trace(m)
         FileCheck().check("KeyedJaggedTensor").check("keys()").check("values()").check(
-            "._stride"
+            "stride"
         ).run(gm.code)
         input = KeyedJaggedTensor.from_offsets_sync(
             values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
@@ -2185,7 +2185,7 @@ def forward(
                 lengths: torch.Tensor,
             ) -> Tuple[KeyedJaggedTensor, int]:
                 output = KeyedJaggedTensor(keys, values, weights, lengths)
-                return output, output._stride
+                return output, output.stride()
 
         # Case 1: KeyedJaggedTensor is only used as an output of the root module.
         m = ModuleUseKeyedJaggedTensorAsOutput()
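
Reviewer note: a minimal usage sketch (not part of this diff) of how the lazily computed stride fields introduced above behave; the key names, lengths, and per-rank batch sizes are made-up example values.

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Variable-batch KJT: "f1" has per-rank batch sizes [2, 1], "f2" has [1, 1].
    kjt = KeyedJaggedTensor(
        keys=["f1", "f2"],
        values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]),
        lengths=torch.tensor([1, 2, 0, 1, 1]),
        stride_per_key_per_rank=[[2, 1], [1, 1]],
    )

    # Nothing is derived in __init__ anymore; these getters compute and cache lazily.
    assert kjt.variable_stride_per_key()              # inferred from stride_per_key_per_rank
    assert kjt.stride() == 3                          # max(sum(s) for s in stride_per_key_per_rank)
    assert kjt.stride_per_key() == [3, 2]             # [sum(s) for s in stride_per_key_per_rank]
    assert kjt.lengths_offset_per_key() == [0, 3, 5]  # _cumsum(stride_per_key())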