@@ -131,10 +131,10 @@ void forced_align_impl(
 }
 
 std::tuple<Tensor, Tensor> compute(
-    const Tensor& logProbs,
-    const Tensor& targets,
-    const Tensor& inputLengths,
-    const Tensor& targetLengths,
+    const Tensor logProbs,
+    const Tensor targets,
+    Tensor inputLengths,
+    Tensor targetLengths,
     const int64_t blank) {
   AOTI_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
   AOTI_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
@@ -174,13 +174,40 @@ std::tuple<Tensor, Tensor> compute(
       blank >= 0 && blank < logProbs.size(-1),
       "blank must be within [0, num classes)");
 
-  // TODO: Requires port of `max` and `item` operators.
-  // TORCH_CHECK(
-  //     logProbs.size(1) == at::max(inputLengths).item().toInt(),
-  //     "input length mismatch");
-  // TORCH_CHECK(
-  //     targets.size(1) == at::max(targetLengths).item().toInt(),
-  //     "target length mismatch");
+  int32_t targetLengths_dtype;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(targetLengths.get(), &targetLengths_dtype));
+  AOTI_TORCH_CHECK(
+      targetLengths_dtype == aoti_torch_dtype_int32() || targetLengths_dtype == aoti_torch_dtype_int64(),
+      "target lengths must be int32 or int64 type");
+  auto target_length_max = amax(targetLengths, 0, false);
+  void* target_max_length_ptr = target_length_max.data_ptr();
+  int64_t target_max_length;
+  if (targetLengths_dtype == aoti_torch_dtype_int32()) {
+    printf("\n\n## INT32\n\n");
+    int32_t* ptr = (int32_t*)(target_max_length_ptr);
+    target_max_length = (int64_t)(*ptr);
+  } else if (targetLengths_dtype == aoti_torch_dtype_int64()) {
+    printf("\n\n## INT64\n\n");
+    target_max_length = *((int64_t*)(target_max_length_ptr));
+  }
+  printf("TARGET MAX LENGTH IS %ld\n", target_max_length);
+  TORCH_CHECK(targets.size(1) == target_max_length, "target length mismatch");
+
+  int32_t inputLengths_dtype;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(inputLengths.get(), &inputLengths_dtype));
+  AOTI_TORCH_CHECK(
+      inputLengths_dtype == aoti_torch_dtype_int32() || inputLengths_dtype == aoti_torch_dtype_int64(),
+      "input lengths must be int32 or int64 type");
+  auto input_length_max = amax(inputLengths, 0, false);
+  void* input_max_length_ptr = input_length_max.data_ptr();
+  int64_t input_max_length;
+  if (inputLengths_dtype == aoti_torch_dtype_int32()) {
+    int32_t* ptr = (int32_t*)(input_max_length_ptr);
+    input_max_length = (int64_t)(*ptr);
+  } else if (inputLengths_dtype == aoti_torch_dtype_int64()) {
+    input_max_length = *((int64_t*)(input_max_length_ptr));
+  }
+  TORCH_CHECK(logProbs.size(1) == input_max_length, "input length mismatch");
 
   const auto B = logProbs.size(0);
   const auto T = logProbs.size(1);
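The dtype dispatch added above reads the amax() result straight out of the tensor buffer and widens it to int64_t, once for targetLengths and once for inputLengths. A minimal standalone sketch of that read-and-widen pattern in plain C++ (widen_to_int64 is a hypothetical helper for illustration, not part of this change or of the AOTI shim API):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical helper: widen a single scalar stored in a raw buffer to
// int64_t, dispatching on whether the buffer holds int32 or int64 data.
// This mirrors the pattern applied to the amax() results in the diff above.
static int64_t widen_to_int64(const void* scalar_ptr, bool is_int32) {
  if (is_int32) {
    int32_t v;
    std::memcpy(&v, scalar_ptr, sizeof(v));  // memcpy sidesteps strict-aliasing concerns
    return static_cast<int64_t>(v);
  }
  int64_t v;
  std::memcpy(&v, scalar_ptr, sizeof(v));
  return v;
}

int main() {
  int32_t max32 = 42;
  int64_t max64 = 7;
  std::printf("%lld %lld\n",
              (long long)widen_to_int64(&max32, /*is_int32=*/true),
              (long long)widen_to_int64(&max64, /*is_int32=*/false));
  return 0;
}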