Commit cb802f5

chore: Improve the AutoTuner log information.
* Change the fallback alert from DEBUG to WARNING level and emit it only once per custom op.
* Add debug information for the profiling cache right after the warmup phase.
* Change the level of exception messages during tactic profiling from ERROR to WARNING; full exception details are pushed to DEBUG.
* Other trivial refinements.

Signed-off-by: Yukun He <[email protected]>
1 parent dc75779 commit cb802f5

File tree

2 files changed: +47, -21 lines

tensorrt_llm/_torch/autotuner.py

Lines changed: 42 additions & 17 deletions
@@ -197,13 +197,13 @@ def autotune(tune_mode: bool = True):
     AutoTuner.get().is_tuning_mode = tune_mode
     autotune_enabled = tune_mode and not old_mode
     if autotune_enabled:
-        logger.info("[Autotuner]: Autotuning process starts ...")
+        logger.info("[Autotuner] Autotuning process starts ...")
     try:
         yield
     finally:
         AutoTuner.get().is_tuning_mode = old_mode
         if autotune_enabled:
-            logger.info("[Autotuner]: Autotuning process ends")
+            logger.info("[Autotuner] Autotuning process ends")


 @dataclass
@@ -350,16 +350,11 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
             runner = runners[runner_id]
             # TODO: check the stored runner and tactic can implement this shape here
             # Should not directly try (runner, tactic) here, or it will hurt a lot of inference perf.
-
-            # Record the cache miss config.
-            # Expect no cache miss in inference. Thus, any cache miss should be recorded.
-            if not is_cache_hit:
-                logger.debug(
-                    f"[AutoTunner]: Using fallback tactic for {custom_op} with input shapes {input_shapes}"
-                )
-                logger.debug(
-                    f"[AutoTunner]: Generated key{AutoTuner._get_cache_key(custom_op, runners[0], input_shapes, tuning_config)}"
-                )
+            if not is_cache_hit and len(self.profiling_cache) > 0:
+                # Only log once for each custom op, and only when the cache is not empty
+                logger.warning_once(
+                    f"[Autotuner] Using the fallback tactic due to a cache miss on input shapes={input_shapes}",
+                    key=(custom_op))
             return runner, tactic

         assert len(runners) > 0, "At least one runner is required"
@@ -370,6 +365,8 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
         # Record the total configs to try
         self.stats.tuned_op_total_configs[custom_op] = len(profiles)

+        new_tuning_failure_occurred = False
+
         for p in profiles:
             tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, runner_id, tactic, _ = self.search_cache(
@@ -396,11 +393,13 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
                     except Exception as e:
                         shapes = self._get_input_sizes(tensors)

-                        logger.error(
-                            f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
+                        logger.warning(
+                            f"[Autotuner] Failed when profiling runner={r}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for more details."
                         )
+                        logger.debug(f"[Autotuner] Exception captured: {e}")

                         # Record the failed profiling combinations
+                        new_tuning_failure_occurred = True
                         if custom_op not in self.stats.failed_profiling_count:
                             self.stats.failed_profiling_count[
                                 custom_op] = set()
@@ -426,9 +425,26 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
                     custom_op] = self.stats.tuned_op_successful_configs.get(
                         custom_op, 0) + 1
                 logger.debug(
-                    f"[Autotuner]: profiling chosen runner: {runners[runner_id]} {tactic} for {cache_key}"
+                    f"[Autotuner] Profiling runner={runners[runner_id]}, tactic={tactic} for cache_key={cache_key}."
+                )
+            else:
+                logger.warning(
+                    f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. "
+                    f"At least one valid (runner, tactic) pair is required. "
+                    f"If get_valid_tactics is intended to return an empty list, please ensure that this profile is not valid for the custom_op "
+                    f"and will not occur during the inference stage, or that a fallback tactic is implemented. Otherwise, the tuning process will crash."
                 )

+        # If any tactic failed to profile, log an error summary.
+        if new_tuning_failure_occurred:
+            logger.error(
+                f"[Autotuner] New tuning errors occurred: "
+                f"{len(self.stats.failed_profiling_count[custom_op])} profiling tactics failed for custom_op={custom_op}. "
+                f"This will not block the tuning process. "
+                f"Please set TLLM_LOG_LEVEL=WARNING to find out when tactic profiling fails. "
+                f"Set TLLM_LOG_LEVEL=DEBUG to get more details of the failures."
+            )
+
         # Get the best runner and tactic from cache
         # If no valid tactic is found, the fallback runner and tactic will be used
         _, runner_id, tactic, _ = self.search_cache(custom_op, runners,
@@ -487,7 +503,7 @@ def _profile_single_kernel(self, runner: TunableRunner,

         shapes = self._get_input_sizes(inputs)
         logger.debug(
-            f"[Autotuner]: profiling {runner} {tactic}, shapes={shapes}, avg_time {avg_time}"
+            f"[Autotuner] Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time}ms."
         )

         return avg_time
@@ -557,7 +573,7 @@ def _optimization_profiles(
                     p.shapes[spec.input_idx][spec.dim_idx] = DynamicDim(
                         min_value, opt_value, max_value)
             generated_profiles.append(p)
-            logger.debug(f"[Autotuner]: generated profile: {p}")
+            logger.debug(f"[Autotuner] Generated profile: {p}")
         return generated_profiles

     @classmethod
@@ -649,3 +665,12 @@ def clear_cache(self) -> None:
     def reset_statistics(self) -> None:
         """Reset all statistics counters."""
         self.stats = AutoTunerStatistics()
+
+    def print_profiling_cache(self):
+        logger.debug(f"[Autotuner] The profiling_cache entries:")
+        logger.debug(
+            f"[Autotuner] Cache contents: (custom_op, runner, hash(attributes), shape_profiles) -> (runner_id, tactic, shape_profile(ignored))"
+        )
+        for key, value in self.profiling_cache.items():
+            runner_id, tactic, _ = value
+            logger.debug(f"[Autotuner] {key}: ({runner_id}, {tactic})")
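Note on the fallback alert introduced above: logger.warning_once(..., key=custom_op) keys the message on the custom op, so repeated cache misses for the same op produce a single WARNING instead of one per call. The snippet below is a minimal sketch of that once-per-key pattern; the DedupLogger helper and its names are illustrative only and are not the TensorRT-LLM logger implementation.

```python
# Illustrative sketch: a "warn once per key" helper mimicking the deduplication
# the commit relies on via logger.warning_once(..., key=custom_op).
# DedupLogger is hypothetical, not a TensorRT-LLM API.
import logging


class DedupLogger:
    def __init__(self, logger: logging.Logger):
        self._logger = logger
        self._seen_keys = set()

    def warning_once(self, msg: str, *, key) -> None:
        # Emit the warning only the first time a given key is seen.
        if key in self._seen_keys:
            return
        self._seen_keys.add(key)
        self._logger.warning(msg)


logging.basicConfig(level=logging.WARNING)
log = DedupLogger(logging.getLogger("autotuner"))

for shapes in [(1, 128), (1, 128), (4, 256)]:
    # Repeated cache misses for the same custom op produce a single warning.
    log.warning_once(
        f"[Autotuner] Using the fallback tactic due to a cache miss on input shapes={shapes}",
        key="my_custom_op",
    )
```

Keying on the custom op rather than on the shapes keeps the inference log quiet even when many distinct shapes miss the cache.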

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 4 deletions
@@ -693,15 +693,16 @@ def disable_optimization(backend: Backend):
                        # No KV cache space!
                        pass
                    else:
-                        logger.info(
-                            f"Run autotuning warmup for batch size={1}")
                        self.forward(batch,
                                     new_tensors_device=None,
                                     resource_manager=resource_manager)
                        torch.cuda.synchronize()

-                logger.info(f"Autotuner Cache size after warmup " +
-                            str(len(AutoTuner.get().profiling_cache)))
+                logger.info(
+                    f"[Autotuner] Cache size after warmup is {len(AutoTuner.get().profiling_cache)}"
+                )
+
+                AutoTuner.get().print_profiling_cache()

                if not (self._run_cuda_graphs
                        or self._torch_compile_piecewise_cuda_graph):
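The new cache dump and the per-tactic failure details are emitted at DEBUG level, so they only show up when the log level is raised. A rough usage sketch, assuming TLLM_LOG_LEVEL is honored at import time and that AutoTuner is importable from tensorrt_llm._torch.autotuner (the module changed above):

```python
# Sketch of how the new diagnostics could be surfaced in a debugging session.
import os

# Assumption: the env var must be set before tensorrt_llm is imported.
os.environ["TLLM_LOG_LEVEL"] = "DEBUG"  # DEBUG exposes per-tactic failures and the cache dump

from tensorrt_llm._torch.autotuner import AutoTuner

tuner = AutoTuner.get()
# After the warmup/forward pass has populated the cache, this logs one DEBUG
# line per (custom_op, runner, hash(attributes), shape_profiles) cache entry.
print(f"Profiling cache holds {len(tuner.profiling_cache)} entries")
tuner.print_profiling_cache()
```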
