Commit 731fc76

zheng-xq authored and lly-zero-one committed
A skeleton implementation for the expression, IR and visitor dispatchers. (pytorch#33)

To run the test: cmake . && make cpptest && ./expr_test
Refactor the RefHandle class. (pytorch#34)
Add convenience operator for Expr.
clang-format change (pytorch#35)
Adding Var, Let and eval_context support. (pytorch#36)
Add LLVM JIT class for online codegen
Refactor llvm codegen
fix caps of LlvmJit
Generate code for integer arithmetic
Test all arithmetic ops with LLVM
Fix rtti
Compat with llvm 7 and 8
Add support for tensor expressions. (pytorch#38)
Add Casting support so mixed dtypes are supported.
Add basic dtype and logging support. This should be merged with PyTorch during integration.
clang-format fix (pytorch#39)
Extend dtypes to support vector types (pytorch#40)
Support LLVM 9 too
Disambiguate dependent type name with template keyword
Remove empty scalar.h
Add basic support for statements. (pytorch#41)
Add support for For, Ramp, Block, Load, Store and Broadcast.
Add support for Buffer.
Adding Stmt evaluation support. (pytorch#42)
Use third_party/googletest from pytorch
Remove nnc/tests/googletest submodule
Move nnc tld to torch/csrc/jit/compiler
Add a README (probably temporary) for jit/compiler
Move from namespace nnc to torch::jit::compiler
Refactor JIT class to isolate no-rtti pieces
Adding comparison operator to Var. (pytorch#43)
Fix typo in README.md
Use absolute imports and pragma once
Use absolute includes in new llvm_jit.h
Build non-LLVM compiler stuff with libtorch
Minimal asmjit codegen from the tensor IR
fix pessimizing moves
IR printer
fix printer bug
Add printer to build system.
Add data structure for schedule support and Split.
clang-format using the new template
Add IRMutator and basic support to substitute Var in Expr and Stmts.
Change the default count of RefCounted as zero.
Merge Expr(node) and Expr::make(node).
Add basic lowering to the tensor expression trees.
fix the schedule_test
fixed lowering
LLVM code generation for simple loops
bugfixes
refcount fixing self-assignment
Make LOG(FATAL) nonreturn
Enable Werror
Adding statement conversion for SplitWithTail
Add reference tests for Split
clang-format
A functional reference check for schedule tests.
clang-format
Add support for Float immediates.
Get absolute path for ASMJIT_DIR (pytorch#24)
Silence deprecation warnings from LLVM
Include legacy PassManager for debug printing
Set code model to medium to avoid indirect jumps in generated asm
Fix argument type of input float buffers
Add support for Casts in LLVM codegen.
Add a complete tensor+lower+llvm test
Enable the failing test
Enable export of compile_commands.json.
Floating point arithmetic
Test fp32 mul using compute expr
Broadcast add test using compute expr
Update to LLVM 9
Implementation of Broadcast for LLVM.
Add Buffer operator() overload, and some other minor features
Cleanup use of ConstantInt API.
fix accidental experimental changes
Change the Compute interface to bring the dim sizes and names together
clang-format
refactor Buffer into its own files
Add support for vector casts in LLVM CodeGen
Implement masked loads and stores.
Implement vector masked loads and stores.
Add a PaddedBuffer test util
Improve the user interface for SimpleIREvaluator
Add a test for Block codegen.
Fix gtest include path
clang-format
Add expressions and support for Max and Min. (pytorch#5)
Rename compiler to tensorexpr and move files around to be more similar to other pytorch parts. (pytorch#6)
  Summary:
  1. Move compiler to tensorexpr folder
  2. Move files from src and include to the same folder (and remove src and include folders)
  3. Rename .cc to .cpp
Add missing include <math.h> (pytorch#7)
Change isnan to std::isnan. It breaks my clang builds. (pytorch#8)
Change the SimpleIREvaluator frontend (pytorch#9)
Add RefHandle for subclass
Make LLVM dependency optional. (pytorch#10)
[wip] Basic fuser pass to select texpr subgraphs
Revert "[wip] Basic fuser pass to select texpr subgraphs" This reverts commit a9d9919.
Revert changes to the main pytorch CMakeLists.txt (for now).
Add a test for aten::_cast_Float lowering. (pytorch#12)
Hook tensorexpr up to the main build, and switch to c10 logging
More ATen op tests. (pytorch#16)
Fix some missing returns
Include tests back to the 'all' target. (pytorch#14)
Even more ATen op tests. (pytorch#18)
Test for relu ATen op. (pytorch#19)
Add intrinsics function support. (pytorch#20)
Remove fmax/fmin, as they are already covered by the Max/Min operators (pytorch#21)
refactor CallNode and BaseCallNode, so we can have a common concrete base class for visitors. (pytorch#22) This is the first step to add other call types.
Add FunctionCall to use existing tensors (pytorch#23)
Add the ability to use an existing tensor expression in other compute functions. (pytorch#24)
fixing broken compilation on mac/clang
adding IRnode for Compare-Select Ops and their LLVM Codegen
Fix Werror. (pytorch#26)
Add tests for some transcendental ops. (pytorch#27)
Add Allocate and Free support. (pytorch#29) Add Eval and test basic alloc support. Add Lowering support for buffer allocation for intermediate tensors.
Tensor expr fuser pass for extremely simple expressions
Make fusion work for arbitrary buffer/tensor combinations of inputs (pytorch#30)
fix Let02 test
Access inputs and intermediates uniformly through Tensors (pytorch#31)
adding LLVM Codegen for Let
Adding ComputeInline support. (pytorch#35)
Fix broken tests (pytorch#36)
Make tx fuser work with arbitrary ranks
[fuser] Broadcast args
Improve naming of arg broadcasting function
modifying CMakeLists.txt to enable ninja test && minor update for LLVM Codegen for Let (handling XQ's comment)
Test cases for tensorexpr fusion (pytorch#37)
CompareSelect Op: Addressing XQ and Owen's comments
Sketch sufficient support for constants to get constant alpha working. (pytorch#40)
  * Refactor to use a switch statement over Node kinds.
  * Sketch sufficient support for constants to get constant alpha working.
Fix indices when inlining non-leaf calls (pytorch#39)
Fixing the inline ordering issue (pytorch#43)
Solve more problems with the inliner
Avoid creating redundant and/or improperly ordered Constant's in fused subgraphs. (pytorch#42)
Move fuser-styled tests to schedule_test (pytorch#44)
Add aten::sub to the new fuser. (pytorch#46)
Refactor CodeGen from SimpleIREval (pytorch#47)
Inline all the things (pytorch#45)
clang-format for atent_test.cpp
Eliminate a ton of warnings for my own sanity. (pytorch#48)
Add support for type promotion/demotion. (pytorch#50)
Flesh out new fuser coverage to several more ops. (pytorch#51)
Adding the first basic CudaCodeGen. (pytorch#52)
aten tests for eq, ge, gt, le, lt
support for aten ops: eq
support for more aten ops: ge, gt, le, lt, ne
Minimal CMake change to link LLVM to libtorch
Fix issues causing assertion failures in llvm debug builds
Fatal on unimplemented llvm codegen ops (Allocate, etc.)
Optionally compile tx fuser kernels with llvm
Test for 2D broadcasted with large dims to show vectorization
Updated isSupported for increased op coverage. (pytorch#54)
Refactor LLVMCodeGen to compile kernel in constructor
Cmake integration to PT codebase (pytorch#28)
  With this change our code blends with the usual PyTorch code and is built the usual way. I added a cmake option to specify where to look for LLVM; if it's not specified, LLVM is not used. An example of invocation (from the root of pytorch repo):
  ```
  USE_LLVM=/path/to/llvm9/install python setup.py develop
  ```
  This command will build libtorch.{a,so} and other libraries, and tensorexpr code will be a part of it. The tests will be built in build/bin/test_tensorexpr (I've ported only one test so far). So, invocation of the tests will be:
  ```
  build/bin/test_tensorexpr
  ```
Remove old padded_buffer.{cpp,h}. (pytorch#56)
Add support for code generation of Log10 intrinsics with LLVM. (pytorch#57)
Remove tests/test_utils.h: inline what's still used and nuke what's unused. (pytorch#58)
Move Fuser tests (tests/tests.py) to test/test_tensorexpr.py. (pytorch#59)
Remove old CMakeLists and README.txt
Add support for vectorized and unmasked loads and stores with LLVM. (pytorch#62)
Enable CodeGen-level optimizations in LLVM. (pytorch#63)
Add Bind/GPUBlock/GPUThread support. (pytorch#64)
Bind/run interface to CodeGen (pytorch#60)
  * Bind/run interface to CodeGen
  * Make LLVMCodeGen implement CodeGen interface
  * Allow bind/run to be unimplemented for the moment (CUDA)
  * Cache compilation result
  * Two nasty bugs: forgot virtual dtor, forgot to clear bindings after run()
Fix ambiguity in CreateExtractElementCall (0ull can be a Value*, I guess?) (pytorch#65)
Allow constants as lhs/rhs args (not just alpha) (pytorch#66)
Use correct tensor type for fuser output (pytorch#67)
clang-format
Rename 'compiler' namespace to 'tensorexpr'.
Include all built llvm targets (pytorch#68)
Switch back to linking only the native LLVM target. (pytorch#69)
Virtual dtors for IRVisitor/IRMutator (pytorch#70)
Add semicolon to make nvcc compile (pytorch#71)
Enable NVRTC for the GPU backend. (pytorch#74)
Fix non-CUDA testing. (pytorch#75)
Getting fused (a)Sin(h), (a)Cos(h), (a)Tan(h), abs working with the interpreter (pytorch#73)
  * Getting fused (a)Sin(h), (a)Cos(h), (a)Tan(h), abs working with the interpreter
  * take the interpreter path only when ENABLE_LLVM is not set
remove the leak tests, as we will get rid of refcounting (pytorch#76)
Implement aten::min, max, and clamp (pytorch#72)
  * Implement aten::min, max, and clamp
  * Propagate NaNs like std::max/min
  * Change NaN propagation in interpreter too
clang-format tensorexpr/tests.h (pytorch#77)
Refactor UniqueNameManager into its own files. (pytorch#79)
refactor cuda_codegen (pytorch#80)
simplify nvrtc major, minor versions (pytorch#81)
Allow CodeGen to take Var args (interpreter support only) (pytorch#78)
  * Test demonstrating dynamic shape
  * Allow binding of Vars to args in interpreter
  * Pass BufferArgs to LLVMCodeGen
  * clang-format-diff
[LLVMCodeGen] Refactor kernel constructor to be less sprawling (pytorch#82)
  * Member TM to TM_ in LLVMCodeGen
  * [LLVMCodeGen] Add helper for getContext
  * [LLVMCodeGen] Refactor type support
  * [LLVMCodeGen] Refactor kernel emission
(TE Interpreter) Support for floor, ceil, trunc, remainder, sqrt and improving tests (pytorch#83)
  * Getting fused (a)Sin(h), (a)Cos(h), (a)Tan(h), abs working with the interpreter
  * take the interpreter path only when ENABLE_LLVM is not set
  * cleaning up the tests for the new aten ops
  * (TE Interpreter) adding support for floor, ceil, trunc, remainder and improving tests
Add Cond and Mod to SimpleIREval (pytorch#84)
[LLVMCodeGen] Support dynamic shapes by binding Var args (pytorch#86)
  * [LLVMCodeGen] Support dynamic shapes by binding Var args
  * Test llvm dynamic shape codegen using Tensor
Add SplitWithMask core support. (pytorch#87)
Add Cuda tests for SplitWithMask (pytorch#88)
Disable DEBUG_PRINT (pytorch#89)
Remove some debug prints (pytorch#90)
Fix the no-CUDA build. (pytorch#92)
Add support for multiple outputs from the fused subgraph. (pytorch#91)
Remove RefCounting (pytorch#93)
Add some comments for KernelScope. Address comments. (pytorch#94)
Completely remove refcount.h (pytorch#95)
fix the fuser pass (pytorch#97)
Rename Kernel to KernelArena (pytorch#98)
Add support for fusion through ConstantChunk ops. (pytorch#96)
Fix implicit noexcept deduction warning. (pytorch#99)
Make llvm tests conditional on USE_LLVM (pytorch#100)
  * Make llvm tests conditional on USE_LLVM
  * Use the right macro and add to gtest harness
  * clang-format
Refactor ComputeNode into ComputeValue, to be able to handle arbitrary multi-output operators. (pytorch#101)
Improve Stmt pretty printing from TensorExprFuser (pytorch#102)
Add support for IfThenElse (pytorch#103)
Add end-to-end support and a PyTorch fuser example on CudaCodeGen (pytorch#104)
fix rebase errors (pytorch#105)
fixes to build on system without LLVM and CUDA (pytorch#107)
  * fixes to build on system without LLVM and CUDA
  * minor edit: fixes to build on system without LLVM and CUDA
Add support for aten::cat to the new fuser. (pytorch#106)
Bail out of fusion if we don't have a complete tensor type (for now). (pytorch#108)
Standardize codegen call() interface and remove bind/run (pytorch#109)
  * Standardize codegen call() interface and remove bind/run
  * revert undef USE_CUDA
Clean up sketchy handling of scalar args in llvm codegen (pytorch#110)
Test 2D dynamic shapes (pytorch#112)
clang-format (pytorch#113)
Add LLVM codegen for a lot of transcendental ops. (pytorch#115)
Fix bug with binary math intrinsics. (pytorch#116)
Use CUDA for 3-arg test (pytorch#117)
Refactor CudaCodeGen into generic registration, so we can have both the Cuda and non-Cuda builds. (pytorch#118)
Add instructions on how to rebase on master.
Dynamic shape support in CUDA codegen (pytorch#120)
  * Dynamic shape support in CUDA codegen
  * free cuda memory
Disable GPU fuser. Revive the Cuda tests (pytorch#121)
Add ExecutionCounter to detect whether the underlying code is executed. (pytorch#122)
Adding GPU index flattening to support arbitrary elementwise and broadcasting support. (pytorch#126)
fix a bug kLog to Intrin::log (pytorch#124)
Allow scalar variables as inputs (pytorch#125)
clang-format (pytorch#127)
Format python tests with `black` (pytorch#128)
Add support for fusion in nested blocks. (pytorch#129)
Teach the LLVM JIT to use dlsym to resolve symbols. (pytorch#130)
Factor out kernel codegen from tx fusion pass (pytorch#131)
Use standard JIT logging in TX fuser.
Move memory management classes (KernelArena, KernelScope, KernelScopedObject) to a separate file. (pytorch#132)
(IR Interpreter) Adding more Operators: Erfc, Expm1, frac, lgamma, neg, sigmoid, reciprocal, neg, relu (pytorch#133)
Add erfc to llvm codegen (pytorch#134)
Squash some warnings (pytorch#135)
(IR interpreter) addcmul (pytorch#137)
  * (IR interpreter) addcmul
Remove IRNode. CodeGen accepts only Stmt. Add ExprEval utility wrapper. (pytorch#138)
Add the benchmark from NNC (pytorch#141)
Fix verifier errors in LLVM codegen when conditional loads feed directly into concats. (pytorch#143)
Strength reduction peephole for pow(). (pytorch#144)
Fix incorrect pow(x, 0) case. (pytorch#145)
Use `const Value*` where possible (pytorch#146)
Make Broadcast work (pytorch#147)
  $ python benchmarks/tensorexpr/benchmark.py broadcast_3args --device gpu --mode fwd --jit_mode trace
Fixed CudaCodeGen output streams. Switch to __ldg by default (pytorch#148)
Add ElementWise support (pytorch#150)
Fix an assertion failure when merging constants into aten::cat fusions. (pytorch#151)
adding LLVM support ops: sigmoid, relu, neg, addcmul, reciprocal, lgamma, expm1 (pytorch#149)
  * adding LLVM support for a few ops
add findllvm
1 parent 0c73c32 commit 731fc76

58 files changed: +4269 -6704 lines changed

benchmarks/tensorexpr/benchmark.py

Lines changed: 9 additions & 18 deletions
@@ -2,16 +2,15 @@
 import itertools
 import framework
 import os
-import types
 import tensor_engine
-#import normalization
+import normalization
 import broadcast
-#import reduction
+import reduction
 import elementwise
-#import softmax
-#import pooling
-#import conv
-#import matmul
+import softmax
+import pooling
+import conv
+import matmul
 
 
 def main():
@@ -32,15 +31,7 @@ def main():
                         help='the underlying tensor engine. only pt for now')
     parser.add_argument('--jit_mode', type=str, default='trace',
                         help='the jit mode to use: one of {trace, none}')
-    parser.add_argument('--cuda_pointwise_loop_levels', type=int, default=None,
-                        help='num of loop levesl for Cuda pointwise operations: 2 or 3')
-    parser.add_argument('--cuda_pointwise_block_count', type=int, default=None,
-                        help='num of block for Cuda pointwise operations')
-    parser.add_argument('--cuda_pointwise_block_size', type=int, default=None,
-                        help='num of blocks for Cuda pointwise operations')
-    parser.add_argument('--cuda_fuser', type=str, default='te',
-                        help='The Cuda fuser backend to use: one of {te, old, none}')
-
+
     args = parser.parse_args()
 
     def set_global_threads(num_threads):
@@ -82,7 +73,7 @@ def run_default_configs(bench_cls, allow_skip=True):
                     continue
                 else:
                     raise ValueError('attempted to run an unsupported benchmark: %s' % (benchmark.desc()))
-            framework.run_benchmark(benchmark, args)
+            framework.run_benchmark(benchmark)
 
     benchmark_classes = framework.benchmark_classes
     if not args.benchmark_names:
@@ -125,7 +116,7 @@ def run_default_configs(bench_cls, allow_skip=True):
                                 pass
                         benchmark = bench_cls(*config)
                         benchmark.jit_mode = args.jit_mode
-                        framework.run_benchmark(benchmark, args)
+                        framework.run_benchmark(benchmark)
 
                 if not match_class_name:
                     available_classes = ', '.join([bench_cls.module() for bench_cls in benchmark_classes])
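
The hunks above drop the four `--cuda_*` options and the `args` parameter of `framework.run_benchmark`. The following is a minimal sketch of what that means for a caller, not the actual `benchmark.py`: the `build_parser` helper, the positional `benchmark_names` argument, and the defaults for `--device` and `--mode` are assumptions made for illustration. Only `--jit_mode` and its default appear in the diff, and the command line is the one quoted in the commit message.

```python
import argparse


def build_parser():
    # Hypothetical reconstruction of the trimmed-down CLI; only --jit_mode
    # (default 'trace') is taken from the diff, the rest is assumed.
    parser = argparse.ArgumentParser(description='tensorexpr benchmark sketch')
    parser.add_argument('benchmark_names', type=str, nargs='*',
                        help='names of the benchmarks to run (assumed positional)')
    parser.add_argument('--device', type=str, default='cpu',
                        help='assumed default; not shown in the diff')
    parser.add_argument('--mode', type=str, default='fwd',
                        help='assumed default; not shown in the diff')
    parser.add_argument('--jit_mode', type=str, default='trace',
                        help='the jit mode to use: one of {trace, none}')
    return parser


# Command line from the commit message, plus one of the removed options.
argv = ['broadcast_3args', '--device', 'gpu', '--mode', 'fwd',
        '--jit_mode', 'trace', '--cuda_fuser=te']

# parse_known_args collects unrecognized options instead of exiting, which
# makes it easy to see that --cuda_fuser is no longer part of the CLI.
args, unknown = build_parser().parse_known_args(argv)
print(args.benchmark_names, args.device, args.mode, args.jit_mode)
print('unrecognized:', unknown)
```

With plain `parse_args()`, as used in the diff, the same command line would exit with an "unrecognized arguments" error, so scripts that still pass the removed flags would need updating.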

benchmarks/tensorexpr/broadcast.py

Lines changed: 4 additions & 142 deletions
@@ -1,7 +1,4 @@
 import framework
-import itertools
-import numpy as np
-import torch
 
 
 class BroadcastMulBench(framework.Benchmark):
@@ -123,142 +120,7 @@ def module():
         return 'broadcast_3args'
 
 
-#framework.register_benchmark_class(BroadcastRowBench)
-#framework.register_benchmark_class(BroadcastMidBench)
-#framework.register_benchmark_class(BroadcastColBench)
-#framework.register_benchmark_class(BroadcastThreeArgs)
-
-# TODO: merge this with elementwise bench
-# A template class for elementwise operations.
-# A derived class will override the class instance to customize its behavior.
-class BroadcastBench(framework.Benchmark):
-    # List of customization class variables.
-    op_str = None
-    binary_op_pt_func = None
-    binary_op_np_func = None
-    unary_op_pt_func = None
-    unary_op_np_func = None
-    split_input = True
-    def __init__(self, mode, device, M, N, K):
-        super().__init__(mode, device)
-        self.M = M
-        self.N = N
-        self.K = K
-        self.d1 = self.rand([M, N], device=device, requires_grad=self.requires_grad)
-        self.d2 = self.rand([K, 1, N], device=device, requires_grad=self.requires_grad)
-        self.d3 = self.rand([M, N], device=device, requires_grad=self.requires_grad)
-        self.d4 = self.rand([K, M, 1], device=device, requires_grad=self.requires_grad)
-        self.inputs = [self.d1, self.d2, self.d3, self.d4]
-
-    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
-        if not binary_op:
-            binary_op = lambda x, y: x + y
-        if not unary_op:
-            unary_op = lambda x: x
-        if self.split_input:
-            d1 = unary_op(d1)
-            d2 = unary_op(d2)
-            d3 = unary_op(d3)
-            d4 = unary_op(d4)
-        else:
-            d1, d2, d3, d4 = unary_op(d1), unary_op(d2), unary_op(d1 + 0.001), unary_op(d4)
-        a = binary_op(d1, d2)
-        b = binary_op(d3, d4)
-        c = a + b
-        return c
-
-    def forward(self, d1, d2, d3, d4):
-        binary_op = self.__class__.binary_op_pt_func
-        unary_op = self.__class__.unary_op_pt_func
-        return self._eval(d1, d2, d3, d4, binary_op, unary_op)
-
-    def reference(self):
-        binary_op = self.__class__.binary_op_np_func
-        unary_op = self.__class__.unary_op_np_func
-        [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
-        return self._eval(d1, d2, d3, d4, binary_op, unary_op)
-
-    def config(self):
-        return [self.M, self.N, self.K]
-
-    @classmethod
-    def module(cls):
-        return 'broadcast_' + cls.op_str
-
-    def memory_workload(self):
-        input_count = len(self.inputs)
-        if self.mode == 'fwd':
-            if self.split_input:
-                sol_count = 1
-                algorithmic_count = 1
-            else:
-                sol_count = 1
-                algorithmic_count = 1
-        else:
-            if self.split_input:
-                sol_count = 1
-                algorithmic_count = input_count
-            else:
-                sol_count = 1
-                algorithmic_count = input_count
-
-        buffer_size = self.M * self.N * self.K * 4
-        return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count}
-
-    @staticmethod
-    def default_configs():
-        return [[1 << 8, 1 << 7, 1 << 9]]
-
-
-def register_broadcast_ops():
-    binary_op_list = [
-        ["mul", lambda a, b: a * b],
-        ["add", lambda a, b: a + b],
-        ["sub", lambda a, b: a - b],
-        ["div", lambda a, b: a / (b + 1e-4)],
-        ["pow", lambda a, b: torch.pow(a, b), lambda a, b: np.power(a, b)], # no fuson triggered
-        ["max", lambda a, b: torch.max(a, b), lambda a, b: np.maximum(a, b)],
-        ["min", lambda a, b: torch.min(a, b), lambda a, b: np.minimum(a, b)],
-    ]
-
-    unary_op_list = [
-        ["exp", lambda x: torch.exp(x), lambda x: np.exp(x)],
-        ["sin", lambda x: torch.sin(x), lambda x: np.sin(x)],
-        ["cos", lambda x: torch.cos(x), lambda x: np.cos(x)],
-    ]
-
-    for split_input, binary_op in itertools.product([True, False], binary_op_list):
-        # Make a copy of BroadcastBench
-        if len(binary_op) == 2:
-            [op_str, op_pt_func] = binary_op
-            op_np_func = op_pt_func
-        elif len(binary_op) == 3:
-            [op_str, op_pt_func, op_np_func] = binary_op
-        split_str = 'split' if split_input else 'shared'
-        op_str = split_str + '_' + op_str
-        bm_cls = type('BroadcastBench_' + op_str, (BroadcastBench,), {})
-        bm_cls.op_str = op_str
-        bm_cls.binary_op_pt_func = op_pt_func
-        bm_cls.binary_op_np_func = op_np_func
-        bm_cls.split_input = split_input
-        framework.register_benchmark_class(bm_cls)
-
-    for split_input, unary_op in itertools.product([True, False], unary_op_list):
-        # Make a copy of BroadcastBench
-        if len(unary_op) == 2:
-            [op_str, op_pt_func] = unary_op
-            op_np_func = op_pt_func
-        elif len(unary_op) == 3:
-            [op_str, op_pt_func, op_np_func] = unary_op
-        split_str = 'split' if split_input else 'shared'
-        op_str = split_str + '_' + op_str
-        bm_cls = type('BroadcastBench_' + op_str, (BroadcastBench,), {})
-        bm_cls.op_str = op_str
-        bm_cls.unary_op_pt_func = op_pt_func
-        bm_cls.unary_op_np_func = op_np_func
-        bm_cls.split_input = split_input
-        framework.register_benchmark_class(bm_cls)
-
-
-register_broadcast_ops()
-
+framework.register_benchmark_class(BroadcastRowBench)
+framework.register_benchmark_class(BroadcastMidBench)
+framework.register_benchmark_class(BroadcastColBench)
+framework.register_benchmark_class(BroadcastThreeArgs)
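
The removed `register_broadcast_ops()` generated one benchmark class per operator by calling `type()` on a template class and then overriding class attributes. The pattern can be sketched in isolation as below. This is an illustrative toy under stated assumptions, not the removed code: `REGISTRY`, `register_benchmark_class`, and `TemplateBench` are hypothetical stand-ins for the `framework` module and `BroadcastBench`, and only two operators are listed.

```python
import itertools

# Hypothetical stand-in for the registry kept by the framework module.
REGISTRY = []


def register_benchmark_class(cls):
    REGISTRY.append(cls)


class TemplateBench:
    # Class-level customization points, overridden on each generated subclass.
    op_str = None
    binary_op = None
    split_input = True

    @classmethod
    def module(cls):
        return 'broadcast_' + cls.op_str

    def forward(self, a, b):
        # Fetch the op from the class so it is called as a plain function
        # rather than being bound to the instance as a method.
        return self.__class__.binary_op(a, b)


def register_ops():
    binary_op_list = [
        ('mul', lambda a, b: a * b),
        ('add', lambda a, b: a + b),
    ]
    for split_input, (name, func) in itertools.product([True, False], binary_op_list):
        op_str = ('split' if split_input else 'shared') + '_' + name
        # type(name, bases, namespace) creates a fresh subclass at runtime.
        bm_cls = type('TemplateBench_' + op_str, (TemplateBench,), {})
        bm_cls.op_str = op_str
        bm_cls.binary_op = func
        bm_cls.split_input = split_input
        register_benchmark_class(bm_cls)


register_ops()
print([cls.module() for cls in REGISTRY])
```

The commit replaces this generated family with four explicit `register_benchmark_class` calls, which trades operator coverage for a shorter, statically readable list of benchmarks.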

benchmarks/tensorexpr/elementwise.py

Lines changed: 13 additions & 117 deletions
@@ -1,18 +1,7 @@
 import framework
-import itertools
-import numpy as np
-import torch
 
-# A template class for elementwise operations.
-# A derived class will override the class instance to customize its behavior.
-class ElementBench(framework.Benchmark):
-    # List of customization class variables.
-    op_str = None
-    binary_op_pt_func = None
-    binary_op_np_func = None
-    unary_op_pt_func = None
-    unary_op_np_func = None
-    split_input = True
+
+class ElementMulBench(framework.Benchmark):
     def __init__(self, mode, device, N):
         super().__init__(mode, device)
         self.N = N
@@ -21,68 +10,28 @@ def __init__(self, mode, device, N):
         self.d3 = self.rand([N], device=device, requires_grad=self.requires_grad)
         self.d4 = self.rand([N], device=device, requires_grad=self.requires_grad)
         self.inputs = [self.d1, self.d2, self.d3, self.d4]
-        self.deterministic = ('rand' not in self.op_str)
 
-    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
-        if not binary_op:
-            binary_op = lambda x, y: x + y
-        if not unary_op:
-            unary_op = lambda x: x
-        if self.split_input:
-            d1 = unary_op(d1)
-            d2 = unary_op(d2)
-            d3 = unary_op(d3)
-            d4 = unary_op(d4)
-        else:
-            d2 = unary_op(d1 + 0.001)
-            d3 = unary_op(d1 + 0.002)
-            d4 = unary_op(d1 + 0.003)
-            d1 = unary_op(d1)
-        a = binary_op(d1, d2)
-        b = binary_op(d3, d4)
-        c = a + b
-        return c
-
     def forward(self, d1, d2, d3, d4):
-        binary_op = self.__class__.binary_op_pt_func
-        unary_op = self.__class__.unary_op_pt_func
-        return self._eval(d1, d2, d3, d4, binary_op, unary_op)
+        y = d1 * d2 + d3 * d4
+        return y
 
     def reference(self):
-        binary_op = self.__class__.binary_op_np_func
-        unary_op = self.__class__.unary_op_np_func
-        [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
-        return self._eval(d1, d2, d3, d4, binary_op, unary_op)
+        return self.numpy(self.d1) * self.numpy(self.d2) + self.numpy(self.d3) * self.numpy(self.d4)
 
     def config(self):
         return [self.N]
 
-    @classmethod
-    def module(cls):
-        return 'element_' + cls.op_str
+    @staticmethod
+    def module():
+        return 'element_mul'
 
     def memory_workload(self):
-        input_count = len(self.inputs)
         if self.mode == 'fwd':
-            if self.split_input:
-                sol_count = input_count + 1
-                algorithmic_count = input_count + 1
-            else:
-                sol_count = 1 + 1
-                algorithmic_count = 1 + 1
-            if 'rand' in self.op_str:
-                sol_count = 1
-                algorithmic_count = 1
+            sol_count = 4 + 1
+            algorithmic_count = 3 + 1
         else:
-            if self.split_input:
-                sol_count = (input_count + 1) + (1 + input_count)
-                algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)
-            else:
-                sol_count = 1 + 1
-                algorithmic_count = 1 + 1
-            if 'rand' in self.op_str:
-                sol_count = 1
-                algorithmic_count = 1
+            sol_count = (4 + 1) + (1 + 4)
+            algorithmic_count = (4 + 1) + ((2 + 1) * 4)
 
         buffer_size = self.N * 4
         return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count}
@@ -92,57 +41,4 @@ def default_configs():
         return [[1 << 27]]
 
 
-def register_element_ops():
-    binary_op_list = [
-        ["mul", lambda a, b: a * b],
-        ["add", lambda a, b: a + b],
-        ["sub", lambda a, b: a - b],
-        ["div", lambda a, b: a / (b + 1e-4)],
-        ["pow", lambda a, b: torch.pow(a, b), lambda a, b: np.power(a, b)], # no fuson triggered
-        ["max", lambda a, b: torch.max(a, b), lambda a, b: np.maximum(a, b)],
-        ["min", lambda a, b: torch.min(a, b), lambda a, b: np.minimum(a, b)],
-    ]
-
-    unary_op_list = [
-        ["exp", lambda x: torch.exp(x), lambda x: np.exp(x)],
-        ["sin", lambda x: torch.sin(x), lambda x: np.sin(x)],
-        ["cos", lambda x: torch.cos(x), lambda x: np.cos(x)],
-        ["rand_like", lambda x: torch.rand_like(x), lambda x: np.random.rand(*x.shape)],
-    ]
-
-    for split_input, binary_op in itertools.product([True, False], binary_op_list):
-        # Make a copy of ElementBench
-        if len(binary_op) == 2:
-            [op_str, op_pt_func] = binary_op
-            op_np_func = op_pt_func
-        elif len(binary_op) == 3:
-            [op_str, op_pt_func, op_np_func] = binary_op
-        split_str = 'split' if split_input else 'shared'
-        op_str = split_str + '_' + op_str
-        bm_cls = type('ElementBench_' + op_str, (ElementBench,), {})
-        bm_cls.op_str = op_str
-        bm_cls.binary_op_pt_func = op_pt_func
-        bm_cls.binary_op_np_func = op_np_func
-        bm_cls.split_input = split_input
-        framework.register_benchmark_class(bm_cls)
-
-    for split_input, unary_op in itertools.product([True, False], unary_op_list):
-        # Make a copy of ElementBench
-        if len(unary_op) == 2:
-            [op_str, op_pt_func] = unary_op
-            op_np_func = op_pt_func
-        elif len(unary_op) == 3:
-            [op_str, op_pt_func, op_np_func] = unary_op
-        split_str = 'split' if split_input else 'shared'
-        op_str = split_str + '_' + op_str
-        bm_cls = type('ElementBench_' + op_str, (ElementBench,), {})
-        bm_cls.op_str = op_str
-        bm_cls.unary_op_pt_func = op_pt_func
-        bm_cls.unary_op_np_func = op_np_func
-        bm_cls.split_input = split_input
-        framework.register_benchmark_class(bm_cls)
-
-
-#framework.register_benchmark_class(ElementMulBench)
-register_element_ops()
-
+framework.register_benchmark_class(ElementMulBench)
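
To make the hard-coded buffer counts in the new `memory_workload` concrete, the arithmetic can be worked through for the default configuration `N = 1 << 27`. The sketch below copies the counts verbatim from the added lines. The standalone function, the `bytes_per_element` parameter name, the `'bwd'` label for the non-forward branch, and the reading of the factor 4 as bytes per element are assumptions for illustration.

```python
def memory_workload(n, mode, bytes_per_element=4):
    # Buffer counts copied from the added lines of ElementMulBench.
    if mode == 'fwd':
        sol_count = 4 + 1
        algorithmic_count = 3 + 1
    else:
        sol_count = (4 + 1) + (1 + 4)
        algorithmic_count = (4 + 1) + ((2 + 1) * 4)
    # The diff computes self.N * 4; the 4 is assumed to be bytes per element.
    buffer_size = n * bytes_per_element
    return {'sol': buffer_size * sol_count,
            'algorithmic': buffer_size * algorithmic_count}


N = 1 << 27  # the single entry returned by default_configs()
for mode in ('fwd', 'bwd'):
    workload = memory_workload(N, mode)
    # Report in GiB so the numbers are easier to read.
    print(mode, {key: value / 2**30 for key, value in workload.items()})
# prints: fwd {'sol': 2.5, 'algorithmic': 2.0}
# prints: bwd {'sol': 5.0, 'algorithmic': 8.5}
```

Each buffer is 512 MiB at this size, so the forward pass is modelled as moving 5 buffers at the speed-of-light bound and 4 in the algorithmic count, exactly as the diff specifies.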
