Skip to content

Create GlowErr type with informational error code enum #2283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 100 additions & 14 deletions include/glow/Support/Error.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@ namespace glow {
/// line numbers also.
extern llvm::ExitOnError exitOnErr;

/// Take a message \p str and prepend it with the given \p file and \p line
/// number. This is useful for augmenting StringErrors with information about
/// where they were generated.
std::string addFileAndLineToError(llvm::StringRef str, llvm::StringRef file,
uint32_t line);

/// Is true_type only if applied to llvm::Error or a descendant.
template <typename T>
struct IsLLVMError : public std::is_base_of<llvm::Error, T> {};
Expand All @@ -45,6 +39,101 @@ template <typename> struct IsLLVMExpected : public std::false_type {};
template <typename T>
struct IsLLVMExpected<llvm::Expected<T>> : public std::true_type {};

/// Represents errors in Glow. GlowErr track the file name and line number of
/// where they were created as well as a textual message and/or a error code to
/// help identify the type of error the occurred programtically.
class GlowErr final : public llvm::ErrorInfo<GlowErr> {
public:
/// Used by ErrorInfo::classID.
static const uint8_t ID;
/// An enumeration of error codes representing various possible errors that
/// could occur.
/// NOTE: when updating this enum, also update ErrorCodeToString function
/// below.
enum class ErrorCode {
// An unknown error ocurred. This is the default value.
UNKNOWN,
// Model loader encountered an unsupported shape.
MODEL_LOADER_UNSUPPORTED_SHAPE,
// Model loader encountered an unsupported operator.
MODEL_LOADER_UNSUPPORTED_OPERATOR,
// Model loader encountered an unsupported attribute.
MODEL_LOADER_UNSUPPORTED_ATTRIBUTE,
// Model loader encountered an unsupported datatype.
MODEL_LOADER_UNSUPPORTED_DATATYPE,
// Model loader encountered an unsupported ONNX version.
MODEL_LOADER_UNSUPPORTED_ONNX_VERSION,
// Model loader encountered an invalid protobuf.
MODEL_LOADER_INVALID_PROTOBUF,
};

/// GlowErr is not convertable to std::error_code. This is included for
/// compatiblity with ErrorInfo.
virtual std::error_code convertToErrorCode() const override {
return llvm::inconvertibleErrorCode();
}

/// Log to \p OS relevant error information including the file name and
/// line number the GlowErr was created on as well as the message and/or error
/// code the GlowErr was created with.
void log(llvm::raw_ostream &OS) const override {
OS << "file: " << fileName_ << " line: " << lineNumber_;
if (ec_ != ErrorCode::UNKNOWN) {
OS << " error code: " << errorCodeToString(ec_);
}
if (!message_.empty()) {
OS << " message: " << message_;
}
}

GlowErr(llvm::StringRef fileName, size_t lineNumber, llvm::StringRef message,
ErrorCode ec)
: lineNumber_(lineNumber), fileName_(fileName), message_(message),
ec_(ec) {}

GlowErr(llvm::StringRef fileName, size_t lineNumber, ErrorCode ec,
llvm::StringRef message)
: lineNumber_(lineNumber), fileName_(fileName), message_(message),
ec_(ec) {}

GlowErr(llvm::StringRef fileName, size_t lineNumber, ErrorCode ec)
: lineNumber_(lineNumber), fileName_(fileName), ec_(ec) {}

GlowErr(llvm::StringRef fileName, size_t lineNumber, llvm::StringRef message)
: lineNumber_(lineNumber), fileName_(fileName), message_(message) {}

private:
/// Convert ErrorCode values to string.
static std::string errorCodeToString(const ErrorCode &ec) {
switch (ec) {
case ErrorCode::UNKNOWN:
return "UNKNOWN";
case ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE:
return "MODEL_LOADER_UNSUPPORTED_SHAPE";
case ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR:
return "MODEL_LOADER_UNSUPPORTED_OPERATOR";
case ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE:
return "MODEL_LOADER_UNSUPPORTED_ATTRIBUTE";
case ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE:
return "MODEL_LOADER_UNSUPPORTED_DATATYPE";
case ErrorCode::MODEL_LOADER_UNSUPPORTED_ONNX_VERSION:
return "MODEL_LOADER_UNSUPPORTED_ONNX_VERSION";
case ErrorCode::MODEL_LOADER_INVALID_PROTOBUF:
return "MODEL_LOADER_INVALID_PROTOBUF";
};
llvm_unreachable("unsupported ErrorCode");
}

/// The line number the error was generated on.
size_t lineNumber_;
/// The name of the file the error was generated in.
std::string fileName_;
/// Optional message associated with the error.
std::string message_;
/// Optional error code associated with the error.
ErrorCode ec_ = ErrorCode::UNKNOWN;
};

