Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion test/cpp/jit/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,6 @@ void testGPU_FusionParser() {
at::Tensor input1 = at::randn({16}, options);
at::Tensor input2 = at::randn({16}, options);
fuser::cuda::scheduleFusion(prog.fusion_.get(), {input1, input2});

// CONSIDER:
// 1. this can be moved to a dedicated "golden" file
// 2. use a fuzzy compare (ignore non-significant whitespaces for example)
Expand Down Expand Up @@ -1082,6 +1081,12 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Te
<< actual_kernel.str() << "\n=================" << std::endl;
TORCH_CHECK(false);
}
fuser::cuda::compileKernel(&prog);
at::Tensor output = at::empty_like(input1);
// No broadcasting is needed, so the last optional argument is omitted.
torch::jit::fuser::cuda::runKernel(&prog, {input1, input2}, {output});
at::Tensor output_ref = input1 * input2 * input1;
TORCH_CHECK(output_ref.equal(output));
}

void testGPU_FusionForLoop() {
Expand Down
6 changes: 3 additions & 3 deletions torch/csrc/jit/codegen/cuda/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ void runKernel(
CudaKernel* entry,
const at::ArrayRef<IValue> inputs,
const std::vector<at::Tensor>& outputs,
const std::vector<int64_t>& broadcasted_shape) {
const c10::optional<at::IntArrayRef>& broadcasted_size) {
validateKernelArgs(*entry, inputs, outputs);

const auto prior_device = at::cuda::current_device();
Expand All @@ -552,7 +552,7 @@ void runKernel(
// from I/O expected by the generated CUDA kernel.
for (auto& input : inputs) {
if (input.isTensor()) {
kernel_args.push(input.toTensor(), broadcasted_shape);
kernel_args.push(input.toTensor(), broadcasted_size);
} else {
kernel_args.push(input);
}
Expand All @@ -567,7 +567,7 @@ void runKernel(
EvaluationContext eval_context(fusion);
for (int i = 0; i < (int)inputs.size(); i++) {
if (inputs[i].isTensor()) {
ExtractSizeStride ess(inputs[i].toTensor(), broadcasted_shape);
ExtractSizeStride ess(inputs[i].toTensor(), broadcasted_size);
int nDims = ess.sizes.size();
TensorView* tv = fusion->inputs()[i]->as<TensorView>();
for (int j = 0; j < nDims; j++) {
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ TORCH_CUDA_API void runKernel(
CudaKernel* entry,
const at::ArrayRef<c10::IValue> inputs,
const std::vector<at::Tensor>& outputs,
const std::vector<int64_t>& broadcasted_shape);
const c10::optional<at::IntArrayRef>& broadcasted_size = c10::nullopt);

// Facility API to run kernel in tests.
TORCH_CUDA_API void runTestKernel(
Expand Down