24
24
namespace kaldi {
25
25
namespace ctc {
26
26
27
+ // This function is used inside Tensor3dCopy to swap x with y(or z) if x's
28
+ // stride is not 1 but y(or z)'s stride is 1. Since x maps to the thread
29
+ // index in GPU, it would be more efficient if {src|dest}_xstride is 1.
30
+ // Note that the one with both src and dst strides being 1 is more preferable
31
+ // to swap with x than those with only src or dst stride being 1.
32
+ static void SwapDimsForX (int32* xdim, int32* ydim, int32* zdim,
33
+ int32* src_xstride, int32* src_ystride, int32* src_zstride,
34
+ int32* dst_xstride, int32* dst_ystride, int32* dst_zstride) {
35
+ if (*src_xstride != 1 || *dst_xstride != 1 ) {
36
+ // first try to look for y or z with both src and dst strides being 1
37
+ if (*src_ystride == 1 && *dst_ystride == 1 ) {
38
+ std::swap (*xdim, *ydim);
39
+ std::swap (*src_xstride, *src_ystride);
40
+ std::swap (*dst_xstride, *dst_ystride);
41
+ }
42
+ else if (*src_zstride == 1 && *dst_zstride == 1 ) {
43
+ std::swap (*xdim, *zdim);
44
+ std::swap (*src_xstride, *src_zstride);
45
+ std::swap (*dst_xstride, *dst_zstride);
46
+ }
47
+ // then try to look for the one with only src or dst stride being 1
48
+ else if (*src_xstride != 1 && *dst_xstride != 1 ) {
49
+ if (*src_ystride == 1 || *dst_ystride == 1 ) {
50
+ std::swap (*xdim, *ydim);
51
+ std::swap (*src_xstride, *src_ystride);
52
+ std::swap (*dst_xstride, *dst_ystride);
53
+ }
54
+ else if (*src_zstride == 1 || *dst_zstride == 1 ) {
55
+ std::swap (*xdim, *zdim);
56
+ std::swap (*src_xstride, *src_zstride);
57
+ std::swap (*dst_xstride, *dst_zstride);
58
+ }
59
+ }
60
+ }
61
+ }
62
+
27
63
template <typename Real>
28
64
void Tensor3dCopy (int32 xdim, int32 ydim, int32 zdim,
29
65
int32 src_xstride, int32 src_ystride, int32 src_zstride,
30
66
int32 dst_xstride, int32 dst_ystride, int32 dst_zstride,
31
67
const Real *src, Real *dst) {
68
+ SwapDimsForX (&xdim, &ydim, &zdim, &src_xstride, &src_ystride, &src_zstride,
69
+ &dst_xstride, &dst_ystride, &dst_zstride);
70
+
32
71
#if HAVE_CUDA == 1
33
72
if (CuDevice::Instantiate ().Enabled ()) {
34
73
Timer tim;
@@ -45,9 +84,10 @@ void Tensor3dCopy(int32 xdim, int32 ydim, int32 zdim,
45
84
} else
46
85
#endif
47
86
{
48
- for (int32 x = 0 ; x < xdim; x++)
87
+ // Always iterate z (outermost), y, x (innermost): when the x strides are 1,
+ // consecutive inner iterations touch adjacent memory, which helps locality.
88
+ for (int32 z = 0 ; z < zdim; z++)
49
89
for (int32 y = 0 ; y < ydim; y++)
50
- for (int32 z = 0 ; z < zdim; z ++)
90
+ for (int32 x = 0 ; x < xdim; x ++)
51
91
dst[x * dst_xstride + y * dst_ystride + z * dst_zstride] =
52
92
src[x * src_xstride + y * src_ystride + z * src_zstride];
53
93
}
0 commit comments