Skip to content

Commit 49f91ab

Browse files
committed
Merge pull request #1 from freewym/opt
swap axis for optimization in Tensor3dCopy()
2 parents 6c15072 + c831186 commit 49f91ab

File tree

1 file changed

+42
-2
lines changed

1 file changed

+42
-2
lines changed

src/ctc/cctc-tombstone.cc

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,50 @@
2424
namespace kaldi {
2525
namespace ctc {
2626

27+
// This function is used inside Tensor3dCopy to swap x with y(or z) if x's
28+
// stride is not 1 but y(or z)'s stride is 1. Since x maps to the thread
29+
// index in GPU, it would be more efficient if {src|dest}_xstride is 1.
30+
// Note that the one with both src and dst strides being 1 is more preferable
31+
// to swap with x than those with only src or dst stride being 1.
32+
static void SwapDimsForX(int32* xdim, int32* ydim, int32* zdim,
33+
int32* src_xstride, int32* src_ystride, int32* src_zstride,
34+
int32* dst_xstride, int32* dst_ystride, int32* dst_zstride) {
35+
if (*src_xstride != 1 || *dst_xstride != 1) {
36+
// first try to look for y or z with both src and dst strides being 1
37+
if (*src_ystride == 1 && *dst_ystride == 1) {
38+
std::swap(*xdim, *ydim);
39+
std::swap(*src_xstride, *src_ystride);
40+
std::swap(*dst_xstride, *dst_ystride);
41+
}
42+
else if (*src_zstride == 1 && *dst_zstride == 1) {
43+
std::swap(*xdim, *zdim);
44+
std::swap(*src_xstride, *src_zstride);
45+
std::swap(*dst_xstride, *dst_zstride);
46+
}
47+
// then try to look for the one with only src or dst stride being 1
48+
else if (*src_xstride != 1 && *dst_xstride != 1) {
49+
if (*src_ystride == 1 || *dst_ystride == 1) {
50+
std::swap(*xdim, *ydim);
51+
std::swap(*src_xstride, *src_ystride);
52+
std::swap(*dst_xstride, *dst_ystride);
53+
}
54+
else if (*src_zstride == 1 || *dst_zstride == 1) {
55+
std::swap(*xdim, *zdim);
56+
std::swap(*src_xstride, *src_zstride);
57+
std::swap(*dst_xstride, *dst_zstride);
58+
}
59+
}
60+
}
61+
}
62+
2763
template <typename Real>
2864
void Tensor3dCopy(int32 xdim, int32 ydim, int32 zdim,
2965
int32 src_xstride, int32 src_ystride, int32 src_zstride,
3066
int32 dst_xstride, int32 dst_ystride, int32 dst_zstride,
3167
const Real *src, Real *dst) {
68+
SwapDimsForX(&xdim, &ydim, &zdim, &src_xstride, &src_ystride, &src_zstride,
69+
&dst_xstride, &dst_ystride, &dst_zstride);
70+
3271
#if HAVE_CUDA == 1
3372
if (CuDevice::Instantiate().Enabled()) {
3473
Timer tim;
@@ -45,9 +84,10 @@ void Tensor3dCopy(int32 xdim, int32 ydim, int32 zdim,
4584
} else
4685
#endif
4786
{
48-
for (int32 x = 0; x < xdim; x++)
87+
// iterate over z, y, x if xstride is 1, for memory-locality reasons.
88+
for (int32 z = 0; z < zdim; z++)
4989
for (int32 y = 0; y < ydim; y++)
50-
for (int32 z = 0; z < zdim; z++)
90+
for (int32 x = 0; x < xdim; x++)
5191
dst[x * dst_xstride + y * dst_ystride + z * dst_zstride] =
5292
src[x * src_xstride + y * src_ystride + z * src_zstride];
5393
}

0 commit comments

Comments
 (0)