Commit ddc01e4

Exclude unsupported data types (#1951)
* Exclude unsupported data types
1 parent 992e17c commit ddc01e4
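
This change adds an isSupportedTypeByDevice() helper to the nvFuser type utilities that reports whether a DataType can actually run on the current CUDA device (BFloat16 requires a major compute capability of at least 8), and uses it in the standalone tensor-factory tests (full, zeros, ones, arange) to skip data types the device does not support.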

File tree: 3 files changed, +37 -0 lines changed
torch/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp

Lines changed: 22 additions & 0 deletions
@@ -43,6 +43,9 @@ TEST_F(NVFuserTest, FusionStandaloneFull_CUDA) {
   fusion->addInput(fill_val2);
   fusion->addInput(fill_val3);
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     auto out_tv = full({size}, fill_val1, aten_to_data_type(dtype));
     fusion->addOutput(out_tv);
     out_tv = full({size, size}, fill_val2, aten_to_data_type(dtype));
@@ -57,6 +60,9 @@ TEST_F(NVFuserTest, FusionStandaloneFull_CUDA) {
   std::vector<at::Tensor> expect;
   expect.reserve(dtypes.size());
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     const auto options =
         at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
     expect.emplace_back(at::full({size}, 11, options));
@@ -94,6 +100,9 @@ TEST_F(NVFuserTest, FusionStandaloneZeros_CUDA) {
   Val* size = IrBuilder::create<Int>();
   fusion->addInput(size);
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     auto out_tv = zeros({size}, aten_to_data_type(dtype));
     fusion->addOutput(out_tv);
     out_tv = zeros({size, size}, aten_to_data_type(dtype));
@@ -108,6 +117,9 @@ TEST_F(NVFuserTest, FusionStandaloneZeros_CUDA) {
   std::vector<at::Tensor> expect;
   expect.reserve(dtypes.size());
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     const auto options =
         at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
     expect.emplace_back(at::zeros({size}, options));
@@ -145,6 +157,9 @@ TEST_F(NVFuserTest, FusionStandaloneOnes_CUDA) {
   Val* size = IrBuilder::create<Int>();
   fusion->addInput(size);
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     auto out_tv = ones({size}, aten_to_data_type(dtype));
     fusion->addOutput(out_tv);
     out_tv = ones({size, size}, aten_to_data_type(dtype));
@@ -159,6 +174,9 @@ TEST_F(NVFuserTest, FusionStandaloneOnes_CUDA) {
   std::vector<at::Tensor> expect;
   expect.reserve(dtypes.size());
   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
     const auto options =
         at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
     expect.emplace_back(at::ones({size}, options));
@@ -183,6 +201,10 @@ TEST_F(NVFuserTest, FusionStandaloneARange_CUDA) {
   auto dtypes = {kFloat, kLong, kDouble};

   for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
+
     auto fusion = std::make_unique<Fusion>();
     FusionGuard fg(fusion.get());

torch/csrc/jit/codegen/cuda/type.cpp

Lines changed: 13 additions & 0 deletions
@@ -1,5 +1,7 @@
 #include <torch/csrc/jit/codegen/cuda/type.h>

+#include <ATen/cuda/CUDAContext.h>
+
 #include <stdexcept>
 #include <unordered_map>

@@ -160,6 +162,17 @@ DataType getTypeFromComplexType(DataType dtype) {
   }
 }

+bool isSupportedTypeByDevice(DataType dtype) {
+  auto prop = at::cuda::getCurrentDeviceProperties();
+  auto major_ver = prop->major;
+  switch (dtype) {
+    case DataType::BFloat16:
+      return major_ver >= 8;
+    default:
+      return true;
+  }
+}
+
 bool isIntegerOp(const BinaryOpType bopt) {
   return bopt >= BinaryOpType::Mod && bopt <= BinaryOpType::Rshift;
 }
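
Aside (not part of this commit): the new helper gates BFloat16 on the device's major compute capability, i.e. it requires a compute capability 8.x (Ampere-class) or newer GPU. Below is a minimal standalone sketch of the same check written against the raw CUDA runtime API; the function name deviceSupportsBFloat16 is made up for illustration and does not exist in the nvFuser sources.

#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical standalone version of the gate added in type.cpp:
// report whether the given device can run BFloat16 kernels.
bool deviceSupportsBFloat16(int device_id) {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
    return false;  // be conservative if the property query fails
  }
  return prop.major >= 8;  // same threshold used by isSupportedTypeByDevice
}

int main() {
  std::printf("BFloat16 supported on device 0: %s\n",
              deviceSupportsBFloat16(0) ? "yes" : "no");
  return 0;
}

Inside nvFuser the equivalent query goes through at::cuda::getCurrentDeviceProperties(), which returns the cudaDeviceProp of the currently active device.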

torch/csrc/jit/codegen/cuda/type.h

Lines changed: 2 additions & 0 deletions
@@ -101,6 +101,8 @@ int getVectorSizeFromType(DataType dtype);
 DataType getTypeFromVectorType(DataType dtype);
 // Return the corresponding scalar of a complex type
 DataType getTypeFromComplexType(DataType dtype);
+// Return if the datatype is supported on the current device
+TORCH_CUDA_CU_API bool isSupportedTypeByDevice(DataType dtype);

 enum class ExprType {
   Invalid,
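
The declaration is exported with TORCH_CUDA_CU_API, presumably so that code outside the CUDA codegen library itself (such as the C++ tests above) can link against the new helper.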
