Skip to content

Commit 41228dd

Browse files
author
maxtext authors
committed
Merge pull request #1753 from SamuelMarks:inference_microbenchmark-EXIT_CODE
PiperOrigin-RevId: 761574119
2 parents a93d7d8 + c6041b6 commit 41228dd

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

MaxText/inference_microbenchmark.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import jax
2020
import json
2121
import os
22+
import sys
2223

2324
from absl import app
2425
from collections.abc import MutableMapping
@@ -426,10 +427,14 @@ def run_benchmarks(config):
426427
return results
427428

428429

429-
def main(config, **kwargs):
430+
def run_benchmarks_with_unsafe_rbg(config, **kwargs):
430431
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
431432
return run_benchmarks(pyconfig.initialize(config, **kwargs))
432433

433434

435+
def main(config, **kwargs):
436+
json.dump(run_benchmarks_with_unsafe_rbg(config, **kwargs), sys.stdout)
437+
438+
434439
if __name__ == "__main__":
435440
app.run(main)

MaxText/inference_microbenchmark_sweep.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def main():
123123
**inference_metadata,
124124
}
125125
try:
126-
microbenchmark_results = inference_microbenchmark.main(config, inference_metadata=inference_metadata)
126+
microbenchmark_results = inference_microbenchmark.run_benchmarks_with_unsafe_rbg(
127+
config, inference_metadata=inference_metadata
128+
)
127129
if microbenchmark_results:
128130
metrics = microbenchmark_results["flattened_results"]
129131
metrics = {k.lower(): v for k, v in metrics.items()}

0 commit comments

Comments
 (0)