Skip to content

Create BackendTransformerBase to host common functions used for backend lowering (#17074) #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 29 additions & 30 deletions caffe2/opt/backend_cutting_test.cc
Original file line number Diff line number Diff line change
@@ -1,46 +1,46 @@
#include "caffe2/core/common.h"
#include "caffe2/opt/backend_cutting.h"
#include "caffe2/core/logging.h"
#include "caffe2/opt/backend_cutting.h"
#include "caffe2/utils/string_utils.h"

#include <gtest/gtest.h>

namespace {
using caffe2::StartsWith;
using caffe2::StartsWith;

void AddConv(caffe2::NetDef* net, int tick) {
auto* op = net->add_op();
op->set_type("MyConv");
op->add_input("N" + c10::to_string(tick));
op->add_input("W" + c10::to_string(tick));
op->add_input("b" + c10::to_string(tick));
op->add_output("N" + c10::to_string(tick + 1));
}
// Appends a "MyConv" op to `net` that reads N<tick>, W<tick>, b<tick>
// and produces N<tick + 1>.
void AddConv(caffe2::NetDef* net, int tick) {
  const auto in_suffix = c10::to_string(tick);
  const auto out_suffix = c10::to_string(tick + 1);
  auto* conv = net->add_op();
  conv->set_type("MyConv");
  conv->add_input("N" + in_suffix);
  conv->add_input("W" + in_suffix);
  conv->add_input("b" + in_suffix);
  conv->add_output("N" + out_suffix);
}

bool Supports(const caffe2::OperatorDef& op) {
return StartsWith(op.type(), "MyConv") || StartsWith(op.type(), "MyRelu") ||
StartsWith(op.type(), "Concat");
}
// The fake backend supports an op iff its type starts with one of the
// recognized prefixes.
bool Supports(const caffe2::OperatorDef& op) {
  const auto& type = op.type();
  for (const char* prefix : {"MyConv", "MyRelu", "Concat"}) {
    if (StartsWith(type, prefix)) {
      return true;
    }
  }
  return false;
}

caffe2::NetDef Transform(const caffe2::NetDef& net) {
caffe2::NetDef net_opt;
auto * op = net_opt.add_op();
op->set_type("BigOpt");
caffe2::NetDef Transform(const caffe2::NetDef& net) {
caffe2::NetDef net_opt;
auto* op = net_opt.add_op();
op->set_type("BigOpt");

for (const auto& i: net.external_input()) {
// Absorb the weights and bias
if (!StartsWith(i, "W") && !StartsWith(i, "b")) {
net_opt.add_external_input(i);
op->add_input(i);
}
}
for (const auto& i: net.external_output()) {
net_opt.add_external_output(i);
op->add_output(i);
for (const auto& i : net.external_input()) {
// Absorb the weights and bias
if (!StartsWith(i, "W") && !StartsWith(i, "b")) {
net_opt.add_external_input(i);
op->add_input(i);
}
return net_opt;
}
for (const auto& i : net.external_output()) {
net_opt.add_external_output(i);
op->add_output(i);
}
return net_opt;
}
} // namespace

// N0 -> MyConv -> N1
TEST(BackendCuttingTest, unit) {
Expand All @@ -56,7 +56,6 @@ TEST(BackendCuttingTest, unit) {
EXPECT_EQ(1, net_opt.external_output_size());
}


// X -> CopyIn -> MyConv -> MyConv -> CopyOut -> Y
TEST(BackendCuttingTest, line) {
caffe2::NetDef net;
Expand Down
122 changes: 122 additions & 0 deletions caffe2/opt/backend_transformer_base.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#include "caffe2/opt/backend_transformer_base.h"
#include "caffe2/onnx/onnx_exporter.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

namespace {
// Tags every op in `net` with a kNetPos argument recording its 0-based
// position in the net's op list.
void AnnotateOpIndex(NetDef* net) {
  int pos = 0;
  for (auto& op : *(net->mutable_op())) {
    AddArgument(kNetPos, pos, &op);
    ++pos;
  }
}
} // namespace

// Returns the model_id argument attached to `net`, or a process-unique
// "unnamed_<N>" id when the net carries none.
std::string BackendTransformerBase::getModelId(const NetDef& net) {
  // Monotonic counter shared across calls so generated ids never collide.
  static std::atomic<size_t> seq_id{0};
  std::string model_id =
      ArgumentHelper(net).GetSingleArgument<std::string>(kModelId, "");
  if (!model_id.empty()) {
    return model_id;
  }
  return "unnamed_" + c10::to_string(seq_id++);
}

// Packs the name, data type, and dims of `shape_info` into a TensorProto.
// Only metadata is copied; no tensor data is attached.
TensorProto BackendTransformerBase::wrapShapeInfoIntoTensorProto(
    const std::string& name,
    const ShapeInfo& shape_info) {
  TensorProto wrapped;
  wrapped.set_name(name);
  wrapped.set_data_type(shape_info.shape.data_type());
  for (const auto dim : shape_info.shape.dims()) {
    wrapped.add_dims(dim);
  }
  return wrapped;
}

// SSA-rewrites `pred_net` in place and records the mapping between
// rewritten and original input names in input_mapping_ /
// reverse_input_mapping_. Returns `input_shape_hints` re-keyed to the
// rewritten names. Mapped names with no corresponding blob in `ws`
// (usually real input tensors) are removed from input_mapping_.
std::unordered_map<std::string, TensorShape>
BackendTransformerBase::ssaRewriteAndMapNames(
    Workspace* ws,
    NetDef* pred_net,
    const std::unordered_set<std::string>& weights,
    const std::unordered_map<std::string, TensorShape>& input_shape_hints) {
  // Make sure weights do not contain output of any op.
  for (const auto& op : pred_net->op()) {
    for (const auto& output : op.output()) {
      CAFFE_ENFORCE_EQ(
          weights.count(output),
          0,
          "Weight ",
          output,
          " shouldn't appear in the output");
    }
  }
  input_mapping_ = onnx::SsaRewrite(nullptr, pred_net, weights);
  // Annotate the ops with net position
  AnnotateOpIndex(pred_net);

  // Need to add mapping for weights. This will be used to create new workspace
  // with mapped weights.
  for (const auto& w : weights) {
    input_mapping_.emplace(w, w);
  }

  // Since we are going to create a mapped workspace, we need to make sure that
  // the parent workspace has the mapped blob names. If the blobs don't exist
  // (usually such blobs are input tensor names), we exclude them from mapping.
  std::vector<std::string> exclude_mapping;
  // Iterate by const reference: the original `const auto kv` copied the
  // pair (two std::string copies) on every iteration.
  for (const auto& kv : input_mapping_) {
    reverse_input_mapping_.emplace(kv.second, kv.first);
    if (!ws->HasBlob(kv.second)) {
      exclude_mapping.emplace_back(kv.first);
    }
  }
  for (const auto& i : exclude_mapping) {
    input_mapping_.erase(i);
  }
  // Translate the caller's shape hints onto the rewritten names; hints for
  // names that were not rewritten pass through unchanged.
  std::unordered_map<std::string, TensorShape> shape_hints_mapped;
  for (const auto& kv : input_shape_hints) {
    const auto it = reverse_input_mapping_.find(kv.first);
    if (it != reverse_input_mapping_.end()) {
      shape_hints_mapped.emplace(it->second, kv.second);
    } else {
      shape_hints_mapped.emplace(kv.first, kv.second);
    }
  }
  return shape_hints_mapped;
}

// Runs bound shape inference over `pred_net` and returns a map from blob
// name to its inferred ShapeInfo.
//
// Seeding precedence matters because emplace() never overwrites:
//   1. shapes of blobs already present in workspace `ws`,
//   2. caller-supplied hints (marked CONSTANT) for blobs not seeded in (1),
//   3. inferred shapes for any remaining blobs.
ShapeInfoMap BackendTransformerBase::inferShapes(
Workspace* ws,
NetDef* pred_net,
const std::unordered_map<std::string, TensorShape>& shape_hints_mapped,
const BoundShapeSpec& spec) {
ShapeInfoMap shape_map;
// Populate shapes from workspace
const std::vector<std::string> ws_blobs = ws->Blobs();
for (const auto& s : ws_blobs) {
auto shape_info = getShapeInfoFromBlob(ws->GetBlob(s));
// Skip blobs whose dim type could not be determined.
if (shape_info.dim_type != ShapeInfo::DimType::UNKNOWN) {
shape_map[s] = shape_info;
}
}
// Layer external hints on top; their dims are treated as CONSTANT.
for (const auto& kv : shape_hints_mapped) {
shape_map.emplace(
std::piecewise_construct,
std::forward_as_tuple(kv.first),
std::forward_as_tuple(ShapeInfo::DimType::CONSTANT, kv.second));
}
// Run bound shape inference using the seeded map as the starting point.
BoundShapeInferencer eng(spec);
eng.InferBoundShapeAndType(*pred_net, shape_map);
const auto& out_map = eng.shape_info();

// Merge inference results; emplace keeps pre-existing workspace/hint
// entries on key collision. NOTE(review): inferred shapes for blobs
// already seeded above are therefore dropped — presumably intentional;
// confirm.
for (const auto& kv : out_map) {
shape_map.emplace(
std::piecewise_construct,
std::forward_as_tuple(kv.first),
std::forward_as_tuple(kv.second.dim_type, kv.second.shape));
}
return shape_map;
}
} // namespace caffe2
70 changes: 70 additions & 0 deletions caffe2/opt/backend_transformer_base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#pragma once

#include "caffe2/core/common.h"
#include "caffe2/core/workspace.h"
#include "caffe2/opt/bound_shape_inferencer.h"
#include "caffe2/proto/caffe2_pb.h"

#include <string>
#include <unordered_map>
#include <vector>

namespace caffe2 {
// Argument names used to annotate nets/ops during backend transformation.
// Declared directly at namespace scope: constexpr variables already have
// internal linkage, and anonymous namespaces should not appear in headers
// (each including TU would otherwise get a distinct entity).
constexpr char kNetPos[] = "net_pos";
constexpr char kModelId[] = "model_id";

// Base class hosting common functionality for backend lowering and graph
// cutting: model-id extraction, SSA rewriting with input-name mapping,
// shape wrapping, and bound shape inference.
class BackendTransformerBase {
public:
BackendTransformerBase() {}
virtual ~BackendTransformerBase() {}

// Mapping of SSA-rewritten input name -> original input name,
// populated by ssaRewriteAndMapNames().
const std::unordered_map<std::string, std::string>& input_mapping() const {
return input_mapping_;
}

// Mapping of original input name -> SSA-rewritten input name.
const std::unordered_map<std::string, std::string>& reverse_input_mapping()
const {
return reverse_input_mapping_;
}

// Transforms `pred_net` in place for the target backend.
// `blacklisted_ops` — presumably net positions (see kNetPos) of ops that
// must not be lowered; confirm against concrete subclasses.
virtual void transform(
Workspace* ws,
NetDef* pred_net,
const std::vector<std::string>& weight_names,
const std::unordered_map<std::string, TensorShape>& shape_hints,
const std::unordered_set<int>& blacklisted_ops) = 0;

protected:
// get model ID from the NetDef; generates a unique one if absent
std::string getModelId(const NetDef& net);

// SSA rewrite the net and return the shape hints re-keyed to the
// rewritten input names
std::unordered_map<std::string, TensorShape> ssaRewriteAndMapNames(
Workspace* ws,
NetDef* pred_net,
const std::unordered_set<std::string>& weights,
const std::unordered_map<std::string, TensorShape>& input_shape_hints);

// Wrap TensorShape into TensorProto (metadata only, no tensor data)
TensorProto wrapShapeInfoIntoTensorProto(
const std::string& name,
const ShapeInfo& shape_info);

// Do bound shape inference and collect shape infos
ShapeInfoMap inferShapes(
Workspace* ws,
NetDef* pred_net,
const std::unordered_map<std::string, TensorShape>& shape_hints_mapped,
const BoundShapeSpec& spec);

// Input mapping of input name -> original input name
std::unordered_map<std::string, std::string> input_mapping_;

// Input mapping of original input name -> input name
std::unordered_map<std::string, std::string> reverse_input_mapping_;
};
} // namespace caffe2
Loading