Skip to content

Commit d880221

Browse files
authored
[Enhancement] Introduce options TL_DISABLE_FAST_MATH and TL_ENABLE_PTXAS_VERBOSE_OUTPUT (#609)
* [Enhancement] Introduce new PassConfig options for fast math and PTXAS verbosity
  - Added `kDisableFastMath` and `kEnablePTXASVerboseOutput` configuration options to enhance control over compilation settings.
  - Updated `LibraryGenerator` to utilize these new pass configurations, allowing for more flexible compilation behavior based on user preferences.
  - Enhanced `PassConfigKey` enumeration to include the new options, ensuring they can be configured appropriately in the pass context.
* [Refactor] Update PTXAS verbosity configuration key in LibraryGenerator
  - Changed the configuration key for PTXAS verbosity from `TL_VERBOSE_PTXAS_OUTPUT` to `TL_ENABLE_PTXAS_VERBOSE_OUTPUT` to align with the new naming convention introduced in recent enhancements.
  - This update ensures consistency in the configuration options used within the `LibraryGenerator` class, improving clarity and maintainability of the code.
* lint fix
1 parent 432dc78 commit d880221

File tree

6 files changed

+29
-1
lines changed

6 files changed

+29
-1
lines changed

src/op/builtin.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
2727
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
2828
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
2929
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
30+
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
31+
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
3032

3133
#define TIR_DEFINE_TL_BUILTIN(OpName) \
3234
const Op &OpName() { \

src/op/builtin.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ static constexpr const char *kDisableWarpSpecialized =
3030
static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
3131
static constexpr const char *kEnableAggressiveSharedMemoryMerge =
3232
"tl.enable_aggressive_shared_memory_merge";
33+
static constexpr const char *kDisableFastMath = "tl.disable_fast_math";
34+
static constexpr const char *kEnablePTXASVerboseOutput =
35+
"tl.enable_ptxas_verbose_output";
3336

3437
/*!
3538
* \brief Whether to disable dynamic tail split

tilelang/jit/adapter/ctypes/adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(self,
9090
self.verbose = verbose
9191
self.wrapper = TLWrapper(self.target)
9292
self.lib_generator = LibraryGenerator(self.target)
93+
self.lib_generator.assign_pass_configs(pass_configs)
9394

9495
self.wrapper.assign_optimized_module(self.ir_module)
9596
self.wrapper.assign_pass_configs(pass_configs)
@@ -145,6 +146,7 @@ def from_database(cls,
145146
adapter.target = Target.canon_target(determine_target(target))
146147
adapter.verbose = verbose
147148
adapter.lib_generator = LibraryGenerator(adapter.target)
149+
adapter.lib_generator.assign_pass_configs(pass_configs)
148150
adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
149151
adapter.lib.init()
150152

tilelang/jit/adapter/cython/adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def __init__(self,
246246
self.verbose = verbose
247247
self.wrapper = TLWrapper(self.target)
248248
self.lib_generator = LibraryGenerator(self.target)
249+
self.lib_generator.assign_pass_configs(pass_configs)
249250

250251
self.wrapper.assign_optimized_module(self.ir_module)
251252
self.wrapper.assign_pass_configs(pass_configs)
@@ -305,6 +306,7 @@ def from_database(cls,
305306

306307
adapter.verbose = verbose
307308
adapter.lib_generator = LibraryGenerator(adapter.target)
309+
adapter.lib_generator.assign_pass_configs(pass_configs)
308310
adapter.lib = adapter.lib_generator.load_lib(lib_path=kernel_lib_path)
309311

310312
adapter.lib.get_last_error.restype = ctypes.c_char_p

tilelang/jit/adapter/libgen.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
import os.path as osp
88
import subprocess
99
import tempfile
10-
from typing import Optional
10+
from typing import Any, Dict, Optional
1111

1212
from tvm.target import Target
1313

1414
from tilelang import tvm as tvm
15+
from tilelang.transform import PassConfigKey
1516
from tilelang.contrib.nvcc import get_nvcc_compiler, get_target_compute_version
1617
from tilelang.contrib.rocm import find_rocm_path, get_rocm_arch
1718
from tilelang.env import TILELANG_TEMPLATE_PATH
@@ -36,10 +37,14 @@ class LibraryGenerator(object):
3637
srcpath: Optional[str] = None
3738
libpath: Optional[str] = None
3839
lib_code: Optional[str] = None
40+
pass_configs: Optional[Dict[str, Any]] = None
3941

4042
def __init__(self, target: Target):
4143
self.target = target
4244

45+
def assign_pass_configs(self, pass_configs: Optional[Dict[str, Any]] = None):
46+
self.pass_configs = pass_configs
47+
4348
def update_lib_code(self, lib_code: str):
4449
self.lib_code = lib_code
4550

@@ -61,6 +66,10 @@ def compile_lib(self, timeout: float = None):
6166
compute_version = "90a"
6267
libpath = src.name.replace(".cu", ".so")
6368

69+
disable_fast_math = self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, False)
70+
verbose_ptxas_output = self.pass_configs.get(
71+
PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)
72+
6473
command = [
6574
get_nvcc_compiler(),
6675
"-std=c++17",
@@ -76,6 +85,10 @@ def compile_lib(self, timeout: float = None):
7685
"-gencode",
7786
f"arch=compute_{compute_version},code=sm_{compute_version}",
7887
]
88+
if not disable_fast_math:
89+
command += ["--use_fast_math"]
90+
if verbose_ptxas_output:
91+
command += ["--ptxas_options", "-v"]
7992
command += [
8093
"-I" + CUTLASS_INCLUDE_DIR,
8194
]

tilelang/transform/pass_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ class PassConfigKey(str, Enum):
2020
TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized"
2121
"""Disable warp specialization optimization. Default: False"""
2222

23+
TL_DISABLE_FAST_MATH = "tl.disable_fast_math"
24+
"""Disable fast math optimization. Default: False"""
25+
26+
TL_ENABLE_PTXAS_VERBOSE_OUTPUT = "tl.enable_ptxas_verbose_output"
27+
"""Enable ptxas verbose output. Default: False"""
28+
2329
TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth"
2430
"""Bitwidth for configuration indices. Default: 32"""
2531

0 commit comments

Comments
 (0)