|
| 1 | +#include <torch/csrc/jit/codegen/cuda/swizzle.h> |
| 2 | + |
| 3 | +#include <torch/csrc/jit/codegen/cuda/arith.h> |
| 4 | +#include <torch/csrc/jit/codegen/cuda/ir_builder.h> |
| 5 | + |
| 6 | +namespace torch { |
| 7 | +namespace jit { |
| 8 | +namespace fuser { |
| 9 | +namespace cuda { |
| 10 | +namespace swizzles { |
| 11 | + |
| 12 | +// ------------------------------------------------------------ |
| 13 | +// Swizzle Definitions |
| 14 | +// for each swizzle name: |
| 15 | +// un(Swizzle Name) e.g. unZShape is the inverse of ZShape, |
| 16 | +// (unswizzle is needed for inlining and is currently not actively used.) |
| 17 | +// ------------------------------------------------------------ |
| 18 | + |
| 19 | +// Unit Z swizzle: |
| 20 | +// Alternate directions of Y dimension: |
| 21 | +// 1 2 3 1 2 3 |
| 22 | +// 4 5 6 => 6 5 4 |
| 23 | +// 7 8 9 7 8 9 |
| 24 | +std::pair<Val*, Val*> ZShape(Val* x, Val* y, Val* size_y) { |
| 25 | + auto zero = x->fusion()->zeroVal(); |
| 26 | + auto one = x->fusion()->oneVal(); |
| 27 | + auto two = IrBuilder::create<Int>(2); |
| 28 | + return {x, where(eq(mod(x, two), zero), y, sub(sub(size_y, one), y))}; |
| 29 | +} |
| 30 | + |
| 31 | +// ZShape is inverse of itself |
| 32 | +std::pair<Val*, Val*> unZShape(Val* x, Val* y, Val* size_y) { |
| 33 | + return ZShape(x, y, size_y); |
| 34 | +} |
| 35 | + |
| 36 | +// Block cyclic Xor swizzle: (bank conflict removal) |
| 37 | +// Apply cyclic Xor within blocks: |
| 38 | +// Example: cyclic Xor |
| 39 | +// 1 2 3 4 1 2 3 4 |
| 40 | +// 5 6 7 8 6 5 8 7 |
| 41 | +// 9 10 11 12 => 11 12 9 10 |
| 42 | +// 13 14 15 16 16 15 14 13 |
| 43 | +std::pair<Val*, Val*> Xor(Val* x, Val* y) { |
| 44 | + // Need to validate in swizzle configuration: |
| 45 | + // size_x == size_y |
| 46 | + return {x, bitwise_xor(x, y)}; |
| 47 | +} |
| 48 | + |
| 49 | +// Xor is inverse of itself |
| 50 | +std::pair<Val*, Val*> unXor(Val* x, Val* y) { |
| 51 | + return Xor(x, y); |
| 52 | +} |
| 53 | + |
| 54 | +// Block cyclic shift swizzle: (bank conflict removal) |
| 55 | +// Apply cyclic shift within blocks: |
| 56 | +// Example: cyclic shift |
| 57 | +// 1 2 3 4 1 2 3 4 |
| 58 | +// 5 6 7 8 8 5 6 7 |
| 59 | +// 9 10 11 12 => 11 12 9 10 |
| 60 | +// 13 14 15 16 14 15 16 13 |
| 61 | +std::pair<Val*, Val*> CyclicShift(Val* x, Val* y, Val* size_x) { |
| 62 | + return {x, mod(add(x, y), size_x)}; |
| 63 | +} |
| 64 | + |
| 65 | +std::pair<Val*, Val*> unCyclicShift(Val* x, Val* y, Val* size_x) { |
| 66 | + return {x, mod(sub(add(size_x, y), x), size_x)}; |
| 67 | +} |
| 68 | + |
| 69 | +// Scatter swizzle: |
| 70 | +// Corresponds to the data layout out of ldmatrix intrinsic. |
| 71 | +// supported dimensions are : 8x4, 16x4, 32x4 |
| 72 | +std::pair<Val*, Val*> Scatter(Val* x, Val* y, int size_x) { |
| 73 | + TORCH_CHECK( |
| 74 | + size_x == 8 || size_x == 16 || size_x == 32, |
| 75 | + "Unsupported Scatter swizzle size"); |
| 76 | + Val* size_x_val = IrBuilder::create<Int>(size_x); |
| 77 | + auto four = IrBuilder::create<Int>(4); |
| 78 | + return {cpp_div(add(mul(y, size_x_val), x), four), mod(x, four)}; |
| 79 | +} |
| 80 | + |
| 81 | +std::pair<Val*, Val*> unScatter(Val* x, Val* y, int size_x) { |
| 82 | + TORCH_CHECK( |
| 83 | + size_x == 8 || size_x == 16 || size_x == 32, |
| 84 | + "Unsupported Scatter swizzle size"); |
| 85 | + Val* size_x_div_4 = IrBuilder::create<Int>(size_x / 4); |
| 86 | + auto four = IrBuilder::create<Int>(4); |
| 87 | + return {add(y, mul(mod(x, size_x_div_4), four)), cpp_div(x, size_x_div_4)}; |
| 88 | +} |
| 89 | + |
| 90 | +} // namespace swizzles |
| 91 | + |
| 92 | +std::pair<Val*, Val*> dispatchSwizzle( |
| 93 | + Swizzle2DType type, |
| 94 | + Val* x, |
| 95 | + Val* y, |
| 96 | + Val* maybe_size_x, |
| 97 | + Val* maybe_size_y) { |
| 98 | + switch (type) { |
| 99 | + case Swizzle2DType::ZShape: |
| 100 | + return swizzles::ZShape(x, y, maybe_size_y); |
| 101 | + case Swizzle2DType::XOR: |
| 102 | + return swizzles::Xor(x, y); |
| 103 | + case Swizzle2DType::CyclicShift: |
| 104 | + return swizzles::CyclicShift(x, y, maybe_size_x); |
| 105 | + case Swizzle2DType::Scatter: |
| 106 | + return swizzles::Scatter(x, y, maybe_size_x->evaluateInt()); |
| 107 | + default: |
| 108 | + TORCH_INTERNAL_ASSERT(false, "Unsupported swizzle type"); |
| 109 | + } |
| 110 | +} |
| 111 | + |
| 112 | +std::pair<Val*, Val*> dispatchUnSwizzle( |
| 113 | + Swizzle2DType type, |
| 114 | + Val* x, |
| 115 | + Val* y, |
| 116 | + Val* maybe_size_x, |
| 117 | + Val* maybe_size_y) { |
| 118 | + switch (type) { |
| 119 | + case Swizzle2DType::ZShape: |
| 120 | + return swizzles::unZShape(x, y, maybe_size_y); |
| 121 | + case Swizzle2DType::XOR: |
| 122 | + return swizzles::unXor(x, y); |
| 123 | + case Swizzle2DType::CyclicShift: |
| 124 | + return swizzles::unCyclicShift(x, y, maybe_size_x); |
| 125 | + case Swizzle2DType::Scatter: |
| 126 | + return swizzles::unScatter(x, y, maybe_size_x->evaluateInt()); |
| 127 | + default: |
| 128 | + TORCH_INTERNAL_ASSERT(false, "Unsupported swizzle type"); |
| 129 | + } |
| 130 | +} |
| 131 | + |
| 132 | +} // namespace cuda |
| 133 | +} // namespace fuser |
| 134 | +} // namespace jit |
| 135 | +} // namespace torch |
0 commit comments