diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index b882b2347..5cec00410 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -1538,7 +1538,13 @@ def view(context, node): shape = mb.concat(values=shape, axis=0) shape = mb.cast(x=shape, dtype="int32") - view = mb.reshape(x=x, shape=shape, name=node.name) + + if types.is_complex(x.dtype): + real, imag = (mb.reshape(x=x, shape=shape, name=node.name) for x in (mb.complex_real(data=x), mb.complex_imag(data=x))) + view = mb.complex(real_data=real, imag_data=imag, name=node.name) + else: + view = mb.reshape(x=x, shape=shape, name=node.name) + context.add(view) @@ -1565,7 +1571,11 @@ def pad(context, node): if inputs[val_index] and inputs[val_index].op.op_type == "const": scalar_val = float(scalar_val.val) - res = mb.pad(x=x, pad=pad, mode=mode, constant_val=scalar_val, name=node.name) + if types.is_complex(x.dtype): + real, imag = (mb.pad(x=x, pad=pad, mode=mode, constant_val=scalar_val, name=node.name) for x in (mb.complex_real(data=x), mb.complex_imag(data=x))) + res = mb.complex(real_data=real, imag_data=imag, name=node.name) + else: + res = mb.pad(x=x, pad=pad, mode=mode, constant_val=scalar_val, name=node.name) context.add(res) @@ -4427,8 +4437,11 @@ def index_select(context, node): @register_torch_op(torch_alias=["abs"]) def _abs(context, node): - inputs = _get_inputs(context, node, expected=1) - context.add(mb.abs(x=inputs[0], name=node.name)) + x = _get_inputs(context, node, expected=1)[0] + if types.is_complex(x.dtype): + context.add(mb.complex_abs(x=x, name=node.name)) + else: + context.add(mb.abs(x=x, name=node.name)) @register_torch_op @@ -5676,6 +5689,23 @@ def fft_irfftn(context, node): irfftn_res = mb.complex_irfftn(data=input_data, shapes=shapes, dims=dims, norm=norm) context.add(irfftn_res, node.name) +@register_torch_op +def stft(context, node): + """ + Lowers torch.stft with the dialect op `complex_stft` from complex_dialect_ops.py + """ + input_data, n_fft, hop_length, win_length, window, normalized, onesided, _ = _get_inputs(context, node, min_expected=2) + if types.is_complex(input_data.dtype): + onesided = False # pytorch defaults onesided to False for complex inputs + stft_res = mb.complex_stft( + input=input_data, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + normalized=normalized, + onesided=onesided) + context.add(stft_res, node.name) @register_torch_op(torch_alias=["torchvision::nms"]) def torchvision_nms(context, node): diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index d7416387f..d807ed313 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -11,6 +11,7 @@ import numpy as np import pytest import torch.nn as nn +import torchaudio import torchvision import coremltools as ct @@ -7877,6 +7878,26 @@ def forward(self, x): (2, 3, 4), ComplexModel(), backend=backend, compute_unit=compute_unit ) + @pytest.mark.parametrize( + "compute_unit, backend", + itertools.product( + compute_units, + backends, + ) + ) + def test_abs(self, compute_unit, backend): + class AbsModel(torch.nn.Module): + def forward(self, x): + x = torch.complex(x, x) + return torch.abs(x) + + TorchBaseTest.run_compare_torch( + (1, 16), + AbsModel(), + backend=backend, + compute_unit=compute_unit, + ) + class TestReal(TorchBaseTest): @pytest.mark.parametrize( @@ -8099,6 +8120,94 @@ def forward(self, x): (2, 3, 4), FftnModel(), backend=backend, compute_unit=compute_unit ) +class TestSTFT(TorchBaseTest): + @pytest.mark.slow + @pytest.mark.parametrize( + "compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided", + itertools.product( + compute_units, + backends, + [(1, 32), (32,), (3, 32)], # input shape + [False, True], # complex + [16], # n_fft + [None, 4, 5], # hop_length + [None, 16, 9], # win_length + [None, torch.hann_window], # window + [None, False, True], # center + ["constant", "reflect", "replicate"], # pad mode + [False, True], # normalized + [None, False, True], # onesided + ) + ) + def test_stft(self, compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): + if complex and onesided: + pytest.skip("Onesided stft not possible for complex inputs") + + class STFTModel(torch.nn.Module): + def forward(self, x): + applied_window = window(win_length) if window and win_length else None + x = torch.complex(x, x) if complex else x + x = torch.stft( + x, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=applied_window, + center=center, + pad_mode=pad_mode, + normalized=normalized, + onesided=onesided, + return_complex=True) + x = torch.stack([torch.real(x), torch.imag(x)], dim=0) + return x + + TorchBaseTest.run_compare_torch( + input_shape, + STFTModel(), + backend=backend, + compute_unit=compute_unit + ) + +class TestSpectrogram(TorchBaseTest): + @pytest.mark.parametrize( + "compute_unit, backend, input_shape, spec, power", + itertools.product( + compute_units, + backends, + [(1, 1000), (1000,), (3, 1000)], # input shape + [torchaudio.transforms.Spectrogram, torchaudio.transforms.MelSpectrogram], + [None, 1, 2] # magnitude or power + ) + ) + def test_spectrogram(self, compute_unit, backend, input_shape, spec, power): + if platform.machine() != "arm64": + pytest.xfail("rdar://108001659 ([PyTorch] Torchaudio Spectrogram Failed on Intel Machine)") + + if spec is torchaudio.transforms.MelSpectrogram and power is None: + pytest.skip("power or magnitude required for melspec") + + class SpectrogramModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + # the other spectrogram options are passed through to stft + # and are tested in TestSTFT + self.spec = spec(power=power, n_fft=128) + + def forward(self, x): + x = self.spec(x) + if power is None: + # complex: stack them + x = torch.stack([torch.real(x), torch.imag(x)], dim=0) + return x + + TorchBaseTest.run_compare_torch( + input_shape, + SpectrogramModel(), + backend=backend, + compute_unit=compute_unit, + rtol=1e-4, + atol=1e-4, + ) class TestNms(TorchBaseTest): @pytest.mark.parametrize( diff --git a/coremltools/converters/mil/mil/input_type.py b/coremltools/converters/mil/mil/input_type.py index 8721d9275..267037f51 100644 --- a/coremltools/converters/mil/mil/input_type.py +++ b/coremltools/converters/mil/mil/input_type.py @@ -285,7 +285,7 @@ def type_domain(self): @type_domain.setter def type_domain(self, val): - msg = "type_domain must be a tuple of builtin types" + msg = f"type_domain {val} must be a tuple of builtin types" if not isinstance(val, tuple) or any(map(lambda t: t not in _SUPPORT_TYPES, val)): raise ValueError(msg) self._type_domain = val diff --git a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py index 2a12f029b..c72673954 100644 --- a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py @@ -729,16 +729,135 @@ class complex_shape(Operation): "T": (types.complex64,), } + # If type_inference or value_inference is invoked when the graph is being constructed, + # x.real and x.imag may not be set since the complex lowering pass hasn't yet been invoked. + # self.x should already have the shape set, so use that instead. + def type_inference(self): if not isinstance(self.x, ComplexVar): raise ValueError("x must be a ComplexVar.") - input_rank = self.x.real.rank + input_rank = self.x.rank return types.tensor(types.int32, tuple([input_rank])) def value_inference(self): - if any_symbolic(self.x.real.shape): + if any_symbolic(self.x.shape): # convert elements in shape to int32 - res = [x if is_symbolic(x) else np.int32(x) for x in self.x.real.shape] + res = [x if is_symbolic(x) else np.int32(x) for x in self.x.shape] return np.array(res) else: - return np.array(self.x.real.shape).astype(np.int32) + return np.array(self.x.shape).astype(np.int32) + +@register_op(namespace="complex") +class complex_abs(Operation): + """ + Returns the absolute value of a complex tensor. + + Parameters + ---------- + x: tensor<[*d], T> (Required) + + Returns + ------- + tensor<[*d], fp32> + * A float tensor with the same shape as ``x`` + + Attributes + ---------- + T: complex64 + """ + + input_spec = InputSpec(x=TensorInputType(type_domain="T")) + + type_domains = { + "T": (types.complex64,), + } + + def type_inference(self): + if not isinstance(self.x, ComplexVar): + raise ValueError("x must be a ComplexVar.") + return types.tensor(infer_fp_dtype_from_complex(self.x.dtype), self.x.shape) + +@register_op(namespace="complex") +class complex_stft(Operation): + """ + Dialect op for 1-D STFT. + + Parameters + ---------- + input: tensor<\*D, T> (Required) + * The input tensor. + n_fft: const i32 (Required) + * Size of the fourier transform. + hop_length: const i32 (Optional) + * Stride between window frames of the input tensor. + win_length: const i32 (optional) + * The size of the window frame. + window: tensor<1, win_length> (optional) + * The window to apply to the input signal before performing the fourier transform. + normalized: const bool (optional, Default=``false``) + * Whether to normalize the results of the STFT + onesided: const bool (optional, Default=``true``) + * For real-valued inputs, whether to return the first half of the results. + + Returns + ------- + tensor<\*V, complex64> + * A complex tensor where real and imag parts have the same shape. + + Attributes + ---------- + T: fp32, complex64 + + References + ---------- + See `torch.stft `_. + """ + + input_spec = InputSpec( + input=TensorInputType(type_domain="T"), + n_fft=TensorInputType(const=True, type_domain=types.int32), + hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32), + win_length=TensorInputType(const=True, optional=True, type_domain=types.int32), + window=TensorInputType(const=True, optional=True, type_domain=types.fp32), + normalized=TensorInputType(const=True, optional=True, type_domain=types.bool), + onesided=TensorInputType(const=True, optional=True, type_domain=types.bool), + ) + + type_domains = { + "T": (types.fp32, types.complex64), + } + + def default_inputs(self): + return DefaultInputs( + hop_length = None, + win_length = None, + window = None, + normalized = False, + onesided = True, + ) + + def type_inference(self): + output_type = (types.complex64) + + # STFT shape is [B x N x T], where N is the number of frequency bins + # and T is the number of windows + # B is 1 for a time series or 2 for a batch of time series + + window_length = self.n_fft.val + hop = self.hop_length.val if self.hop_length else self.n_fft.val // 4 + + # if onesided is true, the input is real valued + # because of Hermitian symmetry, we only need to calculate the FFT + # for the first half of the frequences + if self.onesided and self.onesided.val: + window_length = window_length // 2 + 1 + + frames = (self.input.shape[-1] - self.n_fft.val) // hop + 1 + output_shape = [window_length, frames] + + # add back rank if needed + if self.input.rank == 2: + output_shape = [self.input.shape[0]] + output_shape + + return types.tensor(output_type, tuple(output_shape)) + diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 64943197d..ed36d87f3 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -139,19 +139,12 @@ def _restore_conj( return real_data, imag_data - -def _fft_1d( - input_real: Var, - input_imag: Var, - n: Optional[Var], - dim: Optional[Var], - norm: Optional[Var], - before_op: Operation, - inverse: bool = False, # For inverse FFT. +def _calculate_dft_matrix( + n_fft: Var, + onesided: bool = False, + before_op: Operation = None, ) -> Tuple[Var, Var]: """ - 1-D FFT by DFT Matrix Multiplication. - The core issue is how to derive the DFT matrix. As the DFT matrix is consist of different powers of `w`, where w=e^(2pi/N i), we need to separate the real and imaginary part of w. To achieve that, we need to find a way to construct the following matrix (from the power of `w` in DFT): @@ -163,7 +156,45 @@ def _fft_1d( This matrix could be derived by outer product of two range tensors. After getting that base matrix, we can take sin and cos to get the corresponding `sin_base` and - `cos_base` matrix. Now based on some math formulas including: + `cos_base` matrix. + + If the onesided flag is passed, we can take advantage of Hermitian symmetry and return a + weight matrix consisting of only the first (n_fft // 2 + 1) values. + """ + n_fft = mb.cast(x=n_fft, dtype="fp32", before_op=before_op) + half = mb.floor_div(x=n_fft, y=2., before_op=before_op) + half = mb.add(x=half, y=1., before_op=before_op) + + tmp_x = mb.range_1d(start=0.0, end=(half if onesided else n_fft), step=1.0, before_op=before_op) + tmp_y = mb.range_1d(start=0.0, end=n_fft, step=1.0, before_op=before_op) + + # Use MIL ops to calculate base = torch.outer(tmp, tmp) * (2 * torch.pi / N). + tmp_x = mb.reshape(x=tmp_x, shape=[-1, 1], before_op=before_op) + tmp_y = mb.reshape(x=tmp_y, shape=[1, -1], before_op=before_op) + + base = mb.matmul(x=tmp_x, y=tmp_y, before_op=before_op) + base = mb.mul(x=base, y=2 * np.pi, before_op=before_op) + base = mb.real_div(x=base, y=n_fft, before_op=before_op) + + # Get real part and imaginary part separately. + cos_base = mb.cos(x=base, before_op=before_op) + sin_base = mb.sin(x=base, before_op=before_op) + + return cos_base, sin_base + +def _fft_1d( + input_real: Var, + input_imag: Var, + n: Optional[Var], + dim: Optional[Var], + norm: Optional[Var], + before_op: Operation, + inverse: bool = False, # For inverse FFT. +) -> Tuple[Var, Var]: + """ + 1-D FFT by DFT Matrix Multiplication. + + Now based on some math formulas including: * The addition of complex numbers is: (a+bi)+(c+di)=(a+c)+(b+d)i. * The multiplication of complex numbers is: (a+bi)(c+di)=ac+adi+bci−bd=(ac−bd)+(ad+bc)i. * Euler’s formula: e^xi=cosx+isinx. @@ -202,18 +233,9 @@ def _fft_1d( N = transposed_input_real.shape[0] reshaped_input_real = mb.reshape(x=transposed_input_real, shape=[N, -1], before_op=before_op) reshaped_input_imag = mb.reshape(x=transposed_input_imag, shape=[N, -1], before_op=before_op) - tmp = mb.range_1d(start=0, end=N, step=1, before_op=before_op) - # Use MIL ops to calculate base = torch.outer(tmp, tmp) * (2 * torch.pi / N). - tmp_x = mb.reshape(x=tmp, shape=[-1, 1], before_op=before_op) - tmp_y = mb.reshape(x=tmp, shape=[1, -1], before_op=before_op) - base = mb.matmul(x=tmp_x, y=tmp_y, before_op=before_op) - base = mb.cast(x=base, dtype="fp32", before_op=before_op) - base = mb.mul(x=base, y=2 * np.pi, before_op=before_op) + N = mb.cast(x=N, dtype="fp32", before_op=before_op) - base = mb.real_div(x=base, y=N, before_op=before_op) - # Get real part and imaginary part separately. - cos_base = mb.cos(x=base, before_op=before_op) - sin_base = mb.sin(x=base, before_op=before_op) + cos_base, sin_base = _calculate_dft_matrix(N, onesided=False, before_op=before_op) if not inverse: real_part = mb.add( @@ -288,6 +310,92 @@ def _rfft_1d( return real_data, imag_data +def _stft( + input_real: Var, + input_imaginary: Optional[Var], + n_fft: Var, + hop_length: Optional[Var], + win_length: Optional[Var], + window: Optional[Var], + normalized: Optional[Var], + onesided: Optional[Var], + before_op: Operation, +) -> Tuple[Var, Var]: + """ + We can write STFT in terms of convolutions with a DFT kernel. + At the end: + * The real part output is: cos_base * input_real + sin_base * input_imag + * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag) + Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py + """ + hop_length = hop_length or mb.floor_div(x=n_fft, y=4, before_op=before_op) + + # input should always be 2D + should_increase_rank = input_real.rank == 1 + if should_increase_rank: + input_real = mb.expand_dims(x=input_real, axes=(0,), before_op=before_op) + if input_imaginary: + input_imaginary = mb.expand_dims(x=input_imaginary, axes=(0,), before_op=before_op) + + is_onesided = onesided and onesided.val + cos_base, sin_base = _calculate_dft_matrix( + n_fft, + onesided=is_onesided, + before_op=before_op) + + # create a window of centered 1s of the requested size + if win_length: + n_left = (n_fft.val - win_length.val) // 2 + n_right = n_fft.val - win_length.val - n_left + + left = mb.fill(shape=(n_left,), value=0., before_op=before_op) + if not window: + window = mb.fill(shape=(win_length.val,), value=1., before_op=before_op) + right = mb.fill(shape=(n_right,), value=0., before_op=before_op) + + # concatenate + window = mb.concat(values=(left, window, right), axis=0, before_op=before_op) + + # apply time window + if window: + cos_base = mb.mul(x=window, y=cos_base, before_op=before_op) + sin_base = mb.mul(x=window, y=sin_base, before_op=before_op) + + # conv with DFT kernel across the input signal + sin_base = mb.sub(x=0., y=sin_base, before_op=before_op) + cos_base = mb.expand_dims(x=cos_base, axes=(1,), before_op=before_op) + sin_base = mb.expand_dims(x=sin_base, axes=(1,), before_op=before_op) + hop_size = mb.expand_dims(x=hop_length, axes=(0,), before_op=before_op) + + signal_real = mb.expand_dims(x=input_real, axes=(1,), before_op=before_op) + cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) + sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) + + if input_imaginary: + signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op) + cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) + sin_windows_imag = mb.conv(x=signal_imaginary, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) + + # add everything together + if input_imaginary: + # sin base is already negative so subtract + real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) + imag_result = mb.add(x=sin_windows_real, y=cos_windows_imag, before_op=before_op) + else: + real_result = cos_windows_real + imag_result = sin_windows_real + + # reduce the rank of the output + if should_increase_rank: + real_result = mb.squeeze(x=real_result, axes=(0,), before_op=before_op) + imag_result = mb.squeeze(x=imag_result, axes=(0,), before_op=before_op) + + if normalized and normalized.val: + divisor = mb.sqrt(x=mb.cast(x=n_fft, dtype="fp32", before_op=before_op), before_op=before_op) + real_result = mb.real_div(x=real_result, y=divisor, before_op=before_op) + imag_result = mb.real_div(x=imag_result, y=divisor, before_op=before_op) + + return real_result, imag_result def _wrap_complex_output(original_output: Var, real_data: Var, imag_data: Var) -> ComplexVar: return ComplexVar( @@ -483,11 +591,33 @@ def _lower_complex_irfftn(op: Operation): return real_data +@LowerComplex.register_lower_func(op_type="complex_stft") +def _lower_complex_stft(op: Operation): + is_complex = types.is_complex(op.input.dtype) + + # check parameters for validity + if op.win_length and op.win_length.val > op.n_fft.val: + raise ValueError("Window length must be less than or equal to n_fft") + if is_complex and op.onesided and op.onesided.val: + raise ValueError("Onesided is only valid for real inputs") + + real, imag = _stft( + op.input.real if is_complex else op.input, + op.input.imag if is_complex else None, + op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, before_op=op) + + return _wrap_complex_output(op.outputs[0], real, imag) + @LowerComplex.register_lower_func(op_type="complex_shape") def _lower_complex_shape(op: Operation): return mb.shape(x=op.data.real, before_op=op) +@LowerComplex.register_lower_func(op_type="complex_abs") +def _lower_complex_abs(op: Operation): + mag_r, mag_i = (mb.square(x=x, before_op=op) for x in (op.x.real, op.x.imag)) + mag = mb.add(x=mag_r, y=mag_i, before_op=op) + return mb.sqrt(x=mag, before_op=op) def _match_and_replace_dialect_op(block, op): if not LowerComplex.has_lower_func(op.op_type): diff --git a/coremltools/converters/mil/mil/passes/tests/test_lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/tests/test_lower_complex_dialect_ops.py index a7a03fd15..cf6b5f409 100644 --- a/coremltools/converters/mil/mil/passes/tests/test_lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/tests/test_lower_complex_dialect_ops.py @@ -4,11 +4,15 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause import numpy as np +import pytest +from coremltools import ComputeUnit from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil.passes.defs.lower_complex_dialect_ops import _calculate_dft_matrix from coremltools.converters.mil.testing_utils import ( apply_pass_and_basic_check, assert_model_is_valid, + ct_convert, get_op_types_in_program, ) @@ -54,3 +58,31 @@ def prog(x): inputs, expected_output_shapes={block.outputs[0].name: (1, 2, 3)}, ) + + @pytest.mark.parametrize( + "onesided", + [True, False] + ) + def test_calculate_dft_matrix(self, onesided): + expected_C = np.zeros((16, 16)) + expected_S = np.zeros((16, 16)) + + _range = np.arange(16) + for k in range(16): + expected_C[k, :] = np.cos(2 * np.pi * k * _range / 16) + expected_S[k, :] = np.sin(2 * np.pi * k * _range / 16) + + if onesided: + expected_C = expected_C[:9] + expected_S = expected_S[:9] + + @mb.program(input_specs=[mb.TensorSpec(shape=(1,))]) + def prog(x): + return _calculate_dft_matrix(x, onesided=onesided) + + model = ct_convert(program=prog, convert_to=("neuralnetwork", "fp32"), compute_units=ComputeUnit.CPU_ONLY) + p = model.predict({"x": np.array([16.0])}) + cos_matrix, sin_matrix = p["cos_0"], p["sin_0"] + + np.testing.assert_allclose(expected_C, cos_matrix, atol=1e-04, rtol=1e-05) + np.testing.assert_allclose(expected_S, sin_matrix, atol=1e-04, rtol=1e-05) diff --git a/reqs/test.pip b/reqs/test.pip index 374266bb2..98a340967 100644 --- a/reqs/test.pip +++ b/reqs/test.pip @@ -21,6 +21,7 @@ six sympy > 1.6 gast==0.4.0 torch==2.0.0 +torchaudio==2.0.1 torchvision==0.15.1 xgboost==1.4.2; platform_machine != "arm64" mock