|
| 1 | +#include "Descriptors.h" |
| 2 | +#include <ATen/ATen.h> |
| 3 | + |
| 4 | +namespace at { namespace native { |
| 5 | + |
| 6 | +namespace { |
| 7 | + |
| 8 | +inline miopenDataType_t getDataType(const at::Type& t) { |
| 9 | + auto scalar_type = t.scalarType(); |
| 10 | + if (scalar_type == at::kFloat) { |
| 11 | + return miopenFloat; |
| 12 | + } else if (scalar_type == at::kHalf) { |
| 13 | + return miopenHalf; |
| 14 | + } |
| 15 | + throw std::runtime_error("TensorDescriptor only supports float and half tensors"); |
| 16 | +} |
| 17 | + |
| 18 | +inline miopenDataType_t getDataType(const at::Tensor& t) { |
| 19 | + return getDataType(t.type()); |
| 20 | +} |
| 21 | + |
| 22 | +} // anonymous namespace |
| 23 | + |
| 24 | + |
| 25 | +void TensorDescriptor::set(const at::Tensor &t, size_t pad) { |
| 26 | + set(getDataType(t), t.sizes(), t.strides(), pad); |
| 27 | +} |
| 28 | + |
| 29 | +static int MIOPEN_DIM_MAX = 4; |
| 30 | + |
| 31 | +void TensorDescriptor::set(miopenDataType_t datatype, IntList t_sizes, IntList t_strides, size_t pad) { |
| 32 | + size_t dim = t_sizes.size(); |
| 33 | + if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX) |
| 34 | +#define _STR(X) #X |
| 35 | +#define STR(X) _STR(X) |
| 36 | + throw std::runtime_error("MIOpen supports only up to " STR(MIOPEN_DIM_MAX) " dimensions"); |
| 37 | +#undef _STR |
| 38 | +#undef STR |
| 39 | + int size[MIOPEN_DIM_MAX]; |
| 40 | + int stride[MIOPEN_DIM_MAX]; |
| 41 | + for (size_t i = 0; i < dim; ++i) { |
| 42 | + size[i] = static_cast<int>(t_sizes[i]); |
| 43 | + stride[i] = static_cast<int>(t_strides[i]); |
| 44 | + } |
| 45 | + for (size_t i = dim; i < pad; ++i) { |
| 46 | + size[i] = 1; |
| 47 | + stride[i] = 1; |
| 48 | + } |
| 49 | + set(datatype, static_cast<int>(std::max(dim, pad)), size, stride); |
| 50 | +} |
| 51 | + |
| 52 | +std::string miopenTypeToString(miopenDataType_t dtype) { |
| 53 | + switch (dtype) { |
| 54 | + case miopenFloat: |
| 55 | + return "miopenFloat"; |
| 56 | + case miopenHalf: |
| 57 | + return "miopenHalf"; |
| 58 | + default: |
| 59 | + std::ostringstream oss; |
| 60 | + oss << "(unknown data-type " << static_cast<int>(dtype) << ")"; |
| 61 | + return oss.str(); |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) { |
| 66 | + out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n"; |
| 67 | + int nbDims = 4; |
| 68 | + int dimA[MIOPEN_DIM_MAX]; |
| 69 | + int strideA[MIOPEN_DIM_MAX]; |
| 70 | + miopenDataType_t dtype; |
| 71 | + miopenGetTensorDescriptor(d.desc(), &dtype, dimA, strideA); |
| 72 | + out << " type = " << miopenTypeToString(dtype) << "\n"; |
| 73 | + out << " nbDims = " << nbDims << "\n"; |
| 74 | + // Read out only nbDims of the arrays! |
| 75 | + out << " dimA = "; |
| 76 | + for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) { |
| 77 | + out << i << ", "; |
| 78 | + } |
| 79 | + out << "\n"; |
| 80 | + out << " strideA = "; |
| 81 | + for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) { |
| 82 | + out << i << ", "; |
| 83 | + } |
| 84 | + out << "\n"; |
| 85 | + return out; |
| 86 | +} |
| 87 | + |
| 88 | +void TensorDescriptor::print() { std::cout << *this; } |
| 89 | + |
| 90 | +void FilterDescriptor::set(const at::Tensor &t, int64_t pad) { |
| 91 | + auto dim = t.ndimension(); |
| 92 | + if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX) |
| 93 | +#define _STR(X) #X |
| 94 | +#define STR(X) _STR(X) |
| 95 | + throw std::runtime_error("MIOpen supports only up to " STR(MIOPEN_DIM_MAX) " dimensions"); |
| 96 | +#undef _STR |
| 97 | +#undef STR |
| 98 | + if (!t.is_contiguous()) { |
| 99 | + throw std::runtime_error("MIOpen filters (a.k.a. weights) must be contiguous"); |
| 100 | + } |
| 101 | + int size[MIOPEN_DIM_MAX]; |
| 102 | + int stride[MIOPEN_DIM_MAX]; |
| 103 | + for (int i = 0; i < dim; ++i) { |
| 104 | + size[i] = (int) t.size(i); |
| 105 | + } |
| 106 | + for (int i = dim; i < pad; ++i) { |
| 107 | + size[i] = (int) 1; |
| 108 | + } |
| 109 | + for (int i = dim - 1; i >=0; --i) { |
| 110 | + stride[i] = (i == dim - 1) ? 1 : stride[i+1] * size[i+1]; |
| 111 | + } |
| 112 | + dim = std::max(dim, pad); |
| 113 | + set(getDataType(t), (int) dim, size, stride); |
| 114 | +} |
| 115 | + |
| 116 | +}} |
0 commit comments