Skip to content

Commit d21065c

Browse files
authored
one_hot tensor_op added (#23)
1 parent 5290e45 commit d21065c

File tree

4 files changed

+107
-2
lines changed

4 files changed

+107
-2
lines changed

bnn/core/tensor_ops.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,34 @@ namespace bnn
113113
divide
114114
(TensorCPU<data_type>* x, data_type divisor);
115115

116+
/*
117+
* Returns a one hot tensor with a new axis appended
118+
* in its shape.
119+
*
120+
* @tparam data_type Data type of the elements
121+
* supported by C++.
122+
* @param x Tensor<data_type>* Tensor whose
123+
* elements are to be converted to one
124+
* hot notation.
125+
* @param on_value data_type The value which
126+
* is to be used for filling in the one
127+
* hot tensor when the value in the original
128+
* tensor matches with the index in the
129+
* range [0, depth) for the new axis.
130+
* @param off_value data_type The value which
131+
* is to be used for filling in the one
132+
* hot tensor when the value in the original
133+
* tensor doesn't match with the index in the
134+
* range [0, depth) for the new axis.
135+
* @param depth unsigned The size of the axis which
136+
* is to be appended.
137+
*/
138+
template <class data_type>
139+
TensorCPU<data_type>*
140+
one_hot
141+
(TensorCPU<data_type>* x, data_type on_value,
142+
data_type off_value, unsigned depth);
143+
116144
}
117145
}
118146

bnn/core/tensor_ops_impl.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,6 @@ namespace bnn
207207
template <class data_type>
208208
struct DivideArgs: ScalarArgs<data_type>
209209
{
210-
data_type val;
211-
212210
data_type* zd;
213211
};
214212

@@ -393,6 +391,52 @@ namespace bnn
393391
return z;
394392
}
395393

394+
template <class data_type>
395+
struct OneHot: UnaryArgs<data_type>
396+
{
397+
data_type* zd;
398+
399+
data_type on, off;
400+
401+
unsigned depth;
402+
403+
};
404+
405+
template <class data_type>
406+
void
407+
_one_hot_job
408+
(Args<data_type>* _args, unsigned start,
409+
unsigned end)
410+
{
411+
OneHot<data_type>* args = reinterpret_cast<OneHot<data_type>*>(_args);
412+
unsigned k = start;
413+
for(unsigned i = start; i < end; i++)
414+
{
415+
for(unsigned j = 0; j < args->depth; j++)
416+
{
417+
args->zd[k] = j == args->xd[i] ? args->on : args->off;
418+
k++;
419+
}
420+
}
421+
}
422+
423+
template <class data_type>
424+
TensorCPU<data_type>*
425+
one_hot
426+
(TensorCPU<data_type>* x, data_type on_value,
427+
data_type off_value, unsigned depth)
428+
{
429+
vector<unsigned> shape
430+
(x->get_shape(), x->get_shape() + x->get_ndims());
431+
shape.push_back(depth);
432+
TensorCPU<data_type>* z = new TensorCPU<data_type>(shape);
433+
OneHot<data_type> args;
434+
args.zd = z->get_data_pointer();
435+
args.on = on_value, args.off = off_value, args.depth = depth;
436+
op(x, &args, &_one_hot_job<data_type>);
437+
return z;
438+
}
439+
396440
#include "bnn/templates/core/tensor_ops.hpp"
397441

398442
}

bnn/templates/core/tensor_ops.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,5 +85,17 @@ template TensorCPU<unsigned long long>* divide<unsigned long long>(TensorCPU<uns
8585
template TensorCPU<float>* divide<float>(TensorCPU<float>* x, float divisor);
8686
template TensorCPU<double>* divide<double>(TensorCPU<double>* x, double divisor);
8787
template TensorCPU<long double>* divide<long double>(TensorCPU<long double>* x, long double divisor);
88+
template TensorCPU<bool>* one_hot<bool>(TensorCPU<bool>* x, bool on_value, bool off_value, unsigned depth);
89+
template TensorCPU<short>* one_hot<short>(TensorCPU<short>* x, short on_value, short off_value, unsigned depth);
90+
template TensorCPU<unsigned short>* one_hot<unsigned short>(TensorCPU<unsigned short>* x, unsigned short on_value, unsigned short off_value, unsigned depth);
91+
template TensorCPU<int>* one_hot<int>(TensorCPU<int>* x, int on_value, int off_value, unsigned depth);
92+
template TensorCPU<unsigned int>* one_hot<unsigned int>(TensorCPU<unsigned int>* x, unsigned int on_value, unsigned int off_value, unsigned depth);
93+
template TensorCPU<long>* one_hot<long>(TensorCPU<long>* x, long on_value, long off_value, unsigned depth);
94+
template TensorCPU<unsigned long>* one_hot<unsigned long>(TensorCPU<unsigned long>* x, unsigned long on_value, unsigned long off_value, unsigned depth);
95+
template TensorCPU<long long>* one_hot<long long>(TensorCPU<long long>* x, long long on_value, long long off_value, unsigned depth);
96+
template TensorCPU<unsigned long long>* one_hot<unsigned long long>(TensorCPU<unsigned long long>* x, unsigned long long on_value, unsigned long long off_value, unsigned depth);
97+
template TensorCPU<float>* one_hot<float>(TensorCPU<float>* x, float on_value, float off_value, unsigned depth);
98+
template TensorCPU<double>* one_hot<double>(TensorCPU<double>* x, double on_value, double off_value, unsigned depth);
99+
template TensorCPU<long double>* one_hot<long double>(TensorCPU<long double>* x, long double on_value, long double off_value, unsigned depth);
88100

89101
#endif

bnn/tests/test_core.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@ TEST(Core, TensorOpsDivide)
4848
EXPECT_EQ(z->at(0, 4, 8, 50), 2)<<"Expected quotient is 2.";
4949
}
5050

51+
TEST(Core, TensorOpsOneHot)
52+
{
53+
vector<unsigned> shape = {1000};
54+
TensorCPU<unsigned>* labels = new TensorCPU<unsigned>(shape);
55+
bnn::core::fill(labels, (unsigned)9);
56+
for(unsigned i = 0; i < 9; i++)
57+
{
58+
labels->set(i, i);
59+
}
60+
TensorCPU<unsigned>* new_labels = one_hot(labels, (unsigned)1, (unsigned)0, (unsigned)10);
61+
for(unsigned i = 0; i < 10; i++)
62+
{
63+
for(unsigned j = 0; j < 10; j++)
64+
{
65+
EXPECT_EQ(new_labels->at(i, j), labels->at(i) == j)
66+
<<"Expected one hot value is "<<(labels->at(i) == j)
67+
<<" for "<<i<<", "<<j;
68+
}
69+
}
70+
}
71+
5172
TEST(Core, TensorCPU)
5273
{
5374
TensorCPU<float>* t_f = new TensorCPU<float>;

0 commit comments

Comments
 (0)