Skip to content

[Keras Ops] view_as_complex() and view_as_real() #21221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 5, 2025
2 changes: 2 additions & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
from keras.src.ops.math import segment_sum as segment_sum
from keras.src.ops.math import stft as stft
from keras.src.ops.math import top_k as top_k
from keras.src.ops.math import view_as_complex as view_as_complex
from keras.src.ops.math import view_as_real as view_as_real
from keras.src.ops.nn import average_pool as average_pool
from keras.src.ops.nn import batch_normalization as batch_normalization
from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
Expand Down
2 changes: 2 additions & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
from keras.src.ops.math import segment_sum as segment_sum
from keras.src.ops.math import stft as stft
from keras.src.ops.math import top_k as top_k
from keras.src.ops.math import view_as_complex as view_as_complex
from keras.src.ops.math import view_as_real as view_as_real
from keras.src.ops.nn import average_pool as average_pool
from keras.src.ops.nn import batch_normalization as batch_normalization
from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
Expand Down
98 changes: 98 additions & 0 deletions keras/src/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,3 +1044,101 @@ def logdet(x):
if any_symbolic_tensors((x,)):
return Logdet().symbolic_call(x)
return backend.math.logdet(x)


class ViewAsComplex(Operation):
def call(self, x):
x = backend.convert_to_tensor(x)
if len(x.shape) < 1 or x.shape[-1] != 2:
raise ValueError(
"Input tensor's last dimension must be 2 (real and imaginary)."
)
return x[..., 0] + 1j * x[..., 1]

def compute_output_spec(self, x):
return KerasTensor(shape=x.shape[:-1], dtype="complex64")


class ViewAsReal(Operation):
def call(self, x):
x = backend.convert_to_tensor(x)
real_part = backend.numpy.real(x)
imag_part = backend.numpy.imag(x)
return backend.numpy.stack((real_part, imag_part), axis=-1)

def compute_output_spec(self, x):
return KerasTensor(shape=x.shape + (2,), dtype="float32")


@keras_export("keras.ops.view_as_complex")
def view_as_complex(x):
"""Converts a real tensor with shape `(..., 2)` to a complex tensor,
where the last dimension represents the real and imaginary components
of a complex tensor.

Args:
x: A real tensor with last dimension of size 2.

Returns:
A complex tensor with shape `x.shape[:-1]`.

Example:

```
>>> import numpy as np
>>> from keras import ops

>>> real_imag = np.array([[1.0, 2.0], [3.0, 4.0]])
>>> complex_tensor = ops.view_as_complex(real_imag)
>>> complex_tensor
array([1.+2.j, 3.+4.j])
```
"""
if any_symbolic_tensors((x,)):
return ViewAsComplex().symbolic_call(x)

x = backend.convert_to_tensor(x)
if len(x.shape) < 1 or x.shape[-1] != 2:
raise ValueError(
"Last dimension of input must be size 2 (real and imaginary). "
f"Received shape: {x.shape}"
)
real_part = x[..., 0]
imag_part = x[..., 1]

return backend.cast(real_part, dtype="complex64") + 1j * backend.cast(
imag_part, dtype="complex64"
)


@keras_export("keras.ops.view_as_real")
def view_as_real(x):
"""Converts a complex tensor to a real tensor with shape `(..., 2)`,
where the last dimension represents the real and imaginary components.

Args:
x: A complex tensor.

Returns:
A real tensor where the last dimension contains the
real and imaginary parts.

Example:
```
>>> import numpy as np
>>> from keras import ops

>>> complex_tensor = np.array([1 + 2j, 3 + 4j])
>>> real = ops.view_as_real(complex_tensor)
>>> real
array([[1., 2.],
[3., 4.]])
```
"""
if any_symbolic_tensors((x,)):
return ViewAsReal().symbolic_call(x)

x = backend.convert_to_tensor(x)
real_part = backend.numpy.real(x)
imag_part = backend.numpy.imag(x)
return backend.numpy.stack((real_part, imag_part), axis=-1)
68 changes: 68 additions & 0 deletions keras/src/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from keras.src import backend
from keras.src import testing
from keras.src.backend.common import dtypes
from keras.src.backend.common import standardize_dtype
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.ops import math as kmath

Expand Down Expand Up @@ -1494,3 +1495,70 @@ def test_istft_invalid_window_shape_2D_inputs(self):
fft_length,
window=incorrect_window,
)


@pytest.mark.skipif(
backend.backend() == "openvino",
reason="Complex dtype is not supported on OpenVINO backend.",
)
class ViewAsComplexRealTest(testing.TestCase):
def test_view_as_complex_basic(self):
real_imag = np.array([[1.0, 2.0], [3.0, 4.0]])
expected = np.array([1.0 + 2.0j, 3.0 + 4.0j], dtype=np.complex64)

result = kmath.view_as_complex(real_imag)

self.assertEqual(result.shape, expected.shape)
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
self.assertAllClose(result, expected)

def test_view_as_real_basic(self):
complex_tensor = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

result = kmath.view_as_real(complex_tensor)

self.assertEqual(result.shape, expected.shape)
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
self.assertAllClose(result, expected)

def test_view_as_complex_invalid_shape(self):
bad_input = np.array([1.0, 2.0, 3.0]) # Last dimension not size 2
with self.assertRaisesRegex(
ValueError, "Last dimension of input must be size 2"
):
kmath.view_as_complex(bad_input)

def test_view_as_complex_symbolic_input(self):
x = KerasTensor(shape=(None, 2), dtype="float32")
result = kmath.view_as_complex(x)

self.assertEqual(result.shape, (None,))
self.assertEqual(standardize_dtype(result.dtype), "complex64")

def test_view_as_real_symbolic_input(self):
x = KerasTensor(shape=(None,), dtype="complex64")
result = kmath.view_as_real(x)

self.assertEqual(result.shape, (None, 2))
self.assertEqual(standardize_dtype(result.dtype), "float32")

def test_view_as_complex_multi_dimensional(self):
x = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32)
expected = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64)

result = kmath.view_as_complex(x)

self.assertEqual(result.shape, expected.shape)
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
self.assertAllClose(result, expected)

def test_view_as_real_multi_dimensional(self):
x = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64)
expected = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32)

result = kmath.view_as_real(x)

self.assertEqual(result.shape, expected.shape)
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
self.assertAllClose(result, expected)