diff --git a/py/torch_tensorrt/ptq.py b/py/torch_tensorrt/ptq.py index 923cd76b43..0fa5d4cd64 100644 --- a/py/torch_tensorrt/ptq.py +++ b/py/torch_tensorrt/ptq.py @@ -30,12 +30,13 @@ def get_batch(self, names): batch = self.dataset_iterator.next() self.current_batch_idx += self.batch_size - # Treat the first element as input and others as targets. + inputs_gpu=[] if isinstance(batch, list): - batch = batch[0].to(self.device) + for example in batch: + inputs_gpu.append(example.to(self.device).data_ptr()) else: - batch = batch.to(self.device) - return [batch.data_ptr()] + inputs_gpu.append(batch.to(self.device).data_ptr()) + return inputs_gpu def read_calibration_cache(self):