
Support Backend Specific Verifiers #3469


Closed · wants to merge 2 commits
8 changes: 8 additions & 0 deletions docs/Backends.md
@@ -65,6 +65,14 @@ Additionally, there are virtual functions that backends can override:
[below](#backend-specific-nodes-and-instructions-transformations) for more
information.

- `virtual bool verify(const Function &F) const;`

- Verifies that the given `Function` conforms to the backend-dependent graph constraints.

- `virtual bool verify(const IRFunction &IR) const;`

- Verifies that the given `IRFunction` conforms to the backend-dependent IR constraints.

- `virtual bool shouldLower(const Node *N) const;`

- Allow the backend to prevent lowering for some `Node *N`. For example, if a
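For illustration, a hypothetical backend might override these hooks as follows. This is a minimal sketch: MyBackend and its 4D-convolution restriction are invented for this example and are not part of the PR.

class MyBackend : public Backend {
public:
  bool verify(const Function &F) const override {
    // Reuse the generic per-node support check, then add backend rules.
    if (!checkAllNodesSupported(F)) {
      return false;
    }
    for (const Node &N : F.getNodes()) {
      // Hypothetical constraint: only 4D convolution inputs are handled.
      if (const auto *CN = llvm::dyn_cast<ConvolutionNode>(&N)) {
        if (CN->getInput().dims().size() != 4) {
          return false;
        }
      }
    }
    return true;
  }

  bool verify(const IRFunction &IR) const override {
    // Instruction-level checks would go here.
    return true;
  }
};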
14 changes: 14 additions & 0 deletions include/glow/Backend/Backend.h
@@ -98,6 +98,20 @@ class Backend {
/// \returns whether the provided \p NI is supported by the backend.
virtual bool isOpSupported(const NodeInfo &NI) const = 0;

/// \returns whether all nodes inside \p F are supported.
bool checkAllNodesSupported(const Function &F) const;

/// \returns whether the provided \p F conforms to the backend-dependent graph
/// constraints. Overriding this function gives the backend an opportunity to
/// check that everything conforms to its specific restrictions.
virtual bool verify(const Function &F) const;

/// \returns whether the provided \p IR conforms to the backend-dependent
/// graph constraints. Overriding this function gives the backend an
/// opportunity to check that everything conforms to its specific
/// restrictions.
virtual bool verify(const IRFunction &IR) const;

/// \returns true if the supplied Node \p N should be lowered. By default, all
/// Nodes are candidates for lowering.
virtual bool shouldLower(const Node *N) const { return true; }
48 changes: 48 additions & 0 deletions include/glow/Backend/BackendUtils.h
@@ -149,6 +149,28 @@ class RuntimeBundle {
};
} // namespace runtime

/// Generates a struct named has_\p METHOD_NAME that detects whether a method
/// called \p METHOD_NAME with return type ReturnType exists inside ClassName.
#define CLASS_CONTAINS_METHOD(METHOD_NAME) \
template <typename ClassName, typename ReturnType> \
struct has_##METHOD_NAME { \
private: \
template <typename T> \
static constexpr auto check(T *) -> \
typename std::is_same<decltype(std::declval<T>().METHOD_NAME()), \
ReturnType>::type; \
template <typename> static constexpr std::false_type check(...); \
typedef decltype(check<ClassName>(0)) type; \
\
public: \
static constexpr bool value = type::value; \
};

/// Use template metaprogramming to check if typename ClassName contains a
/// getFusedActivation() method. The invocation below generates a struct named
/// has_getFusedActivation that looks for said method.
CLASS_CONTAINS_METHOD(getFusedActivation)

/// If \p PH is an output placeholder in the Function \p F,
/// \returns true.
/// This is determined by checking if the PH has a user which uses the PH as an
@@ -173,6 +195,32 @@ bool isOutput(const Placeholder *PH, const IRFunction &F);
/// by the current function.
bool isInput(const Placeholder *PH, const IRFunction &F);

/// \returns true if \p N does not have a fused activation.
template <typename T,
std::enable_if_t<!has_getFusedActivation<T, FusedActivation>::value,
int> = 0>
bool checkNoFusion(const T &N) {
return true;
}

/// \returns true if \p N does not have a fused activation.
template <typename T,
std::enable_if_t<has_getFusedActivation<T, FusedActivation>::value,
int> = 0>
bool checkNoFusion(const T &N) {
if (N.getFusedActivation() != FusedActivation::NONE) {
report("Glow backend does not support fused Activations.");
return false;
}
return true;
}

/// \returns true if \p N does not have a fused activation.
bool checkNoFusionForNode(const Node &N);

/// \returns true if \p I does not have a fused activation.
bool checkNoFusionForInstr(const Instruction &I);

/// Contains information for placeholder during allocation.
struct PlaceholderInputOutputInfo {
/// The placeholder address.
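To see how the detection idiom behaves, here is a self-contained sketch of the trait written out by hand (with toy node types rather than Glow classes), showing how the two checkNoFusion overloads are selected at compile time:

#include <iostream>
#include <type_traits>
#include <utility>

enum class FusedActivation { NONE, RELU };

// Hand-expanded equivalent of CLASS_CONTAINS_METHOD(getFusedActivation).
template <typename ClassName, typename ReturnType>
struct has_getFusedActivation {
private:
  template <typename T>
  static constexpr auto check(T *) ->
      typename std::is_same<decltype(std::declval<T>().getFusedActivation()),
                            ReturnType>::type;
  template <typename> static constexpr std::false_type check(...);
  typedef decltype(check<ClassName>(0)) type;

public:
  static constexpr bool value = type::value;
};

struct PlainNode {}; // no getFusedActivation(): SFINAE picks the fallback
struct FusableNode {
  FusedActivation getFusedActivation() const { return FusedActivation::RELU; }
};

int main() {
  std::cout << has_getFusedActivation<PlainNode, FusedActivation>::value;   // 0
  std::cout << has_getFusedActivation<FusableNode, FusedActivation>::value; // 1
}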
8 changes: 8 additions & 0 deletions include/glow/Support/Error.h
@@ -89,6 +89,10 @@ class GlowErr final : public llvm::ErrorInfo<GlowErr> {
MODEL_WRITER_INVALID_FILENAME,
// Model writer cannot serialize graph to the file.
MODEL_WRITER_SERIALIZATION_ERROR,
// Compilation error; IR unsupported after generation.
COMPILE_UNSUPPORTED_IR_AFTER_GENERATE,
// Compilation error; IR unsupported after optimization.
COMPILE_UNSUPPORTED_IR_AFTER_OPTIMIZE,
};

/// GlowErr is not convertible to std::error_code. This is included for
@@ -164,6 +168,10 @@ class GlowErr final : public llvm::ErrorInfo<GlowErr> {
return "MODEL_WRITER_INVALID_FILENAME";
case ErrorCode::MODEL_WRITER_SERIALIZATION_ERROR:
return "MODEL_WRITER_SERIALIZATION_ERROR";
case ErrorCode::COMPILE_UNSUPPORTED_IR_AFTER_GENERATE:
return "COMPILE_UNSUPPORTED_IR_AFTER_GENERATE";
case ErrorCode::COMPILE_UNSUPPORTED_IR_AFTER_OPTIMIZE:
return "COMPILE_UNSUPPORTED_IR_AFTER_OPTIMIZE";
};

llvm_unreachable("unsupported ErrorCode");
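These two codes let the compile path fail cleanly when a verifier rejects a function. The call sites are not part of this diff; a plausible shape, assuming Glow's MAKE_ERR helper and that the check runs right after IR generation, would be:

// Sketch only: where this check lives in the real pipeline is an assumption.
llvm::Error verifyGeneratedIR(const Backend &B, const IRFunction &IR) {
  if (!B.verify(IR)) {
    return MAKE_ERR(GlowErr::ErrorCode::COMPILE_UNSUPPORTED_IR_AFTER_GENERATE,
                    "Unsupported IR after generation for backend " +
                        B.getBackendName());
  }
  return llvm::Error::success();
}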
19 changes: 19 additions & 0 deletions lib/Backend/Backend.cpp
@@ -157,6 +157,25 @@ void Backend::autoInstrument(TraceInfo &traceInfo, IRFunction *IR) const {
IR->pushInstr(new TraceEventInst("end_trace", backingWeight, index));
}

bool Backend::checkAllNodesSupported(const Function &F) const {
bool allSupported = true;
for (const Node &N : F.getNodes()) {
if (!isOpSupported(N)) {
allSupported = false;
report("Unsupported node found while compiling Function " +
F.getName().str() + " for backend " + getBackendName() + ": " +
N.getDebugDesc());
}
}
return allSupported;
}

bool Backend::verify(const Function &F) const {
return checkAllNodesSupported(F);
}

bool Backend::verify(const IRFunction &IR) const { return true; }
Contributor: Maybe it's better to make this pure virtual to force the backend to deal with it, versus thinking it does something when it doesn't?

Author: I agree. I'd like to turn them into pure virtual, but I don't want to break any backends in this commit; I intend to create a beginner / good-first-issue task for someone to write a verifier for the Habana backend.


FunctionPassPipeline Backend::getOptimizationPipeline() const {
return createDefaultGraphOptimizationPassPipeline();
};
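For reference, the pure-virtual variant suggested in the review would drop the default bodies above and force every backend to opt in explicitly. A sketch of that suggestion (not what this PR ships):

// In Backend.h: every concrete backend would then fail to compile until it
// provides both verifiers, instead of silently inheriting a no-op default.
virtual bool verify(const Function &F) const = 0;
virtual bool verify(const IRFunction &IR) const = 0;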
39 changes: 39 additions & 0 deletions lib/Backend/BackendUtils.cpp
@@ -349,6 +349,45 @@ bool isInput(const Placeholder *PH,
return false;
}

/// \returns true if \p N does not have a fused activation.
bool checkNoFusionForNode(const Node &N) {
#define DEF_NODE(CLASS, NAME) \
case Kinded::Kind::CLASS##Kind: { \
const CLASS *CI = llvm::cast<CLASS>(&N); \
return checkNoFusion(*CI); \
break; \
}
switch (N.getKind()) {
#include "glow/AutoGenNodes.def"
default:
llvm_unreachable("Invalid node.");
}
return true;
}

/// \returns true if \p I does not have a fused activation.
bool checkNoFusionForInstr(const Instruction &I) {
#define DEF_VALUE(CLASS, NAME)
#define DEF_INSTR(CLASS, NAME) \
case Kinded::Kind::CLASS##Kind: { \
const CLASS *CI = llvm::cast<CLASS>(&I); \
return checkNoFusion(*CI); \
break; \
}
#define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME) \
case Kinded::Kind::CLASS##Kind: { \
const CLASS *CI = llvm::cast<CLASS>(&I); \
return checkNoFusion(*CI); \
break; \
}
switch (I.getKind()) {
#include "glow/AutoGenInstr.def"
default:
llvm_unreachable("Invalid instruction.");
}
return true;
}

template <typename FUN, typename ARR>
ContiguousPlaceholders getContiguousPlaceHolder(const ARR &holders,
const FUN &F) {
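The dispatch in checkNoFusionForNode and checkNoFusionForInstr relies on the X-macro pattern: the function defines DEF_NODE (or DEF_INSTR), then includes the autogenerated .def file, which expands into one switch case per class. A toy version of the same pattern, with a macro standing in for the #include'd .def file, looks like this:

#include <iostream>

// Stand-in for glow/AutoGenNodes.def, which is normally #include'd.
#define TOY_NODES_DEF                                                          \
  DEF_NODE(AddNode, "add")                                                     \
  DEF_NODE(MulNode, "mul")

enum class Kind { AddNodeKind, MulNodeKind };

static const char *kindName(Kind K) {
#define DEF_NODE(CLASS, NAME)                                                  \
  case Kind::CLASS##Kind:                                                      \
    return NAME;
  switch (K) { TOY_NODES_DEF }
#undef DEF_NODE
  return "unknown";
}

int main() {
  std::cout << kindName(Kind::MulNodeKind) << "\n"; // prints "mul"
}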
107 changes: 107 additions & 0 deletions lib/Backends/Interpreter/Interpreter.cpp
@@ -22,6 +22,7 @@
#include "glow/Graph/Graph.h"
#include "glow/Graph/Nodes.h"
#include "glow/IR/IR.h"
#include "glow/IR/Instrs.h"
#include "glow/Optimizer/IROptimizer/IROptimizer.h"

using namespace glow;
@@ -467,6 +468,112 @@ bool Interpreter::isOpSupported(const NodeInfo &NI) const {
}
}

/// Use template metaprogramming to check if typename ClassName contains a
/// getLayout() method. The invocation below generates a struct named
/// has_getLayout that looks for said method.
CLASS_CONTAINS_METHOD(getLayout)

template <typename T, std::enable_if_t<
!has_getLayout<T, ConvolutionLayout>::value, int> = 0>
static bool checkLayout(const T &I) {
return true;
}

template <typename T,
std::enable_if_t<has_getLayout<T, ConvolutionLayout>::value, int> = 0>
static bool checkLayout(const T &I) {
if (I.getLayout() != NHWC) {
report("Glow Interpreter supports only NHWC");
return false;
}
return true;
}

static bool checkLayoutForNode(const Node &N) {
#define DEF_NODE(CLASS, NAME) \
case Kinded::Kind::CLASS##Kind: { \
const CLASS *CI = llvm::cast<CLASS>(&N); \
return checkLayout(*CI); \
break; \
}
switch (N.getKind()) {
#include "glow/AutoGenNodes.def"
default:
llvm_unreachable("Invalid instruction.");
}
return true;
}

bool Interpreter::verify(const Function &F) const {
if (!checkAllNodesSupported(F)) {
return false;
}
for (const Node &N : F.getNodes()) {
if (!checkLayoutForNode(N)) {
return false;
}
if (!checkNoFusionForNode(N)) {
return false;
}
switch (N.getKind()) {
case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind: {
auto *CQCI = llvm::cast<ChannelwiseQuantizedConvolutionNode>(&N);
if (!CQCI->getGroupwise()) {
report("Glow Interpreter does not support Non-groupwise variant");
return false;
}
continue;
}

default:
continue;
}
}
return true;
}

static bool checkLayoutForInstr(const Instruction &I) {
#define DEF_VALUE(CLASS, NAME)
#define DEF_INSTR(CLASS, NAME) \
case Kinded::Kind::CLASS##Kind: { \
const CLASS *CI = llvm::cast<CLASS>(&I); \
return checkLayout(*CI); \
break; \
}
#define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME)
switch (I.getKind()) {
#include "glow/AutoGenInstr.def"
default:
llvm_unreachable("Invalid instruction.");
}
return true;
}

bool Interpreter::verify(const IRFunction &IR) const {
for (const auto &I : IR.getInstrs()) {
if (!checkNoFusionForInstr(I)) {
return false;
}
if (!checkLayoutForInstr(I)) {
return false;
}
switch (I.getKind()) {
case Kinded::Kind::ChannelwiseQuantizedConvolutionInstKind: {
auto *CQCI = llvm::cast<ChannelwiseQuantizedConvolutionInst>(&I);
if (!CQCI->getGroupwise()) {
report("Glow Interpreter does not support Non-groupwise variant");
return false;
}
continue;
}

default:
continue;
}
}
return true;
}

bool Interpreter::shouldLower(const Node *N) const {
switch (N->getKind()) {
case Kinded::Kind::ConvolutionNodeKind:
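As a usage sketch, callers can now reject unsupported graphs up front instead of hitting asserts at execution time. Hypothetical snippet; the construction of the offending convolution is elided:

Module M;
Function *F = M.createFunction("main");
// ... build F with, e.g., an NCHW convolution (elided) ...
Interpreter backend;
if (!backend.verify(*F)) {
  // Rejected before compilation: the Interpreter supports only NHWC.
}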
3 changes: 3 additions & 0 deletions lib/Backends/Interpreter/Interpreter.h
@@ -49,6 +49,9 @@ class Interpreter final : public BackendUsingGlowIR {

bool isOpSupported(const NodeInfo &NI) const override;

bool verify(const Function &F) const override;
bool verify(const IRFunction &IR) const override;

bool shouldLower(const Node *N) const override;

/// @}
11 changes: 0 additions & 11 deletions lib/Backends/Interpreter/InterpreterNodes.cpp
@@ -282,10 +282,6 @@ void BoundInterpreterFunction::fwdConvolutionInstQuantizedImpl(
}

void BoundInterpreterFunction::fwdConvolutionInst(const ConvolutionInst *I) {
assert(I->getLayout() == NHWC &&
"Glow Interpreter supports only NHWC Convolutions");
assert(I->getFusedActivation() == FusedActivation::NONE &&
"Glow Interpreter does not support fused Activations.");
auto kernelSizes = I->getKernels();
auto pads = I->getPads();
auto strides = I->getStrides();
@@ -307,8 +303,6 @@

void BoundInterpreterFunction::fwdConvolutionGradInst(
const ConvolutionGradInst *I) {
assert(I->getLayout() == NHWC &&
"Glow Interpreter supports only NHWC Convolutions");
auto inW = getWeightHandle(I->getSrc());
auto inG = getWeightHandle(I->getSrcGrad());
auto outG = getWeightHandle(I->getDestGrad());
@@ -593,8 +587,6 @@ void BoundInterpreterFunction::fwdChannelwiseQuantizedConvolutionInst(

void BoundInterpreterFunction::fwdChannelwiseQuantizedConvolutionInst(
const ChannelwiseQuantizedConvolutionInst *I) {
assert(I->getGroupwise() && "Non-groupwise not supported");

using AccumulatorTy = int32_t;

auto inW = getWeightHandle<int8_t>(I->getSrc());
@@ -759,7 +751,6 @@ static void fwdMaxPool(Tensor *inW, Tensor *outW, Tensor *argmaxW,
}

void BoundInterpreterFunction::fwdMaxPoolInst(const MaxPoolInst *I) {
assert(I->getLayout() == NHWC && "Glow Interpreter supports only NHWC Pools");
auto inW = getTensor(I->getSrc());
auto outW = getTensor(I->getDest());

@@ -777,7 +768,6 @@ void BoundInterpreterFunction::fwdMaxPoolWithArgmaxInst(

void BoundInterpreterFunction::fwdMaxPoolWithArgmaxInst(
const MaxPoolWithArgmaxInst *I) {
assert(I->getLayout() == NHWC && "Glow Interpreter supports only NHWC Pools");
auto inW = getTensor(I->getSrc());
auto outW = getTensor(I->getDest());
auto argmaxW = getTensor(I->getArgmax());
@@ -896,7 +886,6 @@ void BoundInterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) {
}

void BoundInterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) {
assert(I->getLayout() == NHWC && "Glow Interpreter supports only NHWC Pools");
if (I->getSrc()->getType()->isQuantizedType()) {
fwdAvgPoolInstI8Impl(I);
return;