Skip to content

Commit 7eb8f1e

Browse files
improves float check + adds tests
1 parent 6b45f05 commit 7eb8f1e

File tree

4 files changed

+54
-5
lines changed

4 files changed

+54
-5
lines changed

keras_hub/src/models/task_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def test_quantized_preset_loading_and_saving(
163163
layer.kernel.dtype,
164164
expected_dtype,
165165
f"Layer '{layer.name}' kernel "
166-
"should have dtype '{expected_dtype}'",
166+
f"should have dtype '{expected_dtype}'",
167167
)
168168

169169
# Ensure inference runs without errors.

keras_hub/src/utils/preset_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -723,11 +723,12 @@ def _resolve_dtype(self, config, kwargs):
723723
saved types verbatim.
724724
725725
Args:
726-
config: The model configuration.
727-
kwargs: Additional keyword arguments, potentially including `dtype`.
726+
config: dict. The model configuration.
727+
kwargs: dict. Additional keyword arguments, potentially including
728+
`dtype`.
728729
729730
Returns:
730-
The resolved dtype.
731+
str, dict, or DTypePolicy. The resolved dtype.
731732
"""
732733
# 1. If a user specified dtype is passed, use that.
733734
if "dtype" in kwargs and kwargs["dtype"] is not None:

keras_hub/src/utils/tensor_utils.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,29 @@ def is_tensor_type(x):
310310

311311

312312
def is_float_dtype(dtype):
313-
return "float" in keras.backend.standardize_dtype(dtype)
313+
"""
314+
Checks if a dtype is a float type by using a regex.
315+
316+
This function standardizes the input dtype and then uses a regular
317+
expression to perform an exact match. It identifies standard floats,
318+
bfloats, and mixed-precision float types.
319+
320+
For example:
321+
- `is_float_dtype("float32")` returns `True`.
322+
- `is_float_dtype("bfloat16")` returns `True`.
323+
- `is_float_dtype("mixed_float16")` returns `True`.
324+
- `is_float_dtype("int8")` returns `False`.
325+
- `is_float_dtype("int8_from_float32")` returns `False`.
326+
327+
Args:
328+
dtype: str, DTypePolicy. The data type to check.
329+
330+
Returns:
331+
bool: `True` if the dtype is a floating-point type, `False` otherwise.
332+
"""
333+
pattern = re.compile(r"^(mixed_)?(b)?float[0-9]*$")
334+
standardized_dtype = keras.backend.standardize_dtype(dtype)
335+
return pattern.match(standardized_dtype) is not None
314336

315337

316338
def is_int_dtype(dtype):

keras_hub/src/utils/tensor_utils_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from keras_hub.src.utils.tensor_utils import convert_preprocessing_inputs
99
from keras_hub.src.utils.tensor_utils import convert_preprocessing_outputs
1010
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
11+
from keras_hub.src.utils.tensor_utils import is_float_dtype
1112
from keras_hub.src.utils.tensor_utils import is_tensor_type
1213
from keras_hub.src.utils.tensor_utils import preprocessing_function
1314
from keras_hub.src.utils.tensor_utils import target_gather
@@ -304,3 +305,28 @@ def test_target_gather_invalid_rank(self):
304305
indices = np.array([0, 1], dtype="int32")
305306
with self.assertRaisesRegex(ValueError, "larger than 3"):
306307
_ = target_gather(targets, indices)
308+
309+
310+
class IsFloatDtypeTest(TestCase):
311+
def test_float_dtypes_return_true(self):
312+
float_dtypes = [
313+
"float16",
314+
"float32",
315+
"float64",
316+
"bfloat16",
317+
]
318+
for dtype in float_dtypes:
319+
self.assertTrue(is_float_dtype(dtype))
320+
321+
def test_non_float_dtypes_return_false(self):
322+
non_float_dtypes = [
323+
"int8",
324+
"int32",
325+
"uint8",
326+
"bool",
327+
"string",
328+
"int8_from_float32",
329+
"int4_from_bfloat16",
330+
]
331+
for dtype in non_float_dtypes:
332+
self.assertFalse(is_float_dtype(dtype))

0 commit comments

Comments
 (0)