
Commit 3d4dd80

chore: Improve the AutoTuner log information.
* Change the fallback alert from DEBUG level to WARNING level, and emit it only once per custom op.
* Add debug information for the profiling cache right after the warmup phase.
* Change the level of exception messages during tactic profiling from ERROR to WARNING; all exception details are pushed to DEBUG.
* Other trivial refinements.

Signed-off-by: Yukun He <[email protected]>
1 parent 470544c commit 3d4dd80

File tree

2 files changed (+28, -18 lines)


tensorrt_llm/_torch/autotuner.py

Lines changed: 23 additions & 14 deletions
@@ -350,16 +350,11 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
             runner = runners[runner_id]
             # TODO: check the stored runner and tactic can implement this shape here
             # Should not directly try (runner, tactic) here, or it will hurt a lot of inference perf.
-
-            # Record the cache miss config.
-            # Expect no cache miss in inference. Thus, any cache miss should be recorded.
-            if not is_cache_hit:
-                logger.debug(
-                    f"[AutoTunner]: Using fallback tactic for {custom_op} with input shapes {input_shapes}"
-                )
-                logger.debug(
-                    f"[AutoTunner]: Generated key{AutoTuner._get_cache_key(custom_op, runners[0], input_shapes, tuning_config)}"
-                )
+            if not is_cache_hit and len(self.profiling_cache) > 0:
+                # Only log once for each custom op and only when cache is not empty
+                logger.warning_once(
+                    f"[AutoTunner]: Using the fallback tactic, due to cache miss on input shapes={input_shapes}",
+                    key=(custom_op))
             return runner, tactic
 
         assert len(runners) > 0, "At least one runner is required"
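
The fallback path now goes through logger.warning_once(..., key=custom_op), so each custom op warns at most once no matter how many cache misses it hits. The warn-once helper itself is not part of this diff; the following is a minimal, standalone sketch of how such a keyed helper can behave (hypothetical names, built on the standard logging module rather than the TensorRT-LLM logger):

import logging

logging.basicConfig(level=logging.WARNING)
_base_logger = logging.getLogger("autotuner_sketch")

class OnceLogger:
    """Warn-once wrapper: a given key triggers at most one WARNING record."""

    def __init__(self, base):
        self._base = base
        self._warned_keys = set()

    def warning_once(self, msg, key):
        if key not in self._warned_keys:
            self._warned_keys.add(key)
            self._base.warning(msg)

logger = OnceLogger(_base_logger)
# Only the first call for "my_custom_op" is emitted; the second is suppressed.
logger.warning_once("[AutoTunner]: Using the fallback tactic, due to cache miss", key="my_custom_op")
logger.warning_once("[AutoTunner]: Using the fallback tactic, due to cache miss", key="my_custom_op")
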
@@ -396,9 +391,10 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
                 except Exception as e:
                     shapes = self._get_input_sizes(tensors)
 
-                    logger.error(
-                        f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
+                    logger.warning(
+                        f"[Autotuner]: Failed when profiling runner={r}, tactic={tac}, shapes={shapes}. Set TLLM_LOG_LEVEL=DEBUG for details."
                     )
+                    logger.debug(f"[Autotuner] Exception captured: {e}")
 
                     # Record the failed profiling combinations
                     if custom_op not in self.stats.failed_profiling_count:
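
Profiling failures are now reported as a single terse WARNING, while the exception itself is demoted to DEBUG so that running with TLLM_LOG_LEVEL=DEBUG reveals the details on demand. As a generic, standalone illustration of that two-level pattern (standard logging module, not the repository's logger; try_profile and profile_fn are hypothetical names):

import logging
import traceback

logger = logging.getLogger("autotuner_sketch")

def try_profile(profile_fn, runner, tactic):
    """Attempt one (runner, tactic) profile; keep the WARNING terse, push details to DEBUG."""
    try:
        return profile_fn()
    except Exception as e:
        logger.warning(
            f"[Autotuner]: Failed when profiling runner={runner}, tactic={tactic}. "
            "Set TLLM_LOG_LEVEL=DEBUG for details.")
        # The exception text and traceback only show up at DEBUG verbosity.
        logger.debug(f"[Autotuner] Exception captured: {e}\n{traceback.format_exc()}")
        return None
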
@@ -426,7 +422,11 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner],
                         custom_op] = self.stats.tuned_op_successful_configs.get(
                             custom_op, 0) + 1
                 logger.debug(
-                    f"[Autotuner]: profiling chosen runner: {runners[runner_id]} {tactic} for {cache_key}"
+                    f"[Autotuner]: Profiling runner={runners[runner_id]}, tactic={tactic} for cache_key={cache_key}"
+                )
+            else:
+                logger.error(
+                    f"[Autotuner]: No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. At least one valid (runner, tactic) pair is required. Please check the implementation of the runner for the custom_op."
                 )
 
         # Get the best runner and tactic from cache
@@ -487,7 +487,7 @@ def _profile_single_kernel(self, runner: TunableRunner,
 
         shapes = self._get_input_sizes(inputs)
         logger.debug(
-            f"[Autotuner]: profiling {runner} {tactic}, shapes={shapes}, avg_time {avg_time}"
+            f"[Autotuner]: Profiled runner={runner}, tactic={tactic}, shapes={shapes}: {avg_time}ms."
         )
 
         return avg_time
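
The reworded message reports avg_time in milliseconds. The measurement itself is untouched by this commit; for context, here is a sketch of how an average kernel time like this is commonly obtained with CUDA events (an assumption about the general approach, not this file's exact implementation; time_kernel_ms is a hypothetical helper):

import torch

def time_kernel_ms(fn, warmup=3, iters=10):
    """Return the average wall time of fn() in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # elapsed_time() already reports milliseconds
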
@@ -649,3 +649,12 @@ def clear_cache(self) -> None:
     def reset_statistics(self) -> None:
         """Reset all statistics counters."""
         self.stats = AutoTunerStatistics()
+
+    def print_profiling_cache(self):
+        logger.debug(f"[Autotuner]: The profiling_cache entries:")
+        logger.debug(
+            f"[Autotuner]: Cache contents: (custom_op, runner, hash(attributes), shape_profiles) -> (runner_id, tactic, shape_profile(ignored))"
+        )
+        for key, value in self.profiling_cache.items():
+            runner_id, tactic, _ = value
+            logger.debug(f"[Autotuner]: {key}: ({runner_id}, {tactic})")
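
The new print_profiling_cache method emits everything at DEBUG, so it stays silent at the default log level. A usage sketch, assuming DEBUG verbosity is enabled first (the set_level call and the logger import path are assumptions about the TensorRT-LLM logger; the AutoTuner module path follows this diff):

from tensorrt_llm._torch.autotuner import AutoTuner
from tensorrt_llm.logger import logger  # assumed import path for the shared logger

logger.set_level("debug")  # assumption: enable DEBUG so the cache dump is visible
AutoTuner.get().print_profiling_cache()
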

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 4 deletions
@@ -693,15 +693,16 @@ def disable_optimization(backend: Backend):
                         # No KV cache space!
                         pass
                     else:
-                        logger.info(
-                            f"Run autotuning warmup for batch size={1}")
                         self.forward(batch,
                                      new_tensors_device=None,
                                      resource_manager=resource_manager)
                         torch.cuda.synchronize()
 
-                logger.info(f"Autotuner Cache size after warmup " +
-                            str(len(AutoTuner.get().profiling_cache)))
+                logger.info(
+                    f"[Autotuner]: Cache size after warmup is {len(AutoTuner.get().profiling_cache)}"
+                )
+
+                AutoTuner.get().print_profiling_cache()
 
                 if not (self._run_cuda_graphs
                         or self._torch_compile_piecewise_cuda_graph):