Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);

#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op &OpName() { \
Expand Down
3 changes: 3 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ static constexpr const char *kDisableWarpSpecialized =
static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
static constexpr const char *kEnableAggressiveSharedMemoryMerge =
"tl.enable_aggressive_shared_memory_merge";
static constexpr const char *kDisableFastMath = "tl.disable_fast_math";
static constexpr const char *kEnablePTXASVerboseOutput =
"tl.enable_ptxas_verbose_output";

/*!
* \brief Whether to disable dynamic tail split
Expand Down
2 changes: 2 additions & 0 deletions tilelang/jit/adapter/ctypes/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(self,
self.verbose = verbose
self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target)
self.lib_generator.assign_pass_configs(pass_configs)

self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
Expand Down Expand Up @@ -145,6 +146,7 @@ def from_database(cls,
adapter.target = Target.canon_target(determine_target(target))
adapter.verbose = verbose
adapter.lib_generator = LibraryGenerator(adapter.target)
adapter.lib_generator.assign_pass_configs(pass_configs)
adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
adapter.lib.init()

Expand Down
2 changes: 2 additions & 0 deletions tilelang/jit/adapter/cython/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def __init__(self,
self.verbose = verbose
self.wrapper = TLWrapper(self.target)
self.lib_generator = LibraryGenerator(self.target)
self.lib_generator.assign_pass_configs(pass_configs)

self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
Expand Down Expand Up @@ -305,6 +306,7 @@ def from_database(cls,

adapter.verbose = verbose
adapter.lib_generator = LibraryGenerator(adapter.target)
adapter.lib_generator.assign_pass_configs(pass_configs)
adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)

adapter.lib.get_last_error.restype = ctypes.c_char_p
Expand Down
15 changes: 14 additions & 1 deletion tilelang/jit/adapter/libgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import os.path as osp
import subprocess
import tempfile
from typing import Optional
from typing import Any, Dict, Optional

from tvm.target import Target

from tilelang import tvm as tvm
from tilelang.transform import PassConfigKey
from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_compute_version
from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
from tilelang.env import TILELANG_TEMPLATE_PATH
Expand All @@ -36,10 +37,14 @@ class LibraryGenerator(object):
srcpath: Optional[str] = None
libpath: Optional[str] = None
lib_code: Optional[str] = None
pass_configs: Optional[Dict[str, Any]] = None

def __init__(self, target: Target):
self.target = target

def assign_pass_configs(self, pass_configs: Optional[Dict[str, Any]] = None):
    """Record the pass configuration used when compiling the library.

    Args:
        pass_configs: Mapping from pass-config key to value. ``None``
            (the default) is stored as an empty dict so that later
            ``self.pass_configs.get(...)`` lookups fall back to their
            defaults instead of raising ``AttributeError``.
    """
    # Keep the caller's dict object when one is given; only None is replaced.
    self.pass_configs = pass_configs if pass_configs is not None else {}

def update_lib_code(self, lib_code: str):
self.lib_code = lib_code

Expand All @@ -61,6 +66,10 @@ def compile_lib(self, timeout: float = None):
compute_version = "90a"
libpath = src.name.replace(".cu", ".so")

disable_fast_math = self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

self.pass_configs is not guaranteed to be a dictionary. If assign_pass_configs is not called or is called with None, self.pass_configs will be None, leading to an AttributeError when .get() is called.

You should handle this case to prevent a runtime crash. A safe way to do this is to provide a default empty dictionary if self.pass_configs is None.

Suggested change
disable_fast_math = self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, False)
disable_fast_math = (self.pass_configs or {}).get(PassConfigKey.TL_DISABLE_FAST_MATH, False)

verbose_ptxas_output = self.pass_configs.get(
PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)

command = [
get_nvcc_compiler(),
"-std=c++17",
Expand All @@ -76,6 +85,10 @@ def compile_lib(self, timeout: float = None):
"-gencode",
f"arch=compute_{compute_version},code=sm_{compute_version}",
]
if not disable_fast_math:
command += ["--use_fast_math"]
if verbose_ptxas_output:
command += ["--ptxas_options", "-v"]
command += [
"-I" + CUTLASS_INCLUDE_DIR,
]
Expand Down
6 changes: 6 additions & 0 deletions tilelang/transform/pass_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ class PassConfigKey(str, Enum):
TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized"
"""Disable warp specialization optimization. Default: False"""

TL_DISABLE_FAST_MATH = "tl.disable_fast_math"
"""Disable fast math optimization. Default: False"""

TL_ENABLE_PTXAS_VERBOSE_OUTPUT = "tl.enable_ptxas_verbose_output"
"""Enable ptxas verbose output. Default: False"""

TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth"
"""Bitwidth for configuration indices. Default: 32"""

Expand Down
Loading