Skip to content

Commit 93c6d68

Browse files
committed
Add length checks now that amax exists
1 parent 3133a16 commit 93c6d68

File tree

1 file changed

+38
-11
lines changed

1 file changed

+38
-11
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,10 @@ void forced_align_impl(
131131
}
132132

133133
std::tuple<Tensor, Tensor> compute(
134-
const Tensor& logProbs,
135-
const Tensor& targets,
136-
const Tensor& inputLengths,
137-
const Tensor& targetLengths,
134+
const Tensor logProbs,
135+
const Tensor targets,
136+
Tensor inputLengths,
137+
Tensor targetLengths,
138138
const int64_t blank) {
139139
AOTI_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
140140
AOTI_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
@@ -174,13 +174,40 @@ std::tuple<Tensor, Tensor> compute(
174174
blank >= 0 && blank < logProbs.size(-1),
175175
"blank must be within [0, num classes)");
176176

177-
// TODO: Requires port of `max` and `item` operators.
178-
// TORCH_CHECK(
179-
// logProbs.size(1) == at::max(inputLengths).item().toInt(),
180-
// "input length mismatch");
181-
// TORCH_CHECK(
182-
// targets.size(1) == at::max(targetLengths).item().toInt(),
183-
// "target length mismatch");
177+
int32_t targetLengths_dtype;
178+
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(targetLengths.get(), &targetLengths_dtype));
179+
AOTI_TORCH_CHECK(
180+
targetLengths_dtype == aoti_torch_dtype_int32() || targetLengths_dtype == aoti_torch_dtype_int64(),
181+
"target lengths must be int32 or int64 type");
182+
auto target_length_max = amax(targetLengths, 0, false);
183+
void *target_max_length_ptr = target_length_max.data_ptr();
184+
int64_t target_max_length;
185+
if (targetLengths_dtype == aoti_torch_dtype_int32()) {
186+
printf("\n\n## INT32\n\n");
187+
int32_t *ptr = (int32_t *)(target_max_length_ptr);
188+
target_max_length = (int64_t)(*ptr);
189+
} else if (targetLengths_dtype == aoti_torch_dtype_int64()) {
190+
printf("\n\n## INT64\n\n");
191+
target_max_length = *((int64_t *)(target_max_length_ptr));
192+
}
193+
printf("TARGET MAX LENGTH IS %ld\n", target_max_length);
194+
TORCH_CHECK(targets.size(1) == target_max_length, "target length mismatch");
195+
196+
int32_t inputLengths_dtype;
197+
TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(inputLengths.get(), &inputLengths_dtype));
198+
AOTI_TORCH_CHECK(
199+
inputLengths_dtype == aoti_torch_dtype_int32() || inputLengths_dtype == aoti_torch_dtype_int64(),
200+
"input lengths must be int32 or int64 type");
201+
auto input_length_max = amax(inputLengths, 0, false);
202+
void *input_max_length_ptr = input_length_max.data_ptr();
203+
int64_t input_max_length;
204+
if (inputLengths_dtype == aoti_torch_dtype_int32()) {
205+
int32_t *ptr = (int32_t *)(input_max_length_ptr);
206+
input_max_length = (int64_t)(*ptr);
207+
} else if (inputLengths_dtype == aoti_torch_dtype_int64()) {
208+
input_max_length = *((int64_t *)(input_max_length_ptr));
209+
}
210+
TORCH_CHECK(logProbs.size(1) == input_max_length, "input length mismatch");
184211

185212
const auto B = logProbs.size(0);
186213
const auto T = logProbs.size(1);

0 commit comments

Comments
 (0)