/// Unwraps the T from within an llvm::Expected<T>. If the Expected<T> contains
/// an error, the program will exit.
#define EXIT_ON_ERR(...) (exitOnErr(__VA_ARGS__))
Expand All @@ -56,15 +145,12 @@ struct IsLLVMExpected<llvm::Expected<T>> : public std::true_type {};
#define TEMP_EXIT_ON_ERR(...) (EXIT_ON_ERR(__VA_ARGS__))

/// Make a new llvm::StringError.
#define MAKE_ERR(str) \
llvm::make_error<llvm::StringError>( \
(addFileAndLineToError(str, __FILE__, __LINE__)), \
llvm::inconvertibleErrorCode())
#define MAKE_ERR(...) llvm::make_error<GlowErr>(__FILE__, __LINE__, __VA_ARGS__)

/// Makes a new llvm::StringError and returns it.
#define RETURN_ERR(str) \
#define RETURN_ERR(...) \
do { \
return MAKE_ERR(str); \
return MAKE_ERR(__VA_ARGS__); \
} while (0)

/// Takes an llvm::Expected<T> \p lhsOrErr and if it is an Error then returns
Expand Down Expand Up @@ -94,10 +180,10 @@ struct IsLLVMExpected<llvm::Expected<T>> : public std::true_type {};

