Skip to content

Commit 6ddaefb

Browse files
authored
[Keras Ops] view_as_complex() and view_as_real() (#21221)
* Add view_as_complex and view_as_real methods and tests * Make op jax compatible * Run linter * Remove outdated test * Remove is_complex() call * Skip tests for openvino * fix backend call * update tests * update tests * Use standardize_dtype for comparisons * Move impl to math.py * Remove math namespace
1 parent 4595239 commit 6ddaefb

File tree

4 files changed

+170
-0
lines changed

4 files changed

+170
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
from keras.src.ops.math import segment_sum as segment_sum
6161
from keras.src.ops.math import stft as stft
6262
from keras.src.ops.math import top_k as top_k
63+
from keras.src.ops.math import view_as_complex as view_as_complex
64+
from keras.src.ops.math import view_as_real as view_as_real
6365
from keras.src.ops.nn import average_pool as average_pool
6466
from keras.src.ops.nn import batch_normalization as batch_normalization
6567
from keras.src.ops.nn import binary_crossentropy as binary_crossentropy

keras/api/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
from keras.src.ops.math import segment_sum as segment_sum
6161
from keras.src.ops.math import stft as stft
6262
from keras.src.ops.math import top_k as top_k
63+
from keras.src.ops.math import view_as_complex as view_as_complex
64+
from keras.src.ops.math import view_as_real as view_as_real
6365
from keras.src.ops.nn import average_pool as average_pool
6466
from keras.src.ops.nn import batch_normalization as batch_normalization
6567
from keras.src.ops.nn import binary_crossentropy as binary_crossentropy

keras/src/ops/math.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,3 +1044,101 @@ def logdet(x):
10441044
if any_symbolic_tensors((x,)):
10451045
return Logdet().symbolic_call(x)
10461046
return backend.math.logdet(x)
1047+
1048+
1049+
class ViewAsComplex(Operation):
1050+
def call(self, x):
1051+
x = backend.convert_to_tensor(x)
1052+
if len(x.shape) < 1 or x.shape[-1] != 2:
1053+
raise ValueError(
1054+
"Input tensor's last dimension must be 2 (real and imaginary)."
1055+
)
1056+
return x[..., 0] + 1j * x[..., 1]
1057+
1058+
def compute_output_spec(self, x):
1059+
return KerasTensor(shape=x.shape[:-1], dtype="complex64")
1060+
1061+
1062+
class ViewAsReal(Operation):
1063+
def call(self, x):
1064+
x = backend.convert_to_tensor(x)
1065+
real_part = backend.numpy.real(x)
1066+
imag_part = backend.numpy.imag(x)
1067+
return backend.numpy.stack((real_part, imag_part), axis=-1)
1068+
1069+
def compute_output_spec(self, x):
1070+
return KerasTensor(shape=x.shape + (2,), dtype="float32")
1071+
1072+
1073+
@keras_export("keras.ops.view_as_complex")
1074+
def view_as_complex(x):
1075+
"""Converts a real tensor with shape `(..., 2)` to a complex tensor,
1076+
where the last dimension represents the real and imaginary components
1077+
of a complex tensor.
1078+
1079+
Args:
1080+
x: A real tensor with last dimension of size 2.
1081+
1082+
Returns:
1083+
A complex tensor with shape `x.shape[:-1]`.
1084+
1085+
Example:
1086+
1087+
```
1088+
>>> import numpy as np
1089+
>>> from keras import ops
1090+
1091+
>>> real_imag = np.array([[1.0, 2.0], [3.0, 4.0]])
1092+
>>> complex_tensor = ops.view_as_complex(real_imag)
1093+
>>> complex_tensor
1094+
array([1.+2.j, 3.+4.j])
1095+
```
1096+
"""
1097+
if any_symbolic_tensors((x,)):
1098+
return ViewAsComplex().symbolic_call(x)
1099+
1100+
x = backend.convert_to_tensor(x)
1101+
if len(x.shape) < 1 or x.shape[-1] != 2:
1102+
raise ValueError(
1103+
"Last dimension of input must be size 2 (real and imaginary). "
1104+
f"Received shape: {x.shape}"
1105+
)
1106+
real_part = x[..., 0]
1107+
imag_part = x[..., 1]
1108+
1109+
return backend.cast(real_part, dtype="complex64") + 1j * backend.cast(
1110+
imag_part, dtype="complex64"
1111+
)
1112+
1113+
1114+
@keras_export("keras.ops.view_as_real")
1115+
def view_as_real(x):
1116+
"""Converts a complex tensor to a real tensor with shape `(..., 2)`,
1117+
where the last dimension represents the real and imaginary components.
1118+
1119+
Args:
1120+
x: A complex tensor.
1121+
1122+
Returns:
1123+
A real tensor where the last dimension contains the
1124+
real and imaginary parts.
1125+
1126+
Example:
1127+
```
1128+
>>> import numpy as np
1129+
>>> from keras import ops
1130+
1131+
>>> complex_tensor = np.array([1 + 2j, 3 + 4j])
1132+
>>> real = ops.view_as_real(complex_tensor)
1133+
>>> real
1134+
array([[1., 2.],
1135+
[3., 4.]])
1136+
```
1137+
"""
1138+
if any_symbolic_tensors((x,)):
1139+
return ViewAsReal().symbolic_call(x)
1140+
1141+
x = backend.convert_to_tensor(x)
1142+
real_part = backend.numpy.real(x)
1143+
imag_part = backend.numpy.imag(x)
1144+
return backend.numpy.stack((real_part, imag_part), axis=-1)

keras/src/ops/math_test.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from keras.src import backend
1010
from keras.src import testing
1111
from keras.src.backend.common import dtypes
12+
from keras.src.backend.common import standardize_dtype
1213
from keras.src.backend.common.keras_tensor import KerasTensor
1314
from keras.src.ops import math as kmath
1415

@@ -1494,3 +1495,70 @@ def test_istft_invalid_window_shape_2D_inputs(self):
14941495
fft_length,
14951496
window=incorrect_window,
14961497
)
1498+
1499+
1500+
@pytest.mark.skipif(
1501+
backend.backend() == "openvino",
1502+
reason="Complex dtype is not supported on OpenVINO backend.",
1503+
)
1504+
class ViewAsComplexRealTest(testing.TestCase):
1505+
def test_view_as_complex_basic(self):
1506+
real_imag = np.array([[1.0, 2.0], [3.0, 4.0]])
1507+
expected = np.array([1.0 + 2.0j, 3.0 + 4.0j], dtype=np.complex64)
1508+
1509+
result = kmath.view_as_complex(real_imag)
1510+
1511+
self.assertEqual(result.shape, expected.shape)
1512+
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
1513+
self.assertAllClose(result, expected)
1514+
1515+
def test_view_as_real_basic(self):
1516+
complex_tensor = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
1517+
expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
1518+
1519+
result = kmath.view_as_real(complex_tensor)
1520+
1521+
self.assertEqual(result.shape, expected.shape)
1522+
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
1523+
self.assertAllClose(result, expected)
1524+
1525+
def test_view_as_complex_invalid_shape(self):
1526+
bad_input = np.array([1.0, 2.0, 3.0]) # Last dimension not size 2
1527+
with self.assertRaisesRegex(
1528+
ValueError, "Last dimension of input must be size 2"
1529+
):
1530+
kmath.view_as_complex(bad_input)
1531+
1532+
def test_view_as_complex_symbolic_input(self):
1533+
x = KerasTensor(shape=(None, 2), dtype="float32")
1534+
result = kmath.view_as_complex(x)
1535+
1536+
self.assertEqual(result.shape, (None,))
1537+
self.assertEqual(standardize_dtype(result.dtype), "complex64")
1538+
1539+
def test_view_as_real_symbolic_input(self):
1540+
x = KerasTensor(shape=(None,), dtype="complex64")
1541+
result = kmath.view_as_real(x)
1542+
1543+
self.assertEqual(result.shape, (None, 2))
1544+
self.assertEqual(standardize_dtype(result.dtype), "float32")
1545+
1546+
def test_view_as_complex_multi_dimensional(self):
1547+
x = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32)
1548+
expected = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64)
1549+
1550+
result = kmath.view_as_complex(x)
1551+
1552+
self.assertEqual(result.shape, expected.shape)
1553+
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
1554+
self.assertAllClose(result, expected)
1555+
1556+
def test_view_as_real_multi_dimensional(self):
1557+
x = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64)
1558+
expected = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32)
1559+
1560+
result = kmath.view_as_real(x)
1561+
1562+
self.assertEqual(result.shape, expected.shape)
1563+
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
1564+
self.assertAllClose(result, expected)

0 commit comments

Comments
 (0)