Skip to content

[MaxText/inference_microbenchmark.py] Ensure correct exit code #1753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion MaxText/inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax
import json
import os
import sys

from absl import app
from collections.abc import MutableMapping
Expand Down Expand Up @@ -426,10 +427,14 @@ def run_benchmarks(config):
return results


def main(config, **kwargs):
def run_benchmarks_with_unsafe_rbg(config, **kwargs):
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
return run_benchmarks(pyconfig.initialize(config, **kwargs))


def main(config, **kwargs):
json.dump(run_benchmarks_with_unsafe_rbg(config, **kwargs), sys.stdout)


if __name__ == "__main__":
app.run(main)
4 changes: 3 additions & 1 deletion MaxText/inference_microbenchmark_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def main():
**inference_metadata,
}
try:
microbenchmark_results = inference_microbenchmark.main(config, inference_metadata=inference_metadata)
microbenchmark_results = inference_microbenchmark.run_benchmarks_with_unsafe_rbg(
config, inference_metadata=inference_metadata
)
if microbenchmark_results:
metrics = microbenchmark_results["flattened_results"]
metrics = {k.lower(): v for k, v in metrics.items()}
Expand Down
Loading