Commit 58ce20b

Merge pull request #160 from jithunnair-amd/skip_tests

Skip KLDivLoss_cuda tests due to hang

2 parents: 3f93283 + eb81bae

7 files changed: +57 -24 lines

CONTRIBUTING.md

Lines changed: 48 additions & 2 deletions
@@ -269,8 +269,7 @@ than Linux, which are worth keeping in mind when fixing these problems.
 3. If you have a Windows box (we have a few on EC2 which you can request access to) and
    you want to run the build, the easiest way is to just run `.jenkins/pytorch/win-build.sh`.
    If you need to rebuild, run `REBUILD=1 .jenkins/pytorch/win-build.sh` (this will avoid
-   blowing away your Conda environment.) I recommend opening `cmd.exe`, and then running
-   `bash` to work in a bash shell (which will make various Linux commands available.)
+   blowing away your Conda environment.)

 Even if you don't know anything about MSVC, you can use cmake to build simple programs on
 Windows; this can be helpful if you want to learn more about some peculiar linking behavior
@@ -296,6 +295,53 @@ cmake ..
 cmake --build .
 ```

+### Known MSVC (and MSVC with NVCC) bugs
+
+The PyTorch codebase sometimes likes to use exciting C++ features, and
+these exciting features lead to exciting bugs in Windows compilers.
+To add insult to injury, the error messages will often not tell you
+which line of code actually induced the erroring template instantiation.
+
+I've found the most effective way to debug these problems is to
+carefully read over diffs, keeping in mind known bugs in MSVC/NVCC.
+Here are a few well known pitfalls and workarounds:
+
+* This is not actually a bug per se, but in general, code generated by MSVC
+  is more sensitive to memory errors; you may have written some code
+  that does a use-after-free or stack overflows; on Linux the code
+  might work, but on Windows your program will crash. ASAN may not
+  catch all of these problems: stay vigilant to the possibility that
+  your crash is due to a real memory problem.
+
+* (NVCC) `at::optional` does not work when used from device code. Don't use
+  it from kernels. Upstream issue: https://github.com/akrzemi1/Optional/issues/58
+  and our local issue #10329.
+
+* `constexpr` generally works less well on MSVC.
+
+  * The idiom `static_assert(f() == f())` to test if `f` is constexpr
+    does not work; you'll get "error C2131: expression did not evaluate
+    to a constant". Don't use these asserts on Windows.
+    (Example: `aten/src/ATen/core/intrusive_ptr.h`)
+
+* (NVCC) Code you access inside a `static_assert` will eagerly be
+  evaluated as if it were device code, and so you might get an error
+  that the code is "not accessible".
+
+  ```
+  class A {
+    static A singleton_;
+    static constexpr inline A* singleton() {
+      return &singleton_;
+    }
+  };
+  static_assert(std::is_same<A*, decltype(A::singleton())>::value, "hmm");
+  ```
+
+* The compiler will run out of heap if you attempt to compile files that
+  are too large. Splitting such files into separate files helps.
+  (Example: `THTensorMath`, `THTensorMoreMath`, `THTensorEvenMoreMath`.)
+
 ## Caffe2 notes

 In 2018, we merged Caffe2 into the PyTorch source repository. While the
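The `static_assert(f() == f())` idiom called out in the new section is easier to see outside diff context. The sketch below is a minimal, hypothetical illustration (the `squared` function is invented for this note and is not PyTorch code): it compiles on GCC/Clang whenever `squared` is usable in a constant expression, while the new docs warn that MSVC can reject the same assert with C2131.

```cpp
// Minimal sketch of the constexpr-detection idiom mentioned in the new
// CONTRIBUTING.md section; `squared` is a hypothetical example function.
constexpr int squared(int x) { return x * x; }

// Compiles only if squared(3) can be evaluated at compile time; the docs above
// note that MSVC may instead report "error C2131: expression did not evaluate
// to a constant", which is why the idiom is discouraged on Windows.
static_assert(squared(3) == squared(3), "squared should be usable in constant expressions");

int main() { return 0; }
```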

aten/src/ATen/Scalar.cpp

Lines changed: 3 additions & 10 deletions
@@ -13,9 +13,7 @@
 namespace at {

 Tensor Scalar::toTensor() const {
-  if (Tag::HAS_t == tag) {
-    return Tensor(t);
-  } else if (Tag::HAS_d == tag) {
+  if (Tag::HAS_d == tag) {
     return CPU(kDouble).scalarTensor(*this);
   } else {
     assert(Tag::HAS_i == tag);
@@ -24,19 +22,14 @@ Tensor Scalar::toTensor() const {
 }

 Scalar Scalar::local() const {
-  if (Tag::HAS_t != tag) {
-    return *this;
-  }
-  return Tensor(t)._local_scalar();
+  return *this;
 }

 Scalar Scalar::operator-() const {
   if (isFloatingPoint()) {
     return Scalar(-v.d);
-  } else if (isIntegral()) {
-    return Scalar(-v.i);
   } else {
-    return -Tensor(t)._local_scalar();
+    return Scalar(-v.i);
   }
 }

aten/src/ATen/Scalar.h

Lines changed: 2 additions & 9 deletions
@@ -8,7 +8,6 @@

 #include "ATen/core/ATenGeneral.h"
 #include "ATen/core/ScalarType.h"
-#include "ATen/TensorBase.h"
 #include "ATen/core/Half.h"

 namespace at {
@@ -34,9 +33,7 @@ class AT_API Scalar {

 #define DEFINE_ACCESSOR(type,name,member) \
   type to##name () const { \
-    if (Tag::HAS_t == tag) { \
-      return local().to##name(); \
-    } else if (Tag::HAS_d == tag) { \
+    if (Tag::HAS_d == tag) { \
       return checked_convert<type, double>(v.d, #type); \
     } else { \
       return checked_convert<type, int64_t>(v.i, #type); \
@@ -58,20 +55,16 @@ class AT_API Scalar {
   bool isIntegral() const {
     return Tag::HAS_i == tag;
   }
-  bool isBackedByTensor() const {
-    return Tag::HAS_t == tag;
-  }

   Scalar operator-() const;

 private:
-  enum class Tag { HAS_d, HAS_i, HAS_t };
+  enum class Tag { HAS_d, HAS_i };
   Tag tag;
   union {
     double d;
     int64_t i = 0;
   } v;
-  detail::TensorBase t;
   friend struct Type;
 };
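Taken together with the `Scalar.cpp` changes above, the header now describes a plain two-way tagged union. The following is a small, self-contained sketch of that shape, not the real ATen class; `MiniScalar` and its members are invented here purely to illustrate what remains once the tensor-backed `HAS_t` case is gone.

```cpp
// Hypothetical stand-in for the simplified at::Scalar: a value is either a
// double or an int64_t, so every accessor branches on just two tags.
#include <cassert>
#include <cstdint>

struct MiniScalar {
  enum class Tag { HAS_d, HAS_i };
  Tag tag;
  union { double d; int64_t i; } v;

  explicit MiniScalar(double val) : tag(Tag::HAS_d) { v.d = val; }
  explicit MiniScalar(int64_t val) : tag(Tag::HAS_i) { v.i = val; }

  bool isFloatingPoint() const { return tag == Tag::HAS_d; }
  bool isIntegral() const { return tag == Tag::HAS_i; }

  // Mirrors the simplified Scalar::operator-(): only two cases remain.
  MiniScalar operator-() const {
    return isFloatingPoint() ? MiniScalar(-v.d) : MiniScalar(-v.i);
  }
};

int main() {
  MiniScalar s(int64_t{3});
  assert((-s).v.i == -3);
  return 0;
}
```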

aten/src/ATen/templates/Type.cpp

Lines changed: 0 additions & 2 deletions
@@ -91,8 +91,6 @@ Tensor Type::tensorWithAllocator(IntList sizes, IntList strides, Allocator* allo
   return tensor(storage, 0, sizes, strides);
 }
 Tensor Type::scalarTensor(Scalar s) const {
-  if(s.isBackedByTensor())
-    return Tensor(s.t).toType(*this);
   return tensor({}).fill_(s);
 }

aten/src/THC/THCAtomics.cuh

Lines changed: 2 additions & 0 deletions
@@ -138,8 +138,10 @@ static inline __device__ void atomicAdd(double *address, double val) {
   } while (assumed != old);
 }
 #elif !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000) || defined(__HIP_PLATFORM_HCC__)
+#if defined(__HIP_PLATFORM_HCC__) && __hcc_workweek__ < 18312
 // This needs to be defined for the host side pass
 static inline __device__ void atomicAdd(double *address, double val) { }
 #endif
+#endif

 #endif // THC_ATOMICS_INC
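The nested guards are harder to follow in diff form, so here is the same logic restated with indentation and comments. This is only an illustrative rearrangement, under the assumption that `__hcc_workweek__` is the HCC build-week macro provided by the ROCm toolchain; the committed code is exactly what the diff shows.

```cpp
// Illustrative restatement of the guard above: the empty host-side
// atomicAdd(double*) stub is now emitted only for ROCm/HCC compilers older
// than workweek 18312, rather than for every HIP or pre-CUDA-8.0 build.
#if !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000) || defined(__HIP_PLATFORM_HCC__)
#  if defined(__HIP_PLATFORM_HCC__) && __hcc_workweek__ < 18312
     // The host-side compilation pass still needs to see a definition on old hcc.
     static inline __device__ void atomicAdd(double *address, double val) { }
#  endif
#endif
```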

test/common_nn.py

Lines changed: 1 addition & 0 deletions
@@ -573,6 +573,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
         reference_fn=lambda i, t, m:
             kldivloss_reference(i, t, get_reduction(m)),
         check_sum_reduction=True,
+        test_cuda=(not TEST_WITH_ROCM)
     ),
     dict(
         module_name='MSELoss',

tools/autograd/templates/python_torch_functions.cpp

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ inline Tensor dispatch_arange(Scalar start, Scalar end, Scalar step, const Tenso

 static inline bool allIntegral(std::initializer_list<std::reference_wrapper<Scalar>> l) {
   for (Scalar& s : l) {
-    if (!(s.isIntegral() || (s.isBackedByTensor() && at::isIntegralType(s.toTensor().type().scalarType())))) {
+    if (!s.isIntegral()) {
       return false;
     }
   }
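With tensor-backed scalars removed, `allIntegral` only needs `Scalar::isIntegral()`. Below is a self-contained sketch of the simplified helper, using an invented `DemoScalar` in place of `at::Scalar`, mainly to show how the `initializer_list` of `reference_wrapper`s is consumed.

```cpp
// Standalone sketch of the simplified allIntegral(); DemoScalar is a
// hypothetical stand-in for at::Scalar exposing only isIntegral().
#include <functional>
#include <initializer_list>

struct DemoScalar {
  bool integral;
  bool isIntegral() const { return integral; }
};

static inline bool allIntegral(std::initializer_list<std::reference_wrapper<DemoScalar>> l) {
  for (DemoScalar& s : l) {  // reference_wrapper converts back to DemoScalar&
    if (!s.isIntegral()) {
      return false;
    }
  }
  return true;
}

int main() {
  DemoScalar start{true}, end{true}, step{false};
  // step is not integral, so the check fails, mirroring how the integral
  // arange overload would be skipped for non-integral arguments.
  return allIntegral({start, end, step}) ? 1 : 0;
}
```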
