crash when enabled triton matmul #2015

Closed
stephen-youn opened this issue Dec 30, 2022 · 1 comment
Labels
bug Something isn't working

Comments


stephen-youn commented Dec 30, 2022

🐛 Describe the bug

When Triton matmul is enabled by setting config.triton.mm to either "autotune" or "triton", compilation crashes with a KeyError on "BLOCK_K", which seems to be a Triton kernel meta-parameter.
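
As far as I can tell, the KeyError comes from an EVEN_K heuristic (visible in the traceback below) that indexes the kernel arguments by "BLOCK_K" before the autotuner has merged its chosen block sizes into them. A minimal stand-alone illustration of that suspected failure mode (my reading of the mechanism, not the actual torch._inductor code):

# Suspected failure mode, illustrated in isolation (assumption: this mirrors
# the EVEN_K heuristic from the traceback; it is not the inductor source).
even_k = lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0

# Once the autotuner has merged its chosen config, the lookup succeeds:
print(even_k({"K": 768, "BLOCK_K": 32, "SPLIT_K": 1}))  # True

# If the heuristic runs with only the runtime arguments, BLOCK_K is absent
# and the same KeyError as in the log below is raised:
try:
    even_k({"K": 768})
except KeyError as e:
    print("KeyError:", e)  # KeyError: 'BLOCK_K'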

Error logs

~/project/sandbox$ python test_bert.py
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py:372: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled.Consider setting `torch.set_float32_matmul_precision('high')`
  warnings.warn(
/usr/local/lib/python3.9/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: /usr/local/lib/python3.9/dist-packages/torchvision/image.so: undefined symbol: _ZN3c107WarningC1ENS_7variantIJNS0_11UserWarningENS0_18DeprecationWarningEEEERKNS_14SourceLocationERKSsb
  warn(f"Failed to load image Python extension: {e}")
sequnece length = 12
Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 296, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/lowering.py", line 222, in wrapped
    return decomp_fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/lowering.py", line 842, in mm
    return TensorBox.create(ir.MatrixMultiply.create(a, b))
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/ir.py", line 2803, in create
    kernel = tuned_mm(
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 202, in tuned_mm
    timing, _, _ = autotune._bench(runnable_kernel, *run_args, **run_kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 40, in _bench
    return do_bench(kernel_call)
  File "/usr/local/lib/python3.9/dist-packages/triton/testing.py", line 140, in do_bench
    fn()
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/autotuner.py", line 36, in kernel_call
    kernel(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/matmul.py", line 134, in forward
    return _matmul_out._call(a, b, out, allow_tf32)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/matmul.py", line 114, in _call
    _kernel[grid](
  File "/usr/local/lib/python3.9/dist-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/triton/runtime/autotuner.py", line 199, in run
    kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/triton_ops/autotune.py", line 560, in <lambda>
    "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
KeyError: 'BLOCK_K'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 676, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/debug_utils.py", line 1032, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/__init__.py", line 1190, in __call__
    return self.compile_fn(model_, inputs_)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 398, in compile_fx
    return aot_autograd(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/optimizations/training.py", line 78, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 2355, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 2052, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_tensor_args, aot_config)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 1307, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/aot_autograd.py", line 1566, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 373, in fw_compiler
    return inner_compile(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/debug_utils.py", line 588, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/debug.py", line 223, in inner
    return fn(*args, **kwargs)
  File "/usr/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/compile_fx.py", line 139, in compile_fx_inner
    graph.run(*example_inputs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 170, in run
    return super().run(*args)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 369, in run_node
    result = super().run_node(n)
  File "/usr/local/lib/python3.9/dist-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_inductor/graph.py", line 299, in call_function
    raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: KeyError: 'BLOCK_K'
  target: aten.mm.default
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.float32, size=(12, 768), stride=[768, 1]), data=Pointwise(
      'cuda',
      torch.float32,
      tmp0 = load(buf5, i1 + 768 * i0)
      return tmp0
      ,
      ranges=(12, 768),
      origins={view}
    ))
  ))
  args[1]: TensorBox(
    ReinterpretView(
      StorageBox(
        InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[768, 768], stride=[768, 1]))
      ),
      FixedLayout('cuda', torch.float32, size=[768, 768], stride=[1, 768]),
      no origins?
    )
  )

While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_encoder': "<class 'transformers.models.bert.modeling_bert.BertEncoder'>", 'self_encoder_layer_0': "<class 'transformers.models.bert.modeling_bert.BertLayer'>", 'self_encoder_layer_0_attention': "<class 'transformers.models.bert.modeling_bert.BertAttention'>", 'self_encoder_layer_0_attention_self': "<class 'transformers.models.bert.modeling_bert.BertSelfAttention'>", 'self_encoder_layer_0_attention_self_query': "<class 'torch.nn.modules.linear.Linear'>"}
  File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 285, in forward
    mixed_query_layer = self.query(hidden_states)
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 426, in forward
    self_outputs = self.self(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 496, in forward
    self_attention_outputs = self.attention(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 610, in forward
    layer_outputs = layer_module(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 1021, in forward
    encoder_outputs = self.encoder(


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/A/project/sandbox/test_bert.py", line 60, in <module>
    print(f"gains={seq_len}:{measure_perf(t)}")
  File "/home/A/project/sandbox/test_bert.py", line 40, in measure_perf
    output = opt_model(**token)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 83, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 212, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/eval_frame.py", line 333, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 480, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/utils.py", line 88, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 400, in _compile
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/convert_frame.py", line 387, in transform
    tracer.run()
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 1684, in run
    super().run()
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
    and self.step()
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/symbolic_convert.py", line 1750, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 553, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 600, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/usr/local/lib/python3.9/dist-packages/torch/_dynamo/output_graph.py", line 681, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: KeyError: 'BLOCK_K'
  target: aten.mm.default
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf6', layout=FixedLayout('cuda', torch.float32, size=(12, 768), stride=[768, 1]), data=Pointwise(
      'cuda',
      torch.float32,
      tmp0 = load(buf5, i1 + 768 * i0)
      return tmp0
      ,
      ranges=(12, 768),
      origins={view}
    ))
  ))
  args[1]: TensorBox(
    ReinterpretView(
      StorageBox(
        InputBuffer(name='primals_6', layout=FixedLayout('cuda', torch.float32, size=[768, 768], stride=[768, 1]))
      ),
      FixedLayout('cuda', torch.float32, size=[768, 768], stride=[1, 768]),
      no origins?
    )
  )

While executing %mm : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %permute), kwargs = {})
Original traceback:
Module stack: {'self_encoder': "<class 'transformers.models.bert.modeling_bert.BertEncoder'>", 'self_encoder_layer_0': "<class 'transformers.models.bert.modeling_bert.BertLayer'>", 'self_encoder_layer_0_attention': "<class 'transformers.models.bert.modeling_bert.BertAttention'>", 'self_encoder_layer_0_attention_self': "<class 'transformers.models.bert.modeling_bert.BertSelfAttention'>", 'self_encoder_layer_0_attention_self_query': "<class 'torch.nn.modules.linear.Linear'>"}
  File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 285, in forward
    mixed_query_layer = self.query(hidden_states)
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 426, in forward
    self_outputs = self.self(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 496, in forward
    self_attention_outputs = self.attention(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 610, in forward
    layer_outputs = layer_module(
 |   File "/usr/local/lib/python3.9/dist-packages/transformers/models/bert/modeling_bert.py", line 1021, in forward
    encoder_outputs = self.encoder(


Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
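
For reference, both knobs mentioned in the error message can be set before compiling (a minimal sketch; suppress_errors only falls back to eager execution, it does not fix the Triton matmul path):

import torch._dynamo

# Print full dynamo/inductor debug output for the failing frame.
torch._dynamo.config.verbose = True

# Swallow backend compiler errors and fall back to eager execution,
# as the error message above suggests.
torch._dynamo.config.suppress_errors = True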

Minified repro

"""
bert
"""

import torch
import numpy as np
from transformers import BertTokenizer, BertModel
from torch._inductor import config
config.triton.mm = "autotune"
#config.triton.mm = "triton"

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
opt_model = torch.compile(model, backend="inductor") # This is the only line of code that we changed
#opt_model = torch.compile(model, passes={"triton-autotune":True})
#opt_model = torch.compile(model, passes={"triton-mm":"triton"})
#opt_model = torch.compile(model, passes={'triton-mm': "triton", 'triton-bmm': True}) # this also fails with a different error message

def measure_perf(token, verbose=False, N=8):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    ref_runtime = list()
    print(f"sequnece length = {token['input_ids'].shape[1]}")
    for i in range(N):
        start_event.record()
        output = model(**token)
        end_event.record()
        torch.cuda.synchronize()
        estimate_ms = start_event.elapsed_time(end_event)
        if verbose:
            print(f"model: estimated_ms={estimate_ms}")
        if i>0 :
            ref_runtime.append(estimate_ms)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    triton_runtime = list()
    for i in range(N):
        start_event.record()
        output = opt_model(**token)
        end_event.record()
        torch.cuda.synchronize()
        estimate_ms = start_event.elapsed_time(end_event)
        if verbose:
            print(f"opt_model: estimated_ms={estimate_ms}")
        if i>0 :
            triton_runtime.append(estimate_ms)

    a = np.mean(ref_runtime)
    b = np.mean(triton_runtime)
    gain = a/b
    print(f"model: estimated_ms={a}")
    print(f"triton model: estimated_ms={b}, gain={gain}")
    return gain

text = "Replace me by any text you'd like."
t = tokenizer(text, return_tensors='pt').to(device="cuda:0")

seq_len = t['input_ids'].shape[1]
print(f"gains={seq_len}:{measure_perf(t)}")
stephen-youn added the bug label Dec 30, 2022

soumith (Member) commented Dec 30, 2022

this will be fixed after pytorch/pytorch#90738 re-lands

jansel added a commit to jansel/pytorch that referenced this issue Jan 10, 2023
…h#91575)

Summary:
This reverts commit 94262ef to reland pytorch#91105 / pytorch#90738.

Fixes pytorch/torchdynamo#2015

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

Pull Request resolved: pytorch#91575

Reviewed By: ngimel

Differential Revision: D42304332

Pulled By: jansel

fbshipit-source-id: 1eefc7320da5de7544d048c5b7ea8716930f31cf