Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ class Partitioner {
void saveRepeatedConstants(const std::string& func_name);
void saveTailDictConstants(const std::string& func_name);
void saveRoPEConstants(const std::string& func_name);
// Runs the PreserveConstScales matcher on the first model of the function
// group `func_name` and records every Constant it reports in that group's
// `consts_to_keep`. NOTE(review): the TEST_ prefix suggests this step is
// experimental - confirm before relying on it.
void TEST_saveScaleConstants(const std::string& func_name);
void matchParameters(const std::string& func_name);
void matchResults(const std::string& func_name);
void createFunction(const std::string& func_name);
Expand Down Expand Up @@ -1436,6 +1437,23 @@ void Partitioner::saveRepeatedConstants(const std::string& func_name) {
}
}

// Records which Constants of the function group `func_name` must be kept as
// constants, according to the PreserveConstScales pattern.
//
// The pattern is run over the first model of the group only; every Constant
// the matcher reports is inserted into the group's `consts_to_keep`.
//
// @param func_name  Key into `all_functions`; looked up with `at()`, so an
//                   unknown name is an error rather than a silent insert.
void Partitioner::TEST_saveScaleConstants(const std::string& func_name) {
    auto& func_group = all_functions.at(func_name);
    auto& model_group = func_group.mdls;

    using CPtr = std::shared_ptr<ov::op::v0::Constant>;
    std::vector<CPtr> to_keep;

    // The matcher writes its findings into `to_keep` through the reference
    // wrapper, so the vector has to stay alive until run_on_model() returns.
    // NOTE(review): front() assumes the group holds at least one model -
    // confirm against the callers that populate `mdls`.
    ov::pass::GraphRewrite rewr;
    rewr.add_matcher<ov::npuw::patterns::opt::PreserveConstScales>(std::ref(to_keep));
    rewr.run_on_model(model_group.front());

    for (const auto& const_to_keep : to_keep) {
        func_group.consts_to_keep.insert(const_to_keep);
    }
}

void Partitioner::saveTailDictConstants(const std::string& func_name) {
if (!part_ctx.use_host_gather_quant) {
// No need to preserve as constants
Expand Down Expand Up @@ -2507,6 +2525,7 @@ ov::npuw::Partitioning ov::npuw::getPartitioning(const std::shared_ptr<ov::Model
p.saveRepeatedConstants(func_group);
p.saveTailDictConstants(func_group);
p.saveRoPEConstants(func_group);
p.TEST_saveScaleConstants(func_group);
p.matchParameters(func_group);
p.matchResults(func_group);
p.matchRepeatedSubgraphs(func_group);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2271,6 +2271,33 @@ void untangleConst(std::shared_ptr<ov::Model> model) {
}
}

// Builds a record-only matcher: it never modifies the graph, it just reports
// Constants of interest back to the caller.
//
// Pattern:
//     MatMul(<any>, Multiply(<any>, Constant))
//
// A matched Constant is appended to `to_keep` when its shape has rank 2 and
// the MatMul has transpose_b set. The callback always returns false, since
// the matched root is left untouched.
//
// @param to_keep  Reference wrapper around the caller-owned output vector;
//                 the vector must outlive every run of this pass.
PreserveConstScales::PreserveConstScales(PreserveConstScales::Results to_keep) {
    auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
    auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({opp::any_input(), qcoeff});
    auto qmm = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), qmuls});

    // Note: Use [=] to make sure the above objects stay alive in the callback
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& node_to_output = m.get_pattern_value_map();

        auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
        auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();

        // Static casts are sufficient: the pattern nodes above were created
        // with wrap_type<> for exactly these operation types.
        auto matched_qcoeff = std::static_pointer_cast<ov::op::v0::Constant>(matched_node_qcoeff);
        auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_node_matmul);

        // Only the rank is needed, so read it directly instead of copying the shape.
        const auto qcoeff_rank = matched_qcoeff->output(0).get_shape().size();

        if (qcoeff_rank == 2 && matched_matmul->get_transpose_b()) {
            to_keep.get().push_back(matched_qcoeff);
            // NOTE(review): prints straight to stdout - consider routing this
            // through the project's logging facility instead.
            std::cout << "Keeping Scale " << matched_qcoeff->get_friendly_name() << " as constant" << std::endl;
        }
        return false;  // root hasn't changed
    };
    register_matcher(std::make_shared<opp::Matcher>(qmm, "OptPreserveConstScales"), std::move(callback));
}

} // namespace opt
} // namespace patterns
} // namespace npuw
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,16 @@ class ConvToMatmul : public ov::pass::MatcherPass {
// UntangleConst
void untangleConst(std::shared_ptr<ov::Model> model);

// Matcher pass that reports Constants to the caller through `Results`.
//
// `Results` is a std::reference_wrapper, so the pass stores a reference to a
// caller-owned vector rather than a copy; that vector must remain valid for
// as long as the pass can run.
class PreserveConstScales : public ov::pass::MatcherPass {
public:
    OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::opt::PreserveConstScales");

    using CPtr = std::shared_ptr<ov::op::v0::Constant>;
    using Results = std::reference_wrapper<std::vector<CPtr>>;

    // explicit: a bare reference_wrapper should never silently convert into a pass.
    explicit PreserveConstScales(Results to_keep);
};

} // namespace opt
} // namespace patterns
} // namespace npuw
Expand Down
Loading