From 7502b820bb43fd1e737e5a9962eff391112aa4af Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 24 Jul 2025 16:51:02 +0200 Subject: [PATCH 1/9] Add TransitionLabelScorer --- src/Nn/LabelScorer/Makefile | 3 +- src/Nn/LabelScorer/TransitionLabelScorer.cc | 139 ++++++++++++++++++++ src/Nn/LabelScorer/TransitionLabelScorer.hh | 86 ++++++++++++ src/Nn/Module.cc | 8 ++ 4 files changed, 235 insertions(+), 1 deletion(-) create mode 100644 src/Nn/LabelScorer/TransitionLabelScorer.cc create mode 100644 src/Nn/LabelScorer/TransitionLabelScorer.hh diff --git a/src/Nn/LabelScorer/Makefile b/src/Nn/LabelScorer/Makefile index 5a42db46..274f80f0 100644 --- a/src/Nn/LabelScorer/Makefile +++ b/src/Nn/LabelScorer/Makefile @@ -21,7 +21,8 @@ LIBSPRINTLABELSCORER_O = \ $(OBJDIR)/FixedContextOnnxLabelScorer.o \ $(OBJDIR)/NoContextOnnxLabelScorer.o \ $(OBJDIR)/NoOpLabelScorer.o \ - $(OBJDIR)/ScoringContext.o + $(OBJDIR)/ScoringContext.o \ + $(OBJDIR)/TransitionLabelScorer.o # ----------------------------------------------------------------------------- diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.cc b/src/Nn/LabelScorer/TransitionLabelScorer.cc new file mode 100644 index 00000000..2b88c96b --- /dev/null +++ b/src/Nn/LabelScorer/TransitionLabelScorer.cc @@ -0,0 +1,139 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "TransitionLabelScorer.hh" + +#include + +#include "LabelScorer.hh" +#include "ScoringContext.hh" + +namespace Nn { + +const Core::ParameterFloat TransitionLabelScorer::paramLabelToLabelScore( + "label-to-label-score", + "Score for label-to-label transitions", + 0.0); + +const Core::ParameterFloat TransitionLabelScorer::paramLabelLoopScore( + "label-loop-score", + "Score for label-loop transitions", + 0.0); + +const Core::ParameterFloat TransitionLabelScorer::paramLabelToBlankScore( + "label-to-blank-score", + "Score for label-to-blank transitions", + 0.0); + +const Core::ParameterFloat TransitionLabelScorer::paramBlankToLabelScore( + "blank-to-label-score", + "Score for blank-to-label transitions", + 0.0); + +const Core::ParameterFloat TransitionLabelScorer::paramBlankLoopScore( + "blank-loop-score", + "Score for blank-loop transitions", + 0.0); + +const Core::ParameterFloat TransitionLabelScorer::paramInitialLabelScore( + "initial-label-score", + "Score for initial-label transitions", + 0.0); + +const Core::ParameterFloat TransitionLabelScorer::paramInitialBlankScore( + "initial-blank-score", + "Score for initial-blank transitions", + 0.0); + +TransitionLabelScorer::TransitionLabelScorer(Core::Configuration const& config) + : Core::Component(config), + Precursor(config), + labelToLabelScore_(paramLabelToLabelScore(config)), + labelLoopScore_(paramLabelLoopScore(config)), + labelToBlankScore_(paramLabelToBlankScore(config)), + blankToLabelScore_(paramBlankToLabelScore(config)), + blankLoopScore_(paramBlankLoopScore(config)), + initialLabelScore_(paramInitialLabelScore(config)), + initialBlankScore_(paramInitialBlankScore(config)), + baseLabelScorer_(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("base-scorer"))) {} + +void TransitionLabelScorer::reset() { + baseLabelScorer_->reset(); +} + +void TransitionLabelScorer::signalNoMoreFeatures() { + baseLabelScorer_->signalNoMoreFeatures(); +} + +ScoringContextRef TransitionLabelScorer::getInitialScoringContext() { + return baseLabelScorer_->getInitialScoringContext(); +} + +ScoringContextRef TransitionLabelScorer::extendedScoringContext(LabelScorer::Request const& request) { + return baseLabelScorer_->extendedScoringContext(request); +} + +void TransitionLabelScorer::cleanupCaches(Core::CollapsedVector const& activeContexts) { + baseLabelScorer_->cleanupCaches(activeContexts); +} + +void TransitionLabelScorer::addInput(DataView const& input) { + baseLabelScorer_->addInput(input); +} + +void TransitionLabelScorer::addInputs(DataView const& input, size_t nTimesteps) { + baseLabelScorer_->addInputs(input, nTimesteps); +} + +std::optional TransitionLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) { + auto result = baseLabelScorer_->computeScoreWithTime(request); + if (result) { + result->score += getTransitionScore(request.transitionType); + } + return result; +} + +std::optional TransitionLabelScorer::computeScoresWithTimes(std::vector const& requests) { + auto results = baseLabelScorer_->computeScoresWithTimes(requests); + if (results) { + for (size_t i = 0ul; i < requests.size(); ++i) { + results->scores[i] += getTransitionScore(requests[i].transitionType); + } + } + return results; +} + +LabelScorer::Score TransitionLabelScorer::getTransitionScore(LabelScorer::TransitionType transitionType) const { + switch (transitionType) { + case TransitionType::LABEL_TO_LABEL: + return labelToLabelScore_; + case TransitionType::LABEL_LOOP: + return labelLoopScore_; + case TransitionType::LABEL_TO_BLANK: + return labelToBlankScore_; + case TransitionType::BLANK_TO_LABEL: + return blankToLabelScore_; + case TransitionType::BLANK_LOOP: + return blankLoopScore_; + case TransitionType::INITIAL_LABEL: + return initialLabelScore_; + case TransitionType::INITIAL_BLANK: + return initialBlankScore_; + default: + error() << "Unknown transition type " << transitionType; + } + return 0; +} + +} // namespace Nn diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.hh b/src/Nn/LabelScorer/TransitionLabelScorer.hh new file mode 100644 index 00000000..23bbc3cf --- /dev/null +++ b/src/Nn/LabelScorer/TransitionLabelScorer.hh @@ -0,0 +1,86 @@ +/** Copyright 2025 RWTH Aachen University. All rights reserved. + * + * Licensed under the RWTH ASR License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSITION_LABEL_SCORER_HH +#define TRANSITION_LABEL_SCORER_HH + +#include "LabelScorer.hh" + +namespace Nn { + +/* + * Label Scorer that adds predefined transition scores to the scores of an underlying base + * label scorer based on the transition type of each request. + * The score for each transition type is set via config parameters. + */ +class TransitionLabelScorer : public LabelScorer { + static const Core::ParameterFloat paramLabelToLabelScore; + static const Core::ParameterFloat paramLabelLoopScore; + static const Core::ParameterFloat paramLabelToBlankScore; + static const Core::ParameterFloat paramBlankToLabelScore; + static const Core::ParameterFloat paramBlankLoopScore; + static const Core::ParameterFloat paramInitialLabelScore; + static const Core::ParameterFloat paramInitialBlankScore; + +public: + using Precursor = LabelScorer; + + TransitionLabelScorer(Core::Configuration const& config); + virtual ~TransitionLabelScorer() = default; + + // Reset base scorer + void reset() override; + + // Forward signal to base scorer + void signalNoMoreFeatures() override; + + // Initial context of base scorer + ScoringContextRef getInitialScoringContext() override; + + // Extend context via base scorer + ScoringContextRef extendedScoringContext(Request const& request) override; + + // Clean up base scorer + void cleanupCaches(Core::CollapsedVector const& activeContexts) override; + + // Add input to base scorer + void addInput(DataView const& input) override; + + // Add inputs to sub-scorer + void addInputs(DataView const& input, size_t nTimesteps) override; + + // Compute score of base scorer and add transition score based on transition type of the request + std::optional computeScoreWithTime(Request const& request) override; + + // Compute scores of base scorer and add transition scores based on transition types of the requests + std::optional computeScoresWithTimes(std::vector const& requests) override; + +private: + Score labelToLabelScore_; + Score labelLoopScore_; + Score labelToBlankScore_; + Score blankToLabelScore_; + Score blankLoopScore_; + Score initialLabelScore_; + Score initialBlankScore_; + + Core::Ref baseLabelScorer_; + + Score getTransitionScore(TransitionType transitionType) const; +}; + +} // namespace Nn + +#endif // TRANSITION_LABEL_SCORER_HH diff --git a/src/Nn/Module.cc b/src/Nn/Module.cc index 62c1a553..45d7ca24 100644 --- a/src/Nn/Module.cc +++ b/src/Nn/Module.cc @@ -23,6 +23,7 @@ #include "LabelScorer/FixedContextOnnxLabelScorer.hh" #include "LabelScorer/NoContextOnnxLabelScorer.hh" #include "LabelScorer/NoOpLabelScorer.hh" +#include "LabelScorer/TransitionLabelScorer.hh" #include "Statistics.hh" #ifdef MODULE_NN @@ -120,6 +121,13 @@ Module_::Module_() [](Core::Configuration const& config) { return Core::ref(new FixedContextOnnxLabelScorer(config)); }); + + // Returns predefined scores based on the transition type of each score request + labelScorerFactory_.registerLabelScorer( + "transition", + [](Core::Configuration const& config) { + return Core::ref(new TransitionLabelScorer(config)); + }); }; Module_::~Module_() { From 74300017455cdd73aabe5327e610f9f6b0fb9ab9 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 24 Jul 2025 16:55:01 +0200 Subject: [PATCH 2/9] Rewrite docstring --- src/Nn/LabelScorer/TransitionLabelScorer.hh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.hh b/src/Nn/LabelScorer/TransitionLabelScorer.hh index 23bbc3cf..9c3a48ab 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.hh +++ b/src/Nn/LabelScorer/TransitionLabelScorer.hh @@ -21,9 +21,9 @@ namespace Nn { /* - * Label Scorer that adds predefined transition scores to the scores of an underlying base - * label scorer based on the transition type of each request. - * The score for each transition type is set via config parameters. + * This PR adds a new label scorer `TransitionLabelScorer` which wraps a base LabelScorer + * and adds predefined transition scores to the base scores depending on the transition type of each request. + * The transition scores are all individually specified as config parameters. */ class TransitionLabelScorer : public LabelScorer { static const Core::ParameterFloat paramLabelToLabelScore; From 2a6272e8e38b0ff85478caf6558e268da11becc7 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 24 Jul 2025 16:59:52 +0200 Subject: [PATCH 3/9] Clean up includes --- src/Nn/LabelScorer/TransitionLabelScorer.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.cc b/src/Nn/LabelScorer/TransitionLabelScorer.cc index 2b88c96b..b4428391 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.cc +++ b/src/Nn/LabelScorer/TransitionLabelScorer.cc @@ -12,13 +12,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "TransitionLabelScorer.hh" #include -#include "LabelScorer.hh" -#include "ScoringContext.hh" - namespace Nn { const Core::ParameterFloat TransitionLabelScorer::paramLabelToLabelScore( From 7e325e17750bf49390e70d4a230c510001e89173 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Thu, 24 Jul 2025 17:01:13 +0200 Subject: [PATCH 4/9] Rewrite docstring again --- src/Nn/LabelScorer/TransitionLabelScorer.hh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.hh b/src/Nn/LabelScorer/TransitionLabelScorer.hh index 9c3a48ab..34c41b4f 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.hh +++ b/src/Nn/LabelScorer/TransitionLabelScorer.hh @@ -21,8 +21,8 @@ namespace Nn { /* - * This PR adds a new label scorer `TransitionLabelScorer` which wraps a base LabelScorer - * and adds predefined transition scores to the base scores depending on the transition type of each request. + * This LabelScorer wraps a base LabelScorer and adds predefined transition scores + * to the base scores depending on the transition type of each request. * The transition scores are all individually specified as config parameters. */ class TransitionLabelScorer : public LabelScorer { From d2d78fee756b22f2ef4e962c7b222c1a46b0645e Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 24 Sep 2025 21:17:33 +0200 Subject: [PATCH 5/9] Refactor params to string list with compile time check --- .../FixedContextOnnxLabelScorer.cc | 2 +- src/Nn/LabelScorer/LabelScorer.hh | 7 +- src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc | 2 +- src/Nn/LabelScorer/TransitionLabelScorer.cc | 69 ++----------------- src/Nn/LabelScorer/TransitionLabelScorer.hh | 30 ++++---- src/Nn/Module.cc | 1 + .../LexiconfreeTimesyncBeamSearch.cc | 14 ++-- 7 files changed, 38 insertions(+), 87 deletions(-) diff --git a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc index 33d178f2..5ae580b6 100644 --- a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc @@ -123,7 +123,7 @@ ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScore timeIncrement = not verticalLabelTransition_; break; default: - error() << "Unknown transition type " << request.transitionType; + error() << "Unknown transition type " << transitionTypeToIndex(request.transitionType); } // If context is not going to be modified, return the original one to avoid copying diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index 6666d28d..b8c127a3 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -76,7 +76,7 @@ class LabelScorer : public virtual Core::Component, public: typedef Search::Score Score; - enum TransitionType { + enum class TransitionType { LABEL_TO_LABEL, LABEL_LOOP, LABEL_TO_BLANK, @@ -84,8 +84,13 @@ public: BLANK_LOOP, INITIAL_LABEL, INITIAL_BLANK, + sentinel // must remain at the end }; + static constexpr size_t transitionTypeToIndex(TransitionType transitionType) { + return static_cast(transitionType); + } + // Request for scoring or context extension struct Request { ScoringContextRef context; diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc index be5bc817..b487701a 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc @@ -221,7 +221,7 @@ Core::Ref StatefulOnnxLabelScorer::extendedScoringContext( updateState = true; break; default: - error() << "Unknown transition type " << request.transitionType; + error() << "Unknown transition type " << transitionTypeToIndex(request.transitionType); } // If history is not going to be modified, return the original one diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.cc b/src/Nn/LabelScorer/TransitionLabelScorer.cc index b4428391..283b7d3b 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.cc +++ b/src/Nn/LabelScorer/TransitionLabelScorer.cc @@ -19,52 +19,15 @@ namespace Nn { -const Core::ParameterFloat TransitionLabelScorer::paramLabelToLabelScore( - "label-to-label-score", - "Score for label-to-label transitions", - 0.0); - -const Core::ParameterFloat TransitionLabelScorer::paramLabelLoopScore( - "label-loop-score", - "Score for label-loop transitions", - 0.0); - -const Core::ParameterFloat TransitionLabelScorer::paramLabelToBlankScore( - "label-to-blank-score", - "Score for label-to-blank transitions", - 0.0); - -const Core::ParameterFloat TransitionLabelScorer::paramBlankToLabelScore( - "blank-to-label-score", - "Score for blank-to-label transitions", - 0.0); - -const Core::ParameterFloat TransitionLabelScorer::paramBlankLoopScore( - "blank-loop-score", - "Score for blank-loop transitions", - 0.0); - -const Core::ParameterFloat TransitionLabelScorer::paramInitialLabelScore( - "initial-label-score", - "Score for initial-label transitions", - 0.0); - -const Core::ParameterFloat TransitionLabelScorer::paramInitialBlankScore( - "initial-blank-score", - "Score for initial-blank transitions", - 0.0); - TransitionLabelScorer::TransitionLabelScorer(Core::Configuration const& config) : Core::Component(config), Precursor(config), - labelToLabelScore_(paramLabelToLabelScore(config)), - labelLoopScore_(paramLabelLoopScore(config)), - labelToBlankScore_(paramLabelToBlankScore(config)), - blankToLabelScore_(paramBlankToLabelScore(config)), - blankLoopScore_(paramBlankLoopScore(config)), - initialLabelScore_(paramInitialLabelScore(config)), - initialBlankScore_(paramInitialBlankScore(config)), - baseLabelScorer_(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("base-scorer"))) {} + transitionScores_(), + baseLabelScorer_(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("base-scorer"))) { + for (size_t idx = 0ul; idx < paramNames.size(); ++idx) { + transitionScores_[idx] = Core::ParameterFloat(paramNames[idx], "", 0.0)(config); + } +} void TransitionLabelScorer::reset() { baseLabelScorer_->reset(); @@ -113,25 +76,7 @@ std::optional TransitionLabelScorer::computeScores } LabelScorer::Score TransitionLabelScorer::getTransitionScore(LabelScorer::TransitionType transitionType) const { - switch (transitionType) { - case TransitionType::LABEL_TO_LABEL: - return labelToLabelScore_; - case TransitionType::LABEL_LOOP: - return labelLoopScore_; - case TransitionType::LABEL_TO_BLANK: - return labelToBlankScore_; - case TransitionType::BLANK_TO_LABEL: - return blankToLabelScore_; - case TransitionType::BLANK_LOOP: - return blankLoopScore_; - case TransitionType::INITIAL_LABEL: - return initialLabelScore_; - case TransitionType::INITIAL_BLANK: - return initialBlankScore_; - default: - error() << "Unknown transition type " << transitionType; - } - return 0; + return transitionScores_[transitionTypeToIndex(transitionType)]; } } // namespace Nn diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.hh b/src/Nn/LabelScorer/TransitionLabelScorer.hh index 34c41b4f..b0d563f9 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.hh +++ b/src/Nn/LabelScorer/TransitionLabelScorer.hh @@ -16,6 +16,9 @@ #ifndef TRANSITION_LABEL_SCORER_HH #define TRANSITION_LABEL_SCORER_HH +#include +#include + #include "LabelScorer.hh" namespace Nn { @@ -26,14 +29,6 @@ namespace Nn { * The transition scores are all individually specified as config parameters. */ class TransitionLabelScorer : public LabelScorer { - static const Core::ParameterFloat paramLabelToLabelScore; - static const Core::ParameterFloat paramLabelLoopScore; - static const Core::ParameterFloat paramLabelToBlankScore; - static const Core::ParameterFloat paramBlankToLabelScore; - static const Core::ParameterFloat paramBlankLoopScore; - static const Core::ParameterFloat paramInitialLabelScore; - static const Core::ParameterFloat paramInitialBlankScore; - public: using Precursor = LabelScorer; @@ -68,13 +63,18 @@ public: std::optional computeScoresWithTimes(std::vector const& requests) override; private: - Score labelToLabelScore_; - Score labelLoopScore_; - Score labelToBlankScore_; - Score blankToLabelScore_; - Score blankLoopScore_; - Score initialLabelScore_; - Score initialBlankScore_; + inline static constexpr auto paramNames = std::to_array({ + "label-to-label-score", + "label-loop-score", + "label-to-blank-score", + "blank-to-label-score", + "blank-loop-score", + "initial-label-score", + "initial-blank-score", + }); + static_assert(paramNames.size() == transitionTypeToIndex(TransitionType::sentinel), "paramNames must match number of TransitionType values"); + + std::array transitionScores_; Core::Ref baseLabelScorer_; diff --git a/src/Nn/Module.cc b/src/Nn/Module.cc index e3acf381..760d622a 100644 --- a/src/Nn/Module.cc +++ b/src/Nn/Module.cc @@ -135,6 +135,7 @@ Module_::Module_() "transition", [](Core::Configuration const& config) { return Core::ref(new TransitionLabelScorer(config)); + }); }; Module_::~Module_() { diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index 81e17380..2ace9a4c 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -48,11 +48,11 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( score(extension.score), trace() { switch (extension.transitionType) { - case Nn::LabelScorer::INITIAL_BLANK: - case Nn::LabelScorer::INITIAL_LABEL: - case Nn::LabelScorer::LABEL_TO_LABEL: - case Nn::LabelScorer::LABEL_TO_BLANK: - case Nn::LabelScorer::BLANK_TO_LABEL: + case Nn::LabelScorer::TransitionType::INITIAL_BLANK: + case Nn::LabelScorer::TransitionType::INITIAL_LABEL: + case Nn::LabelScorer::TransitionType::LABEL_TO_LABEL: + case Nn::LabelScorer::TransitionType::LABEL_TO_BLANK: + case Nn::LabelScorer::TransitionType::BLANK_TO_LABEL: trace = Core::ref(new LatticeTrace( base.trace, extension.pron, @@ -60,8 +60,8 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( {extension.score, 0}, {})); break; - case Nn::LabelScorer::LABEL_LOOP: - case Nn::LabelScorer::BLANK_LOOP: + case Nn::LabelScorer::TransitionType::LABEL_LOOP: + case Nn::LabelScorer::TransitionType::BLANK_LOOP: // Copy base trace and update it trace = Core::ref(new LatticeTrace(*base.trace)); trace->sibling = {}; From 303fa46dde058083856c8c72e70d7108012b4339 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 24 Sep 2025 22:02:18 +0200 Subject: [PATCH 6/9] Remove transitionTypeToIndex function and revert associated changes --- src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc | 2 +- src/Nn/LabelScorer/LabelScorer.hh | 8 ++------ src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc | 2 +- src/Nn/LabelScorer/TransitionLabelScorer.cc | 2 +- src/Nn/LabelScorer/TransitionLabelScorer.hh | 3 ++- .../LexiconfreeTimesyncBeamSearch.cc | 14 +++++++------- 6 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc index 5ae580b6..33d178f2 100644 --- a/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/FixedContextOnnxLabelScorer.cc @@ -123,7 +123,7 @@ ScoringContextRef FixedContextOnnxLabelScorer::extendedScoringContext(LabelScore timeIncrement = not verticalLabelTransition_; break; default: - error() << "Unknown transition type " << transitionTypeToIndex(request.transitionType); + error() << "Unknown transition type " << request.transitionType; } // If context is not going to be modified, return the original one to avoid copying diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index b8c127a3..d06555a5 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -76,7 +76,7 @@ class LabelScorer : public virtual Core::Component, public: typedef Search::Score Score; - enum class TransitionType { + enum TransitionType { LABEL_TO_LABEL, LABEL_LOOP, LABEL_TO_BLANK, @@ -84,13 +84,9 @@ public: BLANK_LOOP, INITIAL_LABEL, INITIAL_BLANK, - sentinel // must remain at the end + numTypes, // must remain at the end }; - static constexpr size_t transitionTypeToIndex(TransitionType transitionType) { - return static_cast(transitionType); - } - // Request for scoring or context extension struct Request { ScoringContextRef context; diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc index b487701a..fd9695bf 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc @@ -221,7 +221,7 @@ Core::Ref StatefulOnnxLabelScorer::extendedScoringContext( updateState = true; break; default: - error() << "Unknown transition type " << transitionTypeToIndex(request.transitionType); + error() << "Unknown transition type " << static_cast(request.transitionType); } // If history is not going to be modified, return the original one diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.cc b/src/Nn/LabelScorer/TransitionLabelScorer.cc index 283b7d3b..091b92e0 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.cc +++ b/src/Nn/LabelScorer/TransitionLabelScorer.cc @@ -76,7 +76,7 @@ std::optional TransitionLabelScorer::computeScores } LabelScorer::Score TransitionLabelScorer::getTransitionScore(LabelScorer::TransitionType transitionType) const { - return transitionScores_[transitionTypeToIndex(transitionType)]; + return transitionScores_[transitionType]; } } // namespace Nn diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.hh b/src/Nn/LabelScorer/TransitionLabelScorer.hh index b0d563f9..15081af9 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.hh +++ b/src/Nn/LabelScorer/TransitionLabelScorer.hh @@ -63,6 +63,7 @@ public: std::optional computeScoresWithTimes(std::vector const& requests) override; private: + // List of names is set and size-checked against the TransitionType enum at compile time inline static constexpr auto paramNames = std::to_array({ "label-to-label-score", "label-loop-score", @@ -72,7 +73,7 @@ private: "initial-label-score", "initial-blank-score", }); - static_assert(paramNames.size() == transitionTypeToIndex(TransitionType::sentinel), "paramNames must match number of TransitionType values"); + static_assert(paramNames.size() == TransitionType::numTypes, "paramNames must match number of TransitionType values"); std::array transitionScores_; diff --git a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc index 2ace9a4c..81e17380 100644 --- a/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc +++ b/src/Search/LexiconfreeTimesyncBeamSearch/LexiconfreeTimesyncBeamSearch.cc @@ -48,11 +48,11 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( score(extension.score), trace() { switch (extension.transitionType) { - case Nn::LabelScorer::TransitionType::INITIAL_BLANK: - case Nn::LabelScorer::TransitionType::INITIAL_LABEL: - case Nn::LabelScorer::TransitionType::LABEL_TO_LABEL: - case Nn::LabelScorer::TransitionType::LABEL_TO_BLANK: - case Nn::LabelScorer::TransitionType::BLANK_TO_LABEL: + case Nn::LabelScorer::INITIAL_BLANK: + case Nn::LabelScorer::INITIAL_LABEL: + case Nn::LabelScorer::LABEL_TO_LABEL: + case Nn::LabelScorer::LABEL_TO_BLANK: + case Nn::LabelScorer::BLANK_TO_LABEL: trace = Core::ref(new LatticeTrace( base.trace, extension.pron, @@ -60,8 +60,8 @@ LexiconfreeTimesyncBeamSearch::LabelHypothesis::LabelHypothesis( {extension.score, 0}, {})); break; - case Nn::LabelScorer::TransitionType::LABEL_LOOP: - case Nn::LabelScorer::TransitionType::BLANK_LOOP: + case Nn::LabelScorer::LABEL_LOOP: + case Nn::LabelScorer::BLANK_LOOP: // Copy base trace and update it trace = Core::ref(new LatticeTrace(*base.trace)); trace->sibling = {}; From ddd75c7df1c994a019ee683047f1b9b2f326e84e Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 24 Sep 2025 22:08:59 +0200 Subject: [PATCH 7/9] Revert unnecessary static_cast --- src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc index fd9695bf..be5bc817 100644 --- a/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc +++ b/src/Nn/LabelScorer/StatefulOnnxLabelScorer.cc @@ -221,7 +221,7 @@ Core::Ref StatefulOnnxLabelScorer::extendedScoringContext( updateState = true; break; default: - error() << "Unknown transition type " << static_cast(request.transitionType); + error() << "Unknown transition type " << request.transitionType; } // If history is not going to be modified, return the original one From 5b89d0f779b6aee9a800d1f594ce15a15d6f54aa Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Tue, 30 Sep 2025 14:21:10 +0200 Subject: [PATCH 8/9] Move transition type string array to LabelScorer.hh --- src/Nn/LabelScorer/LabelScorer.hh | 11 +++++++++++ src/Nn/LabelScorer/TransitionLabelScorer.cc | 13 +++++-------- src/Nn/LabelScorer/TransitionLabelScorer.hh | 19 +------------------ 3 files changed, 17 insertions(+), 26 deletions(-) diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index d06555a5..5997fe96 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -87,6 +87,17 @@ public: numTypes, // must remain at the end }; + inline static constexpr auto transitionTypeArray = std::to_array>({ + {"label-to-label", LABEL_TO_LABEL}, + {"label-loop", LABEL_LOOP}, + {"label-to-blank", LABEL_TO_BLANK}, + {"blank-to-label", BLANK_TO_LABEL}, + {"blank-loop", BLANK_LOOP}, + {"initial-label", INITIAL_LABEL}, + {"initial-blank", INITIAL_BLANK}, + }); + static_assert(transitionTypeArray.size() == TransitionType::numTypes, "transitionTypeArray size must match number of TransitionType values"); + // Request for scoring or context extension struct Request { ScoringContextRef context; diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.cc b/src/Nn/LabelScorer/TransitionLabelScorer.cc index 091b92e0..c61309da 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.cc +++ b/src/Nn/LabelScorer/TransitionLabelScorer.cc @@ -24,8 +24,9 @@ TransitionLabelScorer::TransitionLabelScorer(Core::Configuration const& config) Precursor(config), transitionScores_(), baseLabelScorer_(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("base-scorer"))) { - for (size_t idx = 0ul; idx < paramNames.size(); ++idx) { - transitionScores_[idx] = Core::ParameterFloat(paramNames[idx], "", 0.0)(config); + for (auto const& [stringIdentifier, enumValue] : transitionTypeArray) { + auto paramName = std::string(stringIdentifier) + "-score"; + transitionScores_[enumValue] = Core::ParameterFloat(paramName.c_str(), "", 0.0)(config); } } @@ -60,7 +61,7 @@ void TransitionLabelScorer::addInputs(DataView const& input, size_t nTimesteps) std::optional TransitionLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) { auto result = baseLabelScorer_->computeScoreWithTime(request); if (result) { - result->score += getTransitionScore(request.transitionType); + result->score += transitionScores_[request.transitionType]; } return result; } @@ -69,14 +70,10 @@ std::optional TransitionLabelScorer::computeScores auto results = baseLabelScorer_->computeScoresWithTimes(requests); if (results) { for (size_t i = 0ul; i < requests.size(); ++i) { - results->scores[i] += getTransitionScore(requests[i].transitionType); + results->scores[i] += transitionScores_[requests[i].transitionType]; } } return results; } -LabelScorer::Score TransitionLabelScorer::getTransitionScore(LabelScorer::TransitionType transitionType) const { - return transitionScores_[transitionType]; -} - } // namespace Nn diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.hh b/src/Nn/LabelScorer/TransitionLabelScorer.hh index 15081af9..a8bc4af0 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.hh +++ b/src/Nn/LabelScorer/TransitionLabelScorer.hh @@ -16,9 +16,6 @@ #ifndef TRANSITION_LABEL_SCORER_HH #define TRANSITION_LABEL_SCORER_HH -#include -#include - #include "LabelScorer.hh" namespace Nn { @@ -63,23 +60,9 @@ public: std::optional computeScoresWithTimes(std::vector const& requests) override; private: - // List of names is set and size-checked against the TransitionType enum at compile time - inline static constexpr auto paramNames = std::to_array({ - "label-to-label-score", - "label-loop-score", - "label-to-blank-score", - "blank-to-label-score", - "blank-loop-score", - "initial-label-score", - "initial-blank-score", - }); - static_assert(paramNames.size() == TransitionType::numTypes, "paramNames must match number of TransitionType values"); - - std::array transitionScores_; + std::unordered_map transitionScores_; Core::Ref baseLabelScorer_; - - Score getTransitionScore(TransitionType transitionType) const; }; } // namespace Nn From b9d919b326fca182f576b8b6a6ac10333ccbce93 Mon Sep 17 00:00:00 2001 From: Simon Berger Date: Wed, 1 Oct 2025 16:44:31 +0200 Subject: [PATCH 9/9] Move transitionTypeArray to protected space --- src/Nn/LabelScorer/LabelScorer.hh | 23 +++++++++++---------- src/Nn/LabelScorer/TransitionLabelScorer.cc | 2 +- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/Nn/LabelScorer/LabelScorer.hh b/src/Nn/LabelScorer/LabelScorer.hh index 5997fe96..ed6072c0 100644 --- a/src/Nn/LabelScorer/LabelScorer.hh +++ b/src/Nn/LabelScorer/LabelScorer.hh @@ -87,17 +87,6 @@ public: numTypes, // must remain at the end }; - inline static constexpr auto transitionTypeArray = std::to_array>({ - {"label-to-label", LABEL_TO_LABEL}, - {"label-loop", LABEL_LOOP}, - {"label-to-blank", LABEL_TO_BLANK}, - {"blank-to-label", BLANK_TO_LABEL}, - {"blank-loop", BLANK_LOOP}, - {"initial-label", INITIAL_LABEL}, - {"initial-blank", INITIAL_BLANK}, - }); - static_assert(transitionTypeArray.size() == TransitionType::numTypes, "transitionTypeArray size must match number of TransitionType values"); - // Request for scoring or context extension struct Request { ScoringContextRef context; @@ -154,6 +143,18 @@ public: // Return two vectors: one vector with scores and one vector with times // By default loops over the single-request version virtual std::optional computeScoresWithTimes(std::vector const& requests); + +protected: + inline static constexpr auto transitionTypeArray_ = std::to_array>({ + {"label-to-label", LABEL_TO_LABEL}, + {"label-loop", LABEL_LOOP}, + {"label-to-blank", LABEL_TO_BLANK}, + {"blank-to-label", BLANK_TO_LABEL}, + {"blank-loop", BLANK_LOOP}, + {"initial-label", INITIAL_LABEL}, + {"initial-blank", INITIAL_BLANK}, + }); + static_assert(transitionTypeArray_.size() == TransitionType::numTypes, "transitionTypeArray size must match number of TransitionType values"); }; } // namespace Nn diff --git a/src/Nn/LabelScorer/TransitionLabelScorer.cc b/src/Nn/LabelScorer/TransitionLabelScorer.cc index c61309da..f7afac18 100644 --- a/src/Nn/LabelScorer/TransitionLabelScorer.cc +++ b/src/Nn/LabelScorer/TransitionLabelScorer.cc @@ -24,7 +24,7 @@ TransitionLabelScorer::TransitionLabelScorer(Core::Configuration const& config) Precursor(config), transitionScores_(), baseLabelScorer_(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("base-scorer"))) { - for (auto const& [stringIdentifier, enumValue] : transitionTypeArray) { + for (auto const& [stringIdentifier, enumValue] : transitionTypeArray_) { auto paramName = std::string(stringIdentifier) + "-score"; transitionScores_[enumValue] = Core::ParameterFloat(paramName.c_str(), "", 0.0)(config); }