Merge from upstream #154

Merged: 4 commits, Aug 27, 2018

50 changes: 48 additions & 2 deletions CONTRIBUTING.md
@@ -269,8 +269,7 @@ than Linux, which are worth keeping in mind when fixing these problems.
3. If you have a Windows box (we have a few on EC2 which you can request access to) and
you want to run the build, the easiest way is to just run `.jenkins/pytorch/win-build.sh`.
If you need to rebuild, run `REBUILD=1 .jenkins/pytorch/win-build.sh` (this will avoid
-blowing away your Conda environment.) I recommend opening `cmd.exe`, and then running
-`bash` to work in a bash shell (which will make various Linux commands available.)
+blowing away your Conda environment.)

Even if you don't know anything about MSVC, you can use cmake to build simple programs on
Windows; this can be helpful if you want to learn more about some peculiar linking behavior
@@ -296,6 +295,53 @@ cmake ..
cmake --build .
```

### Known MSVC (and MSVC with NVCC) bugs

The PyTorch codebase sometimes likes to use exciting C++ features, and
these exciting features lead to exciting bugs in Windows compilers.
To add insult to injury, the error messages will often not tell you
which line of code actually induced the erroring template instantiation.

I've found the most effective way to debug these problems is to
carefully read over diffs, keeping in mind known bugs in MSVC/NVCC.
Here are a few well-known pitfalls and workarounds:

* This is not actually a bug per se, but in general, code generated by MSVC
is more sensitive to memory errors; you may have written some code
that does a use-after-free or overflows the stack; on Linux the code
might work, but on Windows your program will crash. ASAN may not
catch all of these problems: stay vigilant to the possibility that
your crash is due to a real memory problem.

* (NVCC) `at::optional` does not work when used from device code. Don't use
it from kernels. Upstream issue: https://github.com/akrzemi1/Optional/issues/58
and our local issue #10329.

* `constexpr` generally works less well on MSVC.

* The idiom `static_assert(f() == f())` to test if `f` is constexpr
does not work; you'll get "error C2131: expression did not evaluate
to a constant". Don't use these asserts on Windows.
(Example: `aten/src/ATen/core/intrusive_ptr.h`)
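
A minimal sketch of the idiom (with a hypothetical function `f`, not taken from the PyTorch sources); per the bullet above, asserts of this shape can trip C2131 on MSVC even when the function really is `constexpr`:

```
// Hypothetical constexpr function, used only to illustrate the idiom.
constexpr int f() { return 2; }

// Using f() inside a constant expression asserts that it is usable at
// compile time.  A conforming compiler accepts this; per the note above,
// MSVC can instead report "error C2131: expression did not evaluate to
// a constant", so asserts of this form are avoided on Windows.
static_assert(f() == f(), "f() must be usable in a constant expression");
```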

* (NVCC) Code you access inside a `static_assert` will eagerly be
evaluated as if it were device code, and so you might get an error
that the code is "not accessible".

```
class A {
  static A singleton_;
  static constexpr inline A* singleton() {
    return &singleton_;
  }
};
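// Per the note above, NVCC may eagerly evaluate A::singleton() as if it
// were device code while checking this static_assert, and then report
// that the host function is "not accessible".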
static_assert(std::is_same<A*, decltype(A::singleton())>::value, "hmm");
```

* The compiler will run out of heap space if you attempt to compile files that
are too large. Splitting such a file into several smaller files helps.
(Example: `THTensorMath`, `THTensorMoreMath`, `THTensorEvenMoreMath`.)

## Caffe2 notes

In 2018, we merged Caffe2 into the PyTorch source repository. While the
13 changes: 3 additions & 10 deletions aten/src/ATen/Scalar.cpp
@@ -13,9 +13,7 @@
 namespace at {

 Tensor Scalar::toTensor() const {
-  if (Tag::HAS_t == tag) {
-    return Tensor(t);
-  } else if (Tag::HAS_d == tag) {
+  if (Tag::HAS_d == tag) {
     return CPU(kDouble).scalarTensor(*this);
   } else {
     assert(Tag::HAS_i == tag);
@@ -24,19 +22,14 @@ Tensor Scalar::toTensor() const {
 }

 Scalar Scalar::local() const {
-  if (Tag::HAS_t != tag) {
-    return *this;
-  }
-  return Tensor(t)._local_scalar();
+  return *this;
 }

 Scalar Scalar::operator-() const {
   if (isFloatingPoint()) {
     return Scalar(-v.d);
-  } else if (isIntegral()) {
-    return Scalar(-v.i);
   } else {
-    return -Tensor(t)._local_scalar();
+    return Scalar(-v.i);
   }
 }

11 changes: 2 additions & 9 deletions aten/src/ATen/Scalar.h
@@ -8,7 +8,6 @@

 #include "ATen/core/ATenGeneral.h"
 #include "ATen/core/ScalarType.h"
-#include "ATen/TensorBase.h"
 #include "ATen/core/Half.h"

 namespace at {
@@ -34,9 +33,7 @@ class AT_API Scalar {

 #define DEFINE_ACCESSOR(type,name,member) \
   type to##name () const { \
-    if (Tag::HAS_t == tag) { \
-      return local().to##name(); \
-    } else if (Tag::HAS_d == tag) { \
+    if (Tag::HAS_d == tag) { \
       return checked_convert<type, double>(v.d, #type); \
     } else { \
       return checked_convert<type, int64_t>(v.i, #type); \
@@ -58,20 +55,16 @@ class AT_API Scalar {
   bool isIntegral() const {
     return Tag::HAS_i == tag;
   }
-  bool isBackedByTensor() const {
-    return Tag::HAS_t == tag;
-  }

   Scalar operator-() const;

 private:
-  enum class Tag { HAS_d, HAS_i, HAS_t };
+  enum class Tag { HAS_d, HAS_i };
   Tag tag;
   union {
     double d;
     int64_t i = 0;
   } v;
-  detail::TensorBase t;
   friend struct Type;
 };

2 changes: 0 additions & 2 deletions aten/src/ATen/templates/Type.cpp
@@ -91,8 +91,6 @@ Tensor Type::tensorWithAllocator(IntList sizes, IntList strides, Allocator* allo
   return tensor(storage, 0, sizes, strides);
 }
 Tensor Type::scalarTensor(Scalar s) const {
-  if(s.isBackedByTensor())
-    return Tensor(s.t).toType(*this);
   return tensor({}).fill_(s);
 }

9 changes: 0 additions & 9 deletions test/common.py
@@ -98,15 +98,6 @@ def _check_module_exists(name):
 import numpy


-def skipIfRocm(fn):
-    @wraps(fn)
-    def wrapper(*args, **kwargs):
-        if TEST_WITH_ROCM:
-            raise unittest.SkipTest("test doesn't currently work on the ROCm stack")
-        else:
-            fn(*args, **kwargs)
-    return wrapper
-
 def skipIfRocm(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
2 changes: 1 addition & 1 deletion tools/autograd/templates/python_torch_functions.cpp
@@ -83,7 +83,7 @@ inline Tensor dispatch_arange(Scalar start, Scalar end, Scalar step, const Tenso

 static inline bool allIntegral(std::initializer_list<std::reference_wrapper<Scalar>> l) {
   for (Scalar& s : l) {
-    if (!(s.isIntegral() || (s.isBackedByTensor() && at::isIntegralType(s.toTensor().type().scalarType())))) {
+    if (!s.isIntegral()) {
       return false;
     }
   }