From c6041b64be0d1ad67330386286fab28a1e07def1 Mon Sep 17 00:00:00 2001 From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com> Date: Mon, 19 May 2025 22:32:54 -0700 Subject: [PATCH] [MaxText/inference_microbenchmark.py] Ensure correct exit code (closes #1742) --- MaxText/inference_microbenchmark.py | 7 ++++++- MaxText/inference_microbenchmark_sweep.py | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py index 667b6530a..be895c827 100644 --- a/MaxText/inference_microbenchmark.py +++ b/MaxText/inference_microbenchmark.py @@ -19,6 +19,7 @@ import jax import json import os +import sys from absl import app from collections.abc import MutableMapping @@ -426,10 +427,14 @@ def run_benchmarks(config): return results -def main(config, **kwargs): +def run_benchmarks_with_unsafe_rbg(config, **kwargs): jax.config.update("jax_default_prng_impl", "unsafe_rbg") return run_benchmarks(pyconfig.initialize(config, **kwargs)) +def main(config, **kwargs): + json.dump(run_benchmarks_with_unsafe_rbg(config, **kwargs), sys.stdout) + + if __name__ == "__main__": app.run(main) diff --git a/MaxText/inference_microbenchmark_sweep.py b/MaxText/inference_microbenchmark_sweep.py index d9d48fcaa..b36a6c132 100644 --- a/MaxText/inference_microbenchmark_sweep.py +++ b/MaxText/inference_microbenchmark_sweep.py @@ -123,7 +123,9 @@ def main(): **inference_metadata, } try: - microbenchmark_results = inference_microbenchmark.main(config, inference_metadata=inference_metadata) + microbenchmark_results = inference_microbenchmark.run_benchmarks_with_unsafe_rbg( + config, inference_metadata=inference_metadata + ) if microbenchmark_results: metrics = microbenchmark_results["flattened_results"] metrics = {k.lower(): v for k, v in metrics.items()}