Skip to content

Commit 02cc730

Browse files
jytugapaszke
authored andcommitted
Thd functions v3 (#46)
1 parent ef1983f commit 02cc730

File tree

11 files changed

+549
-113
lines changed

11 files changed

+549
-113
lines changed

torch/lib/THD/base/Storage.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#pragma once
22

3-
#include "Tensor.hpp"
43
#include "Type.hpp"
54

65
#include <cstddef>
@@ -13,6 +12,8 @@
1312

1413
namespace thd {
1514

15+
class Tensor;
16+
1617
struct Storage {
1718
Storage() {};
1819
Storage(const Storage& other) = delete;
@@ -48,3 +49,4 @@ using IntStorage = StorageScalarInterface<long long>;
4849

4950
} // namespace thd
5051

52+
#include "Tensor.hpp"

torch/lib/THD/base/Tensor.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#pragma once
22

3+
#include "Storage.hpp"
34
#include "Type.hpp"
45

6+
#include <TH/TH.h>
57
#include <cstddef>
68
#include <cstdint>
79
#include <initializer_list>
@@ -38,6 +40,21 @@ struct Tensor {
3840

3941
virtual Tensor& resize(const std::initializer_list<long>& new_size) = 0;
4042
virtual Tensor& resize(const std::vector<long>& new_size) = 0;
43+
virtual Tensor& resize(THLongStorage *size,
44+
THLongStorage *stride) = 0;
45+
virtual Tensor& resizeAs(const Tensor& src) = 0;
46+
virtual Tensor& set(const Tensor& src) = 0;
47+
virtual Tensor& setStorage(const Storage& storage,
48+
ptrdiff_t storageOffset,
49+
THLongStorage *size,
50+
THLongStorage *stride) = 0;
51+
virtual Tensor& narrow(const Tensor& src,
52+
int dimension,
53+
long firstIndex,
54+
long size) = 0;
55+
virtual Tensor& select(const Tensor& src, int dimension, long sliceIndex) = 0;
56+
virtual Tensor& transpose(const Tensor& src, int dimension1, int dimension2) = 0;
57+
virtual Tensor& unfold(const Tensor& src, int dimension, long size, long step) = 0;
4158

4259
virtual thd::Type type() const = 0;
4360
};

torch/lib/THD/base/storages/THStorage.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ struct THStorage : public interface_traits<real>::storage_interface_type {
5151
virtual thd::Type type() const override;
5252

5353
virtual std::unique_ptr<Tensor> newTensor() const override;
54+
virtual storage_type *getRaw() const;
5455

5556
protected:
5657
storage_type *storage;
5758
};
5859

5960
} // namespace thd
60-

torch/lib/THD/base/storages/generic/THStorage.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,4 +93,9 @@ std::unique_ptr<Tensor> THStorage<real>::newTensor() const {
9393
return std::unique_ptr<Tensor>(new THTensor<real>());
9494
}
9595

96+
template<>
97+
THStorage<real>::storage_type *THStorage<real>::getRaw() const {
98+
return storage;
99+
}
100+
96101
#endif

torch/lib/THD/base/tensors/THTensor.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include "THTensor.hpp"
22
#include "../Traits.hpp"
33

4-
54
namespace thd {
65

76
#include "generic/THTensor.cpp"

torch/lib/THD/base/tensors/THTensor.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ struct th_tensor_traits {};
1717
#include "base/tensors/generic/THTensor.hpp"
1818
#include <TH/THGenerateAllTypes.h>
1919

20+
} // namespace thd
21+
22+
#include "../storages/THStorage.hpp"
23+
24+
namespace thd {
2025

2126
template<typename real>
2227
struct THTensor : public interface_traits<real>::tensor_interface_type {
@@ -50,6 +55,23 @@ struct THTensor : public interface_traits<real>::tensor_interface_type {
5055

5156
virtual THTensor& resize(const std::initializer_list<long>& new_size) override;
5257
virtual THTensor& resize(const std::vector<long>& new_size) override;
58+
virtual THTensor& resize(THLongStorage *size,
59+
THLongStorage *stride) override;
60+
virtual THTensor& resizeAs(const Tensor& src) override;
61+
virtual THTensor& set(const Tensor& src) override;
62+
virtual THTensor& setStorage(const Storage& storage,
63+
ptrdiff_t storageOffset,
64+
THLongStorage *size,
65+
THLongStorage *stride) override;
66+
67+
virtual THTensor& narrow(const Tensor& src, int dimension,
68+
long firstIndex, long size) override;
69+
virtual THTensor& select(const Tensor& src, int dimension,
70+
long sliceIndex) override;
71+
virtual THTensor& transpose(const Tensor& src, int dimension1,
72+
int dimension2) override;
73+
virtual THTensor& unfold(const Tensor& src, int dimension,
74+
long size, long step) override;
5375

5476
virtual THTensor& fill(scalar_type value) override;
5577
virtual THTensor& add(const Tensor& source, scalar_type scalar) override;
@@ -59,6 +81,9 @@ struct THTensor : public interface_traits<real>::tensor_interface_type {
5981
private:
6082
template<typename iterator>
6183
THTensor& resize(const iterator& begin, const iterator& end);
84+
template<typename iterator>
85+
THTensor& resize(const iterator& size_begin, const iterator& size_end,
86+
const iterator& stride_begin, const iterator& stride_end);
6287

6388
protected:
6489
tensor_type *tensor;

torch/lib/THD/base/tensors/generic/THTensor.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,100 @@ auto THTensor<real>::resize(const std::vector<long> &new_size) -> THTensor& {
9494
return resize(new_size.begin(), new_size.end());
9595
}
9696

97+
template<>
98+
auto THTensor<real>::resize(THLongStorage *size,
99+
THLongStorage *stride) -> THTensor& {
100+
THTensor_(resize)(tensor, size, stride);
101+
return *this;
102+
}
103+
104+
template<>
105+
auto THTensor<real>::resizeAs(const Tensor& src) -> THTensor& {
106+
THTensor_(resizeAs)(tensor, dynamic_cast<const THTensor<real>&>(src).tensor);
107+
return *this;
108+
}
109+
97110
template<>
98111
template<typename iterator>
99112
auto THTensor<real>::resize(const iterator& begin, const iterator& end) -> THTensor& {
100113
THLongStorage *sizes = THLongStorage_newWithSize(std::distance(begin, end));
101114
long *sizes_d = sizes->data;
102115
for (auto it = begin; it != end; ++it)
103116
*sizes_d++ = *it;
117+
// TODO this might leak on error
104118
THTensor_(resize)(tensor, sizes, nullptr);
119+
THLongStorage_free(sizes);
120+
return *this;
121+
}
122+
123+
template<>
124+
auto THTensor<real>::set(const Tensor& src) -> THTensor& {
125+
THTensor_(set)(
126+
tensor,
127+
(dynamic_cast<const THTensor<real>&>(src)).tensor
128+
);
129+
return *this;
130+
}
131+
132+
template<>
133+
auto THTensor<real>::setStorage(const Storage& storage,
134+
ptrdiff_t storageOffset,
135+
THLongStorage *size,
136+
THLongStorage *stride) -> THTensor& {
137+
THTensor_(setStorage)(
138+
tensor,
139+
(dynamic_cast<const THStorage<real>&>(storage)).getRaw(),
140+
storageOffset,
141+
size,
142+
stride
143+
);
144+
return *this;
145+
}
146+
147+
template<>
148+
auto THTensor<real>::narrow(const Tensor& src,
149+
int dimension,
150+
long firstIndex,
151+
long size) -> THTensor& {
152+
THTensor_(narrow)(
153+
tensor,
154+
(dynamic_cast<const THTensor<real>&>(src)).tensor,
155+
dimension,
156+
firstIndex,
157+
size
158+
);
159+
return *this;
160+
}
161+
162+
template<>
163+
auto THTensor<real>::select(const Tensor& src, int dimension,
164+
long sliceIndex) -> THTensor& {
165+
THTensor_(select)(
166+
tensor,
167+
(dynamic_cast<const THTensor<real>&>(src)).tensor,
168+
dimension,
169+
sliceIndex
170+
);
171+
return *this;
172+
}
173+
174+
template<>
175+
auto THTensor<real>::transpose(const Tensor& src, int dimension1,
176+
int dimension2) -> THTensor& {
177+
auto src_raw = (dynamic_cast<const THTensor<real>&>(src)).tensor;
178+
if (tensor != src_raw)
179+
set(src);
180+
THTensor_(transpose)(tensor, src_raw, dimension1, dimension2);
181+
return *this;
182+
}
183+
184+
template<>
185+
auto THTensor<real>::unfold(const Tensor& src, int dimension,
186+
long size, long step) ->THTensor& {
187+
auto src_raw = (dynamic_cast<const THTensor<real>&>(src)).tensor;
188+
if (tensor != src_raw)
189+
set(src);
190+
THTensor_(unfold)(tensor, src_raw, dimension, size, step);
105191
return *this;
106192
}
107193

torch/lib/THD/master_worker/common/Functions.hpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,28 @@
55
namespace thd {
66

77
enum Functions: std::uint16_t {
8-
construct,
9-
constructWithSize,
10-
free,
11-
resize,
12-
resizeAs,
13-
resize1d,
14-
resize2d,
15-
resize3d,
16-
resize4d,
17-
resize5d,
18-
set,
19-
setStorage,
20-
setStorage1d,
21-
setStorage2d,
22-
setStorage3d,
23-
setStorage4d,
24-
narrow,
25-
select,
26-
add,
27-
fill,
8+
tensorConstruct,
9+
tensorConstructWithSize,
10+
tensorFree,
11+
tensorResize,
12+
tensorResizeAs,
13+
tensorResize1d,
14+
tensorResize2d,
15+
tensorResize3d,
16+
tensorResize4d,
17+
tensorResize5d,
18+
tensorSet,
19+
tensorSetStorage,
20+
tensorSetStorage1d,
21+
tensorSetStorage2d,
22+
tensorSetStorage3d,
23+
tensorSetStorage4d,
24+
tensorNarrow,
25+
tensorSelect,
26+
tensorTranspose,
27+
tensorUnfold,
28+
tensorAdd,
29+
tensorFill,
2830

2931
// storage functions
3032
storageSet,

0 commit comments

Comments
 (0)