Skip to content

Commit 134247c

Browse files
committed
Remove the deps/JetStream changes.
1 parent e8f1469 commit 134247c

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

benchmarks/run_offline.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
5353
nums = 5
5454
start = time.perf_counter()
5555
for i in range(nums):
56-
if i == nums - 1 and FLAGS.profiling_prefill:
56+
if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
5757
jax.profiler.start_trace(FLAGS.profiling_output)
5858
profiler_started = True
5959

@@ -66,7 +66,7 @@ def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
6666
jax.block_until_ready(decode_state)
6767

6868
end = time.perf_counter()
69-
return (end - start) / nums, decode_state
69+
return (end - start) / nums, decode_state, profiler_started
7070

7171

7272
MAXTEXT_PREFILL = {
@@ -93,7 +93,7 @@ def main(argv):
9393
decode_state = engine.init_decode_state()
9494
profiler_started = False
9595
for batch, _ in MAXTEXT_PREFILL.items():
96-
runtime, decode_state = run_prefill_time(
96+
runtime, decode_state, profiler_started = run_prefill_time(
9797
engine, params, decode_state, batch, profiler_started
9898
)
9999
prefill_times[batch] = runtime
@@ -109,6 +109,7 @@ def main(argv):
109109

110110
profiling_output = FLAGS.profiling_output
111111
print("======= decode starting ===")
112+
112113
dec_times = []
113114
for i in range(10):
114115
if profiling_output and i == 7 and not profiler_started:

0 commit comments

Comments
 (0)