Skip to content

Commit abbf58b

Browse files
committed
Fix NumPy data type conversion error on Windows
1 parent 21a9832 commit abbf58b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

py/torch_tensorrt/dynamo/conversion/impl/embedding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def embedding_bag_with_ITensor_offsets(
184184
loop1 = ctx.net.add_loop()
185185
trip_limit1 = ctx.net.add_constant(
186186
shape=(),
187-
weights=trt.Weights(np.array([offsets.shape[0] - 1], dtype=np.dtype("i"))),
187+
weights=trt.Weights(np.array([offsets.shape[0] - 1], dtype=np.int32)),
188188
).get_output(0)
189189
loop1.add_trip_limit(trip_limit1, trt.TripLimit.COUNT)
190190

@@ -205,7 +205,7 @@ def embedding_bag_with_ITensor_offsets(
205205
###### Inner loop: traverse indices ######
206206
loop2 = ctx.net.add_loop()
207207
trip_limit2 = ctx.net.add_constant(
208-
shape=(), weights=trt.Weights(np.array([len_embed], dtype=np.dtype("i")))
208+
shape=(), weights=trt.Weights(np.array([len_embed], dtype=np.int32))
209209
).get_output(0)
210210
loop2.add_trip_limit(trip_limit2, trt.TripLimit.COUNT)
211211
rec2_j_tensor = loop2.add_recurrence(constant_0)

0 commit comments

Comments
 (0)