diff --git a/src/libtorchaudio/forced_align/cpu/compute.cpp b/src/libtorchaudio/forced_align/cpu/compute.cpp
index 4b87b20e5f..164eca5ea0 100644
--- a/src/libtorchaudio/forced_align/cpu/compute.cpp
+++ b/src/libtorchaudio/forced_align/cpu/compute.cpp
@@ -1,23 +1,27 @@
-#include <torch/script.h>
-#include <torch/torch.h>
-#include <limits>
+#include <limits>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
+
 using namespace std;
 namespace torchaudio {
 namespace alignment {
 namespace cpu {
+
+using torch::stable::Tensor;
+
 // Inspired from
 // https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
-template <typename scalar_t, torch::ScalarType target_scalar_type>
+template <typename scalar_t, typename target_t>
 void forced_align_impl(
-    const torch::Tensor& logProbs,
-    const torch::Tensor& targets,
-    const int64_t blank,
-    torch::Tensor& paths) {
+    const Tensor logProbs,
+    const Tensor targets,
+    target_t blank,
+    Tensor paths) {
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
-  using target_t = typename std::
-      conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
   const auto batchIndex =
       0; // TODO: support batch version and use the real batch index
   const auto T = logProbs.size(1);
@@ -34,42 +38,38 @@ void forced_align_impl(
     backPtr_a[i] = -1;
   }
 
-  auto logProbs_a = logProbs.accessor<scalar_t, 3>();
-  auto targets_a = targets.accessor<target_t, 2>();
-  auto paths_a = paths.accessor<target_t, 2>();
+  auto logProbs_a = Accessor<3, scalar_t, true>(logProbs);
+  auto targets_a = Accessor<2, target_t, true>(targets);
+  auto paths_a = Accessor<2, target_t, false>(paths);
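+  // The stable ABI does not expose torch::Tensor::accessor, so reads and
+  // writes go through Accessor<ndim, T, is_const> (a small local helper,
+  // assumed to be defined earlier in this file) that wraps data_ptr() and
+  // row-major strides behind index()/set_index().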
   auto R = 0;
   for (auto i = 1; i < L; i++) {
-    if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
+    if (targets_a.index(batchIndex, i) == targets_a.index(batchIndex, i - 1)) {
       ++R;
     }
   }
-  TORCH_CHECK(
+  AOTI_TORCH_CHECK(
       T >= L + R,
-      "targets length is too long for CTC. Found log_probs length: ",
-      T,
-      ", targets length: ",
-      L,
-      ", and number of repeats: ",
-      R);
+      "targets length is too long for CTC");
   auto start = T - (L + R) > 0 ? 0 : 1;
   auto end = (S == 1) ? 1 : 2;
   for (auto i = start; i < end; i++) {
-    auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
-    alphas_a[i] = logProbs_a[batchIndex][0][labelIdx]; // alphas_a[0, i]
+    auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);
+    alphas_a[i] = logProbs_a.index(batchIndex, 0, labelIdx); // alphas_a[0, i]
   }
   for (auto t = 1; t < T; t++) {
     if (T - t <= L + R) {
       if ((start % 2 == 1) &&
-          targets_a[batchIndex][start / 2] !=
-              targets_a[batchIndex][start / 2 + 1]) {
+          targets_a.index(batchIndex, start / 2) !=
+              targets_a.index(batchIndex, start / 2 + 1)) {
         start = start + 1;
       }
       start = start + 1;
     }
     if (t <= L + R) {
       if (end % 2 == 0 && end < 2 * L &&
-          targets_a[batchIndex][end / 2 - 1] !=
-              targets_a[batchIndex][end / 2]) {
+          targets_a.index(batchIndex, end / 2 - 1) !=
+              targets_a.index(batchIndex, end / 2)) {
         end = end + 1;
       }
       end = end + 1;
@@ -82,8 +82,8 @@ void forced_align_impl(
     }
     if (start == 0) {
       alphas_a[curIdxOffset * S] =
-          alphas_a[prevIdxOffset * S] + logProbs_a[batchIndex][t][blank]; // alphas_a[curIdxOffset][0]
-      backPtr_a[S * t] = 0; // backPtr_a[t][0] = 0
+          alphas_a[prevIdxOffset * S] + logProbs_a.index(batchIndex, t, blank);
+      backPtr_a[S * t] = 0; // backPtr_a[t][0] = 0
       startloop += 1;
     }
 
@@ -92,14 +92,14 @@ void forced_align_impl(
       auto x1 = alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i - 1];
       auto x2 = -std::numeric_limits<scalar_t>::infinity();
-      auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
+      auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);
       // In CTC, the optimal path may optionally choose to skip a blank label.
       // x2 represents skipping a letter, and can only happen if we're not
       // currently on a blank_label, and we're not on a repeat letter
       // (i != 1) just ensures we don't access targets[i - 2] if i < 2
       if (i % 2 != 0 && i != 1 &&
-          targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
+          targets_a.index(batchIndex, i / 2) !=
+              targets_a.index(batchIndex, i / 2 - 1)) {
         x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i - 2];
       }
       scalar_t result = 0.0;
@@ -113,7 +113,8 @@ void forced_align_impl(
         result = x0;
         backPtr_a[t * S + i] = 0; // backPtr_a[t][i] = 0
       }
-      alphas_a[curIdxOffset * S + i] = result + logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
+
+      alphas_a[curIdxOffset * S + i] = result + logProbs_a.index(batchIndex, t, labelIdx); // alphas_a[curIdxOffset][i]
     }
   }
   auto idx1 = (T - 1) % 2;
@@ -122,75 +123,121 @@ void forced_align_impl(
   delete[] alphas_a;
   // path stores the token index for each time step after force alignment.
   for (auto t = T - 1; t > -1; t--) {
-    auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
-    paths_a[batchIndex][t] = lbl_idx;
+    auto lbl_idx =
+        ltrIdx % 2 == 0 ? blank : targets_a.index(batchIndex, ltrIdx / 2);
+    paths_a.set_index(lbl_idx, batchIndex, t);
     ltrIdx -= backPtr_a[t * S + ltrIdx]; // backPtr_a[t][ltrIdx]
   }
   delete[] backPtr_a;
 }
 
-std::tuple<torch::Tensor, torch::Tensor> compute(
-    const torch::Tensor& logProbs,
-    const torch::Tensor& targets,
-    const torch::Tensor& inputLengths,
-    const torch::Tensor& targetLengths,
+std::tuple<Tensor, Tensor> compute(
+    const Tensor logProbs,
+    const Tensor targets,
+    Tensor inputLengths,
+    Tensor targetLengths,
     const int64_t blank) {
-  TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
-  TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
-  TORCH_CHECK(
-      logProbs.device() == targets.device(),
+  AOTI_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
+  AOTI_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
+  AOTI_TORCH_CHECK(
+      logProbs.get_device() == targets.get_device(),
       "log_probs and targets need to be on the same device");
-  TORCH_CHECK(
-      logProbs.dtype() == torch::kFloat64 ||
-          logProbs.dtype() == torch::kFloat32 ||
-          logProbs.dtype() == torch::kFloat16,
+  int32_t logprobs_dtype;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(logProbs.get(), &logprobs_dtype));
+  AOTI_TORCH_CHECK(
+      logprobs_dtype == aoti_torch_dtype_float64() ||
+          logprobs_dtype == aoti_torch_dtype_float32() ||
+          logprobs_dtype == aoti_torch_dtype_float16(),
       "log_probs must be float64, float32 or float16 (half) type");
-  TORCH_CHECK(
-      targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
+  int32_t targets_dtype;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(targets.get(), &targets_dtype));
+  AOTI_TORCH_CHECK(
+      targets_dtype == aoti_torch_dtype_int32() ||
+          targets_dtype == aoti_torch_dtype_int64(),
       "targets must be int32 or int64 type");
-  TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
-  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
-  TORCH_CHECK(
+  AOTI_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
+  AOTI_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
+  AOTI_TORCH_CHECK(
       logProbs.dim() == 3,
       "log_probs must be 3-D (batch_size, input length, num classes)");
-  TORCH_CHECK(
+  AOTI_TORCH_CHECK(
       targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
-  TORCH_CHECK(
+  AOTI_TORCH_CHECK(
       inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
-  TORCH_CHECK(
+  AOTI_TORCH_CHECK(
       targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
-  TORCH_CHECK(
+  AOTI_TORCH_CHECK(
       logProbs.size(0) == 1,
       "The batch dimension for log_probs must be 1 at the current version.");
-  TORCH_CHECK(
+  AOTI_TORCH_CHECK(
       targets.size(0) == 1,
       "The batch dimension for targets must be 1 at the current version.");
-  TORCH_CHECK(
+  AOTI_TORCH_CHECK(
       blank >= 0 && blank < logProbs.size(-1),
       "blank must be within [0, num classes)");
-  TORCH_CHECK(
-      logProbs.size(1) == at::max(inputLengths).item().toInt(),
-      "input length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(targetLengths).item().toInt(),
-      "target length mismatch");
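+  // Tensor::item() is unavailable through the stable ABI, so the scalar
+  // result of the amax() reductions below is read back through data_ptr(),
+  // branching on the length tensor's dtype (validated to be int32 or int64).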
+  int32_t targetLengths_dtype;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_dtype(targetLengths.get(), &targetLengths_dtype));
+  AOTI_TORCH_CHECK(
+      targetLengths_dtype == aoti_torch_dtype_int32() ||
+          targetLengths_dtype == aoti_torch_dtype_int64(),
+      "target lengths must be int32 or int64 type");
+  auto target_length_max = amax(targetLengths, 0, false);
+  void* target_max_length_ptr = target_length_max.data_ptr();
+  int64_t target_max_length;
+  if (targetLengths_dtype == aoti_torch_dtype_int32()) {
+    target_max_length = (int64_t)(*(int32_t*)target_max_length_ptr);
+  } else {
+    target_max_length = *(int64_t*)target_max_length_ptr;
+  }
+  AOTI_TORCH_CHECK(
+      targets.size(1) == target_max_length, "target length mismatch");
+
+  int32_t inputLengths_dtype;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_dtype(inputLengths.get(), &inputLengths_dtype));
+  AOTI_TORCH_CHECK(
+      inputLengths_dtype == aoti_torch_dtype_int32() ||
+          inputLengths_dtype == aoti_torch_dtype_int64(),
+      "input lengths must be int32 or int64 type");
+  auto input_length_max = amax(inputLengths, 0, false);
+  void* input_max_length_ptr = input_length_max.data_ptr();
+  int64_t input_max_length;
+  if (inputLengths_dtype == aoti_torch_dtype_int32()) {
+    input_max_length = (int64_t)(*(int32_t*)input_max_length_ptr);
+  } else {
+    input_max_length = *(int64_t*)input_max_length_ptr;
+  }
+  AOTI_TORCH_CHECK(
+      logProbs.size(1) == input_max_length, "input length mismatch");
   const auto B = logProbs.size(0);
   const auto T = logProbs.size(1);
-  auto paths = torch::zeros(
-      {B, T},
-      torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      logProbs.scalar_type(), "forced_align_impl", [&] {
-        if (targets.scalar_type() == torch::kInt64) {
-          forced_align_impl<scalar_t, torch::kInt64>(
-              logProbs, targets, blank, paths);
-        } else {
-          forced_align_impl<scalar_t, torch::kInt32>(
-              logProbs, targets, blank, paths);
-        }
-      });
+
+  int64_t paths_size[2] = {B, T};
+  int64_t paths_stride[2] = {T, 1};
+  AtenTensorHandle paths_h;
+  int32_t targets_device;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_device_type(targets.get(), &targets_device));
+  TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(
+      2, paths_size, paths_stride, targets_dtype, targets_device,
+      targets.get_device(), &paths_h));
+  auto paths = Tensor(paths_h);
+
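+  // AT_DISPATCH_FLOATING_TYPES_AND_HALF is not available through the stable
+  // ABI, so the (already validated) dtype combinations are enumerated by
+  // hand; c10::Half from the header-only library is assumed for float16.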
+  if (targets_dtype == aoti_torch_dtype_int64()) {
+    if (logprobs_dtype == aoti_torch_dtype_float64()) {
+      forced_align_impl<double, int64_t>(logProbs, targets, blank, paths);
+    } else if (logprobs_dtype == aoti_torch_dtype_float32()) {
+      forced_align_impl<float, int64_t>(logProbs, targets, blank, paths);
+    } else if (logprobs_dtype == aoti_torch_dtype_float16()) {
+      forced_align_impl<c10::Half, int64_t>(logProbs, targets, blank, paths);
+    }
+  } else if (targets_dtype == aoti_torch_dtype_int32()) {
+    if (logprobs_dtype == aoti_torch_dtype_float64()) {
+      forced_align_impl<double, int32_t>(logProbs, targets, blank, paths);
+    } else if (logprobs_dtype == aoti_torch_dtype_float32()) {
+      forced_align_impl<float, int32_t>(logProbs, targets, blank, paths);
+    } else if (logprobs_dtype == aoti_torch_dtype_float16()) {
+      forced_align_impl<c10::Half, int32_t>(logProbs, targets, blank, paths);
+    }
+  }
   return std::make_tuple(
       paths,
       logProbs
@@ -198,10 +245,20 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
 }
 
+// Boxed wrapper for the stable dispatcher: arguments arrive as a flat
+// StableIValue stack, and the outputs are written back into the leading
+// stack slots.
+void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  Tensor t1(to<AtenTensorHandle>(stack[0]));
+  Tensor t2(to<AtenTensorHandle>(stack[1]));
+  Tensor t3(to<AtenTensorHandle>(stack[2]));
+  Tensor t4(to<AtenTensorHandle>(stack[3]));
+  int64_t blank = to<int64_t>(stack[4]);
+  auto result = compute(
+      std::move(t1), std::move(t2), std::move(t3), std::move(t4), blank);
+  stack[0] = from(std::get<0>(result));
+  stack[1] = from(std::get<1>(result));
+}
 
-
-TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
-  m.impl("forced_align", &compute);
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("forced_align", &boxed_compute);
 }
 
 } // namespace cpu