80 changes: 70 additions & 10 deletions python/tvm/relax/backend/cuda/flashinfer.py
@@ -21,6 +21,8 @@
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List
import hashlib
import json

import tvm
from tvm.target import Target
@@ -37,7 +39,57 @@ def _compile_flashinfer_kernels(
FLASHINFER_TVM_BINDING_DIR,
)

# Todo(tvm-team): enable compilation cache
# ------------------------------------------------------------------------
# Caching Flow: create build_directory and compute cache hash.
# ------------------------------------------------------------------------
build_directory = FLASHINFER_JIT_DIR / name
build_directory.mkdir(parents=True, exist_ok=True)
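# Compiled object files and the cache hash for this kernel are kept in this per-kernel directory.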

def get_object_file_path(src: Path) -> Path:
obj_name = src.stem + ".o"
obj_path = build_directory / obj_name
return obj_path

# Compute latest modification time among all source files
latest_src_mtime = max(src.stat().st_mtime for src in source_paths)

# Get modification time for the current file (the one that contains this function)
current_file_mtime = Path(__file__).stat().st_mtime

# Build the hash key from metadata
hash_key = {
"name": name,
"target": str(target),
"latest_src_mtime": latest_src_mtime,
"current_file_mtime": current_file_mtime,
}

hash_value = hashlib.md5(
json.dumps(hash_key, sort_keys=True, indent=2).encode("utf-8")
).hexdigest()
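# The digest changes whenever any kernel source, the target, or this builder script is modified,
# so stale object files are never reused.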

# Check if a valid hash exists in the build directory
hash_file = build_directory / "hash.md5"
if hash_file.exists():
with open(hash_file, "r") as f:
cached_hash = f.read().strip()
if cached_hash == hash_value:
# Check that all object files exist
object_files = []
all_exist = True
for src in source_paths:
obj_path = get_object_file_path(src)
if not obj_path.exists():
all_exist = False
break
object_files.append(obj_path)
if all_exist:
return object_files

# The cache is missing or outdated: record the new hash and compile the sources.
with open(hash_file, "w") as f:
f.write(hash_value)

# ------------------------------------------------------------------------
# 1) Common CUDA compile flags
# ------------------------------------------------------------------------
@@ -82,17 +134,12 @@ def _compile_flashinfer_kernels(
Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
] + CUTLASS_INCLUDE_DIRS

# Where object files will be placed
build_directory = FLASHINFER_JIT_DIR / name
build_directory.mkdir(parents=True, exist_ok=True)

# ------------------------------------------------------------------------
# 3) Function to compile a single source file
# ------------------------------------------------------------------------
def compile_single_source(src: Path) -> Path:
# Derive the .o filename from the source filename
obj_name = src.stem + ".o"
obj_path = build_directory / obj_name
obj_path = get_object_file_path(src)

# Construct the command
cmd = (
@@ -202,7 +249,12 @@ def gen_flashinfer_prefill_module(
)
jit_args = {
"backend": backend,
"uri": "batch_prefill_tvm",
"uri": f"batch_prefill_tvm_dtype_q_{dtype_q}_"
+ f"dtype_kv_{dtype_kv}_"
+ f"dtype_o_{dtype_o}_"
+ f"qk_head_dim_{qk_head_dim}_"
+ f"v_head_dim_{v_head_dim}_"
+ f"enable_inline_rope_{enable_inline_rope}",
"dtype_q": torch_dtype_q,
"dtype_kv": torch_dtype_kv,
"dtype_o": torch_dtype_o,
@@ -273,7 +325,11 @@ def gen_flashinfer_decode_module(
torch_dtype_kv = getattr(torch, dtype_kv)
torch_dtype_o = getattr(torch, dtype_o)
jit_args = {
"uri": "batch_decode_tvm",
"uri": f"batch_decode_tvm_dtype_q_{dtype_q}_"
+ f"dtype_kv_{dtype_kv}_"
+ f"dtype_o_{dtype_o}_"
+ f"qk_head_dim_{qk_head_dim}_"
+ f"v_head_dim_{v_head_dim}",
"dtype_q": torch_dtype_q,
"dtype_kv": torch_dtype_kv,
"dtype_o": torch_dtype_o,
@@ -343,7 +399,11 @@ def gen_flashinfer_mla_module(
torch_dtype_kv = getattr(torch, dtype_kv)
torch_dtype_o = getattr(torch, dtype_o)
jit_args = {
"uri": "batch_mla_tvm",
"uri": f"batch_mla_tvm_dtype_q_{dtype_q}_"
+ f"dtype_kv_{dtype_kv}_"
+ f"dtype_o_{dtype_o}_"
+ f"head_dim_ckv_{head_dim_ckv}_"
+ f"head_dim_kpe_{head_dim_kpe}",
"dtype_q": torch_dtype_q,
"dtype_kv": torch_dtype_kv,
"dtype_o": torch_dtype_o,
@@ -169,7 +169,7 @@ def set_global_func(head_dim, dtype):
mod = tvm.IRModule({"main": tir_func})
with target:
mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
f = tvm.compile(mod["main"], target=target)
f = tvm.tir.build(mod["main"], target=target)
builts.append(f.entry_func)

(
@@ -155,7 +155,7 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]):
mod = tvm.IRModule({"main": tir_func})
with target:
mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
f = tvm.compile(mod["main"], target=target)
f = tvm.tir.build(mod["main"], target=target)
builts.append(f.entry_func)

(
@@ -168,7 +168,7 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]):
mod = tvm.IRModule({"main": tir_func})
with target:
mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
f = tvm.compile(mod["main"], target=target)
f = tvm.tir.build(mod["main"], target=target)
builts.append(f.entry_func)

(