|
17 | 17 | * indices (based on input indices) that match the root dimension.
|
18 | 18 | *
|
19 | 19 | * For example with GLOBAL tensor:
|
20 |
| - * TV[I, J] |
21 |
| - * TV[Io, Ii{4}, J] = TV.split(I, factor=4) |
| 20 | + * TV[I, K] |
| 21 | + * TV[Io, Ii{4}, K] = TV.split(I, factor=4) |
22 | 22 | * ALLOC: NONE
|
23 | 23 | * INDEX: indexCompute {i, j, k} -> {i * 4 + j, k}
|
24 |
| - * FLATTENED_INDEX: {i * 4 + j, k} -> {i * 4 + j * J + k} |
| 24 | + * FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * K + k} |
25 | 25 | * PREDICATE: {i * 4 + j, k} -> i * 4 + j < I
|
26 | 26 | *
|
27 | 27 | *
|
28 | 28 | * For example with SHARED tensor:
|
29 | 29 | *
|
30 |
| - * global_TV[I, J] |
31 |
| - * global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4) |
| 30 | + * global_TV[I, K] |
| 31 | + * global_TV[Io, Ii{4}, K] = global_TV.split(I, factor=4) |
32 | 32 | * smem_TV.compute_at(global_TV, 1)
|
33 | 33 | * global_TV.parallelize(1, threadIDx.x)
|
34 | 34 | *
|
35 |
| - * ALLOC: alloc(smem_TV, 4 x J) |
| 35 | + * ALLOC: alloc(smem_TV, 4 x K) |
36 | 36 | * INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k}
|
37 |
| - * FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {threadIdx.x * 4 + j * J + k} |
| 37 | + * FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {(threadIdx.x * 4 + j) * K + k} |
38 | 38 | * PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if
|
39 | 39 | * global
|
40 | 40 | *
|
41 | 41 | *
|
42 | 42 | * For example with LOCAL tensor:
|
43 |
| - * global_TV[I, J, K] |
44 |
| - * global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4) |
45 |
| - * reg_TV.compute_at(global_TV, 1) |
| 43 | + * global_TV[I, K, L] |
| 44 | + * global_TV[Io, Ii{4}, K, L] = global_TV.split(I, factor=4) |
| 45 | + * reg_TV.compute_at(global_TV, 2) |
46 | 46 | * global_TV.parallelize(1, threadIDx.x)
|
47 | 47 | * global_TV{i, j, k, l} -> { i * 4 + j, k, l }
|
48 |
| - * global_TV{ i * 4 + j, k, l } -> { i * 4 + j * J * K + k * K + l} |
| 48 | + * global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * K * L + k * L + l} |
49 | 49 | *
|
50 |
| - * ALLOC: alloc(reg_TV, J x K) |
| 50 | + * ALLOC: alloc(reg_TV, K x L) |
51 | 51 | * INDEX: {k, l} -> {k, l}
|
52 |
| - * FLATTENED_INDEX: {k, l} -> {k * J + l} |
53 |
| - * PREDICATE: i * 4 + j < I && k < J && l < K -> // Same as if global |
| 52 | + * FLATTENED_INDEX: {k, l} -> {k * L + l} |
| 53 | + * PREDICATE: i * 4 + j < I && k < K && l < L -> // Same as if global |
54 | 54 | *
|
55 | 55 | * These indices can then be flattened later based on strides.
|
56 | 56 | */
|
|
0 commit comments