Skip to content
13 changes: 13 additions & 0 deletions src/Nn/LabelScorer/LabelScorer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public:
BLANK_LOOP,
INITIAL_LABEL,
INITIAL_BLANK,
numTypes, // must remain at the end
};

// Request for scoring or context extension
Expand Down Expand Up @@ -142,6 +143,18 @@ public:
// Return two vectors: one vector with scores and one vector with times
// By default loops over the single-request version
virtual std::optional<ScoresWithTimes> computeScoresWithTimes(std::vector<Request> const& requests);

protected:
inline static constexpr auto transitionTypeArray_ = std::to_array<std::pair<std::string_view, TransitionType>>({
{"label-to-label", LABEL_TO_LABEL},
{"label-loop", LABEL_LOOP},
{"label-to-blank", LABEL_TO_BLANK},
{"blank-to-label", BLANK_TO_LABEL},
{"blank-loop", BLANK_LOOP},
{"initial-label", INITIAL_LABEL},
{"initial-blank", INITIAL_BLANK},
});
static_assert(transitionTypeArray_.size() == TransitionType::numTypes, "transitionTypeArray size must match number of TransitionType values");
};

} // namespace Nn
Expand Down
3 changes: 2 additions & 1 deletion src/Nn/LabelScorer/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ LIBSPRINTLABELSCORER_O = \
$(OBJDIR)/NoContextOnnxLabelScorer.o \
$(OBJDIR)/NoOpLabelScorer.o \
$(OBJDIR)/ScoringContext.o \
$(OBJDIR)/StatefulOnnxLabelScorer.o
$(OBJDIR)/StatefulOnnxLabelScorer.o \
$(OBJDIR)/TransitionLabelScorer.o

# -----------------------------------------------------------------------------

Expand Down
79 changes: 79 additions & 0 deletions src/Nn/LabelScorer/TransitionLabelScorer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/** Copyright 2025 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "TransitionLabelScorer.hh"

#include <Nn/Module.hh>

namespace Nn {

TransitionLabelScorer::TransitionLabelScorer(Core::Configuration const& config)
: Core::Component(config),
Precursor(config),
transitionScores_(),
baseLabelScorer_(Nn::Module::instance().labelScorerFactory().createLabelScorer(select("base-scorer"))) {
for (auto const& [stringIdentifier, enumValue] : transitionTypeArray_) {
auto paramName = std::string(stringIdentifier) + "-score";
transitionScores_[enumValue] = Core::ParameterFloat(paramName.c_str(), "", 0.0)(config);
}
}

void TransitionLabelScorer::reset() {
baseLabelScorer_->reset();
}

void TransitionLabelScorer::signalNoMoreFeatures() {
baseLabelScorer_->signalNoMoreFeatures();
}

ScoringContextRef TransitionLabelScorer::getInitialScoringContext() {
return baseLabelScorer_->getInitialScoringContext();
}

ScoringContextRef TransitionLabelScorer::extendedScoringContext(LabelScorer::Request const& request) {
return baseLabelScorer_->extendedScoringContext(request);
}

void TransitionLabelScorer::cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) {
baseLabelScorer_->cleanupCaches(activeContexts);
}

void TransitionLabelScorer::addInput(DataView const& input) {
baseLabelScorer_->addInput(input);
}

void TransitionLabelScorer::addInputs(DataView const& input, size_t nTimesteps) {
baseLabelScorer_->addInputs(input, nTimesteps);
}

std::optional<LabelScorer::ScoreWithTime> TransitionLabelScorer::computeScoreWithTime(LabelScorer::Request const& request) {
auto result = baseLabelScorer_->computeScoreWithTime(request);
if (result) {
result->score += transitionScores_[request.transitionType];
}
return result;
}

std::optional<LabelScorer::ScoresWithTimes> TransitionLabelScorer::computeScoresWithTimes(std::vector<LabelScorer::Request> const& requests) {
auto results = baseLabelScorer_->computeScoresWithTimes(requests);
if (results) {
for (size_t i = 0ul; i < requests.size(); ++i) {
results->scores[i] += transitionScores_[requests[i].transitionType];
}
}
return results;
}

} // namespace Nn
70 changes: 70 additions & 0 deletions src/Nn/LabelScorer/TransitionLabelScorer.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/** Copyright 2025 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef TRANSITION_LABEL_SCORER_HH
#define TRANSITION_LABEL_SCORER_HH

#include "LabelScorer.hh"

namespace Nn {

/*
* This LabelScorer wraps a base LabelScorer and adds predefined transition scores
* to the base scores depending on the transition type of each request.
* The transition scores are all individually specified as config parameters.
*/
class TransitionLabelScorer : public LabelScorer {
public:
using Precursor = LabelScorer;

TransitionLabelScorer(Core::Configuration const& config);
virtual ~TransitionLabelScorer() = default;

// Reset base scorer
void reset() override;

// Forward signal to base scorer
void signalNoMoreFeatures() override;

// Initial context of base scorer
ScoringContextRef getInitialScoringContext() override;

// Extend context via base scorer
ScoringContextRef extendedScoringContext(Request const& request) override;

// Clean up base scorer
void cleanupCaches(Core::CollapsedVector<ScoringContextRef> const& activeContexts) override;

// Add input to base scorer
void addInput(DataView const& input) override;

// Add inputs to sub-scorer
void addInputs(DataView const& input, size_t nTimesteps) override;

// Compute score of base scorer and add transition score based on transition type of the request
std::optional<ScoreWithTime> computeScoreWithTime(Request const& request) override;

// Compute scores of base scorer and add transition scores based on transition types of the requests
std::optional<ScoresWithTimes> computeScoresWithTimes(std::vector<Request> const& requests) override;

private:
std::unordered_map<TransitionType, Score> transitionScores_;

Core::Ref<LabelScorer> baseLabelScorer_;
};

} // namespace Nn

#endif // TRANSITION_LABEL_SCORER_HH
8 changes: 8 additions & 0 deletions src/Nn/Module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "LabelScorer/NoContextOnnxLabelScorer.hh"
#include "LabelScorer/NoOpLabelScorer.hh"
#include "LabelScorer/StatefulOnnxLabelScorer.hh"
#include "LabelScorer/TransitionLabelScorer.hh"
#include "Statistics.hh"

#ifdef MODULE_NN
Expand Down Expand Up @@ -128,6 +129,13 @@ Module_::Module_()
[](Core::Configuration const& config) {
return Core::ref(new StatefulOnnxLabelScorer(config));
});

// Returns predefined scores based on the transition type of each score request
labelScorerFactory_.registerLabelScorer(
"transition",
[](Core::Configuration const& config) {
return Core::ref(new TransitionLabelScorer(config));
});
};

Module_::~Module_() {
Expand Down