Skip to content

Commit bfcde08

Browse files
Yinghai Lufacebook-github-bot
Yinghai Lu
authored andcommitted
[trt] Algorithm recorder/replayer (#4)
Summary: Pull Request resolved: pytorch/pytorch-canary#4 Pull Request resolved: pytorch#67211 Record the algorithm selection, dump it in json format and replay it. This has potential to 1. consistently repro the issue (algo selection could be sensitive to local benchmark timing) 2. manual edit the dumped json file to control algorithm selection. Reviewed By: wushirong, 842974287 Differential Revision: D31888836 fbshipit-source-id: 4611fda548f7391776f1ad61572b1f59fa4665b6
1 parent ecf7e96 commit bfcde08

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

torch/fx/experimental/fx2trt/fx2trt.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ def run(
414414
int8_mode=False,
415415
force_fp32_output=False,
416416
strict_type_constraints=False,
417+
algorithm_selector=None,
417418
) -> TRTInterpreterResult:
418419
# For float outputs, we set their dtype to fp16 only if fp16_mode=True and
419420
# force_fp32_output=False.
@@ -444,6 +445,10 @@ def run(
444445
for optimization_profile in self.optimization_profiles:
445446
builder_config.add_optimization_profile(optimization_profile)
446447

448+
if algorithm_selector:
449+
builder_config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE)
450+
builder_config.algorithm_selector = algorithm_selector
451+
447452
engine = self.builder.build_engine(self.network, builder_config)
448453
assert engine
449454
return engine, self._input_names, self._output_names

0 commit comments

Comments
 (0)