Skip to content

Commit 18b2751

Browse files
bertmaherfacebook-github-bot
authored andcommitted
[nnc] Make our exceptions c10::Errors, get C++ stacktraces (pytorch#64332)
Summary: Pull Request resolved: pytorch#64332 With this diff, if a compiler bug occurs (unlikely, I know!) we'll be able to get a c++ stacktrace leading to the exception, rather than just a terse message. E.g., ``` RuntimeError: UNSUPPORTED DTYPE Exception raised from compilation_error at ../torch/csrc/jit/tensorexpr/exceptions.h:32 (most recent call first): frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6b (0x7f966659b2eb in /fsx/users/bertrand/c\ onda/envs/pytorch/lib/python3.8/site-packages/torch/lib/libc10.so) frame #1: <unknown function> + 0x376f099 (0x7f966a195099 in /fsx/users/bertrand/conda/envs/pytorch/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so) frame #2: <unknown function> + 0x3763bf5 (0x7f966a189bf5 in /fsx/users/bertrand/conda/envs/pytorch/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so) frame #3: torch::jit::tensorexpr::CudaCodeGen::Initialize() + 0xdd8 (0x7f966a193368 in /fsx/users/bertrand/conda/envs/pytorch/lib/python3.8/site-packages/torch/lib/libtorch_cuda\ .so) ``` Test Plan: Imported from OSS Reviewed By: huiguoo Differential Revision: D30745610 Pulled By: bertmaher fbshipit-source-id: a1cfaa7364ef4120de834e9cbe57ced1d082ab4e
1 parent 6cac7ca commit 18b2751

File tree

2 files changed

+41
-27
lines changed

2 files changed

+41
-27
lines changed

torch/csrc/jit/tensorexpr/exceptions.h

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,66 +26,78 @@ namespace torch {
2626
namespace jit {
2727
namespace tensorexpr {
2828

29-
class unsupported_dtype : public std::runtime_error {
29+
TORCH_API std::string buildErrorMessage(const std::string& s);
30+
31+
class compilation_error : public c10::Error {
3032
public:
31-
explicit unsupported_dtype() : std::runtime_error("UNSUPPORTED DTYPE") {}
33+
explicit compilation_error(const std::string& err)
34+
: c10::Error(
35+
{
36+
__func__,
37+
__FILE__,
38+
static_cast<uint32_t>(__LINE__),
39+
},
40+
buildErrorMessage(err)) {}
41+
};
42+
43+
class unsupported_dtype : public compilation_error {
44+
public:
45+
explicit unsupported_dtype() : compilation_error("UNSUPPORTED DTYPE") {}
3246
explicit unsupported_dtype(const std::string& err)
33-
: std::runtime_error("UNSUPPORTED DTYPE: " + err) {}
47+
: compilation_error("UNSUPPORTED DTYPE: " + err) {}
3448
};
3549

36-
class out_of_range_index : public std::runtime_error {
50+
class out_of_range_index : public compilation_error {
3751
public:
38-
explicit out_of_range_index() : std::runtime_error("OUT OF RANGE INDEX") {}
52+
explicit out_of_range_index() : compilation_error("OUT OF RANGE INDEX") {}
3953
explicit out_of_range_index(const std::string& err)
40-
: std::runtime_error("OUT OF RANGE INDEX: " + err) {}
54+
: compilation_error("OUT OF RANGE INDEX: " + err) {}
4155
};
4256

43-
class unimplemented_lowering : public std::runtime_error {
57+
class unimplemented_lowering : public compilation_error {
4458
public:
4559
explicit unimplemented_lowering()
46-
: std::runtime_error("UNIMPLEMENTED LOWERING") {}
60+
: compilation_error("UNIMPLEMENTED LOWERING") {}
4761
explicit unimplemented_lowering(ExprPtr expr)
48-
: std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(expr)) {}
62+
: compilation_error("UNIMPLEMENTED LOWERING: " + std::to_string(expr)) {}
4963
explicit unimplemented_lowering(StmtPtr stmt)
50-
: std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(stmt)) {}
64+
: compilation_error("UNIMPLEMENTED LOWERING: " + std::to_string(stmt)) {}
5165
};
5266

53-
class malformed_input : public std::runtime_error {
67+
class malformed_input : public compilation_error {
5468
public:
55-
explicit malformed_input() : std::runtime_error("MALFORMED INPUT") {}
69+
explicit malformed_input() : compilation_error("MALFORMED INPUT") {}
5670
explicit malformed_input(const std::string& err)
57-
: std::runtime_error("MALFORMED INPUT: " + err) {}
71+
: compilation_error("MALFORMED INPUT: " + err) {}
5872
explicit malformed_input(ExprPtr expr)
59-
: std::runtime_error("MALFORMED INPUT: " + std::to_string(expr)) {}
73+
: compilation_error("MALFORMED INPUT: " + std::to_string(expr)) {}
6074
explicit malformed_input(const std::string& err, ExprPtr expr)
61-
: std::runtime_error(
75+
: compilation_error(
6276
"MALFORMED INPUT: " + err + " - " + std::to_string(expr)) {}
6377
explicit malformed_input(StmtPtr stmt)
64-
: std::runtime_error("MALFORMED INPUT: " + std::to_string(stmt)) {}
78+
: compilation_error("MALFORMED INPUT: " + std::to_string(stmt)) {}
6579
explicit malformed_input(const std::string& err, StmtPtr stmt)
66-
: std::runtime_error(
80+
: compilation_error(
6781
"MALFORMED INPUT: " + err + " - " + std::to_string(stmt)) {}
6882
};
6983

70-
class malformed_ir : public std::runtime_error {
84+
class malformed_ir : public compilation_error {
7185
public:
72-
explicit malformed_ir() : std::runtime_error("MALFORMED IR") {}
86+
explicit malformed_ir() : compilation_error("MALFORMED IR") {}
7387
explicit malformed_ir(const std::string& err)
74-
: std::runtime_error("MALFORMED IR: " + err) {}
88+
: compilation_error("MALFORMED IR: " + err) {}
7589
explicit malformed_ir(ExprPtr expr)
76-
: std::runtime_error("MALFORMED IR: " + std::to_string(expr)) {}
90+
: compilation_error("MALFORMED IR: " + std::to_string(expr)) {}
7791
explicit malformed_ir(const std::string& err, ExprPtr expr)
78-
: std::runtime_error(
92+
: compilation_error(
7993
"MALFORMED IR: " + err + " - " + std::to_string(expr)) {}
8094
explicit malformed_ir(StmtPtr stmt)
81-
: std::runtime_error("MALFORMED IR: " + std::to_string(stmt)) {}
95+
: compilation_error("MALFORMED IR: " + std::to_string(stmt)) {}
8296
explicit malformed_ir(const std::string& err, StmtPtr stmt)
83-
: std::runtime_error(
97+
: compilation_error(
8498
"MALFORMED IR: " + err + " - " + std::to_string(stmt)) {}
8599
};
86100

87-
TORCH_API std::string buildErrorMessage(const std::string& s);
88-
89101
} // namespace tensorexpr
90102
} // namespace jit
91103
} // namespace torch

torch/csrc/jit/tensorexpr/loopnest.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,11 +476,13 @@ bool LoopNest::vectorize(ForPtr f) {
476476
normalize(to<For>(new_f));
477477
new_f = FlattenIndexes(new_f);
478478
new_f = v.vectorize(to<For>(new_f));
479-
} catch (std::runtime_error& e) {
479+
} catch (compilation_error& e) {
480480
// We clone f before vectorizing. So, any partial vectorization will
481481
// have modified the clone. In case of an exception, we can continue
482482
// using f.
483483
new_f = f;
484+
} catch (std::runtime_error& e) {
485+
new_f = f;
484486
}
485487

486488
if (new_f != f) {

0 commit comments

Comments
 (0)