Skip to content

Commit 4e446b8

Browse files
zou3519facebook-github-bot
authored andcommitted
Make profiler.build_table() O(n) rather than O(n^2) (#10969)
Summary: Fixes #10851 Speeds up profiling results dramatically. For the following script: ``` import torch import time ITER = 2000 x = torch.randn(1, 1, requires_grad=True) with torch.autograd.profiler.profile() as prof: y = x for i in range(ITER): y = 3 * y - 2 * y y.backward() start = time.time() print("Done running. Preparing prof") x = str(prof) print("Done preparing prof results") end = time.time() print("Elapsed: {}".format(end - start)) ``` I get 7s before / 0.13s after these changes. cc apaszke Pull Request resolved: #10969 Differential Revision: D9556129 Pulled By: zou3519 fbshipit-source-id: 26b421686f8a42cdaace6382567d403e6385dc12
1 parent 396dec0 commit 4e446b8

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torch/autograd/profiler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -554,11 +554,11 @@ def build_table(events, sort_by=None, header=None):
554554
header_sep = '-' * max_name_length + (' ' + '-' * col_width) * 5
555555

556556
# Have to use a list because nonlocal is Py3 only...
557-
result = ['']
557+
result = []
558558

559559
def append(s):
560-
result[0] += s
561-
result[0] += '\n'
560+
result.append(s)
561+
result.append('\n') # Yes, newline after the end as well
562562

563563
# Actual printing
564564
if header is not None:
@@ -572,4 +572,4 @@ def append(s):
572572
append(row_format.format(evt.key, evt.cpu_time_str, evt.cuda_time_str,
573573
evt.count, evt.cpu_time_total_str, evt.cuda_time_total_str))
574574

575-
return result[0]
575+
return ''.join(result)

0 commit comments

Comments
 (0)