Skip to content

Commit a058e93

Browse files
ezyangfacebook-github-bot
authored andcommitted
Refactor error msg stack handling, add TORCH_RETHROW (pytorch#37101)
Summary: Pull Request resolved: pytorch#37101 Fixes pytorch#36954. The basic concept is to streamline the process of rethrowing c10::Error with extra error information. This is in a few steps: - I completely remodeled the Error data type and the internal invariants. Instead of manually adding in newlines, the message stack formatting process is responsible for inserting newlines and spacing as necessary. Call sites are then modified to respect the new API model. - TORCH_RETHROW macro is added, which adds context to an error message and then rethrows it. New internal assert failure looks like: ``` 0 INTERNAL ASSERT FAILED at ../c10/test/util/exception_test.cpp:64, please report a bug to PyTorch. Exception raised from TestBody at ../c10/test/util/exception_test.cpp:64 (most recent call first): frame #0: <unknown function> + 0x6aab9 (0x7ff611d3aab9 in /data/users/ezyang/pytorch-tmp/build/lib/libc10.so) frame #1: ... ``` Error message with context looks like: ``` This is an error This is context 1 This is context 2 ``` Signed-off-by: Edward Z. Yang <[email protected]> Test Plan: Imported from OSS Differential Revision: D21202891 Pulled By: ezyang fbshipit-source-id: 361cadd16bc52e5886dba08e79277771ada76169
1 parent efd8f70 commit a058e93

File tree

6 files changed

+156
-57
lines changed

6 files changed

+156
-57
lines changed

c10/test/util/exception_test.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,24 @@
22
#include <gtest/gtest.h>
33
#include <stdexcept>
44

5+
using c10::Error;
6+
57
namespace {
68
bool throw_func() {
79
throw std::runtime_error("I'm throwing...");
810
}
11+
12+
template<class Functor>
13+
inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
14+
try {
15+
std::forward<Functor>(functor)();
16+
} catch (const Error& e) {
17+
EXPECT_STREQ(e.what_without_backtrace(), expectedMessage);
18+
return;
19+
}
20+
ADD_FAILURE() << "Expected to throw exception with message \""
21+
<< expectedMessage << "\" but didn't throw";
22+
}
923
} // namespace
1024

