Jax float64 precision issues do not play ball with hypothesis #368

Open
ev-br opened this issue May 5, 2025 · 3 comments

Comments

@ev-br
Member

ev-br commented May 5, 2025

A typical failure looks like this (from test_diff):

self = <hypothesis.extra.array_api.ArrayStrategy object at 0x7e6a6cf6c990>, val = 2.112233982580733, val_0d = Array(2.1122339, dtype=float32)
strategy = FloatStrategy(min_value=2.0, max_value=64.0, allow_nan=False, smallest_nonzero_magnitude=2.2250738585072014e-308)

    def check_set_value(self, val, val_0d, strategy):
        if val == val and self.builtin(val_0d) != val:
            if self.builtin is float:
                assert self.finfo is not None  # for mypy
                try:
                    is_subnormal = 0 < abs(val) < self.finfo.smallest_normal
                except Exception:
                    # val may be a non-float that does not support the
                    # operations __lt__ and __abs__
                    is_subnormal = False
                if is_subnormal:
                    raise InvalidArgument(
                        f"Generated subnormal float {val} from strategy "
                        f"{strategy} resulted in {val_0d!r}, probably "
                        f"as a result of array module {self.xp.__name__} "
                        "being built with flush-to-zero compiler options. "
                        "Consider passing allow_subnormal=False."
                    )
>           raise InvalidArgument(
                f"Generated array element {val!r} from strategy {strategy} "
                f"cannot be represented with dtype {self.dtype}. "
                f"Array module {self.xp.__name__} instead "
                f"represents the element as {val_0d}. "
                "Consider using a more precise elements strategy, "
                "for example passing the width argument to floats()."
            )
E           hypothesis.errors.InvalidArgument: Generated array element 2.112233982580733 from strategy FloatStrategy(min_value=2.0, max_value=64.0, allow_nan=False, smallest_nonzero_magnitude=2.2250738585072014e-308) cannot be represented with dtype <class 'jax.numpy.float64'>. Array module jax.numpy instead represents the element as 2.112233877182007. Consider using a more precise elements strategy, for example passing the width argument to floats().
E           while generating 'x' from sampled_from((<class 'jax.numpy.uint8'>, <class 'jax.numpy.int8'>, <class 'jax.numpy.int16'>, <class 'jax.numpy.int32'>, <class 'jax.numpy.float32'>, <class 'jax.numpy.float64'>, <class 'jax.numpy.complex64'>, <class 'jax.numpy.complex128'>)).flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
E           Explanation:
E               These lines were always and only run by failing examples:
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/jax/_src/array.py:328
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/jax/_src/array.py:651
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/numpy/_core/getlimits.py:609
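The key detail in the traceback is `val_0d = Array(2.1122339, dtype=float32)`. The test requested float64, but in JAX's default 32-bit mode the value was stored as float32, so hypothesis's round-trip check in check_set_value (`self.builtin(val_0d) != val`) fails. Below is a minimal sketch of that check. It assumes NumPy's float32 as a stand-in for JAX's silent downcast, and `roundtrips` is a hypothetical helper, not part of hypothesis:

```python
import numpy as np


def roundtrips(val: float, dtype) -> bool:
    """Hypothetical helper: does `val` survive a trip through a 0-d array of `dtype`?

    Mirrors the `self.builtin(val_0d) != val` test in hypothesis's check_set_value.
    """
    val_0d = np.asarray(val, dtype=dtype)
    return float(val_0d) == val


val = 2.112233982580733  # the failing element from the traceback

print(roundtrips(val, np.float64))  # True: exactly representable in float64
print(roundtrips(val, np.float32))  # False: rounded to the nearest float32
print(float(np.float32(val)))       # the float32-rounded value that replaced val
```

Because the element is drawn as a full-precision Python float, any silent narrowing to float32 shows up as a round-trip mismatch rather than as a dtype error.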

@ev-br ev-br mentioned this issue May 5, 2025
@jakevdp
Contributor

jakevdp commented May 5, 2025

JAX is only compliant with the array API dtype semantics when jax_enable_x64 is set to true. Any testing has to take that into account.
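For code that uses JAX directly rather than through environment variables, the same flag can be set in-process with `jax.config.update`. A minimal sketch, assuming it runs at startup before any arrays are created:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit dtypes; the in-process counterpart of JAX_ENABLE_X64=true.
jax.config.update("jax_enable_x64", True)

val = 2.112233982580733  # the element hypothesis generated in the traceback
x = jnp.asarray(val, dtype=jnp.float64)

print(x.dtype)          # float64
print(float(x) == val)  # True: the round trip is now exact
```

With x64 left disabled, the same `asarray` call returns a float32 array (JAX warns that the requested dtype is not available), which is the situation in the traceback.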

ev-br added a commit to ev-br/array-api-tests that referenced this issue May 5, 2025
@ev-br
Member Author

ev-br commented May 9, 2025

Thanks Jake!
So, for completeness, the stanza to run a test from the test suite locally is:

$ JAX_ENABLE_X64=true ARRAY_API_TESTS_VERSION="2024.12" ARRAY_API_TESTS_MODULE=jax.numpy pytest path/to/test

(EDITED to account for the correction below.)
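When the tests are launched from a Python script instead of a shell, the environment can be built the same way. A minimal sketch; `array_api_tests_env` is a hypothetical helper, not part of array-api-tests, and it only sets the three variables from the command above:

```python
import os


def array_api_tests_env(base=None):
    """Hypothetical helper: environment for running array-api-tests on jax.numpy.

    Copies `base` (default: the current process environment) and adds the
    variables from the shell command; the input mapping is not modified.
    """
    env = dict(os.environ if base is None else base)
    env.update(
        JAX_ENABLE_X64="true",  # note: JAX_ENABLE_X64, not JAX_ENABLE_FLOAT64
        ARRAY_API_TESTS_VERSION="2024.12",
        ARRAY_API_TESTS_MODULE="jax.numpy",
    )
    return env
```

For example, `subprocess.run([sys.executable, "-m", "pytest", "path/to/test"], env=array_api_tests_env(), check=True)`, where `path/to/test` is the same placeholder as in the shell command.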

@jakevdp
Contributor

jakevdp commented May 9, 2025

> Thanks Jake! So for completeness, the stanza to locally run a test from the test suite is
>
> $ JAX_ENABLE_FLOAT64=True ARRAY_API_TESTS_VERSION="2024.12" ARRAY_API_TESTS_MODULE=jax.numpy pytest path/to/test

Almost: the env variable is JAX_ENABLE_X64, not JAX_ENABLE_FLOAT64.
