Skip to content

Commit a3ecb33

Browse files
authored
Improve the comments at the beginning of index_compute.h (#1946)
I just started to learn indexing, and the comment at the beginning of index_compute.h does not look good...
1 parent f7bc341 commit a3ecb33

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

torch/csrc/jit/codegen/cuda/index_compute.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,40 +17,40 @@
1717
* indices (based on input indices) that match the root dimension.
1818
*
1919
* For example with GLOBAL tensor:
20-
* TV[I, J]
21-
* TV[Io, Ii{4}, J] = TV.split(I, factor=4)
20+
* TV[I, K]
21+
* TV[Io, Ii{4}, K] = TV.split(I, factor=4)
2222
* ALLOC: NONE
2323
* INDEX: indexCompute {i, j, k} -> {i * 4 + j, k}
24-
* FLATTENED_INDEX: {i * 4 + j, k} -> {i * 4 + j * J + k}
24+
* FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * K + k}
2525
* PREDICATE: {i * 4 + j, k} -> i * 4 + j < I
2626
*
2727
*
2828
* For example with SHARED tensor:
2929
*
30-
* global_TV[I, J]
31-
* global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4)
30+
* global_TV[I, K]
31+
* global_TV[Io, Ii{4}, K] = global_TV.split(I, factor=4)
3232
* smem_TV.compute_at(global_TV, 1)
3333
* global_TV.parallelize(1, threadIDx.x)
3434
*
35-
* ALLOC: alloc(smem_TV, 4 x J)
35+
* ALLOC: alloc(smem_TV, 4 x K)
3636
* INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k}
37-
* FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {threadIdx.x * 4 + j * J + k}
37+
* FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {(threadIdx.x * 4 + j) * K + k}
3838
* PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if
3939
* global
4040
*
4141
*
4242
* For example with LOCAL tensor:
43-
* global_TV[I, J, K]
44-
* global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4)
45-
* reg_TV.compute_at(global_TV, 1)
43+
* global_TV[I, K, L]
44+
* global_TV[Io, Ii{4}, K, L] = global_TV.split(I, factor=4)
45+
* reg_TV.compute_at(global_TV, 2)
4646
* global_TV.parallelize(1, threadIDx.x)
4747
* global_TV{i, j, k, l} -> { i * 4 + j, k, l }
48-
* global_TV{ i * 4 + j, k, l } -> { i * 4 + j * J * K + k * K + l}
48+
* global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * K * L + k * L + l}
4949
*
50-
* ALLOC: alloc(reg_TV, J x K)
50+
* ALLOC: alloc(reg_TV, K x L)
5151
* INDEX: {k, l} -> {k, l}
52-
* FLATTENED_INDEX: {k, l} -> {k * J + l}
53-
* PREDICATE: i * 4 + j < I && k < J && l < K -> // Same as if global
52+
* FLATTENED_INDEX: {k, l} -> {k * L + l}
53+
* PREDICATE: i * 4 + j < I && k < K && l < L -> // Same as if global
5454
*
5555
* These indices can then be flattened later based on strides.
5656
*/

0 commit comments

Comments
 (0)