Update THC to master #193

Merged: 7 commits, Nov 1, 2016

2 changes: 2 additions & 0 deletions torch/lib/THC/CMakeLists.txt
@@ -269,6 +269,8 @@ INSTALL(FILES
generic/THCTensorMathPointwise.cu
generic/THCTensorMathReduce.h
generic/THCTensorMathReduce.cu
generic/THCTensorMathScan.h
generic/THCTensorMathScan.cu
generic/THCTensorScatterGather.h
generic/THCTensorScatterGather.cu
generic/THCTensorIndex.h
30 changes: 30 additions & 0 deletions torch/lib/THC/THCBlas.cu
@@ -48,6 +48,36 @@ double THCudaBlas_Ddot(THCState *state, long n, double *x, long incx, double *y, long incy)
return 0;
}

#ifdef CUDA_HALF_TENSOR
half THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long incy)
{
#if CUDA_VERSION >= 8000
if (n == 1) {
incx = 1;
incy = 1;
}

if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
int i_n = (int)n;
int i_incx = (int)incx;
int i_incy = (int)incy;
half result;
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasDotEx(handle, i_n, x, CUDA_R_16F, i_incx, y, CUDA_R_16F, i_incy, &result, CUDA_R_16F, CUDA_R_32F));
return result;
}

THError("Cublas_Hdot only supports n, incx and incy "
"up to signed integer limits: %d", INT_MAX);
return THC_float2half(0);
#else
THError("Cublas_Hdot requires CUDA 8.0+");
return THC_float2half(0);
#endif
}
#endif

/* Level 2 */
void THCudaBlas_Sgemv(THCState *state, char trans, long m, long n, float alpha, float *a, long lda, float *x, long incx, float beta, float *y, long incy)
{
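
A note on the new half-precision routine above: cublasDotEx reads FP16 inputs but accumulates in FP32 (the final CUDA_R_16F/CUDA_R_32F arguments give the result and execution types), which avoids the precision loss of a pure FP16 reduction. A minimal calling sketch, assuming device pointers d_x and d_y; the wrapper name is hypothetical and not part of this PR:

#ifdef CUDA_HALF_TENSOR
/* Hypothetical helper (not in this PR): dot product of two length-n
   FP16 vectors with unit stride, accumulated in FP32 by cuBLAS. */
static half example_hdot(THCState *state, half *d_x, half *d_y, long n)
{
  return THCudaBlas_Hdot(state, n, d_x, 1, d_y, 1);
}
#endif
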
3 changes: 3 additions & 0 deletions torch/lib/THC/THCBlas.h
@@ -7,6 +7,9 @@
/* Level 1 */
THC_API float THCudaBlas_Sdot(THCState *state, long n, float *x, long incx, float *y, long incy);
THC_API double THCudaBlas_Ddot(THCState *state, long n, double *x, long incx, double *y, long incy);
#ifdef CUDA_HALF_TENSOR
THC_API half THCudaBlas_Hdot(THCState *state, long n, half *x, long incx, half *y, long incy);
#endif

/* Level 2 */
THC_API void THCudaBlas_Sgemv(THCState *state, char trans, long m, long n, float alpha, float *a, long lda, float *x, long incx, float beta, float *y, long incy);
93 changes: 93 additions & 0 deletions torch/lib/THC/THCTensorMath.cu
@@ -3,7 +3,20 @@
#include "THCTensorCopy.h"
#include "THCApply.cuh"
#include "THCNumerics.cuh"
#include "THCTensorMath.cuh"

#include <thrust/copy.h>
#include <thrust/count.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/sequence.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>
#if CUDA_VERSION >= 7000
#include <thrust/system/cuda/execution_policy.h>
#endif
#include <cfloat>

template <typename T>
@@ -14,5 +27,85 @@ struct TensorFillOp {
const T val;
};

// copypasta from https://github.com/thrust/thrust/blob/master/examples/strided_range.cu
template <typename Iterator>
class strided_range
{
public:

typedef typename thrust::iterator_difference<Iterator>::type difference_type;

struct stride_functor : public thrust::unary_function<difference_type,
difference_type>
{
difference_type stride;

stride_functor(difference_type stride)
: stride(stride) {}

__host__ __device__
difference_type operator()(const difference_type& i) const
{
return stride * i;
}
};

typedef typename thrust::counting_iterator<difference_type> CountingIterator;
typedef typename thrust::transform_iterator<stride_functor, CountingIterator> TransformIterator;
typedef typename thrust::permutation_iterator<Iterator,TransformIterator> PermutationIterator;

// type of the strided_range iterator
typedef PermutationIterator iterator;

// construct strided_range for the range [first,last)
strided_range(Iterator first, Iterator last, difference_type stride)
: first(first), last(last), stride(stride) {}

iterator begin(void) const
{
return PermutationIterator(first,
TransformIterator(CountingIterator(0),
stride_functor(stride)));
}

iterator end(void) const
{
return begin() + ((last - first) + (stride - 1)) / stride;
}

protected:
Iterator first;
Iterator last;
difference_type stride;
};

struct idx_functor
{
long div;
long size;

__host__ __device__
idx_functor(long div, long size) : div(div), size(size) {}

__host__ __device__
long operator()(long val) {
return (val / div) % size + 1;
}
};

template <typename T>
struct NonZeroOp
{
NonZeroOp() {}
__host__ __device__ bool operator()(T lhs) const {
if (THCNumerics<T>::ne(lhs, ScalarConvert<float, T>::to(0.0))) {
return true;
} else {
return false;
}
}
};


#include "generic/THCTensorMath.cu"
#include "THCGenerateAllTypes.h"
26 changes: 26 additions & 0 deletions torch/lib/THC/THCTensorMath.cuh
@@ -0,0 +1,26 @@
#ifndef THC_TENSORMATH_CUH
#define THC_TENSORMATH_CUH

// Copy the kth diagonal of a matrix B to a vector A.
template <typename T>
__global__ void THCTensor_copyFromDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t size, ptrdiff_t strideSum, ptrdiff_t strideA) {
for (ptrdiff_t linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
linearIndex < size;
linearIndex += gridDim.x * blockDim.x) {
const ptrdiff_t bOffset = start + strideSum * linearIndex;
a[strideA * linearIndex] = b[bOffset];
}
}

// Copy vector B to the kth diagonal of a matrix A
template <typename T>
__global__ void THCTensor_copyToDiagonal(T* a, T* b, ptrdiff_t start, ptrdiff_t size, ptrdiff_t strideSum, ptrdiff_t strideB) {
for (ptrdiff_t linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
linearIndex < size;
linearIndex += gridDim.x * blockDim.x) {
const ptrdiff_t aOffset = start + strideSum * linearIndex;
a[aOffset] = b[strideB * linearIndex];
}
}

#endif
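
Both kernels use a grid-stride loop, and strideSum is the sum of the matrix's row and column strides: stepping along a diagonal advances the row and the column index together. A hedged launch sketch for the main diagonal (k = 0) of a contiguous n x n float matrix; the wrapper name and launch configuration are illustrative, not taken from the PR:

void example_copy_main_diagonal(float *d_vec, float *d_mat, ptrdiff_t n,
                                cudaStream_t stream)
{
  const ptrdiff_t start = 0;          /* element (0,0) for k = 0 */
  const ptrdiff_t strideSum = n + 1;  /* row stride (n) + column stride (1) */
  const int threads = 256;
  const int blocks = (int)((n + threads - 1) / threads);
  THCTensor_copyFromDiagonal<float><<<blocks, threads, 0, stream>>>(
      d_vec, d_mat, start, /* size = */ n, strideSum, /* strideA = */ 1);
}
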
13 changes: 3 additions & 10 deletions torch/lib/THC/THCTensorMath.h
@@ -25,6 +25,9 @@
#include "generic/THCTensorMathCompareT.h"
#include "THCGenerateAllTypes.h"

#include "generic/THCTensorMathScan.h"
#include "THCGenerateAllTypes.h"

#include "generic/THCTensorMasked.h"
#include "THCGenerateAllTypes.h"

@@ -37,14 +40,6 @@
#include "generic/THCTensorSort.h"
#include "THCGenerateAllTypes.h"

THC_API void THCudaTensor_tril(THCState *state, THCudaTensor *self, THCudaTensor *src, long k);
THC_API void THCudaTensor_triu(THCState *state, THCudaTensor *self, THCudaTensor *src, long k);
THC_API void THCudaTensor_diag(THCState *state, THCudaTensor *self, THCudaTensor *src, long k);
THC_API float THCudaTensor_trace(THCState *state, THCudaTensor *self);

THC_API void THCudaTensor_cumsum(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim);
THC_API void THCudaTensor_cumprod(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim);

// MAGMA (i.e. CUDA implementation of LAPACK functions)
THC_API void THCudaTensor_gesv(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_);
THC_API void THCudaTensor_gels(THCState *state, THCudaTensor *rb_, THCudaTensor *ra_, THCudaTensor *b_, THCudaTensor *a_);
@@ -58,8 +53,6 @@ THC_API void THCudaTensor_potrf(THCState *state, THCudaTensor *ra_, THCudaTensor
THC_API void THCudaTensor_potrs(THCState *state, THCudaTensor *rb_, THCudaTensor *a, THCudaTensor *b);
THC_API void THCudaTensor_qr(THCState *state, THCudaTensor *rq_, THCudaTensor *rr_, THCudaTensor *a);

THC_API float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value);

THC_API void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size);
THC_API void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size);

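
The float-only declarations removed above (tril/triu/diag/trace, cumsum/cumprod, dist) are superseded by generic headers that THCGenerateAllTypes.h instantiates once per tensor type. A sketch of what generic/THCTensorMathScan.h conventionally looks like under the TH generic pattern; the exact contents are an assumption based on that convention, not copied from this PR:

#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCTensorMathScan.h"
#else

THC_API void THCTensor_(cumsum)(THCState *state, THCTensor *self, THCTensor *src, long dim);
THC_API void THCTensor_(cumprod)(THCState *state, THCTensor *self, THCTensor *src, long dim);

#endif
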
62 changes: 0 additions & 62 deletions torch/lib/THC/THCTensorMath2.cu
@@ -8,46 +8,6 @@
#include "THCTensorMathReduce.cuh"
#include "THCTensorMathPointwise.cuh"

#include <thrust/device_ptr.h>
#include <thrust/transform_reduce.h>
#include <thrust/functional.h>
#include <thrust/inner_product.h>
#if CUDA_VERSION >= 7000
#include <thrust/system/cuda/execution_policy.h>
#endif

struct TensorTPowOp {
TensorTPowOp(float v) : val(v) {}

__device__ __forceinline__ void operator()(float* out, float* in) {
*out = powf(val, *in);
}

__device__ __forceinline__ void operator()(float* v) {
*v = powf(val, *v);
}

const float val;
};

void THCudaTensor_tpow(THCState *state, THCudaTensor *self_, float value, THCudaTensor *src)
{
THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
if (self_ == src) {
if (!THC_pointwiseApply1(state, self_, TensorTPowOp(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCudaTensor_resizeAs(state, self_, src);

if (!THC_pointwiseApply2(state, self_, src, TensorTPowOp(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}

THCudaCheck(cudaGetLastError());
}

struct TensorATan2Op {
__device__ __forceinline__ void operator()(float* out, float* a, float* b) {
*out = atan2f(*a, *b);
@@ -68,28 +28,6 @@ void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx,
THCudaCheck(cudaGetLastError());
}

float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value)
{
THAssert(THCudaTensor_checkGPU(state, 2, self, src));
self = THCudaTensor_newContiguous(state, self);
ptrdiff_t size = THCudaTensor_nElement(state, self);
src = THCudaTensor_newContiguous(state, src);
thrust::device_ptr<float> self_data(THCudaTensor_data(state, self));
thrust::device_ptr<float> src_data(THCudaTensor_data(state, src));

float result = thrust::inner_product(
#if CUDA_VERSION >= 7000
thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
self_data, self_data+size, src_data, (float) 0,
thrust::plus<float>(), TensorDistOp<float>(value));

THCudaTensor_free(state, src);
THCudaTensor_free(state, self);

return pow(result, (float)1.0/value);
}

void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size)
{
THAssert(THCudaTensor_checkGPU(state, 1, r_));
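
For reference, the removed THCudaTensor_dist computes the p-norm distance (sum_i |self_i - src_i|^p)^(1/p) in a single fused pass: thrust::inner_product applies TensorDistOp elementwise and plus-reduces on the fly, and the host takes the 1/p root at the end. A standalone sketch of the same pattern in plain Thrust, with an assumed functor standing in for TensorDistOp:

#include <cmath>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/inner_product.h>

/* Assumed stand-in for TensorDistOp: |a - b|^p per element pair. */
struct AbsDiffPow {
  float p;
  AbsDiffPow(float p_) : p(p_) {}
  __host__ __device__ float operator()(float a, float b) const {
    return powf(fabsf(a - b), p);
  }
};

float example_dist(const thrust::device_vector<float>& a,
                   const thrust::device_vector<float>& b, float p)
{
  float sum = thrust::inner_product(a.begin(), a.end(), b.begin(), 0.0f,
                                    thrust::plus<float>(), AbsDiffPow(p));
  return powf(sum, 1.0f / p);
}
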