/// Takes a predicate \p and if it is false then creates a new llvm::StringError
/// and returns it.
#define RETURN_ERR_IF_NOT(p, str) \
#define RETURN_ERR_IF_NOT(p, ...) \
do { \
if (!(p)) { \
RETURN_ERR(str); \
RETURN_ERR(__VA_ARGS__); \
} \
} while (0)
} // end namespace glow
Expand Down
39 changes: 26 additions & 13 deletions lib/Importer/ONNXModelLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ llvm::Error ONNXModelLoader::setVersion(ONNX_NAMESPACE::ModelProto MP) {
opsetVersion_ = 0;
RETURN_ERR_IF_NOT(
irVersion_ >= 3,
"This ONNX model with ir_version < 3 is too old to be supported.");
"This ONNX model with ir_version < 3 is too old to be supported.",
GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_ONNX_VERSION);
for (const auto &imp : MP.opset_import()) {
if (!imp.has_domain() || imp.domain() == "") {
opsetVersion_ = imp.version();
Expand All @@ -156,7 +157,8 @@ ONNXModelLoader::loadProto(google::protobuf::io::ZeroCopyInputStream &iStream) {
codedStream.SetTotalBytesLimit(MAX_PROTO_SIZE, MAX_PROTO_SIZE);
ONNX_NAMESPACE::ModelProto MP;
bool parseNet = MP.ParseFromCodedStream(&codedStream);
RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto");
RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto",
GlowErr::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF);
return MP;
}

Expand All @@ -169,7 +171,8 @@ ONNXModelLoader::loadProto(const void *onnxModel, size_t onnxModelSize) {
llvm::Expected<ONNX_NAMESPACE::ModelProto>
ONNXModelLoader::loadProto(const std::string &filename) {
std::ifstream ff(filename, std::ios::in | std::ios::binary);
RETURN_ERR_IF_NOT(ff, "Can't find the model or network files.");
RETURN_ERR_IF_NOT(ff, "Can't find the model or network files.",
GlowErr::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF);

// TODO: intend to find a way to reuse the following function later
// for the text format onnx model:
Expand All @@ -181,7 +184,8 @@ ONNXModelLoader::loadProto(const std::string &filename) {
ONNX_NAMESPACE::ModelProto MP;
bool parseNet = google::protobuf::TextFormat::ParseFromString(str, &MP);

RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto");
RETURN_ERR_IF_NOT(parseNet, "Failed to parse ModelProto",
GlowErr::ErrorCode::MODEL_LOADER_INVALID_PROTOBUF);
return MP;
}

Expand Down Expand Up @@ -232,7 +236,8 @@ static llvm::Error loadTensor(const ONNX_NAMESPACE::TensorProto &in,
std::istringstream inStream(in.raw_data(), std::stringstream::binary);
inStream.read(T->getUnsafePtr(), T->size() * sizeof(float));
} else {
RETURN_ERR("Unsupported Tensor format.");
RETURN_ERR("Unsupported Tensor format.",
GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
}
} else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT64) {
T->reset(ElemKind::Int64ITy, dim);
Expand All @@ -247,7 +252,8 @@ static llvm::Error loadTensor(const ONNX_NAMESPACE::TensorProto &in,
std::istringstream inStream(in.raw_data(), std::stringstream::binary);
inStream.read(T->getUnsafePtr(), T->size() * sizeof(int64_t));
} else {
RETURN_ERR("Unsupported Tensor format.");
RETURN_ERR("Unsupported Tensor format.",
GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
}
} else if (in.data_type() == ONNX_NAMESPACE::TensorProto::INT32) {
// There are few cases when we will have int32 tensors. For example, the
Expand All @@ -264,10 +270,12 @@ static llvm::Error loadTensor(const ONNX_NAMESPACE::TensorProto &in,
std::istringstream inStream(in.raw_data(), std::stringstream::binary);
inStream.read(T->getUnsafePtr(), T->size() * sizeof(int32_t));
} else {
RETURN_ERR("Unsupported Tensor format.");
RETURN_ERR("Unsupported Tensor format.",
GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
}
} else {
RETURN_ERR("Only float and index tensors are supported");
RETURN_ERR("Only float and index tensors are supported",
GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);
}
return llvm::Error::success();
}
Expand Down Expand Up @@ -307,7 +315,8 @@ llvm::Error ONNXModelLoader::loadConstant(const ONNX_NAMESPACE::NodeProto &op,

RETURN_ERR_IF_NOT(dict.at("value")->type() ==
ONNX_NAMESPACE::AttributeProto::TENSOR,
"Only Tensor type constants are supported.");
"Only Tensor type constants are supported.",
GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_DATATYPE);

std::unique_ptr<Tensor> T(new Tensor());
RETURN_IF_ERR(loadTensor(dict.at("value")->t(), T.get()));
Expand Down Expand Up @@ -511,7 +520,8 @@ llvm::Error ONNXModelLoader::loadPool(const ONNX_NAMESPACE::NodeProto &op,

// Glow doesn't support argmax output yet.
if (op.output_size() > 1) {
RETURN_ERR("Glow doesn't support argmax output yet.");
RETURN_ERR("Glow doesn't support argmax output yet.",
GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
}
// Load the inputs:
NodeValue in;
Expand All @@ -529,7 +539,8 @@ llvm::Error ONNXModelLoader::loadPool(const ONNX_NAMESPACE::NodeProto &op,

if (in.dims().size() != 4 || kernels.size() != 2) {
// Glow only handles 2D pooling currently.
RETURN_ERR("Glow only handles 2D pooling currently.");
RETURN_ERR("Glow only handles 2D pooling currently.",
GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_SHAPE);
}

auto *tr = G_.createTranspose(opName, in, NCHW2NHWC);
Expand Down Expand Up @@ -785,7 +796,8 @@ llvm::Error ONNXModelLoader::loadPad(const ONNX_NAMESPACE::NodeProto &op,
} else if (modeStr == "edge") {
mode = PaddingMode::EDGE;
} else {
RETURN_ERR("Pad: Invalid mode");
RETURN_ERR("Pad: Invalid mode",
GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_ATTRIBUTE);
}
}
float value = 0.f; // Default
Expand Down Expand Up @@ -873,7 +885,8 @@ llvm::Error ONNXModelLoader::loadOperator(const ONNX_NAMESPACE::NodeProto &op) {
return loadPad(op, dict);
}

RETURN_ERR("Failed to load operator.");
RETURN_ERR("Failed to load operator.",
GlowErr::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
}

llvm::Error ONNXModelLoader::loadInitializers(ONNX_NAMESPACE::GraphProto &net) {
Expand Down
7 changes: 2 additions & 5 deletions lib/Support/Error.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
namespace glow {
llvm::ExitOnError exitOnErr("Encountered an error, exiting.\n");

std::string addFileAndLineToError(llvm::StringRef str, llvm::StringRef file,
uint32_t line) {
return llvm::formatv("Error at file {0} line {1} \"{2}\"", file, line, str);
}

/// ID used by llvm::ErrorInfo::isA's dynamic typing.
uint8_t const GlowErr::ID = 0;
} // namespace glow