Skip to content

Commit df85a92

Browse files
mengluy0125pytorchmergebot
authored andcommitted
[Inductor][Observability] Add logging for split cat pass (pytorch#116442)
Summary: Add logs for both in the pre and post grad passes Test Plan: ``` buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode split_batch ``` [2023-12-26 17:14:24,203] [0/0] torch._inductor.fx_passes.post_grad: [INFO] counters of inductor dict after apply the split cat in the post grad pass: Counter({'pattern_matcher_nodes': 4076, 'pattern_matcher_count': 2917, 'remove_split_with_size_one': 1322, 'split_cat_norm': 461, 'consecutive_split_merged': 371, 'scmerge_cat_removed': 41, 'scmerge_cat_added': 32, 'scmerge_split_removed': 28, 'getitem_cat_merged': 11, 'batch_fusion': 7, 'scmerge_split_sections_removed': 3, 'scmerge_split_added': 2, 'split_squeeze_replaced': 2}) [2023-12-26 17:16:28,437] torch._inductor.fx_passes.post_grad: [INFO] counters of inductor dict after apply the split cat in the post grad pass: Counter({'pattern_matcher_nodes': 4122, 'pattern_matcher_count': 2935, 'remove_split_with_size_one': 1322, 'split_cat_norm': 461, 'consecutive_split_merged': 371, 'scmerge_cat_removed': 41, 'batch_fusion': 39, 'scmerge_cat_added': 32, 'scmerge_split_removed': 28, 'getitem_cat_merged': 11, 'scmerge_split_sections_removed': 3, 'scmerge_split_added': 2, 'split_squeeze_replaced': 2}) Differential Revision: D52425400 Pull Request resolved: pytorch#116442 Approved by: https://github.com/yanboliang
1 parent 8deaa13 commit df85a92

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

test/dynamo/test_logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def test_dynamo_error(self, records):
143143
)
144144

145145
test_aot = within_range_record_test(2, 6, aot=logging.INFO)
146-
test_inductor_debug = within_range_record_test(3, 15, inductor=logging.DEBUG)
146+
test_inductor_debug = within_range_record_test(3, 17, inductor=logging.DEBUG)
147147
test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO)
148148

149149
@make_logging_test()

torch/_inductor/compile_fx.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
logging as dynamo_logging,
3333
utils as dynamo_utils,
3434
)
35-
from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code
35+
from torch._dynamo.utils import counters, detect_fake_mode, lazy_format_graph_code
3636
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
3737
from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
3838

@@ -511,6 +511,10 @@ def fx_codegen_and_compile(
511511
post_grad_passes(gm, is_inference=is_inference)
512512
V.debug.fx_graph_transformed(gm, example_inputs)
513513
post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm))
514+
log.debug(
515+
"counters of inductor dict after apply passes on the input FX graph in the post grad pass: %s",
516+
counters["inductor"],
517+
)
514518

515519
with V.set_fake_mode(fake_mode):
516520
graph = GraphLowering(
@@ -1010,6 +1014,10 @@ def compile_fx(
10101014
)
10111015

10121016
model_ = pre_grad_passes(model_, example_inputs_)
1017+
log.debug(
1018+
"counters of inductor dict after apply passes on the input FX graph in the pre grad pass: %s",
1019+
counters["inductor"],
1020+
)
10131021

10141022
if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
10151023
return flatten_graph_inputs(

0 commit comments

Comments
 (0)