diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 0bbcd31189cc..a0faf7003f8e 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -60,6 +60,8 @@ from keras.src.ops.math import segment_sum as segment_sum from keras.src.ops.math import stft as stft from keras.src.ops.math import top_k as top_k +from keras.src.ops.math import view_as_complex as view_as_complex +from keras.src.ops.math import view_as_real as view_as_real from keras.src.ops.nn import average_pool as average_pool from keras.src.ops.nn import batch_normalization as batch_normalization from keras.src.ops.nn import binary_crossentropy as binary_crossentropy diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 0bbcd31189cc..a0faf7003f8e 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -60,6 +60,8 @@ from keras.src.ops.math import segment_sum as segment_sum from keras.src.ops.math import stft as stft from keras.src.ops.math import top_k as top_k +from keras.src.ops.math import view_as_complex as view_as_complex +from keras.src.ops.math import view_as_real as view_as_real from keras.src.ops.nn import average_pool as average_pool from keras.src.ops.nn import batch_normalization as batch_normalization from keras.src.ops.nn import binary_crossentropy as binary_crossentropy diff --git a/keras/src/ops/math.py b/keras/src/ops/math.py index 6cedef62cee4..44539428b31b 100644 --- a/keras/src/ops/math.py +++ b/keras/src/ops/math.py @@ -1044,3 +1044,101 @@ def logdet(x): if any_symbolic_tensors((x,)): return Logdet().symbolic_call(x) return backend.math.logdet(x) + + +class ViewAsComplex(Operation): + def call(self, x): + x = backend.convert_to_tensor(x) + if len(x.shape) < 1 or x.shape[-1] != 2: + raise ValueError( + "Input tensor's last dimension must be 2 (real and imaginary)." + ) + return x[..., 0] + 1j * x[..., 1] + + def compute_output_spec(self, x): + return KerasTensor(shape=x.shape[:-1], dtype="complex64") + + +class ViewAsReal(Operation): + def call(self, x): + x = backend.convert_to_tensor(x) + real_part = backend.numpy.real(x) + imag_part = backend.numpy.imag(x) + return backend.numpy.stack((real_part, imag_part), axis=-1) + + def compute_output_spec(self, x): + return KerasTensor(shape=x.shape + (2,), dtype="float32") + + +@keras_export("keras.ops.view_as_complex") +def view_as_complex(x): + """Converts a real tensor with shape `(..., 2)` to a complex tensor, + where the last dimension represents the real and imaginary components + of a complex tensor. + + Args: + x: A real tensor with last dimension of size 2. + + Returns: + A complex tensor with shape `x.shape[:-1]`. + + Example: + + ``` + >>> import numpy as np + >>> from keras import ops + + >>> real_imag = np.array([[1.0, 2.0], [3.0, 4.0]]) + >>> complex_tensor = ops.view_as_complex(real_imag) + >>> complex_tensor + array([1.+2.j, 3.+4.j]) + ``` + """ + if any_symbolic_tensors((x,)): + return ViewAsComplex().symbolic_call(x) + + x = backend.convert_to_tensor(x) + if len(x.shape) < 1 or x.shape[-1] != 2: + raise ValueError( + "Last dimension of input must be size 2 (real and imaginary). " + f"Received shape: {x.shape}" + ) + real_part = x[..., 0] + imag_part = x[..., 1] + + return backend.cast(real_part, dtype="complex64") + 1j * backend.cast( + imag_part, dtype="complex64" + ) + + +@keras_export("keras.ops.view_as_real") +def view_as_real(x): + """Converts a complex tensor to a real tensor with shape `(..., 2)`, + where the last dimension represents the real and imaginary components. + + Args: + x: A complex tensor. + + Returns: + A real tensor where the last dimension contains the + real and imaginary parts. + + Example: + ``` + >>> import numpy as np + >>> from keras import ops + + >>> complex_tensor = np.array([1 + 2j, 3 + 4j]) + >>> real = ops.view_as_real(complex_tensor) + >>> real + array([[1., 2.], + [3., 4.]]) + ``` + """ + if any_symbolic_tensors((x,)): + return ViewAsReal().symbolic_call(x) + + x = backend.convert_to_tensor(x) + real_part = backend.numpy.real(x) + imag_part = backend.numpy.imag(x) + return backend.numpy.stack((real_part, imag_part), axis=-1) diff --git a/keras/src/ops/math_test.py b/keras/src/ops/math_test.py index 5786827717d1..d3ca13e5aac2 100644 --- a/keras/src/ops/math_test.py +++ b/keras/src/ops/math_test.py @@ -9,6 +9,7 @@ from keras.src import backend from keras.src import testing from keras.src.backend.common import dtypes +from keras.src.backend.common import standardize_dtype from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.ops import math as kmath @@ -1494,3 +1495,70 @@ def test_istft_invalid_window_shape_2D_inputs(self): fft_length, window=incorrect_window, ) + + +@pytest.mark.skipif( + backend.backend() == "openvino", + reason="Complex dtype is not supported on OpenVINO backend.", +) +class ViewAsComplexRealTest(testing.TestCase): + def test_view_as_complex_basic(self): + real_imag = np.array([[1.0, 2.0], [3.0, 4.0]]) + expected = np.array([1.0 + 2.0j, 3.0 + 4.0j], dtype=np.complex64) + + result = kmath.view_as_complex(real_imag) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(standardize_dtype(result.dtype), expected.dtype) + self.assertAllClose(result, expected) + + def test_view_as_real_basic(self): + complex_tensor = np.array([1 + 2j, 3 + 4j], dtype=np.complex64) + expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + + result = kmath.view_as_real(complex_tensor) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(standardize_dtype(result.dtype), expected.dtype) + self.assertAllClose(result, expected) + + def test_view_as_complex_invalid_shape(self): + bad_input = np.array([1.0, 2.0, 3.0]) # Last dimension not size 2 + with self.assertRaisesRegex( + ValueError, "Last dimension of input must be size 2" + ): + kmath.view_as_complex(bad_input) + + def test_view_as_complex_symbolic_input(self): + x = KerasTensor(shape=(None, 2), dtype="float32") + result = kmath.view_as_complex(x) + + self.assertEqual(result.shape, (None,)) + self.assertEqual(standardize_dtype(result.dtype), "complex64") + + def test_view_as_real_symbolic_input(self): + x = KerasTensor(shape=(None,), dtype="complex64") + result = kmath.view_as_real(x) + + self.assertEqual(result.shape, (None, 2)) + self.assertEqual(standardize_dtype(result.dtype), "float32") + + def test_view_as_complex_multi_dimensional(self): + x = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32) + expected = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64) + + result = kmath.view_as_complex(x) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(standardize_dtype(result.dtype), expected.dtype) + self.assertAllClose(result, expected) + + def test_view_as_real_multi_dimensional(self): + x = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64) + expected = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32) + + result = kmath.view_as_real(x) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(standardize_dtype(result.dtype), expected.dtype) + self.assertAllClose(result, expected)