1125
TEST(ExceptionTest, TORCH_INTERNAL_ASSERT_DEBUG_ONLY) {
@@ -22,3 +36,32 @@ TEST(ExceptionTest, TORCH_INTERNAL_ASSERT_DEBUG_ONLY) {
2236
TEST(WarningTest, JustPrintWarning) {
2337
TORCH_WARN("I'm a warning");
2438
}
39+
40+
TEST(ExceptionTest, ErrorFormatting) {
41+
expectThrowsEq([]() {
42+
TORCH_CHECK(false, "This is invalid");
43+
}, "This is invalid");
44+
45+
expectThrowsEq([]() {
46+
try {
47+
TORCH_CHECK(false, "This is invalid");
48+
} catch (Error& e) {
49+
TORCH_RETHROW(e, "While checking X");
50+
}
51+
}, "This is invalid (While checking X)");
52+
53+
expectThrowsEq([]() {
54+
try {
55+
try {
56+
TORCH_CHECK(false, "This is invalid");
57+
} catch (Error& e) {
58+
TORCH_RETHROW(e, "While checking X");
59+
}
60+
} catch (Error& e) {
61+
TORCH_RETHROW(e, "While checking Y");
62+
}
63+
},
64+
R"msg(This is invalid
65+
While checking X
66+
While checking Y)msg");
67+
}

c10/util/Exception.cpp

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,15 @@
44
#include <c10/util/Logging.h>
55

66
#include <iostream>
7+
#include <sstream>
78
#include <numeric>
89
#include <string>
910

1011
namespace c10 {
1112

12-
Error::Error(
13-
const std::string& new_msg,
14-
const std::string& backtrace,
15-
const void* caller)
16-
: msg_stack_{new_msg}, backtrace_(backtrace), caller_(caller) {
17-
msg_ = msg();
18-
msg_without_backtrace_ = msg_without_backtrace();
13+
Error::Error(std::string msg, std::string backtrace, const void* caller)
14+
: msg_(std::move(msg)), backtrace_(std::move(backtrace)), caller_(caller) {
15+
refresh_what();
1916
}
2017

2118
// PyTorch-style error message
@@ -38,29 +35,45 @@ Error::Error(
3835
"] ",
3936
condition,
4037
". ",
41-
msg,
42-
"\n"),
38+
msg),
4339
backtrace,
4440
caller) {}
4541

46-
std::string Error::msg() const {
47-
return std::accumulate(
48-
msg_stack_.begin(), msg_stack_.end(), std::string("")) +
49-
backtrace_;
42+
std::string Error::compute_what(bool include_backtrace) const {
43+
std::ostringstream oss;
44+
45+
oss << msg_;
46+
47+
if (context_.size() == 1) {
48+
// Fold error and context in one line
49+
oss << " (" << context_[0] << ")";
50+
} else {
51+
for (const auto& c : context_) {
52+
oss << "\n " << c;
53+
}
54+
}
55+
56+
if (include_backtrace) {
57+
oss << "\n" << backtrace_;
58+
}
59+
60+
return oss.str();
5061
}
5162

52-
std::string Error::msg_without_backtrace() const {
53-
return std::accumulate(msg_stack_.begin(), msg_stack_.end(), std::string(""));
63+
void Error::refresh_what() {
64+
what_ = compute_what(/*include_backtrace*/ true);
65+
what_without_backtrace_ = compute_what(/*include_backtrace*/ false);
5466
}
5567

56-
void Error::AppendMessage(const std::string& new_msg) {
57-
msg_stack_.push_back(new_msg);
58-
// Refresh the cache
59-
// TODO: Calling AppendMessage O(n) times has O(n^2) cost. We can fix
68+
void Error::add_context(std::string new_msg) {
69+
context_.push_back(std::move(new_msg));
70+
// TODO: Calling add_context O(n) times has O(n^2) cost. We can fix
6071
// this perf problem by populating the fields lazily... if this ever
6172
// actually is a problem.
62-
msg_ = msg();
63-
msg_without_backtrace_ = msg_without_backtrace();
73+
// NB: If you do fix this, make sure you do it in a thread safe way!
74+
// what() is almost certainly expected to be thread safe even when
75+
// accessed across multiple threads
76+
refresh_what();
6477
}
6578

6679
namespace Warning {

c10/util/Exception.h

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,26 @@ namespace c10 {
2626
/// NB: c10::Error is handled specially by the default torch to suppress the
2727
/// backtrace, see torch/csrc/Exceptions.h
2828
class C10_API Error : public std::exception {
29-
std::vector<std::string> msg_stack_;
29+
// The actual error message.
30+
std::string msg_;
31+
32+
// Context for the message (in order of decreasing specificity). Context will
33+
// be automatically formatted appropriately, so it is not necessary to add
34+
// extra leading/trailing newlines to strings inside this vector
35+
std::vector<std::string> context_;
36+
37+
// The C++ backtrace at the point when this exception was raised. This
38+
// may be empty if there is no valid backtrace. (We don't use optional
39+
// here to reduce the dependencies this file has.)
3040
std::string backtrace_;
3141

3242
// These two are derived fields from msg_stack_ and backtrace_, but we need
3343
// fields for the strings so that we can return a const char* (as the
34-
// signature of std::exception requires).
35-
std::string msg_;
36-
std::string msg_without_backtrace_;
44+
// signature of std::exception requires). Currently, the invariant
45+
// is that these fields are ALWAYS populated consistently with respect
46+
// to msg_stack_ and backtrace_.
47+
std::string what_;
48+
std::string what_without_backtrace_;
3749

3850
// This is a little debugging trick: you can stash a relevant pointer
3951
// in caller, and then when you catch the exception, you can compare
@@ -43,11 +55,11 @@ class C10_API Error : public std::exception {
4355
const void* caller_;
4456

4557
public:
46-
Error(
47-
const std::string& msg,
48-
const std::string& backtrace,
49-
const void* caller = nullptr);
50-
Error(SourceLocation source_location, const std::string& msg);
58+
// PyTorch-style Error constructor. NB: the implementation of this
59+
// is actually in Logging.cpp
60+
Error(SourceLocation source_location, std::string msg);
61+
62+
// Caffe2-style error message
5163
Error(
5264
const char* file,
5365
const uint32_t line,
@@ -56,30 +68,51 @@ class C10_API Error : public std::exception {
5668
const std::string& backtrace,
5769
const void* caller = nullptr);
5870

59-
void AppendMessage(const std::string& msg);
71+
// Base constructor
72+
Error(
73+
std::string msg,
74+
std::string backtrace,
75+
const void* caller = nullptr);
6076

61-
const std::vector<std::string>& msg_stack() const {
62-
return msg_stack_;
77+
// Add some new context to the message stack. The last added context
78+
// will be formatted at the end of the context list upon printing.
79+
// WARNING: This method is O(n) in the size of the stack, so don't go
80+
// wild adding a ridiculous amount of context to error messages.
81+
void add_context(std::string msg);
82+
83+
const std::string& msg() const {
84+
return msg_;
85+
}
86+
87+
const std::vector<std::string>& context() const {
88+
return context_;
89+
}
90+
91+
const std::string& backtrace() const {
92+
return backtrace_;
6393
}
6494

6595
/// Returns the complete error message, including the source location.
96+
/// The returned pointer is invalidated if you call add_context() on
97+
/// this object.
6698
const char* what() const noexcept override {
67-
return msg_.c_str();
99+
return what_.c_str();
68100
}
69101

70102
const void* caller() const noexcept {
71103
return caller_;
72104
}
73105

74106
/// Returns only the error message string, without source location.
107+
/// The returned pointer is invalidated if you call add_context() on
108+
/// this object.
75109
const char* what_without_backtrace() const noexcept {
76-
return msg_without_backtrace_.c_str();
110+
return what_without_backtrace_.c_str();
77111
}
78112

79113
private:
80-
// Compute the full message from msg_ and msg_without_backtrace_
81-
std::string msg() const;
82-
std::string msg_without_backtrace() const;
114+
void refresh_what();
115+
std::string compute_what(bool include_backtrace) const;
83116
};
84117

85118
class C10_API WarningHandler {
@@ -204,6 +237,16 @@ inline std::string if_empty_then(std::string x, std::string y) {
204237
// Error reporting macros
205238
// ----------------------------------------------------------------------------
206239

240+
#ifdef STRIP_ERROR_MESSAGES
241+
#define TORCH_RETHROW(e, ...) throw
242+
#else
243+
#define TORCH_RETHROW(e, ...) \
244+
do { \
245+
e.add_context(::c10::str(__VA_ARGS__)); \
246+
throw; \
247+
} while (false)
248+
#endif
249+
207250
// A utility macro to provide assert()-like functionality; that is, enforcement
208251
// of internal invariants in code. It supports an arbitrary number of extra
209252
// arguments (evaluated only on failure), which will be printed in the assert

c10/util/Logging.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ void ThrowEnforceNotMet(
4444
const void* caller) {
4545
c10::Error e(file, line, condition, msg, (*GetFetchStackTrace())(), caller);
4646
if (FLAGS_caffe2_use_fatal_for_enforce) {
47-
LOG(FATAL) << e.msg_stack()[0];
47+
LOG(FATAL) << e.msg();
4848
}
4949
throw e;
5050
}
@@ -63,8 +63,8 @@ void ThrowEnforceFiniteNotMet(
6363

6464
// PyTorch-style error message
6565
// (This must be defined here for access to GetFetchStackTrace)
66-
Error::Error(SourceLocation source_location, const std::string& msg)
67-
: Error(msg, str(" (", source_location, ")\n", (*GetFetchStackTrace())())) {
66+
Error::Error(SourceLocation source_location, std::string msg)
67+
: Error(std::move(msg), str("Exception raised from ", source_location, " (most recent call first):\n", (*GetFetchStackTrace())())) {
6868
}
6969

7070
using APIUsageLoggerType = std::function<void(const std::string&)>;

caffe2/core/operator.h

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <string>
1212
#include <typeinfo>
1313
#include <vector>
14+
#include <sstream>
1415

1516
#include <c10/macros/Macros.h>
1617
#include <c10/util/Registry.h>
@@ -154,9 +155,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
154155
return inputs_.at(idx)->template Get<T>();
155156
} catch (::caffe2::EnforceNotMet& enf) {
156157
if (has_debug_def()) {
157-
enf.AppendMessage(".\nOffending Blob name: ");
158-
enf.AppendMessage(debug_def().input(idx));
159-
enf.AppendMessage(".\n");
158+
TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), ".");
160159
}
161160
throw enf;
162161
}
@@ -180,9 +179,7 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
180179
return tensor;
181180
} catch (::caffe2::EnforceNotMet& enf) {
182181
if (has_debug_def()) {
183-
enf.AppendMessage(".\nOffending Blob name: ");
184-
enf.AppendMessage(debug_def().input(idx));
185-
enf.AppendMessage(".\n");
182+
TORCH_RETHROW(enf, "Offending Blob name: ", debug_def().input(idx), ".");
186183
}
187184
throw enf;
188185
}
@@ -521,26 +518,30 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
521518
return;
522519
}
523520

524-
bool found_input;
521+
bool found_input = false;
522+
bool found_output = false;
525523
if (err->caller() != nullptr) {
524+
std::ostringstream oss;
526525
for (size_t i = 0; i < inputs_.size(); i++) {
527526
if (inputs_[i]->GetRaw() == err->caller()) {
528527
found_input = true;
529-
err->AppendMessage(
530-
"\n** while accessing input: " + debug_def().input(i));
528+
oss << "while accessing input: " << debug_def().input(i);
531529
break;
532530
}
533531
}
534532
for (size_t i = 0; i < outputs_.size(); i++) {
535533
if (outputs_[i]->GetRaw() == err->caller()) {
534+
found_output = true;
536535
if (found_input) {
537-
err->AppendMessage("\n OR ");
536+
oss << " OR ";
538537
}
539-
err->AppendMessage(
540-
"\n** while accessing output: " + debug_def().output(i));
538+
oss << "while accessing output: " << debug_def().output(i);
541539
break;
542540
}
543541
}
542+
if (found_input || found_output) {
543+
err->add_context(oss.str());
544+
}
544545
}
545546
}
546547

@@ -1071,7 +1072,7 @@ class Operator : public OperatorBase {
10711072
return result;
10721073
} catch (EnforceNotMet& err) {
10731074
if (has_debug_def()) {
1074-
err.AppendMessage(
1075+
err.add_context(
10751076
"Error from operator: \n" + ProtoDebugString(debug_def()));
10761077
AddRelatedBlobInfo(&err);
10771078
}
@@ -1109,7 +1110,7 @@ class Operator : public OperatorBase {
11091110
return result;
11101111
} catch (EnforceNotMet& err) {
11111112
if (has_debug_def()) {
1112-
err.AppendMessage(
1113+
err.add_context(
11131114
"Error from operator: \n" + ProtoDebugString(debug_def()));
11141115
AddRelatedBlobInfo(&err);
11151116
}

caffe2/ideep/utils/ideep_operator.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ class IDEEPOperator : public OperatorBase {
6262
StopAllObservers();
6363
return result;
6464
} catch (EnforceNotMet& err) {
65-
err.AppendMessage(getErrorMsg());
66-
throw;
65+
TORCH_RETHROW(err, getErrorMsg());
6766
} catch (ideep::error& e) {
6867
LOG(ERROR) << "IDEEP error:" << e.message;
6968
throw;

0 commit comments

Comments
 (0)