@@ -53,7 +53,7 @@ def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
53
53
nums = 5
54
54
start = time .perf_counter ()
55
55
for i in range (nums ):
56
- if i == nums - 1 and FLAGS .profiling_prefill :
56
+ if i == nums - 1 and FLAGS .profiling_prefill and not profiler_started :
57
57
jax .profiler .start_trace (FLAGS .profiling_output )
58
58
profiler_started = True
59
59
@@ -66,7 +66,7 @@ def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
66
66
jax .block_until_ready (decode_state )
67
67
68
68
end = time .perf_counter ()
69
- return (end - start ) / nums , decode_state
69
+ return (end - start ) / nums , decode_state , profiler_started
70
70
71
71
72
72
MAXTEXT_PREFILL = {
@@ -93,7 +93,7 @@ def main(argv):
93
93
decode_state = engine .init_decode_state ()
94
94
profiler_started = False
95
95
for batch , _ in MAXTEXT_PREFILL .items ():
96
- runtime , decode_state = run_prefill_time (
96
+ runtime , decode_state , profiler_started = run_prefill_time (
97
97
engine , params , decode_state , batch , profiler_started
98
98
)
99
99
prefill_times [batch ] = runtime
@@ -109,6 +109,7 @@ def main(argv):
109
109
110
110
profiling_output = FLAGS .profiling_output
111
111
print ("======= decode starting ===" )
112
+
112
113
dec_times = []
113
114
for i in range (10 ):
114
115
if profiling_output and i == 7 and not profiler_started :
0 commit comments