diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index 6666d28d..ed6072c0 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -84,6 +84,7 @@ public: BLANK_LOOP, INITIAL_LABEL, INITIAL_BLANK, + numTypes, // must remain at the end }; // Request for scoring or context extension @@ -142,6 +143,18 @@ public: // Return two vectors: one vector with scores and one vector with times // By default loops over the single-request version virtual std::optional computeScoresWithTimes(std::vector const& requests); + +protected: + inline static constexpr auto transitionTypeArray_ = std::to_array>({ + {"label-to-label", LABEL_TO_LABEL}, + {"label-loop", LABEL_LOOP}, + {"label-to-blank", LABEL_TO_BLANK}, + {"blank-to-label", BLANK_TO_LABEL}, + {"blank-loop", BLANK_LOOP}, + {"initial-label", INITIAL_LABEL}, + {"initial-blank", INITIAL_BLANK}, + }); + static_assert(transitionTypeArray_.size() == TransitionType::numTypes, "transitionTypeArray size must match number of TransitionType values"); }; } // namespace Nn diff --git a/src/Nn/LabelScorer/Makefile b/src/Nn/LabelScorer/Makefile index 22d17757..2a634b91 100644 --- a/src/Nn/LabelScorer/Makefile +++ b/src/Nn/LabelScorer/Makefile @@ -22,7 +22,8 @@ LIBSPRINTLABELSCORER_O = \ $(OBJDIR)/NoContextOnnxLabelScorer.o \ $(OBJDIR)/NoOpLabelScorer.o \ $(OBJDIR)/ScoringContext.o \ - $(OBJDIR)/StatefulOnnxLabelScorer.o + $(OBJDIR)/StatefulOnnxLabelScorer.o \ + $(OBJDIR)/TransitionLabelScorer.o # ----------------------------------------------------------------------------- diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.cc b/src/Nn/LabelScorer/TransitionLabelScorer.cc new file mode 100644 index 00000000..f7afac18 --- /dev/null +++ b/src/Nn/LabelScorer/TransitionLabelScorer.cc @@ -0,0 +1,79 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TransitionLabelScorer.hh" + +#include + +namespace Nn { + +TransitionLabelScorer::TransitionLabelScorer(Core::Configuration const& config) + : Core::Component(config), + Precursor(config), + transitionScores_(), + baseLabelScorer_(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("base-scorer"))) { + for (auto const& [stringIdentifier, enumValue] : transitionTypeArray_) { + auto paramName = std::string(stringIdentifier) + "-score"; + transitionScores_[enumValue] = Core::ParameterFloat(paramName.c_str(), "", 0.0)(config); + } +} + +void TransitionLabelScorer::reset() { + baseLabelScorer_->reset(); +} + +void TransitionLabelScorer::signalNoMoreFeatures() { + baseLabelScorer_->signalNoMoreFeatures(); +} + +ScoringContextRef TransitionLabelScorer::getInitialScoringContext() { + return baseLabelScorer_->getInitialScoringContext(); +} + +ScoringContextRef TransitionLabelScorer::extendedScoringContext(LabelScorer::Request const& request) { + return baseLabelScorer_->extendedScoringContext(request); +} + +void TransitionLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { + baseLabelScorer_->cleanupCaches(activeContexts); +} + +void TransitionLabelScorer::addInput(DataView const& input) { + baseLabelScorer_->addInput(input); +} + +void TransitionLabelScorer::addInputs(DataView const& input, size_t nTimesteps) { + baseLabelScorer_->addInputs(input, nTimesteps); +} + +std::optional TransitionLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) { + auto result = baseLabelScorer_->computeScoreWithTime(request); + if (result) { + result->score += transitionScores_[request.transitionType]; + } + return result; +} + +std::optional TransitionLabelScorer::computeScoresWithTimes(std::vector const& requests) { + auto results = baseLabelScorer_->computeScoresWithTimes(requests); + if (results) { + for (size_t i = 0ul; i < requests.size(); ++i) { + results->scores[i] += transitionScores_[requests[i].transitionType]; + } + } + return results; +} + +} // namespace Nn diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.hh b/src/Nn/LabelScorer/TransitionLabelScorer.hh new file mode 100644 index 00000000..a8bc4af0 --- /dev/null +++ b/src/Nn/LabelScorer/TransitionLabelScorer.hh @@ -0,0 +1,70 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSITION_LABEL_SCORER_HH +#define TRANSITION_LABEL_SCORER_HH + +#include "LabelScorer.hh" + +namespace Nn { + +/* + * This LabelScorer wraps a base LabelScorer and adds predefined transition scores + * to the base scores depending on the transition type of each request. + * The transition scores are all individually specified as config parameters. + */ +class TransitionLabelScorer : public LabelScorer { +public: + using Precursor = LabelScorer; + + TransitionLabelScorer(Core::Configuration const& config); + virtual ~TransitionLabelScorer() = default; + + // Reset base scorer + void reset() override; + + // Forward signal to base scorer + void signalNoMoreFeatures() override; + + // Initial context of base scorer + ScoringContextRef getInitialScoringContext() override; + + // Extend context via base scorer + ScoringContextRef extendedScoringContext(Request const& request) override; + + // Clean up base scorer + void cleanupCaches(Core::CollapsedVector const& activeContexts) override; + + // Add input to base scorer + void addInput(DataView const& input) override; + + // Add inputs to sub-scorer + void addInputs(DataView const& input, size_t nTimesteps) override; + + // Compute score of base scorer and add transition score based on transition type of the request + std::optional computeScoreWithTime(Request const& request) override; + + // Compute scores of base scorer and add transition scores based on transition types of the requests + std::optional computeScoresWithTimes(std::vector const& requests) override; + +private: + std::unordered_map transitionScores_; + + Core::Ref baseLabelScorer_; +}; + +} // namespace Nn + +#endif // TRANSITION_LABEL_SCORER_HH diff --git a/src/Nn/Module.cc b/src/Nn/Module.cc index db673397..760d622a 100644 --- a/src/Nn/Module.cc +++ b/src/Nn/Module.cc @@ -24,6 +24,7 @@ #include "LabelScorer/NoContextOnnxLabelScorer.hh" #include "LabelScorer/NoOpLabelScorer.hh" #include "LabelScorer/StatefulOnnxLabelScorer.hh" +#include "LabelScorer/TransitionLabelScorer.hh" #include "Statistics.hh" #ifdef MODULE_NN @@ -128,6 +129,13 @@ Module_::Module_() [](Core::Configuration const& config) { return Core::ref(new StatefulOnnxLabelScorer(config)); }); + + // Returns predefined scores based on the transition type of each score request + labelScorerFactory_.registerLabelScorer( + "transition", + [](Core::Configuration const& config) { + return Core::ref(new TransitionLabelScorer(config)); + }); }; Module_::~Module_() {