
[STABLE FORCED_ALIGN ABI PORT] Remove Accessor use in forced_align #4022

Draft: wants to merge 22 commits into base: main
Changes from all commits (22 commits):
2062dc7  Make alphas_a standard C array (samanklesaria, Aug 5, 2025)
e70113c  Convert backptr to standard array (samanklesaria, Aug 5, 2025)
4039399  Create Accessor class (samanklesaria, Aug 5, 2025)
b733629  Add MutAccessor (samanklesaria, Aug 5, 2025)
9beb34a  Fix multidimensional indexing bug (samanklesaria, Aug 5, 2025)
11d1e21  Use strides rather than computing standard strides from dims (samanklesaria, Aug 5, 2025)
b47c053  Merge Accessor and MutAccessor (samanklesaria, Aug 6, 2025)
7a94b04  Move Accessor to its own file and add tests (samanklesaria, Aug 6, 2025)
75d246a  Add comment about original indexing (samanklesaria, Aug 6, 2025)
30ed519  Add requested comment about scalar_t (samanklesaria, Aug 6, 2025)
be13f64  WIP (samanklesaria, Aug 6, 2025)
258ca00  Merge branch 'main' into forced_align_accessors (samanklesaria, Aug 6, 2025)
77fd1ad  Use stable tensors throughout forced_align code (samanklesaria, Aug 6, 2025)
ced6124  Free alphas_a array (samanklesaria, Aug 7, 2025)
d27a416  Merge branch 'stable_forced_align' into forced_align_backptr (samanklesaria, Aug 7, 2025)
71ce212  Free backPtr_a (samanklesaria, Aug 7, 2025)
eb50150  Merge branch 'forced_align_backptr' into forced_align_accessors (samanklesaria, Aug 7, 2025)
9629864  Fix merge conflict (samanklesaria, Aug 7, 2025)
847b726  Correct dimensionality of path variable (samanklesaria, Aug 7, 2025)
2663def  Use 1d indexing in original layout for alphas_a (samanklesaria, Aug 8, 2025)
5fa467d  Merge branch 'stable_forced_align' into forced_align_backptr (samanklesaria, Aug 8, 2025)
724606a  Merge branch 'forced_align_backptr' into forced_align_accessors (samanklesaria, Aug 8, 2025)
1 change: 1 addition & 0 deletions src/libtorchaudio/CMakeLists.txt
@@ -6,6 +6,7 @@ set(
lfilter.cpp
overdrive.cpp
utils.cpp
accessor_tests.cpp
)

set(
46 changes: 46 additions & 0 deletions src/libtorchaudio/accessor.h
@@ -0,0 +1,46 @@
#pragma once

#include <torch/csrc/stable/tensor.h>
#include <type_traits>
#include <cstdarg>

using torch::stable::Tensor;

template<unsigned int k, typename T, bool IsConst = true>
class Accessor {
int64_t strides[k];
T *data;

public:
using tensor_type = typename std::conditional<IsConst, const Tensor&, Tensor&>::type;

Accessor(tensor_type tensor) {
data = (T*)tensor.template data_ptr();
for (unsigned int i = 0; i < k; i++) {
strides[i] = tensor.stride(i);
}
}

T index(...) {
va_list args;
va_start(args, k);
int64_t ix = 0;
for (unsigned int i = 0; i < k; i++) {
ix += strides[i] * va_arg(args, int);
}
va_end(args);
return data[ix];
}

template<bool C = IsConst>
typename std::enable_if<!C, void>::type set_index(T value, ...) {
va_list args;
va_start(args, value);
int64_t ix = 0;
for (unsigned int i = 0; i < k; i++) {
ix += strides[i] * va_arg(args, int);
}
va_end(args);
data[ix] = value;
}
};
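
For orientation (not part of the diff): a minimal standalone sketch of the offset arithmetic Accessor::index performs. The flat offset is the sum of each per-dimension stride times the requested index for that dimension. The buffer and stride values below are hypothetical and use plain arrays so the sketch compiles without the stable tensor headers.

#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical 2x3 row-major buffer; element strides are {3, 1}.
  int64_t data[6] = {10, 11, 12, 13, 14, 15};
  int64_t strides[2] = {3, 1};
  // Same arithmetic as Accessor::index: offset = strides[0] * i + strides[1] * j.
  auto at = [&](int64_t i, int64_t j) { return data[strides[0] * i + strides[1] * j]; };
  assert(at(0, 0) == 10);
  assert(at(1, 2) == 15);
  return 0;
}
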
46 changes: 46 additions & 0 deletions src/libtorchaudio/accessor_tests.cpp
@@ -0,0 +1,46 @@
#include <libtorchaudio/accessor.h>
#include <cstdint>
#include <torch/torch.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/library.h>

namespace torchaudio {

namespace accessor_tests {

using namespace std;
using torch::stable::Tensor;

bool test_accessor(const Tensor tensor) {
int64_t* data_ptr = (int64_t*)tensor.data_ptr();
auto accessor = Accessor<3, int64_t>(tensor);
for (unsigned int i = 0; i < tensor.size(0); i++) {
for (unsigned int j = 0; j < tensor.size(1); j++) {
for (unsigned int k = 0; k < tensor.size(2); k++) {
auto check = *(data_ptr++) == accessor.index(i, j, k);
if (!check) {
return false;
}
}
}
}
return true;
}

void boxed_test_accessor(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor t1(to<AtenTensorHandle>(stack[0]));
auto result = test_accessor(std::move(t1));
stack[0] = from(result);
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"_test_accessor(Tensor log_probs) -> bool");
}

STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("torchaudio::_test_accessor", &boxed_test_accessor);
}

}
}
191 changes: 116 additions & 75 deletions src/libtorchaudio/forced_align/cpu/compute.cpp
@@ -1,42 +1,53 @@
#include <torch/script.h>
#include <torch/torch.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <libtorchaudio/accessor.h>
#include <torch/headeronly/util/Half.h>


using namespace std;

namespace torchaudio {
namespace alignment {
namespace cpu {

using torch::stable::Tensor;

// Inspired from
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
template <typename scalar_t, at::ScalarType target_scalar_type>
template <typename scalar_t, typename target_t>
void forced_align_impl(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const int64_t blank,
torch::Tensor& paths) {
const Tensor logProbs,
const Tensor targets,
target_t blank,
Tensor paths) {
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
const auto batchIndex =
0; // TODO: support batch version and use the real batch index
const auto T = logProbs.size(1);
const auto L = targets.size(1);
const auto S = 2 * L + 1;
torch::Tensor alphas = torch::empty(
{2, S},
torch::TensorOptions()
.device(logProbs.device())
.dtype(logProbs.dtype()))
.fill_(kNegInfinity);
torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
auto logProbs_a = logProbs.accessor<scalar_t, 3>();
auto targets_a = targets.accessor<target_t, 2>();
auto paths_a = paths.accessor<target_t, 2>();
auto alphas_a = alphas.accessor<scalar_t, 2>();
auto backPtr_a = backPtr.accessor<int8_t, 2>();

auto alphas_a = new scalar_t[2 * S]; // scalar_t is just logProbs.dtype()
for (int i = 0; i < 2 * S; i++) {
alphas_a[i] = kNegInfinity;
}

auto backPtr_a = new int8_t[T * S];
for (int i = 0; i < T * S; i++) {
backPtr_a[i] = -1;
}

auto logProbs_a = Accessor<3, scalar_t, true>(logProbs);
auto targets_a = Accessor<2, target_t, true>(targets);
auto paths_a = Accessor<2, target_t, false>(paths);
auto R = 0;
for (auto i = 1; i < L; i++) {
if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
if (targets_a.index(batchIndex, i) == targets_a.index(batchIndex, i - 1)) {
++R;
}
}
@@ -51,22 +62,23 @@ void forced_align_impl(
auto start = T - (L + R) > 0 ? 0 : 1;
auto end = (S == 1) ? 1 : 2;
for (auto i = start; i < end; i++) {
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
alphas_a[0][i] = logProbs_a[batchIndex][0][labelIdx];
auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);
alphas_a[i] = logProbs_a.index(batchIndex,0,labelIdx);

}
for (auto t = 1; t < T; t++) {
if (T - t <= L + R) {
if ((start % 2 == 1) &&
targets_a[batchIndex][start / 2] !=
targets_a[batchIndex][start / 2 + 1]) {
targets_a.index(batchIndex, start / 2) !=
targets_a.index(batchIndex, start / 2 + 1)) {
start = start + 1;
}
start = start + 1;
}
if (t <= L + R) {
if (end % 2 == 0 && end < 2 * L &&
targets_a[batchIndex][end / 2 - 1] !=
targets_a[batchIndex][end / 2]) {
targets_a.index(batchIndex, end / 2 - 1) !=
targets_a.index(batchIndex, end / 2)) {
end = end + 1;
}
end = end + 1;
Expand All @@ -75,72 +87,76 @@ void forced_align_impl(
auto curIdxOffset = t % 2;
auto prevIdxOffset = (t - 1) % 2;
for (auto j = 0; j < S; ++j) {
alphas_a[curIdxOffset][j] = -std::numeric_limits<scalar_t>::infinity();
alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t>::infinity(); // alphas_a[curIdxOffset][j]
}
if (start == 0) {
alphas_a[curIdxOffset][0] =
alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
backPtr_a[t][0] = 0;
alphas_a[curIdxOffset * S] =
alphas_a[prevIdxOffset * S] + logProbs_a.index(batchIndex, t, blank);
backPtr_a[S * t] = 0; // backPtr_a[t][0] = 0
startloop += 1;
}

for (auto i = startloop; i < end; i++) {
auto x0 = alphas_a[prevIdxOffset][i];
auto x1 = alphas_a[prevIdxOffset][i - 1];
auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a[prevIdxOffset][i];
auto x1 = alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i - 1];
auto x2 = -std::numeric_limits<scalar_t>::infinity();

auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);

// In CTC, the optimal path may optionally choose to skip a blank label.
// x2 represents skipping a letter, and can only happen if we're not
// currently on a blank_label, and we're not on a repeat letter
// (i != 1) just ensures we don't access targets[i - 2] if i < 2
if (i % 2 != 0 && i != 1 &&
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
x2 = alphas_a[prevIdxOffset][i - 2];
targets_a.index(batchIndex, i / 2) != targets_a.index(batchIndex, i / 2 - 1)) {
x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i - 2];
}
scalar_t result = 0.0;
if (x2 > x1 && x2 > x0) {
result = x2;
backPtr_a[t][i] = 2;
backPtr_a[t * S + i] = 2; // backPtr_a[t][i] = 2
} else if (x1 > x0 && x1 > x2) {
result = x1;
backPtr_a[t][i] = 1;
backPtr_a[t * S + i] = 1; // backPtr_a[t][i] = 1
} else {
result = x0;
backPtr_a[t][i] = 0;
backPtr_a[t * S + i] = 0; // backPtr_a[t][i] = 0
}
alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];

alphas_a[curIdxOffset * S + i] = result + logProbs_a.index(batchIndex, t, labelIdx); // alphas_a[curIdxOffset][i]
}
}
auto idx1 = (T - 1) % 2;
auto ltrIdx = alphas_a[idx1][S - 1] > alphas_a[idx1][S - 2] ? S - 1 : S - 2;
auto ltrIdx = alphas_a[S * idx1 + S - 1] >
alphas_a[S * idx1 + S - 2] ? S - 1 : S - 2; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
delete[] alphas_a;
// path stores the token index for each time step after force alignment.
for (auto t = T - 1; t > -1; t--) {
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
paths_a[batchIndex][t] = lbl_idx;
ltrIdx -= backPtr_a[t][ltrIdx];
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a.index(batchIndex, ltrIdx / 2);
paths_a.set_index(lbl_idx, batchIndex, t);
ltrIdx -= backPtr_a[t * S + ltrIdx]; // backPtr_a[t][ltrIdx]
}
delete[] backPtr_a;
}

std::tuple<torch::Tensor, torch::Tensor> compute(
const torch::Tensor& logProbs,
const torch::Tensor& targets,
const torch::Tensor& inputLengths,
const torch::Tensor& targetLengths,
std::tuple<Tensor, Tensor> compute(
const Tensor& logProbs,
const Tensor& targets,
const Tensor& inputLengths,
const Tensor& targetLengths,
const int64_t blank) {
TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
TORCH_CHECK(
logProbs.device() == targets.device(),
logProbs.get_device() == targets.get_device(),
"log_probs and targets need to be on the same device");
TORCH_CHECK(
logProbs.dtype() == torch::kFloat64 ||
logProbs.dtype() == torch::kFloat32 ||
logProbs.dtype() == torch::kFloat16,
logProbs.dtype() == aoti_torch_dtype_float64() ||
logProbs.dtype() == aoti_torch_dtype_float32() ||
logProbs.dtype() == aoti_torch_dtype_float16(),
"log_probs must be float64, float32 or float16 (half) type");
TORCH_CHECK(
targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
targets.dtype() == aoti_torch_dtype_int32() || targets.dtype() == aoti_torch_dtype_int64(),
"targets must be int32 or int64 type");
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
@@ -163,39 +179,64 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
blank >= 0 && blank < logProbs.size(-1),
"blank must be within [0, num classes)");

TORCH_CHECK(
logProbs.size(1) == at::max(inputLengths).item().toInt(),
"input length mismatch");
TORCH_CHECK(
targets.size(1) == at::max(targetLengths).item().toInt(),
"target length mismatch");
// TODO: Requires port of `max` and `item` operators.
// TORCH_CHECK(
// logProbs.size(1) == at::max(inputLengths).item().toInt(),
// "input length mismatch");
// TORCH_CHECK(
// targets.size(1) == at::max(targetLengths).item().toInt(),
// "target length mismatch");

const auto B = logProbs.size(0);
const auto T = logProbs.size(1);
auto paths = torch::zeros(
{B, T},
torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
logProbs.scalar_type(), "forced_align_impl", [&] {
if (targets.scalar_type() == torch::kInt64) {
forced_align_impl<scalar_t, torch::kInt64>(
logProbs, targets, blank, paths);
} else {
forced_align_impl<scalar_t, torch::kInt32>(
logProbs, targets, blank, paths);
}
});

int64_t paths_size[2] = {B, T};
int64_t paths_stride[2] = {T, 1};
AtenTensorHandle paths_h;
int32_t targets_device;
aoti_torch_get_device_type(targets.get(), &targets_device);
aoti_torch_empty_strided(2, paths_size, paths_stride, targets.dtype(), targets_device, targets.get_device(), &paths_h);
auto paths = Tensor(paths_h);


if (targets.dtype() == aoti_torch_dtype_int64()) {
if (logProbs.dtype() == aoti_torch_dtype_float64()) {
forced_align_impl<double, int64_t>(logProbs, targets, blank, paths);
} else if (logProbs.dtype() == aoti_torch_dtype_float32()) {
forced_align_impl<float, int64_t>(logProbs, targets, blank, paths);
} else if (logProbs.dtype() == aoti_torch_dtype_float16()) {
forced_align_impl<c10::Half, int64_t>(logProbs, targets, blank, paths);
}
} else if (targets.dtype() == aoti_torch_dtype_int32()) {
if (logProbs.dtype() == aoti_torch_dtype_float64()) {
forced_align_impl<double, int32_t>(logProbs, targets, blank, paths);
} else if (logProbs.dtype() == aoti_torch_dtype_float32()) {
forced_align_impl<float, int32_t>(logProbs, targets, blank, paths);
} else if (logProbs.dtype() == aoti_torch_dtype_float16()) {
forced_align_impl<c10::Half, int32_t>(logProbs, targets, blank, paths);
}
}
return std::make_tuple(
paths,
logProbs
);
}


void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor t1(to<AtenTensorHandle>(stack[0]));
Tensor t2(to<AtenTensorHandle>(stack[1]));
Tensor t3(to<AtenTensorHandle>(stack[2]));
Tensor t4(to<AtenTensorHandle>(stack[3]));
int64_t blank = to<int64_t>(stack[4]);
auto result = compute(
std::move(t1), std::move(t2), std::move(t3), std::move(t4), blank);
stack[0] = from(std::get<0>(result));
stack[1] = from(std::get<1>(result));
}


TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("forced_align", &compute);
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("forced_align", &boxed_compute);
}

} // namespace cpu
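
For orientation (not part of the diff): a minimal standalone sketch of the manual dtype dispatch pattern that replaces AT_DISPATCH_FLOATING_TYPES_AND_HALF in compute(). The runtime dtype is mapped to a compile-time template argument through an explicit if/else chain. The enum and kernel below are hypothetical stand-ins for the aoti_torch_dtype_* queries and forced_align_impl used in the diff.

#include <cstdio>

// Hypothetical runtime dtype tag; the diff branches on aoti_torch_dtype_* values instead.
enum class DType { Float32, Float64 };

// Hypothetical templated kernel standing in for forced_align_impl.
template <typename scalar_t>
void kernel() {
  std::printf("kernel instantiated for a %zu-byte scalar\n", sizeof(scalar_t));
}

// Explicit dispatch: map the runtime tag to a compile-time template argument,
// mirroring the if/else chains that replace the AT_DISPATCH_* macro.
void dispatch(DType dtype) {
  if (dtype == DType::Float64) {
    kernel<double>();
  } else if (dtype == DType::Float32) {
    kernel<float>();
  }
}

int main() {
  dispatch(DType::Float32);
  return 0;
}
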