
Merge from upstream #218


Merged 48 commits on Sep 21, 2018.

Commits:
3da8d71  remove protobuf inclusion in core/logging.h (#11814)  (Yangqing, Sep 19, 2018)
1c86860  Expunge (transitive) caffe2_pb2 dependency from tensor_impl.h from co…  (ezyang, Sep 19, 2018)
a79f5d7  Add pretty printer for JIT IR (#10319)  (Sep 19, 2018)
1f34be4  Raise error when perf test result is NaN (#11588)  (Sep 19, 2018)
2c358ea  Caffe2: add plan name to logging (#11704)  (shen-pan, Sep 19, 2018)
e80d1d2  Revert D9924348: Expunge (transitive) caffe2_pb2 dependency from tens…  (Sep 19, 2018)
8aedc27  checking device types of input and weights at RNN (#10185)  (weiyangfb, Sep 19, 2018)
a26ad5a  Remove unnecessary check on device option pointer (#11845)  (Sep 19, 2018)
53b5f14  Remove inclusion of caffe2 pb (#11820)  (Yangqing, Sep 19, 2018)
77af40c  prioritize Accelerate over OpenBLAS (#11812)  (soumith, Sep 19, 2018)
b46f1b8  Open-source ThreadSafeActivationCleaningPredictor (#11779)  (salexspb, Sep 19, 2018)
8601b33  fix half grad assignment (#11781)  (Sep 19, 2018)
32494c2  OperatorDef <==> NodeProto Conversion (#11621)  (houseroad, Sep 19, 2018)
b3a2665  Code-reorg to have TORCH_ARG in its own header (#11787)  (goldsborough, Sep 19, 2018)
8c3a94e  Improve autograd profiler performance (#11773)  (apaszke, Sep 19, 2018)
fa32317  Add empty tensor tests to test_sparse (#11228)  (Sep 19, 2018)
5247250  Add env:// rendezvous test (#11782)  (pietern, Sep 19, 2018)
3b1a5a1  Refactor tests part 2 (#11811)  (ajyu, Sep 19, 2018)
ce55767  Add the missing header (#11864)  (houseroad, Sep 19, 2018)
c307907  Minor data loader doc improvements  (ssnl, Sep 19, 2018)
cf5a21e  Add back proto opt disable feature that was lost during refactor (#11…  (bwasti, Sep 19, 2018)
24e958a  Move bernoulli into ATen (#10273)  (ssnl, Sep 19, 2018)
cedd12d  Explicitly qualify references to CPU. (#11819)  (ezyang, Sep 19, 2018)
b06e35b  Back out "Revert D9924348: Expunge (transitive) caffe2_pb2 dependency…  (ezyang, Sep 19, 2018)
f4d2503  Fix Array.h when compiled with C++17 (#11816)  (ezyang, Sep 19, 2018)
6302e40  Delete unnecessary include from allocator.cc/event_cpu.h  (ezyang, Sep 19, 2018)
ae1a972  Fix #11752: correct numerical issue with log_softmax (#11866)  (sytrus-in-github, Sep 20, 2018)
6831d64  Fix the symbolic for embedding_bag in ONNX_ATEN_FALLBACK (#11840)  (houseroad, Sep 20, 2018)
1091c5e  Throw error on indexing a 0 dim tensor (#11679)  (Sep 20, 2018)
c22dcc2  Show build output in verbose mode of C++ extensions (#11724)  (goldsborough, Sep 20, 2018)
aa8cd73  Enable build_test on windows (#11802)  (mingzhe09088, Sep 20, 2018)
c64331f  Add test for verifying combine_spatial_bn values in DPM (#11710)  (Sep 20, 2018)
83740ea  Avoid using PyThreadState.frame as it is not a public member. (#11855)  (xuhdev, Sep 20, 2018)
23dd5b4  Back out "Open-source ThreadSafeActivationCleaningPredictor"  (salexspb, Sep 20, 2018)
8f4601f  renable test_scalar_fusion  (zou3519, Sep 20, 2018)
1c77f9e  Support torch.distributed.barrier in gloo backend  (pietern, Sep 20, 2018)
0927386  Workaround CUDA logging on some embedded platforms (#11851)  (soumith, Sep 20, 2018)
8770128  fix link to previous versions (#11894)  (soumith, Sep 20, 2018)
9cd0ae5  Remove deprecated factory functions from Type.  (ezyang, Sep 20, 2018)
24ec813  Defer lazyInitCUDA() until needed (#11893)  (pietern, Sep 20, 2018)
d8f6be6  Remove torch/legacy (#11823)  (cpuhrsch, Sep 20, 2018)
068eac2  Jit fuse clamp (#11574)  (t-vi, Sep 20, 2018)
6100c0e  Introduce ExtensionVersioner for C++ extensions (#11725)  (goldsborough, Sep 20, 2018)
b91b15d  Implementing Matrix Norm for torch.norm (#11261)  (yya007, Sep 20, 2018)
c7751f4  MIOpen bug fixes and performance enhancements (#11766)  (Sep 20, 2018)
4f7cf5c  Merge remote-tracking branch 'rocm_upstream/upstream' into ifu  (iotamudelta, Sep 20, 2018)
0655821  Skip as it fails on CI.  (iotamudelta, Sep 21, 2018)
613dacb  Skip failing test.  (iotamudelta, Sep 21, 2018)
2 changes: 1 addition & 1 deletion .jenkins/pytorch/macos-test.sh
@@ -15,7 +15,7 @@ if [ ! -d "${PYTORCH_ENV_DIR}/miniconda3" ]; then
 fi
 export PATH="${PYTORCH_ENV_DIR}/miniconda3/bin:$PATH"
 source ${PYTORCH_ENV_DIR}/miniconda3/bin/activate
-conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja
+conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja future six
 if [ -z "${IN_CIRCLECI}" ]; then
   rm -rf ${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages/torch*
 fi
16 changes: 14 additions & 2 deletions .jenkins/pytorch/perf_test/compare_with_baseline.py
@@ -1,5 +1,6 @@
 import sys
 import json
+import math
 import numpy
 import argparse

@@ -35,14 +36,25 @@
 print("population mean: ", mean)
 print("population sigma: ", sigma)

+# Let the test pass if baseline number is NaN (which happened in
+# the past when we didn't have logic for catching NaN numbers)
+if math.isnan(mean) or math.isnan(sigma):
+    mean = sys.maxsize
+    sigma = 0.001
+
 sample_stats_data = json.loads(args.sample_stats)

-sample_mean = sample_stats_data['mean']
-sample_sigma = sample_stats_data['sigma']
+sample_mean = float(sample_stats_data['mean'])
+sample_sigma = float(sample_stats_data['sigma'])

 print("sample mean: ", sample_mean)
 print("sample sigma: ", sample_sigma)

+if math.isnan(sample_mean):
+    raise Exception('''Error: sample mean is NaN''')
+elif math.isnan(sample_sigma):
+    raise Exception('''Error: sample sigma is NaN''')
+
 z_value = (sample_mean - mean) / sigma

 print("z-value: ", z_value)
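Note on the guard above: a NaN baseline is replaced by an effectively infinite mean (sys.maxsize) with a tiny sigma, so any finite sample passes, while a NaN sample aborts the run outright. A minimal standalone sketch of the same z-score guard, written in C++ for consistency with the other examples in this section (names and output are illustrative, not taken from the script):

#include <cmath>
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Z-score of a sample against a baseline, replicating the script's NaN
// handling: NaN baseline -> trivially passes, NaN sample -> hard error.
double guarded_z(double mean, double sigma, double sample_mean) {
  if (std::isnan(mean) || std::isnan(sigma)) {
    // NaN baseline: substitute a huge mean so any finite sample passes.
    mean = static_cast<double>(INT64_MAX);
    sigma = 0.001;
  }
  if (std::isnan(sample_mean)) {
    throw std::runtime_error("Error: sample mean is NaN");
  }
  return (sample_mean - mean) / sigma;
}

int main() {
  std::cout << guarded_z(10.0, 0.5, 11.0) << "\n";          // 2.0
  std::cout << guarded_z(std::nan(""), 0.5, 11.0) << "\n";  // huge negative, passes
}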
3 changes: 3 additions & 0 deletions .jenkins/pytorch/perf_test/test_gpu_speed_mnist.sh
@@ -20,6 +20,9 @@ test_gpu_speed_mnist () {
   SAMPLE_ARRAY=()
   NUM_RUNS=$1

+  # Needs warm up to get accurate number
+  python main.py --epochs 1 --no-log
+
   for (( i=1; i<=$NUM_RUNS; i++ )) do
     runtime=$(get_runtime_of_command python main.py --epochs 1 --no-log)
     echo $runtime
2 changes: 1 addition & 1 deletion .jenkins/pytorch/win-test.sh
@@ -45,7 +45,7 @@ curl https://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe -O
 call C:\\Jenkins\\Miniconda3\\Scripts\\activate.bat C:\\Jenkins\\Miniconda3
 call conda install -y -q numpy mkl cffi pyyaml boto3

-pip install ninja
+pip install ninja future

 call "C:\\Program Files (x86)\\Microsoft Visual Studio\\2017\\Community\\VC\\Auxiliary\\Build\\vcvarsall.bat" x86_amd64

2 changes: 1 addition & 1 deletion README.md
@@ -239,7 +239,7 @@ You can then build the documentation by running ``make <format>`` from the
 ### Previous Versions

 Installation instructions and binaries for previous PyTorch versions may be found
-on [our website](http://pytorch.org/previous-versions/).
+on [our website](http://pytorch.org/previous-versions).


 ## Getting Started
4 changes: 2 additions & 2 deletions aten/src/ATen/CPUApplyUtils.h
@@ -207,7 +207,7 @@ inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
   for (size_t i = 0; i < tensors.size() - 1; i++) {
     oss << tensors[i].sizes() << ", ";
   }
-  oss << "and " << tensors[tensors.size() - 1]
+  oss << "and " << tensors[tensors.size() - 1].sizes()
       << " to have the same number of elements, but got ";
   for (size_t i = 0; i < tensors.size() - 1; i++) {
     oss << tensors[i].numel() << ", ";
@@ -220,7 +220,7 @@ inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
 inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
   checkBackend("CPU_tensor_apply", tensors, Backend::CPU);
   if (!_all_equal_numel(tensors))
-    throw std::runtime_error(_all_equal_numel_error(tensors));
+    AT_ERROR(_all_equal_numel_error(tensors));
   // An empty tensor has no elements
   for (auto& t : tensors)
     if (t.numel() == 0)
32 changes: 0 additions & 32 deletions aten/src/ATen/Declarations.cwrap
@@ -3218,38 +3218,6 @@
       kwarg_only: True
     - double p
 ]]
-[[
-  name: _bernoulli_
-  backends:
-    - CPU
-    - CUDA
-  cname: bernoulli
-  return: self
-  variants: function
-  arguments:
-    - THTensor* self
-    - arg: THGenerator* generator
-      default: nullptr
-      kwarg_only: True
-    - double p
-]]
-[[
-  name: _th_bernoulli
-  types:
-    - Float
-    - Double
-  return: argument 0
-  variants: function
-  cname: bernoulli_Tensor
-  arguments:
-    - arg: THTensor* output
-      output: True
-      resize: self
-    - arg: THGenerator* generator
-      default: nullptr
-      kwarg_only: True
-    - THTensor* self
-]]
 [[
   name: _dirichlet_grad
   types:
9 changes: 9 additions & 0 deletions aten/src/ATen/core/DeviceType.h
@@ -8,6 +8,7 @@
 #include <ATen/core/Macros.h>

 #include <ostream>
+#include <functional>

 namespace at {

@@ -32,3 +33,11 @@ AT_CORE_API std::string DeviceTypeName(
 AT_CORE_API std::ostream& operator<<(std::ostream& stream, at::DeviceType type);

 } // namespace at
+
+namespace std {
+template <> struct hash<at::DeviceType> {
+  std::size_t operator()(const at::DeviceType &k) const {
+    return std::hash<int>()(static_cast<int>(k));
+  }
+};
+} // namespace std
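This std::hash specialization is what lets at::DeviceType serve directly as an unordered_map key; the static-context registry added later in this PR relies on it. A self-contained sketch of the same pattern, using a stand-in enum rather than the real ATen header:

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for at::DeviceType; the real enum lives in ATen/core/DeviceType.h.
enum class DeviceType : int { CPU = 0, CUDA = 1 };

namespace std {
template <> struct hash<DeviceType> {
  std::size_t operator()(const DeviceType& k) const {
    // Hash the underlying integer value, exactly as the PR does.
    return std::hash<int>()(static_cast<int>(k));
  }
};
} // namespace std

int main() {
  // With the specialization in place, the enum works as a hash-map key.
  std::unordered_map<DeviceType, std::string> names;
  names[DeviceType::CPU] = "cpu";
  names[DeviceType::CUDA] = "cuda";
  std::cout << names[DeviceType::CPU] << "\n";  // prints "cpu"
}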
8 changes: 3 additions & 5 deletions aten/src/ATen/core/Tensor.h
@@ -441,12 +441,10 @@ struct AT_API Tensor {
   Tensor & atan_();
   Tensor baddbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1) const;
   Tensor & baddbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1);
-  Tensor bernoulli(const Tensor & p, Generator * generator=nullptr) const;
-  Tensor bernoulli(double p, Generator * generator=nullptr) const;
-  Tensor bernoulli() const;
+  Tensor bernoulli(Generator * generator=nullptr) const;
   Tensor & bernoulli_(const Tensor & p, Generator * generator=nullptr);
-  Tensor & bernoulli_(double p, Generator * generator=nullptr);
-  Tensor & bernoulli_();
+  Tensor & bernoulli_(double p=0.5, Generator * generator=nullptr);
+  Tensor bernoulli(double p, Generator * generator=nullptr) const;
   Tensor bincount(const Tensor & weights={}, int64_t minlength=0) const;
   Tensor bmm(const Tensor & mat2) const;
   Tensor ceil() const;
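With this change the no-argument overload folds into bernoulli(Generator*), and the in-place double overload gains a default p=0.5. A hedged sketch of call sites under the reworked API (assumes a built ATen; dtype and device defaults are the library's):

#include <ATen/ATen.h>

int main() {
  at::Tensor probs = at::rand({3, 3});  // entries in [0, 1)
  // Sample using the tensor's own values as per-element probabilities.
  at::Tensor draws = probs.bernoulli();
  // In-place fill with fair coin flips: p now defaults to 0.5.
  at::Tensor coins = at::empty({3, 3});
  coins.bernoulli_();
  // Explicit probability is still available.
  coins.bernoulli_(0.25);
  return 0;
}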
14 changes: 4 additions & 10 deletions aten/src/ATen/core/TensorMethods.h
@@ -605,23 +605,17 @@ inline Tensor Tensor::baddbmm(const Tensor & batch1, const Tensor & batch2, Scal
 inline Tensor & Tensor::baddbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) {
   return type().baddbmm_(*this, batch1, batch2, beta, alpha);
 }
-inline Tensor Tensor::bernoulli(const Tensor & p, Generator * generator) const {
-  return type().bernoulli(*this, p, generator);
-}
-inline Tensor Tensor::bernoulli(double p, Generator * generator) const {
-  return type().bernoulli(*this, p, generator);
-}
-inline Tensor Tensor::bernoulli() const {
-  return type().bernoulli(*this);
+inline Tensor Tensor::bernoulli(Generator * generator) const {
+  return type().bernoulli(*this, generator);
 }
 inline Tensor & Tensor::bernoulli_(const Tensor & p, Generator * generator) {
   return type().bernoulli_(*this, p, generator);
 }
 inline Tensor & Tensor::bernoulli_(double p, Generator * generator) {
   return type().bernoulli_(*this, p, generator);
 }
-inline Tensor & Tensor::bernoulli_() {
-  return type().bernoulli_(*this);
+inline Tensor Tensor::bernoulli(double p, Generator * generator) const {
+  return type().bernoulli(*this, p, generator);
 }
 inline Tensor Tensor::bincount(const Tensor & weights, int64_t minlength) const {
   return type().bincount(*this, weights, minlength);
21 changes: 2 additions & 19 deletions aten/src/ATen/core/Type.h
@@ -381,8 +381,6 @@ struct AT_API Type {
   virtual Tensor all(const Tensor & self, int64_t dim, bool keepdim) const = 0;
   virtual bool allclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan) const = 0;
   virtual Tensor any(const Tensor & self, int64_t dim, bool keepdim) const = 0;
-  AT_DEPRECATED(virtual Tensor arange(Scalar start, Scalar end, Scalar step) const = 0);
-  AT_DEPRECATED(virtual Tensor arange(Scalar end) const = 0);
   virtual Tensor argmax(const Tensor & self, int64_t dim, bool keepdim) const = 0;
   virtual Tensor argmax(const Tensor & self) const = 0;
   virtual Tensor argmin(const Tensor & self, int64_t dim, bool keepdim) const = 0;
@@ -397,12 +395,10 @@
   virtual Tensor & atan_(Tensor & self) const = 0;
   virtual Tensor baddbmm(const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const = 0;
   virtual Tensor & baddbmm_(Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const = 0;
-  virtual Tensor bernoulli(const Tensor & self, const Tensor & p, Generator * generator) const = 0;
-  virtual Tensor bernoulli(const Tensor & self, double p, Generator * generator) const = 0;
-  virtual Tensor bernoulli(const Tensor & self) const = 0;
+  virtual Tensor bernoulli(const Tensor & self, Generator * generator) const = 0;
   virtual Tensor & bernoulli_(Tensor & self, const Tensor & p, Generator * generator) const = 0;
   virtual Tensor & bernoulli_(Tensor & self, double p, Generator * generator) const = 0;
-  virtual Tensor & bernoulli_(Tensor & self) const = 0;
+  virtual Tensor bernoulli(const Tensor & self, double p, Generator * generator) const = 0;
   virtual Tensor bincount(const Tensor & self, const Tensor & weights, int64_t minlength) const = 0;
   virtual Tensor bmm(const Tensor & self, const Tensor & mat2) const = 0;
   virtual Tensor ceil(const Tensor & self) const = 0;
@@ -430,7 +426,6 @@
   virtual Tensor div(const Tensor & self, Scalar other) const = 0;
   virtual Tensor & div_(Tensor & self, Scalar other) const = 0;
   virtual Tensor dot(const Tensor & self, const Tensor & tensor) const = 0;
-  AT_DEPRECATED(virtual Tensor empty(IntList size) const = 0);
   virtual Tensor erf(const Tensor & self) const = 0;
   virtual Tensor & erf_(Tensor & self) const = 0;
   virtual Tensor erfc(const Tensor & self) const = 0;
@@ -441,13 +436,11 @@
   virtual Tensor & expm1_(Tensor & self) const = 0;
   virtual Tensor expand(const Tensor & self, IntList size, bool implicit) const = 0;
   virtual Tensor expand_as(const Tensor & self, const Tensor & other) const = 0;
-  AT_DEPRECATED(virtual Tensor eye(int64_t n, int64_t m) const = 0);
   virtual Tensor flatten(const Tensor & self, int64_t start_dim, int64_t end_dim) const = 0;
   virtual Tensor & fill_(Tensor & self, Scalar value) const = 0;
   virtual Tensor & fill_(Tensor & self, const Tensor & value) const = 0;
   virtual Tensor floor(const Tensor & self) const = 0;
   virtual Tensor & floor_(Tensor & self) const = 0;
-  AT_DEPRECATED(virtual Tensor full(IntList size, Scalar fill_value) const = 0);
   virtual Tensor ger(const Tensor & self, const Tensor & vec2) const = 0;
   virtual std::tuple<Tensor,Tensor> gesv(const Tensor & self, const Tensor & A) const = 0;
   virtual Tensor fft(const Tensor & self, int64_t signal_ndim, bool normalized) const = 0;
@@ -469,7 +462,6 @@
   virtual bool is_signed(const Tensor & self) const = 0;
   virtual bool is_sparse(const Tensor & self) const = 0;
   virtual std::tuple<Tensor,Tensor> kthvalue(const Tensor & self, int64_t k, int64_t dim, bool keepdim) const = 0;
-  AT_DEPRECATED(virtual Tensor linspace(Scalar start, Scalar end, int64_t steps) const = 0);
   virtual Tensor log(const Tensor & self) const = 0;
   virtual Tensor & log_(Tensor & self) const = 0;
   virtual Tensor log10(const Tensor & self) const = 0;
@@ -479,7 +471,6 @@
   virtual Tensor log2(const Tensor & self) const = 0;
   virtual Tensor & log2_(Tensor & self) const = 0;
   virtual Tensor logdet(const Tensor & self) const = 0;
-  AT_DEPRECATED(virtual Tensor logspace(Scalar start, Scalar end, int64_t steps) const = 0);
   virtual Tensor log_softmax(const Tensor & self, int64_t dim) const = 0;
   virtual Tensor logsumexp(const Tensor & self, int64_t dim, bool keepdim) const = 0;
   virtual Tensor matmul(const Tensor & self, const Tensor & other) const = 0;
@@ -504,16 +495,9 @@
   virtual Tensor mvlgamma(const Tensor & self, int64_t p) const = 0;
   virtual Tensor & mvlgamma_(Tensor & self, int64_t p) const = 0;
   virtual Tensor narrow(const Tensor & self, int64_t dim, int64_t start, int64_t length) const = 0;
-  AT_DEPRECATED(virtual Tensor ones(IntList size) const = 0);
   virtual Tensor permute(const Tensor & self, IntList dims) const = 0;
   virtual Tensor pin_memory(const Tensor & self) const = 0;
   virtual Tensor pinverse(const Tensor & self, double rcond) const = 0;
-  AT_DEPRECATED(virtual Tensor rand(IntList size, Generator * generator) const = 0);
-  AT_DEPRECATED(virtual Tensor randint(int64_t high, IntList size, Generator * generator) const = 0);
-  AT_DEPRECATED(virtual Tensor randint(int64_t low, int64_t high, IntList size, Generator * generator) const = 0);
-  AT_DEPRECATED(virtual Tensor randn(IntList size, Generator * generator) const = 0);
-  AT_DEPRECATED(virtual Tensor randperm(int64_t n, Generator * generator) const = 0);
-  AT_DEPRECATED(virtual Tensor range(Scalar start, Scalar end, Scalar step) const = 0);
   virtual Tensor repeat(const Tensor & self, IntList repeats) const = 0;
   virtual Tensor reshape(const Tensor & self, IntList shape) const = 0;
   virtual Tensor reshape_as(const Tensor & self, const Tensor & other) const = 0;
@@ -581,7 +565,6 @@
   virtual Tensor var(const Tensor & self, int64_t dim, bool unbiased, bool keepdim) const = 0;
   virtual Tensor view_as(const Tensor & self, const Tensor & other) const = 0;
   virtual Tensor where(const Tensor & condition, const Tensor & self, const Tensor & other) const = 0;
-  AT_DEPRECATED(virtual Tensor zeros(IntList size) const = 0);
   virtual Tensor norm(const Tensor & self, Scalar p) const = 0;
   virtual Tensor norm(const Tensor & self, Scalar p, int64_t dim, bool keepdim) const = 0;
   virtual Tensor clone(const Tensor & self) const = 0;
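The removed AT_DEPRECATED factories on Type correspond to the free factory functions in the at:: namespace, which remain the supported spelling. A brief hedged sketch of the replacement call sites (assumes a built ATen of this era; defaults are the library's):

#include <ATen/ATen.h>

int main() {
  // Free functions replace the removed Type::zeros/ones/arange/etc.
  at::Tensor z = at::zeros({2, 3});
  at::Tensor o = at::ones({2, 3});
  at::Tensor a = at::arange(10);
  return 0;
}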
22 changes: 22 additions & 0 deletions aten/src/ATen/core/context_base.cpp
@@ -0,0 +1,22 @@
+#include <ATen/core/context_base.h>
+
+namespace caffe2 {
+
+// TODO: rename context.h -> context_cpu.h & context_base.h -> context.h
+StaticContextMap& GetStaticContexts() {
+  static StaticContextMap static_contexts;
+  return static_contexts;
+}
+
+void set_static_context(at::DeviceType t, BaseStaticContext* ptr) {
+  auto& static_contexts = GetStaticContexts();
+  static_contexts[t] = ptr;
+}
+
+BaseStaticContext* get_static_context(at::DeviceType t) {
+  auto* ptr = GetStaticContexts()[t];
+  AT_ASSERTM(ptr, "StaticContext for ", t, " is not registered yet.");
+  return ptr;
+}
+
+} // namespace caffe2
25 changes: 25 additions & 0 deletions aten/src/ATen/core/context_base.h
@@ -10,6 +10,7 @@
 #include <ATen/core/Error.h>
 #include <ATen/core/UniqueVoidPtr.h>
 #include <ATen/core/typeid.h>
+#include <ATen/core/ATenGeneral.h>

 namespace caffe2 {
 class Event;
@@ -184,3 +185,27 @@ class AT_CORE_API BaseContext {
 };

 } // namespace at
+
+namespace caffe2 {
+
+using at::BaseContext;
+using at::BaseStaticContext;
+
+using StaticContextMap = std::unordered_map<at::DeviceType, BaseStaticContext*>;
+AT_API StaticContextMap& GetStaticContexts();
+AT_API void set_static_context(at::DeviceType t, BaseStaticContext* ptr);
+AT_API BaseStaticContext* get_static_context(at::DeviceType t);
+
+template <at::DeviceType t>
+struct StaticContextFunctionRegisterer {
+  explicit StaticContextFunctionRegisterer(BaseStaticContext* ptr) {
+    set_static_context(t, ptr);
+  }
+};
+
+#define REGISTER_STATIC_CONTEXT(t, f) \
+  namespace {                         \
+  static StaticContextFunctionRegisterer<t> g_static_context_##d(f); \
+  }
+
+} // namespace caffe2
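GetStaticContexts() hands out a function-local static map, and REGISTER_STATIC_CONTEXT expands to a global registerer object whose constructor fills the slot for one device type at static-initialization time (the PR's macro builds the variable name via token pasting). A self-contained sketch of the pattern, with all names as stand-ins for illustration; the real interfaces live in ATen/core/context_base.h. It also leans on the std::hash specialization shown earlier:

#include <cassert>
#include <cstddef>
#include <functional>
#include <unordered_map>

enum class DeviceType : int { CPU = 0, CUDA = 1 };

namespace std {
template <> struct hash<DeviceType> {
  std::size_t operator()(const DeviceType& k) const {
    return std::hash<int>()(static_cast<int>(k));
  }
};
} // namespace std

struct BaseStaticContext { virtual ~BaseStaticContext() = default; };

using StaticContextMap = std::unordered_map<DeviceType, BaseStaticContext*>;

StaticContextMap& GetStaticContexts() {
  static StaticContextMap m;  // function-local static avoids init-order issues
  return m;
}

void set_static_context(DeviceType t, BaseStaticContext* p) {
  GetStaticContexts()[t] = p;
}

// Constructor side effect performs the registration at load time.
template <DeviceType t>
struct StaticContextFunctionRegisterer {
  explicit StaticContextFunctionRegisterer(BaseStaticContext* p) {
    set_static_context(t, p);
  }
};

#define REGISTER_STATIC_CONTEXT(t, f) \
  static StaticContextFunctionRegisterer<t> g_static_context(f);

struct CPUStaticContext : BaseStaticContext {};
static CPUStaticContext g_cpu_ctx;
REGISTER_STATIC_CONTEXT(DeviceType::CPU, &g_cpu_ctx)

int main() {
  // By the time main runs, the CPU slot is already populated.
  assert(GetStaticContexts()[DeviceType::CPU] == &g_cpu_ctx);
}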
10 changes: 10 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_base.h
@@ -438,4 +438,14 @@ interleave2(const Vec256<T>& a, const Vec256<T>& b) {
       Vec256<T>::loadu(static_cast<void*>(buffer2)));
 }

+template <typename src_T, typename dst_T>
+void convert(const src_T *src, dst_T *dst, int64_t n) {
+#pragma unroll
+  for (int64_t i = 0; i < n; i++) {
+    *dst = static_cast<dst_T>(*src);
+    src++;
+    dst++;
+  }
+}
+
 }}}
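This is the scalar fallback: a plain element-wise static_cast that specific type pairs can shadow with SIMD specializations, as vec256_int.h does below. A minimal usage sketch of the generic template, copied here for illustration (in ATen it lives in the vec256 headers):

#include <cstdint>
#include <iostream>

// Generic fallback: element-wise cast from src_T to dst_T.
template <typename src_T, typename dst_T>
void convert(const src_T* src, dst_T* dst, int64_t n) {
  for (int64_t i = 0; i < n; i++) {
    *dst++ = static_cast<dst_T>(*src++);
  }
}

int main() {
  int32_t src[4] = {1, 2, 3, 4};
  float dst[4];
  convert(src, dst, 4);         // element-wise static_cast
  std::cout << dst[2] << "\n";  // prints 3
}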
32 changes: 32 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_int.h
@@ -208,6 +208,38 @@ struct Vec256<int32_t> : public Vec256i {
   }
 };

+template <>
+void convert(const int32_t *src, float *dst, int64_t n) {
+  int64_t i;
+  // int32_t and float have same size
+#pragma unroll
+  for (i = 0; i <= (n - Vec256<int32_t>::size); i += Vec256<int32_t>::size) {
+    auto input_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
+    auto output_vec = _mm256_cvtepi32_ps(input_vec);
+    _mm256_storeu_ps(reinterpret_cast<float*>(dst + i), output_vec);
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = static_cast<float>(src[i]);
+  }
+}
+
+template <>
+void convert(const int32_t *src, double *dst, int64_t n) {
+  int64_t i;
+  // int32_t has half the size of double
+#pragma unroll
+  for (i = 0; i <= (n - Vec256<double>::size); i += Vec256<double>::size) {
+    auto input_128_vec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
+    auto output_vec = _mm256_cvtepi32_pd(input_128_vec);
+    _mm256_storeu_pd(reinterpret_cast<double*>(dst + i), output_vec);
+  }
+#pragma unroll
+  for (; i < n; i++) {
+    dst[i] = static_cast<double>(src[i]);
+  }
+}
+
 template <>
 struct Vec256<int16_t> : public Vec256i {
   static constexpr int size = 16;
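The float specialization converts eight int32 lanes per iteration with AVX2 and leaves the remainder to the scalar tail loop; the double variant reads 128-bit lanes because each int32 widens to a 64-bit double. A standalone illustration of the core intrinsics (assumes AVX2 hardware; compile with -mavx2):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  alignas(32) int32_t src[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  alignas(32) float dst[8];
  // Load 8 packed int32 lanes, convert each lane to float, store back.
  __m256i iv = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src));
  __m256 fv = _mm256_cvtepi32_ps(iv);
  _mm256_storeu_ps(dst, fv);
  printf("%.1f %.1f\n", dst[0], dst[7]);  // 1.0 8.0
}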