🐛 Describe the bug
When the Triton matmul is enabled by setting config.triton.mm to either "autotune" or "triton", compilation crashes with a KeyError on "BLOCK_K", which appears to be a Triton kernel meta-parameter.
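For reference, the crash is triggered as soon as either value is set before torch.compile runs (this is the same setting used in the minified repro further below):

    from torch._inductor import config
    config.triton.mm = "autotune"   # setting "triton" instead crashes in the same way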
Error logs
~/project/sandbox$ python test_bert.py
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py:372: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled.Consider setting `torch.set_float32_matmul_precision('high')`
warnings.warn(
/usr/local/lib/python3.9/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: /usr/local/lib/python3.9/dist-packages/torchvision/image.so: undefined symbol: _ZN3c107WarningC1ENS_7variantIJNS0_11UserWarningENS0_18DeprecationWarningEEEERKNS_14SourceLocationERKSsb
warn(f"Failed to load image Python extension: {e}")
sequence length = 12
Traceback (most recent call last):
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 296, in call_function
out = lowerings[target](*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/lowering.py", line 222, in wrapped
return decomp_fn(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/lowering.py", line 842, in mm
return TensorBox.create(ir.MatrixMultiply.create(a, b))
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/ir.py", line 2803, in create
kernel = tuned_mm(
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 202, in tuned_mm
timing, _, _ = autotune._bench(runnable_kernel, *run_args, **run_kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 40, in _bench
return do_bench(kernel_call)
File "/usr/local/lib/python3.9/dist-packages/triton/testing.py", line 140, in do_bench
fn()
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 36, in kernel_call
kernel(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/matmul.py", line 134, in forward
return _matmul_out._call(a, b, out, allow_tf32)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/matmul.py", line 114, in _call
_kernel[grid](
File "/usr/local/lib/python3.9/dist-packages/triton/runtime/jit.py", line 106, in launcher
return self.run(*args, grid=grid, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/triton/runtime/autotuner.py", line 199, in run
kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/autotune.py", line 560, in <lambda>
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
KeyError: 'BLOCK_K'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 676, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/debug_utils.py", line 1032, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/__init__.py", line 1190, in __call__
return self.compile_fn(model_, inputs_)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 398, in compile_fx
return aot_autograd(
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/optimizations/training.py", line 78, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 2355, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 2052, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_tensor_args, aot_config)
File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 1307, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 1566, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 373, in fw_compiler
return inner_compile(
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/debug_utils.py", line 588, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/debug.py", line 223, in inner
return fn(*args, **kwargs)
File "/usr/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 139, in compile_fx_inner
graph.run(*example_inputs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 170, in run
return super().run(*args)
File "/usr/local/lib/python3.9/dist-packages/torch/fx/interpreter.py", line 136, in run
self.env[node] = self.run_node(node)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 369, in run_node
result = super().run_node(n)
File "/usr/local/lib/python3.9/dist-packages/torch/fx/interpreter.py", line 177, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 299, in call_function
raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: KeyError: 'BLOCK_K'
target: aten.mm.default
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.float32, size=(12, 768), stride=[768, 1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf5, i1 + 768 * i0)
return tmp0
,
ranges=(12, 768),
origins={view}
))
))
args[1]: TensorBox(
ReinterpretView(
StorageBox(
InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[768, 768], stride=[768, 1]))
),
FixedLayout('cuda', torch.float32, size=[768, 768], stride=[1, 768]),
no origins?
)
)
While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_encoder': "<class 'transformers.models.bert.modeling_bert.BertEncoder'>", 'self_encoder_layer_0': "<class 'transformers.models.bert.modeling_bert.BertLayer'>", 'self_encoder_layer_0_attention': "<class 'transformers.models.bert.modeling_bert.BertAttention'>", 'self_encoder_layer_0_attention_self': "<class 'transformers.models.bert.modeling_bert.BertSelfAttention'>", 'self_encoder_layer_0_attention_self_query': "<class 'torch.nn.modules.linear.Linear'>"}
File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 285, in forward
mixed_query_layer = self.query(hidden_states)
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 426, in forward
self_outputs = self.self(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 496, in forward
self_attention_outputs = self.attention(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 610, in forward
layer_outputs = layer_module(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 1021, in forward
encoder_outputs = self.encoder(
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/A/project/sandbox/test_bert.py", line 60, in <module>
print(f"gains={seq_len}:{measure_perf(t)}")
File "/home/A/project/sandbox/test_bert.py", line 40, in measure_perf
output = opt_model(**token)
File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1482, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 83, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 212, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 333, in catch_errors
return callback(frame, cache_size, hooks)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 480, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
return _compile(
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 400, in _compile
out_code = transform_code_object(code, transform)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 387, in transform
tracer.run()
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 1684, in run
super().run()
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
and self.step()
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
getattr(self, inst.opname)(inst)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 1750, in RETURN_VALUE
self.output.compile_subgraph(self)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 553, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 600, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 681, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: KeyError: 'BLOCK_K'
target: aten.mm.default
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.float32, size=(12, 768), stride=[768, 1]), data=Pointwise(
'cuda',
torch.float32,
tmp0 = load(buf5, i1 + 768 * i0)
return tmp0
,
ranges=(12, 768),
origins={view}
))
))
args[1]: TensorBox(
ReinterpretView(
StorageBox(
InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[768, 768], stride=[768, 1]))
),
FixedLayout('cuda', torch.float32, size=[768, 768], stride=[1, 768]),
no origins?
)
)
While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_encoder': "<class 'transformers.models.bert.modeling_bert.BertEncoder'>", 'self_encoder_layer_0': "<class 'transformers.models.bert.modeling_bert.BertLayer'>", 'self_encoder_layer_0_attention': "<class 'transformers.models.bert.modeling_bert.BertAttention'>", 'self_encoder_layer_0_attention_self': "<class 'transformers.models.bert.modeling_bert.BertSelfAttention'>", 'self_encoder_layer_0_attention_self_query': "<class 'torch.nn.modules.linear.Linear'>"}
File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 285, in forward
mixed_query_layer = self.query(hidden_states)
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 426, in forward
self_outputs = self.self(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 496, in forward
self_attention_outputs = self.attention(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 610, in forward
layer_outputs = layer_module(
| File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 1021, in forward
encoder_outputs = self.encoder(
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
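The failing frame (torch/_inductor/triton_ops/autotune.py, line 560) evaluates a heuristic lambda over a dict built from the kernel's positional arguments plus the supplied kwargs; the EVEN_K heuristic references "BLOCK_K", which is a meta-parameter that only exists once the autotuner injects a config. Below is a minimal, Triton-independent sketch of that failure mode; the apply_heuristics helper and the argument values are illustrative, not the actual Inductor code:

    # Illustrative sketch of the KeyError above: a heuristic computed from the
    # call arguments references a meta-parameter ("BLOCK_K") that is only
    # present when an autotuner supplies it.
    def apply_heuristics(arg_names, args, kwargs, heuristics):
        merged = {**dict(zip(arg_names, args)), **kwargs}
        for name, heur in heuristics.items():
            kwargs[name] = heur(merged)  # KeyError if a referenced key is missing
        return kwargs

    heuristics = {
        # mirrors the EVEN_K heuristic shown in the traceback
        "EVEN_K": lambda a: a["K"] % (a["BLOCK_K"] * a["SPLIT_K"]) == 0,
    }

    # Launching without BLOCK_K/SPLIT_K in the arguments reproduces the lookup failure:
    apply_heuristics(["M", "N", "K"], (12, 768, 768), {}, heuristics)  # raises KeyError: 'BLOCK_K'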
Minified repro
"""
bert
"""
import torch
import numpy as np
from transformers import BertTokenizer, BertModel
from torch._inductor import config
config.triton.mm = "autotune"
#config.triton.mm = "triton"
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
opt_model = torch.compile(model, backend="inductor")  # This is the only line of code that we changed
#opt_model = torch.compile(model, passes={"triton-autotune":True})
#opt_model = torch.compile(model, passes={"triton-mm":"triton"})
#opt_model = torch.compile(model, passes={'triton-mm': "triton", 'triton-bmm': True})  # this also fails with a different error message
def measure_perf(token, verbose=False, N=8):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    ref_runtime = list()
    print(f"sequence length = {token['input_ids'].shape[1]}")
    for i in range(N):
        start_event.record()
        output = model(**token)
        end_event.record()
        torch.cuda.synchronize()
        estimate_ms = start_event.elapsed_time(end_event)
        if verbose:
            print(f"model: estimated_ms={estimate_ms}")
        if i > 0:  # skip the first iteration as warmup
            ref_runtime.append(estimate_ms)
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    triton_runtime = list()
    for i in range(N):
        start_event.record()
        output = opt_model(**token)
        end_event.record()
        torch.cuda.synchronize()
        estimate_ms = start_event.elapsed_time(end_event)
        if verbose:
            print(f"opt_model: estimated_ms={estimate_ms}")
        if i > 0:  # skip the first iteration as warmup
            triton_runtime.append(estimate_ms)
    a = np.mean(ref_runtime)
    b = np.mean(triton_runtime)
    gain = a / b
    print(f"model: estimated_ms={a}")
    print(f"triton model: estimated_ms={b}, gain={gain}")
    return gain
text = "Replace me by any text you'd like."
t = tokenizer(text, return_tensors='pt').to(device="cuda:0")
seq_len = t['input_ids'].shape[1]
print(f"gains={seq_len}:{measure_perf(t)}")
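As a temporary workaround (the second option is the one suggested by the error message itself), the Triton matmul config can be left at its default, or Dynamo can be told to fall back to eager on backend compile errors; a sketch:

    # Possible workarounds while the BLOCK_K issue is open (not a fix):
    import torch._dynamo

    # 1) Do not set config.triton.mm to "autotune"/"triton", so the default
    #    ATen matmul lowering is used instead of the Triton kernel.

    # 2) Or suppress backend compile errors and fall back to eager,
    #    as suggested in the error output above.
    torch._dynamo.config.suppress_errors = True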