
Integrate from upstream #251


Merged
merged 27 commits on Oct 8, 2018
Commits
0ebbfc2
Add utility function make_tensor (#12288)
devashisht Oct 5, 2018
9ebac3d
Improve type kind error message (#12344)
Oct 5, 2018
705d80b
Remove some Type.tensor usages and remove native_tensor without size.…
gchanan Oct 5, 2018
e2d2b27
Revert D10212616: [pytorch][PR] Remove some Type.tensor usages and re…
gchanan Oct 5, 2018
f808684
Fix bug in grad.py when conv bias != None (#12281)
daquexian Oct 5, 2018
c7e8044
Support additional device types (#12293)
nairbv Oct 5, 2018
bd09ab6
Remove stages from IR, they are no longer used
zdevito Oct 5, 2018
f9fb37c
Guard Denormals-Are-Zero with runtime CPU check (#12386)
colesbury Oct 5, 2018
54d9823
Make caffe2::Tensor::dims() return an IntList instead of a const vect…
ezyang Oct 5, 2018
57fcc57
set CMAKE_INSTALL_MESSAGE to NEVER (#12392)
anderspapitto Oct 5, 2018
b937cbb
Fix a bug that would resize tensor storage on export
zdevito Oct 5, 2018
99de456
Split reduction_front_backops.[cc|cu] into smaller units to allow bui…
3l1 Oct 5, 2018
e1fe617
Fix flipped pad buffer constructor arguments
Oct 6, 2018
3f04ca9
Remove duplicate math transpilation function (ROCm 233) (#12387)
iotamudelta Oct 6, 2018
058a318
Warn about local_rank not being globally unique. (#12370)
ezyang Oct 6, 2018
92b0e70
Add weak script mode for script functions (#11963)
Oct 6, 2018
14b48a2
Use custom CPU thread pool in async_scheduling (#12295)
Oct 6, 2018
6954659
Remove some Type.tensor usages and remove native_tensor without size.…
gchanan Oct 6, 2018
0e966fc
Back out "[caffe2] Use custom CPU thread pool in async_scheduling" (#…
gchanan Oct 6, 2018
ac9bb8e
Make dynamic_cast_if_rtti safer (#12408)
smessmer Oct 6, 2018
6f664d3
Improve TypeMeta (#11502)
smessmer Oct 6, 2018
db8d01b
Move JIT tests to gtest (#12030)
goldsborough Oct 7, 2018
0e44db8
Add check for backend of arguments to bmm cpu (#12434)
t-vi Oct 8, 2018
8689d8a
Format inline code block. (#12441)
marcemq Oct 8, 2018
def655e
fix critical section of atomic add op
Oct 8, 2018
28e1571
Add the x64 msvc toolchain into PATH (#12446)
peterjc123 Oct 8, 2018
3d39130
Merge remote-tracking branch 'rocm_upstream/upstream' into ifu
iotamudelta Oct 8, 2018
1 change: 1 addition & 0 deletions .jenkins/pytorch/win-build.sh
@@ -91,6 +91,7 @@ if "%REBUILD%"=="" ( call conda install -y -q numpy cffi pyyaml boto3 )
:: Install ninja
if "%REBUILD%"=="" ( pip install ninja )

call "C:\\Program Files (x86)\\Microsoft Visual Studio\\2017\\Community\\VC\\Auxiliary\\Build\\vcvarsall.bat" x64
call "C:\\Program Files (x86)\\Microsoft Visual Studio\\2017\\Community\\VC\\Auxiliary\\Build\\vcvarsall.bat" x86_amd64

git submodule update --init --recursive
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -5,6 +5,8 @@ cmake_minimum_required(VERSION 3.5 FATAL_ERROR)
# ---[ Project and semantic versioning.
project(Caffe2 CXX C)

set(CMAKE_INSTALL_MESSAGE NEVER)

set(CMAKE_CXX_STANDARD 11)
if (NOT MSVC)
set(CMAKE_C_STANDARD 11)
2 changes: 1 addition & 1 deletion README.md
@@ -210,7 +210,7 @@ python setup.py install

### Docker image

Dockerfile is supplied to build images with cuda support and cudnn v7. You can pass the -e PYTHON_VERSION=x.y flag to specify which python is to be used by Miniconda, or leave it unset to use the default. Build as usual
Dockerfile is supplied to build images with cuda support and cudnn v7. You can pass the `-e PYTHON_VERSION=x.y` flag to specify which python is to be used by Miniconda, or leave it unset to use the default. Build as usual
```
docker build -t pytorch -f docker/pytorch/Dockerfile .
```
4 changes: 2 additions & 2 deletions aten/src/ATen/CMakeLists.txt
@@ -20,8 +20,8 @@ CONFIGURE_FILE(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
CONFIGURE_FILE(cuda/CUDAConfig.h.in "${CMAKE_CURRENT_SOURCE_DIR}/cuda/CUDAConfig.h")

# NB: If you edit these globs, you'll have to update setup.py package_data as well
FILE(GLOB base_h "*.h" "detail/*.h")
FILE(GLOB base_cpp "*.cpp" "detail/*.cpp")
FILE(GLOB base_h "*.h" "detail/*.h" "cpu/*.h")
FILE(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp")
add_subdirectory(core)
FILE(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh")
FILE(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp")
18 changes: 2 additions & 16 deletions aten/src/ATen/Context.cpp
@@ -13,13 +13,10 @@
#include "ATen/CPUGenerator.h"
#include "ATen/RegisterCPU.h"
#include "ATen/Tensor.h"
#include <ATen/cpu/FlushDenormal.h>

#include "TH/TH.h" // for USE_LAPACK

#ifdef USE_SSE3
#include <pmmintrin.h>
#endif

namespace at {

static inline void errorHandler(const char * msg, void * data) {
@@ -94,18 +91,7 @@ bool Context::hasLAPACK() const {
}

bool Context::setFlushDenormal(bool on) {
#ifdef USE_SSE3
// Setting flush-to-zero (FTZ) flag
_MM_SET_FLUSH_ZERO_MODE(on ? _MM_FLUSH_ZERO_ON
: _MM_FLUSH_ZERO_OFF);

// Setting denormals-are-zero (DAZ) flag
_MM_SET_DENORMALS_ZERO_MODE(on ? _MM_DENORMALS_ZERO_ON
: _MM_DENORMALS_ZERO_OFF);
return true;
#else
return false;
#endif
return at::cpu::set_flush_denormal(on);
}

TypeExtendedInterface& getType(TensorOptions options) {
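For orientation, a minimal sketch of how the refactored entry point might be called; the SSE3 intrinsics now live behind at::cpu::set_flush_denormal (hence the new cpu/ globs in the ATen CMakeLists above), and Context::setFlushDenormal simply reports whether the mode could be enabled on the running CPU:

```
#include <ATen/Context.h>

// Asks ATen to flush denormals to zero; returns false when the build or
// the CPU the process is running on does not support FTZ/DAZ.
bool enable_flush_denormal() {
  return at::globalContext().setFlushDenormal(true);
}
```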
15 changes: 15 additions & 0 deletions aten/src/ATen/InitialTensorOptions.h
@@ -0,0 +1,15 @@
#pragma once

#include <ATen/core/TensorOptions.h>

namespace at {

// Represents the initial TensorOptions, before the "defaults" are ever changed.
// This is designed to be used in library code, where the explicit devices, dtypes, etc. are known.
// NOTE: this is not a stable API.
inline TensorOptions initialTensorOptions() {
return TensorOptions(kCPU).dtype(kFloat).layout(kStrided)
.requires_grad(false).is_variable(false);
}

}
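A rough usage sketch, mirroring the SparseTensorImpl change below: callers start from the pristine defaults and override only the fields they actually know:

```
#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>

// Builds an empty indices tensor the same way the new SparseTensorImpl
// constructor does: unchanged defaults, explicit device and dtype.
at::Tensor empty_long_indices(at::DeviceType device) {
  return at::empty(
      {1, 0},
      at::initialTensorOptions().device(device).dtype(at::ScalarType::Long));
}
```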
11 changes: 6 additions & 5 deletions aten/src/ATen/SparseTensorImpl.cpp
@@ -1,14 +1,15 @@
#include <ATen/ATen.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/InitialTensorOptions.h>

namespace at {

namespace {
Backend sparseTensorIdToDenseBackend(TensorTypeId type_id) {
DeviceType sparseTensorIdToDeviceType(TensorTypeId type_id) {
if (type_id == SparseCPUTensorId()) {
return Backend::CPU;
return kCPU;
} else if (type_id == SparseCUDATensorId()) {
return Backend::CUDA;
return kCUDA;
} else {
AT_ERROR("Cannot construct SparseTensor with non-sparse tensor type ID ", type_id);
}
@@ -33,8 +34,8 @@ SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, const caffe2::TypeM
, size_{0}
, sparseDims_(1)
, denseDims_(0)
, indices_(globalContext().getNonVariableTypeOpt(sparseTensorIdToDenseBackend(type_id), ScalarType::Long)->tensor({1, 0}))
, values_(globalContext().getNonVariableTypeOpt(sparseTensorIdToDenseBackend(type_id), dataTypeToScalarType(data_type.id()))->tensor()) {}
, indices_(at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorIdToDeviceType(type_id)).dtype(ScalarType::Long)))
, values_(at::empty({0}, at::initialTensorOptions().device(sparseTensorIdToDeviceType(type_id)).dtype(dataTypeToScalarType(data_type.id())))) {}

IntList SparseTensorImpl::sizes() const {
return size_;
58 changes: 58 additions & 0 deletions aten/src/ATen/core/ArrayRef.h
@@ -108,6 +108,15 @@ class ArrayRef final {
return Data + Length;
}

// These are actually the same as iterator, since ArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return Data;
}
constexpr const_iterator cend() const {
return Data + Length;
}

constexpr reverse_iterator rbegin() const {
return reverse_iterator(end());
}
@@ -209,4 +218,53 @@ class ArrayRef final {
/// @}
};

template <typename T>
std::ostream& operator<<(std::ostream & out, ArrayRef<T> list) {
int i = 0;
out << "[";
for(auto e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "]";
return out;
}

// WARNING: Template instantiation will NOT be willing to do implicit
// conversions to get you to an at::ArrayRef, which is why we need so
// many overloads.

template <typename T>
bool operator==(at::ArrayRef<T> a1, at::ArrayRef<T> a2) {
return a1.equals(a2);
}

template <typename T>
bool operator!=(at::ArrayRef<T> a1, at::ArrayRef<T> a2) {
return !a1.equals(a2);
}

template <typename T>
bool operator==(std::vector<T> a1, at::ArrayRef<T> a2) {
return at::ArrayRef<T>(a1).equals(a2);
}

template <typename T>
bool operator!=(std::vector<T> a1, at::ArrayRef<T> a2) {
return !at::ArrayRef<T>(a1).equals(a2);
}

template <typename T>
bool operator==(at::ArrayRef<T> a1, std::vector<T> a2) {
return a1.equals(at::ArrayRef<T>(a2));
}

template <typename T>
bool operator!=(at::ArrayRef<T> a1, std::vector<T> a2) {
return !a1.equals(at::ArrayRef<T>(a2));
}

using IntList = ArrayRef<int64_t>;

} // namespace at
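A small illustrative sketch of the additions above (headers assumed): streaming uses the new operator<<, and the mixed vector/ArrayRef comparison only compiles because of the explicit operator== overloads, since template deduction will not convert a std::vector to an ArrayRef:

```
#include <iostream>
#include <vector>
#include <ATen/core/ArrayRef.h>

// Prints a size list, e.g. "[2, 3]", then compares it to an expected shape.
bool check_and_print(at::IntList sizes, const std::vector<int64_t>& expected) {
  std::cout << sizes << std::endl;  // operator<<(std::ostream&, ArrayRef<T>)
  return sizes == expected;         // operator==(ArrayRef<T>, std::vector<T>)
}
```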
59 changes: 34 additions & 25 deletions aten/src/ATen/core/Device.cpp
@@ -2,23 +2,37 @@
#include <ATen/core/Error.h>
#include <ATen/core/Macros.h>

#include <algorithm>
#include <array>
#include <exception>
#include <ostream>
#include <string>
#include <tuple>
#include <vector>

namespace at {
namespace {
std::pair<Device::Type, size_t> parse_type(const std::string& device_string) {
auto position = device_string.find("cpu");
if (position != std::string::npos) {
return {Device::Type::CPU, 3};
DeviceType parse_type(const std::string& device_string) {
static const std::array<std::pair<std::string, DeviceType>, 7> types = {{
{"cpu", DeviceType::CPU},
{"cuda", DeviceType::CUDA},
{"mkldnn", DeviceType::MKLDNN},
{"opengl", DeviceType::OPENGL},
{"opencl", DeviceType::OPENCL},
{"ideep", DeviceType::IDEEP},
{"hip", DeviceType::HIP},
}};
auto device = std::find_if(
types.begin(),
types.end(),
[device_string](const std::pair<std::string, DeviceType>& p) {
return p.first == device_string;
});
if (device != types.end()) {
return device->second;
}
position = device_string.find("cuda");
if (position != std::string::npos) {
return {Device::Type::CUDA, 4};
}
AT_ERROR("Expected 'cpu' or 'cuda' device type at start of device string");
AT_ERROR(
"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, or hip device type at start of device string");
}
} // namespace

@@ -47,28 +61,23 @@ std::pair<Device::Type, size_t> parse_type(const std::string& device_string) {
// }
Device::Device(const std::string& device_string) : Device(Type::CPU) {
AT_CHECK(!device_string.empty(), "Device string must not be empty");

size_t position;
std::tie(type_, position) = parse_type(device_string);

// e.g. 'cuda', 'cpu'.
if (position == device_string.size()) {
int index = device_string.find(":");
if (index == std::string::npos) {
type_ = parse_type(device_string);
return;
} else {
std::string s;
s = device_string.substr(0, index);
AT_CHECK(!s.empty(), "Device string must not be empty");
type_ = parse_type(s);
}

AT_CHECK(
device_string[position] == ':',
"Expected ':' to separate device type from index in device string");
// Skip the colon.
position += 1;

const auto index_string = device_string.substr(position);
std::string device_index = device_string.substr(index + 1);
try {
index_ = at::stoi(index_string);
index_ = at::stoi(device_index);
} catch (const std::exception&) {
AT_ERROR(
"Could not parse device index '",
index_string,
device_index,
"' in device string '",
device_string,
"'");
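A hedged usage sketch of the extended device-string parsing (the header path and indices here are illustrative):

```
#include <ATen/core/Device.h>

// Device strings are parsed as "<type>" or "<type>:<index>".
void device_string_examples() {
  at::Device cpu("cpu");     // type only, no index
  at::Device gpu("cuda:1");  // type plus index, split on ':'
  at::Device hip("hip");     // one of the newly accepted device types
}
```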
12 changes: 0 additions & 12 deletions aten/src/ATen/core/Formatting.cpp
@@ -28,18 +28,6 @@ struct FormatGuard {
std::ios saved;
};

std::ostream& operator<<(std::ostream & out, IntList list) {
int i = 0;
out << "[";
for(auto e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "]";
return out;
}

std::ostream& operator<<(std::ostream & out, Backend b) {
return out << toString(b);
}
1 change: 0 additions & 1 deletion aten/src/ATen/core/Formatting.h
@@ -8,7 +8,6 @@

namespace at {

CAFFE2_API std::ostream& operator<<(std::ostream& out, IntList list);
CAFFE2_API std::ostream& operator<<(std::ostream& out, Backend b);
CAFFE2_API std::ostream& operator<<(std::ostream& out, const Type& t);
CAFFE2_API std::ostream& print(
12 changes: 6 additions & 6 deletions aten/src/ATen/core/ScalarType.h
@@ -69,8 +69,9 @@ enum class ScalarType : int8_t {
};

static inline DataType scalarTypeToDataType(ScalarType scalar_type) {
#define DEFINE_CASE(ctype,name,_) \
case ScalarType:: name : return caffe2::TypeMeta::Id<ctype>();
#define DEFINE_CASE(ctype, name, _) \
case ScalarType::name: \
return caffe2::TypeIdentifier::Get<ctype>();

switch(scalar_type) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
@@ -93,9 +94,9 @@ static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
}

static inline ScalarType dataTypeToScalarType(DataType dtype) {
#define DEFINE_IF(ctype,name,_) \
if (dtype == caffe2::TypeMeta::Id<ctype>()) { \
return ScalarType:: name; \
#define DEFINE_IF(ctype, name, _) \
if (dtype == caffe2::TypeIdentifier::Get<ctype>()) { \
return ScalarType::name; \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_IF)
#undef DEFINE_IF
@@ -189,7 +190,6 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
}

class Tensor;
typedef ArrayRef<int64_t> IntList;
typedef ArrayRef<Tensor> TensorList;

inline std::ostream& operator<<(
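A brief illustrative round trip through the two helpers touched in this file (header path assumed):

```
#include <ATen/core/ScalarType.h>

// Converts a ScalarType to caffe2's type identifier and back again.
bool roundtrip_is_float() {
  auto id = at::scalarTypeToDataType(at::ScalarType::Float);
  return at::dataTypeToScalarType(id) == at::ScalarType::Float;
}
```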
13 changes: 13 additions & 0 deletions aten/src/ATen/core/SmallVector.h
@@ -1015,6 +1015,19 @@ inline size_t capacity_in_bytes(const SmallVector<T, N>& X) {
return X.capacity_in_bytes();
}

template <typename T, unsigned N>
std::ostream& operator<<(std::ostream & out, const SmallVector<T, N>& list) {
int i = 0;
out << "[";
for(auto e : list) {
if (i++ > 0)
out << ", ";
out << e;
}
out << "]";
return out;
}

} // end namespace at

namespace std {
11 changes: 11 additions & 0 deletions aten/src/ATen/core/Tensor.h
@@ -677,6 +677,17 @@ struct CAFFE2_API WeakTensor {
private:
c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl> weak_impl_;
};

namespace detail {
// Helper creator for the Tensor class which doesn't require the user to
// pass in an intrusive_ptr; instead it converts the arguments passed to it
// into the requested intrusive_ptr type.
template <typename T, typename... Args>
Tensor make_tensor(Args&&... args) {
return Tensor(c10::make_intrusive<T>(std::forward<Args>(args)...));
}
} // namespace detail

} // namespace at

#include "ATen/core/TensorMethods.h"