
Commit 97396cd

malfet authored and pytorchmergebot committed
Fix undefined behavior detected by clang-12 (pytorch#106354)
Compiler behavior when a non-zero offset is added to a null pointer is undefined and is a bad habit.

- When `lapackEig` is called to estimate a workspace size, do not add the matrix size to the W pointer.
- When `unpack_pivots_cpu_kernel` is called with zero `dim_size`, exit early.
- When `topk_impl_loop` is called with `k` equal to zero, exit right away, as the output tensors are empty anyway.
- Ignore adding a non-zero storage offset in `TensorImpl::data_ptr_impl_impl`, which can be the case if a tensor is created as `torch.empty(3)[4:]`.
- In `s_addmm_out_sparse_dense_worker`, do not call `axpy` over an empty vector.
- In `_sparse_binary_op_intersection_kernel_impl`, skip computing `ptr_indices_dim` when `sparse_dim` is empty.
- Exit the `grid_sample` forward/backward kernels early if either `input` or `grid` is an empty tensor.

Found by UBSan in clang-12. Before the change, the UBSan report looked as follows:
```
ASAN_SYMBOLIZER_PATH=/usr/lib/llvm-12/bin/llvm-symbolizer UBSAN_OPTIONS=print_stacktrace=1 LD_PRELOAD=/usr/lib/llvm-12/lib/clang/12.0.1/lib/linux/libclang_rt.asan-x86_64.so python test_fx_experimental.py -v -k test_normalize_operator_exhaustive_linalg_eig_cpu_float32
Test results will be stored in test-reports/python-unittest/test_fx_experimental

Running tests...
----------------------------------------------------------------------
test_normalize_operator_exhaustive_linalg_eig_cpu_float32 (__main__.TestNormalizeOperatorsCPU) ...
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:111: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:112: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:118: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:119: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,
/var/lib/jenkins/workspace/aten/src/ATen/native/BatchLinearAlgebra.cpp:937:17: runtime error: applying non-zero offset 20 to null pointer
    #0 0x7f2025794888 in void at::native::lapackEig<float, float>(char, char, int, float*, int, float*, float*, int, float*, int, float*, int, float*, int*) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9945888)
    #1 0x7f20257da256 in void at::native::(anonymous namespace)::apply_linalg_eig<float>(at::Tensor&, at::Tensor&, at::Tensor&, at::Tensor&, bool) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x998b256)
    #2 0x7f20257d902d in at::native::(anonymous namespace)::linalg_eig_kernel(at::Tensor&, at::Tensor&, at::Tensor&, at::Tensor const&, bool) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x998a02d)
    #3 0x7f20257b5b3d in at::native::linalg_eig_out_info(at::Tensor const&, at::Tensor&, at::Tensor&, at::Tensor&, bool) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9966b3d)
    #4 0x7f20257b4770 in at::native::linalg_eig_out(at::Tensor const&, at::Tensor&, at::Tensor&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9965770)
    #5 0x7f20280710e6 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor&, at::Tensor&> (at::Tensor const&, at::Tensor&, at::Tensor&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CPU_out_linalg_eig_out(at::Tensor const&, at::Tensor&, at::Tensor&))>, std::tuple<at::Tensor&, at::Tensor&>, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor&, at::Tensor&> >, std::tuple<at::Tensor&, at::Tensor&> (at::Tensor const&, at::Tensor&, at::Tensor&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor&, at::Tensor&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xc2220e6)
    #6 0x7f202727a045 in at::_ops::linalg_eig_out::call(at::Tensor const&, at::Tensor&, at::Tensor&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xb42b045)
    #7 0x7f20257b7e29 in at::native::linalg_eig(at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9968e29)
    #8 0x7f2028070bf0 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CPU__linalg_eig(at::Tensor const&))>, std::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<at::Tensor const&> >, std::tuple<at::Tensor, at::Tensor> (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xc221bf0)
    #9 0x7f2026b1f787 in std::tuple<at::Tensor, at::Tensor> c10::Dispatcher::redispatch<std::tuple<at::Tensor, at::Tensor>, at::Tensor const&>(c10::TypedOperatorHandle<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&)> const&, c10::DispatchKeySet, at::Tensor const&) const (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xacd0787)
    #10 0x7f20273230a7 in at::_ops::linalg_eig::redispatch(c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xb4d40a7)
    #11 0x7f202c3cc32d in torch::autograd::VariableType::(anonymous namespace)::linalg_eig(c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x1057d32d)
    #12 0x7f202c3cba96 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor> (c10::DispatchKeySet, at::Tensor const&), &(torch::autograd::VariableType::(anonymous namespace)::linalg_eig(c10::DispatchKeySet, at::Tensor const&))>, std::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, std::tuple<at::Tensor, at::Tensor> (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x1057ca96)
    #13 0x7f20272798e0 in at::_ops::linalg_eig::call(at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xb42a8e0)
    #14 0x7f2043d97ae3 in torch::autograd::THPVariable_linalg_eig(_object*, _object*, _object*) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_python.so+0x23feae3)
    #15 0x5072d6 in cfunction_call /usr/local/src/conda/python-3.9.17/Objects/methodobject.c:543:19
...

SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/native/BatchLinearAlgebra.cpp:937:17 in
```
Pull Request resolved: pytorch#106354
Approved by: https://github.com/huydhn, https://github.com/lezcano
1 parent 6e2a284 · commit 97396cd

12 files changed: +57 -21 lines

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -925,7 +925,7 @@ template<> void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int
   // lapack [sd]geev wants to separate output arrays: wr and wi for the real
   // and imaginary parts
   double *wr = w;
-  double *wi = w + n;
+  double *wi = w ? w + n : nullptr;
   (void)rwork; // unused
   dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
 }
@@ -934,7 +934,7 @@ template<> void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int ld
   // lapack [sd]geev wants to separate output arrays: wr and wi for the real
   // and imaginary parts
   float *wr = w;
-  float *wi = w + n;
+  float *wi = w ? w + n : nullptr;
   (void)rwork; // unused
   sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
 }
```

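A minimal repro sketch for this fix (assumed, distilled from the UBSan log above): any CPU `linalg.eig` call first runs a LAPACK workspace-size query, during which the eigenvalue buffer `w` is still null, so the unguarded `w + n` tripped the pointer-overflow check.

```python
import torch

# Under a UBSan build this used to report "applying non-zero offset 20 to
# null pointer" from the workspace-size query inside lapackEig.
values, vectors = torch.linalg.eig(torch.rand(5, 5))
print(values.shape)  # torch.Size([5])
```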
aten/src/ATen/native/BatchLinearAlgebraKernel.cpp

Lines changed: 4 additions & 5 deletions

```diff
@@ -16,7 +16,7 @@
 #include <ATen/ops/empty.h>
 #include <ATen/ops/empty_strided.h>
 #endif
-namespace at { namespace native {
+namespace at::native {
 
 namespace {
 /*
@@ -1102,15 +1102,14 @@ void svd_kernel(const Tensor& A,
 }
 
 void unpack_pivots_cpu_kernel(TensorIterator& iter, const int64_t dim_size, const int64_t max_pivot) {
-  if (iter.numel() == 0) {
+  if (iter.numel() == 0 || dim_size == 0) {
     return;
   }
   auto loop = [&](char* const* const data, const int64_t* const strides, const int64_t nelems) {
     auto* perm_ptr = data[0];
     const auto* pivots_ptr = data[1];
 
-    for (const auto elem : c10::irange(nelems)) {
-      (void)elem; //Suppress unused variable warning
+    for (C10_UNUSED const auto elem : c10::irange(nelems)) {
       // WARNING: linalg.lu_factor returns int32 pivots,
       // this behavior could change in the future.
       const auto perm_data = reinterpret_cast<int64_t*>(perm_ptr);
@@ -1224,4 +1223,4 @@ REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
 REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
 REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
 REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
-}} // namespace at::native
+} // namespace at::native
```

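A hedged sketch of the new zero-`dim_size` early exit (assuming empty matrices are accepted by `lu_factor` here): factoring an empty matrix yields zero pivots, so `unpack_pivots_cpu_kernel` now returns before touching the (possibly null) data pointers.

```python
import torch

# LU of an empty matrix produces empty LU data and zero pivots;
# unpacking them exercises the dim_size == 0 path.
LU, pivots = torch.linalg.lu_factor(torch.empty(0, 0))
P, L, U = torch.lu_unpack(LU, pivots)
print(P.shape, L.shape, U.shape)
```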
aten/src/ATen/native/GridSampler.cpp

Lines changed: 14 additions & 0 deletions

```diff
@@ -57,6 +57,9 @@ namespace {
   int64_t out_H = grid.size(2);
   int64_t out_W = grid.size(3);
   auto output = at::empty({N, C, out_D, out_H, out_W}, input.options());
+  if (output.numel() == 0) {
+    return output;
+  }
   int64_t inp_sN = input.stride(0);
   int64_t inp_sC = input.stride(1);
   int64_t inp_sD = input.stride(2);
@@ -219,6 +222,10 @@
     }
   })();
   auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  if (grid.numel() == 0 || input.numel() == 0) {
+    grad_grid.zero_();
+    return std::make_tuple(grad_input, grad_grid);
+  }
   // If interpolation mode is Nearest, then grad_grid is not filled in the
   // loop below.
   if (interpolation_mode == GridSamplerInterpolation::Nearest) {
@@ -567,6 +574,9 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,
   int64_t out_H = grid.size(1);
   int64_t out_W = grid.size(2);
   auto output = at::empty({N, C, out_H, out_W}, input.options());
+  if (output.numel() == 0) {
+    return output;
+  }
   int64_t inp_sN = input.stride(0);
   int64_t inp_sC = input.stride(1);
   int64_t inp_sH = input.stride(2);
@@ -715,6 +725,10 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output,
 
   auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  if (grid.numel() == 0 || input.numel() == 0) {
+    grad_grid.zero_();
+    return std::make_tuple(grad_input, grad_grid);
+  }
   // If interpolation mode is Nearest, then grad_grid is not filled in the
   // loop below.
   if (interpolation_mode == GridSamplerInterpolation::Nearest) {
```

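A hedged repro sketch for the early exits above (assumed shapes): an empty batch makes `input`, `grid`, and therefore `output` zero-element, so the forward returns the empty output immediately and the backward returns a zeroed `grad_grid` without entering the sampling loops.

```python
import torch
import torch.nn.functional as F

# 3D sampler with an empty batch: output has shape (0, 2, 1, 2, 2).
inp = torch.rand(0, 2, 3, 3, 3, requires_grad=True)
grid = torch.rand(0, 1, 2, 2, 3, requires_grad=True)
out = F.grid_sample(inp, grid, align_corners=False)
out.sum().backward()  # grad_grid is zero-filled rather than left uninitialized
```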
aten/src/ATen/native/TensorAdvancedIndexing.cpp

Lines changed: 5 additions & 4 deletions

```diff
@@ -857,7 +857,9 @@ TORCH_IMPL_FUNC(index_copy_out)
 // Not calling into index_reduce_func_impl because of a different dtype dispatch
 TORCH_IMPL_FUNC(index_add_cpu_out)
 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) {
-  if (!result.is_same(self)) result.copy_(self);
+  if (!result.is_same(self)) {
+    result.copy_(self);
+  }
   auto numel = index.numel();
 
   auto index_contig = index.contiguous();
@@ -870,7 +872,7 @@ TORCH_IMPL_FUNC(index_add_cpu_out)
   // selfSlice.add_(sourceSlice);
   // }
   // But much faster as this reuses the iterator from add_
-  if (numel == 0) {
+  if (numel == 0 || self.numel() == 0) {
     return;
   }
 
@@ -945,8 +947,7 @@ TORCH_IMPL_FUNC(index_add_cpu_out)
       add_stub(iter.device_type(), iter, alpha);
     }
   });
-  }
-  else {
+  } else {
     TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must one or zero for given self.dim() (", self.dim(), ")");
 
     // explicitly capture all required variables to work around windows build
```

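A hedged sketch of the added `self.numel() == 0` guard: a non-empty `index` into a `self` whose trailing dimension is zero-sized previously built a slice iterator over zero-element data.

```python
import torch

# index.numel() != 0 but self.numel() == 0: the new guard returns early
# instead of iterating over zero-sized slices.
x = torch.zeros(3, 0)
x.index_add_(0, torch.tensor([0, 2]), torch.ones(2, 0))
print(x.shape)  # torch.Size([3, 0])
```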
aten/src/ATen/native/TopKImpl.h

Lines changed: 5 additions & 0 deletions

```diff
@@ -22,6 +22,11 @@ void topk_impl_loop(
     const bool sorted,
     char** data, const int64_t* strides, const int64_t n) {
 
+  // If k is zero, then output values and indices are empty tensors
+  // So iterating over other dims is pointless
+  if (k == 0) {
+    return;
+  }
   using elem_t = std::pair<accscalar_t, int64_t>;
   std::vector<elem_t> queue(dim_size);
   for (const auto i : c10::irange(n)) {
```

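The `k == 0` case in user-facing terms (a hedged sketch; as the comment above notes, the outputs are empty, so the per-slice selection loop is skipped):

```python
import torch

# k == 0 yields empty values/indices tensors; topk_impl_loop now
# returns before iterating over the remaining dimensions.
values, indices = torch.topk(torch.rand(3, 5), k=0, dim=1)
print(values.shape, indices.shape)  # torch.Size([3, 0]) torch.Size([3, 0])
```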
aten/src/ATen/native/cpu/GridSamplerKernel.cpp

Lines changed: 8 additions & 1 deletion

```diff
@@ -1157,13 +1157,16 @@ void grid_sampler_2d_cpu_kernel_impl(
   auto spatial_size = H * W;
   auto grain_size = spatial_size == 0 ? (N + 1)
       : at::divup(at::internal::GRAIN_SIZE, spatial_size * 4 /* 2d * 2 tensors*/);
+  if (output.numel() == 0) {
+    return;
+  }
 
 #define HANDLE_CASE(interp, padding, align_corners)                  \
   case padding: {                                                    \
     ApplyGridSample<scalar_t, 2, interp, padding, align_corners>     \
         grid_sample(inp_acc);                                        \
     parallel_for(0, N, grain_size, [&](int64_t begin, int64_t end) { \
-      for (const auto n : c10::irange(begin, end)) {                  \
+      for (const auto n : c10::irange(begin, end)) {                 \
         auto out_slice = out_acc[n];                                 \
         auto inp_slice = inp_acc[n];                                 \
         grid_sample_2d_grid_slice_iterator(                          \
@@ -1220,6 +1223,10 @@ void grid_sampler_2d_backward_cpu_kernel_impl(
     int64_t padding_mode,
     bool align_corners,
     std::array<bool,2> output_mask) {
+  if (grad_output_.numel() == 0) {
+    grad_grid.zero_();
+    return;
+  }
   // grad_output should be contiguous most of time. Ensuring that it is
   // contiguous can greatly simplify this code.
   auto grad_output = grad_output_.contiguous();
```

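A hedged 2D counterpart for the vectorized CPU kernel in this file (assumed shapes): with an empty batch the forward exits once `output.numel() == 0`, and the backward zeroes `grad_grid` and returns when `grad_output` is empty.

```python
import torch
import torch.nn.functional as F

# 2D sampler with an empty batch exercises both new early exits.
inp = torch.rand(0, 2, 4, 4, requires_grad=True)
grid = torch.rand(0, 3, 3, 2, requires_grad=True)
F.grid_sample(inp, grid, align_corners=True).sum().backward()
```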
aten/src/ATen/native/cpu/SparseFactories.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -17,7 +17,7 @@ void _spdiags_kernel_cpu(
     TensorBase& values,
     TensorBase& indices) {
   auto* row_index_write_ptr = indices.data_ptr<int64_t>();
-  auto* col_index_write_ptr = row_index_write_ptr + indices.stride(0);
+  auto* col_index_write_ptr = row_index_write_ptr ? row_index_write_ptr + indices.stride(0) : nullptr;
   const int64_t diagonals_index_stride = diagonals.stride(0);
   const int64_t diagonals_read_stride = diagonals.stride(1);
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
```

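A hedged sketch of the guarded offset (assuming zero diagonals are accepted by `spdiags`): with no diagonals, the output `indices` tensor is empty and its data pointer may be null, so offsetting it by `stride(0)` is now conditional.

```python
import torch

# Zero diagonals: the result has nnz == 0 and an empty indices buffer.
out = torch.sparse.spdiags(torch.empty(0, 5),
                           torch.empty(0, dtype=torch.long), (5, 5))
print(out._nnz())  # 0
```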
aten/src/ATen/native/cpu/group_norm_kernel.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -675,7 +675,7 @@ void GroupNormInputBackward(
     const int64_t g = i % G;
     const T_ACC* ds_ptr = ds + i * D;
     const T_ACC* db_ptr = db + i * D;
-    const PT* gamma_ptr = gamma + g * D;
+    const PT* gamma_ptr = !gamma_null ? gamma + g * D : nullptr;
     CalcDsDb(ds_ptr, db_ptr, gamma_null, gamma_ptr, d, K, ds_arr.data(), db_arr.data());
     T_ACC ds_val = std::accumulate(ds_arr.cbegin(), ds_arr.cend(), T_ACC(0));
     T_ACC db_val = std::accumulate(db_arr.cbegin(), db_arr.cend(), T_ACC(0));
```

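A hedged repro sketch: calling `group_norm` without an affine `weight` leaves `gamma` null in the backward kernel (`gamma_null == true`), so the per-group offset `gamma + g * D` is now only formed when `gamma` is present.

```python
import torch
import torch.nn.functional as F

# No weight/bias: gamma is null inside GroupNormInputBackward.
x = torch.rand(2, 6, 4, requires_grad=True)
F.group_norm(x, num_groups=3).sum().backward()
```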
aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h

Lines changed: 5 additions & 6 deletions

```diff
@@ -276,7 +276,7 @@ void _sparse_binary_op_intersection_kernel_impl(
   KernelLauncher::launch(iter,
       // NOTE: capture by value required by CUDA
       [=] FUNCAPI (index_t nnz_idx) -> int64_t {
-        const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
+        const auto* RESTRICT ptr_indices_dim = ptr_indices ? ptr_indices + nnz_idx * indices_nnz_stride : nullptr;
         int64_t hash = 0;
         for (int64_t dim = 0; dim < sparse_dim; ++dim) {
           const auto dim_hash_coeff = hash_coeffs[dim];
@@ -299,8 +299,7 @@ void _sparse_binary_op_intersection_kernel_impl(
     // NOTE: argsort.dtype == nnz_arange.dtype
     const auto argsort = nnz_arange.narrow(-1, 0, probably_coalesced._nnz());
     return std::make_tuple(probably_coalesced_indices_hash, argsort);
-  }
-  else {
+  } else {
     // NOTE: we want argsort.dtype == nnz_arange.dtype,
     // but sort() produces indices of type int64_t,
     // so we convert to nnz_arange.dtype to avoid issues
@@ -360,12 +359,12 @@ void _sparse_binary_op_intersection_kernel_impl(
   KernelLauncher::launch(iter,
       // NOTE: capture by value required by CUDA
       [=] FUNCAPI (index_t nnz_idx) -> index_t {
-        // Compute hash value
-        const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
         int64_t hash = 0;
         if (hash_ptr) {
           hash = hash_ptr[nnz_idx];
-        } else {
+        } else if (sparse_dim) {
+          // Compute hash value
+          const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
           for (int64_t dim = 0; dim < sparse_dim; ++dim) {
             const auto dim_hash_coeff = hash_coeffs[dim];
             const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
```

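One assumed trigger, as a hedged sketch: sparse-sparse `mul` between operands with zero nnz gives an empty `indices` tensor, so `ptr_indices` may be null when the hashing lambdas run; the pointer offset is now computed only when actually needed.

```python
import torch

# Both operands have nnz == 0, so the intersection kernel sees empty indices.
i = torch.empty(2, 0, dtype=torch.long)
a = torch.sparse_coo_tensor(i, torch.empty(0), (3, 4))
b = torch.sparse_coo_tensor(i, torch.empty(0), (3, 4))
print((a * b)._nnz())  # 0
```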
aten/src/ATen/native/sparse/SparseTensorMath.cpp

Lines changed: 4 additions & 0 deletions

```diff
@@ -1222,6 +1222,10 @@ void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j,
     int64_t row = indices_accessor[0][i];
     int64_t col = indices_accessor[1][i];
     if (col >= 0 && col < dim_j && row >= 0 && row < dim_i) {
+      // AXPY call is no-op over an empty vector
+      if (dim_k == 0) {
+        continue;
+      }
       at::native::cpublas::axpy<scalar_t>(dim_k,
           cast_alpha * val,
           dense_ptr + col * dense_stride0, dense_stride1,
```

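A hedged sketch of the `dim_k == 0` case: a dense operand with zero columns makes every `axpy` span empty, so the worker now skips the call outright.

```python
import torch

# dense has zero columns, so dim_k == 0 inside the addmm worker.
s = torch.rand(3, 4).to_sparse()
out = torch.sparse.mm(s, torch.rand(4, 0))
print(out.shape)  # torch.Size([3, 0])
```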
c10/core/TensorImpl.h

Lines changed: 5 additions & 1 deletion

```diff
@@ -1550,7 +1550,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   // Shared implementation of mutable_data_ptr_impl() and the future
   // mutable_data_ptr_impl().
   template <typename T, typename Func>
-  T* data_ptr_impl_impl(const Func& get_data) const {
+  __ubsan_ignore_pointer_overflow__ T* data_ptr_impl_impl(
+      const Func& get_data) const {
     if (C10_UNLIKELY(!has_storage())) {
       throw_data_ptr_access_error();
     }
@@ -1560,6 +1561,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
         "Caffe2 uses a lazy allocation, so you will need to call "
         "mutable_data() or raw_mutable_data() to actually allocate memory.");
     // Caller does the type check.
+    // Note: storage_offset_ can be non-null even for zero-elements tensors
+    // (for example if created as `torch.empty(5)[10:]`) that triggers
+    // applying non-zero offset to null pointer in UBSan
     return get_data() + storage_offset_;
   }
```

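The commit message's example as a hedged sketch: slicing past the end of a tensor clamps to a zero-element view that still records a non-zero storage offset, which `data_ptr_impl_impl` adds to the base pointer unconditionally; the new attribute suppresses UBSan's pointer-overflow check for exactly this case.

```python
import torch

# A zero-element view with a non-zero storage offset, as in the note above.
t = torch.empty(3)[4:]
print(t.numel(), t.storage_offset())  # 0 3
```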
c10/macros/Macros.h

Lines changed: 3 additions & 0 deletions

```diff
@@ -30,11 +30,14 @@
 #define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined")))
 #define __ubsan_ignore_signed_int_overflow__ \
   __attribute__((no_sanitize("signed-integer-overflow")))
+#define __ubsan_ignore_pointer_overflow__ \
+  __attribute__((no_sanitize("pointer-overflow")))
 #define __ubsan_ignore_function__ __attribute__((no_sanitize("function")))
 #else
 #define __ubsan_ignore_float_divide_by_zero__
 #define __ubsan_ignore_undefined__
 #define __ubsan_ignore_signed_int_overflow__
+#define __ubsan_ignore_pointer_overflow__
 #define __ubsan_ignore_function__
 #endif
```
