
[complex] conv1d #75013


Closed
97 changes: 94 additions & 3 deletions aten/src/ATen/native/Convolution.cpp
@@ -652,6 +652,88 @@ static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) {
return tensor.narrow(dim, n * g, n).contiguous();
}

namespace {

std::pair<Tensor, Tensor> complex_to_real(const Tensor& inp) {
  auto inp_view_as_complex = at::view_as_real(inp);
  auto dim_i = inp_view_as_complex.dim() - 1;
  auto i_r = inp_view_as_complex.select(dim_i, 0);
  auto i_i = inp_view_as_complex.select(dim_i, 1);
  return std::make_pair(i_r, i_i);
}

at::Tensor complex_convolution(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    IntArrayRef output_padding,
    int64_t groups) {
  check_input_same_type_as_parameters(input, weight, bias);
  Tensor i_r, i_i, w_r, w_i;
  std::tie(i_r, i_i) = complex_to_real(input.resolve_conj());
  std::tie(w_r, w_i) = complex_to_real(weight.resolve_conj());

  // [NOTE] Complex Convolution
  // conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
  // where W, x and b are all complex inputs.
  // With Gauss Trick:
  // a = conv(Wr, xr, br),
  // b = conv(Wi, xi, 0),
  // c = conv(Wr + Wi, xr + xi, bi + br)
  // conv(W, x, b) = a - b + i(c - a - b)
  Tensor a, b, c;
  if (!bias.defined()) {
    a = at::convolution(i_r, w_r, bias, stride, padding, dilation, false, output_padding, groups);
    b = at::convolution(i_i, w_i, bias, stride, padding, dilation, false, output_padding, groups);
    c = at::convolution(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, false, output_padding, groups);
  } else {
    Tensor b_r, b_i;
    std::tie(b_r, b_i) = complex_to_real(bias.resolve_conj());
    a = at::convolution(i_r, w_r, b_r, stride, padding, dilation, false, output_padding, groups);
    b = at::convolution(i_i, w_i, Tensor(), stride, padding, dilation, false, output_padding, groups);
    c = at::convolution(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, false, output_padding, groups);
  }

  auto i = c10::Scalar(c10::complex<double>(0, 1));
  return a - b + i * (c - a - b);
}

at::Tensor complex_convolution_mode(
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt,
    at::IntArrayRef stride,
    c10::string_view padding,
    at::IntArrayRef dilation,
    int64_t groups) {
  auto bias = bias_opt.value_or(Tensor());
  check_input_same_type_as_parameters(input, weight, bias);
  Tensor i_r, i_i, w_r, w_i;
  std::tie(i_r, i_i) = complex_to_real(input.resolve_conj());
  std::tie(w_r, w_i) = complex_to_real(weight.resolve_conj());

  // See [NOTE] Complex Convolution
  Tensor a, b, c;
  if (!bias.defined()) {
    a = at::_convolution_mode(i_r, w_r, bias, stride, padding, dilation, groups);
    b = at::_convolution_mode(i_i, w_i, bias, stride, padding, dilation, groups);
    c = at::_convolution_mode(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, groups);
  } else {
    Tensor b_r, b_i;
    std::tie(b_r, b_i) = complex_to_real(bias.resolve_conj());
    a = at::_convolution_mode(i_r, w_r, b_r, stride, padding, dilation, groups);
    b = at::_convolution_mode(i_i, w_i, Tensor(), stride, padding, dilation, groups);
    c = at::_convolution_mode(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, groups);
  }

  auto i = c10::Scalar(c10::complex<double>(0, 1));
  return a - b + i * (c - a - b);
}

} // namespace

at::Tensor conv1d(
const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
@@ -663,7 +745,12 @@ at::Tensor conv1d(
  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
  auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
  Tensor output;
  if (at::isComplexType(input_.scalar_type())) {
    output = complex_convolution(input, weight, bias, stride, padding, dilation, {0}, groups);
  } else {
    output = at::convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
  }
  return is_batched ? output : output.squeeze(0);
}

@@ -787,8 +874,12 @@ at::Tensor conv1d(
  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
  auto output = at::_convolution_mode(
      input, weight, bias, stride, std::move(padding), dilation, groups);
  Tensor output;
  if (at::isComplexType(input_.scalar_type())) {
    output = complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups);
  } else {
    output = at::_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups);
  }
  return is_batched ? output : output.squeeze(0);
}

Contributor
Nice, so this adds support for complex half too?

Collaborator Author
Not at this point, as +, -, *, and view_as_real will need to support complex half first.

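As a sanity check of the [NOTE] Complex Convolution math above, the following Python sketch (not part of the PR) reproduces the Gauss trick with three real conv1d calls and compares it against the direct four-convolution expansion. It uses only real convolutions, so it runs on any recent PyTorch build; the commented-out comparison at the end assumes a build that includes this change.

# Sanity check of the Gauss trick used in complex_convolution (sketch only, not part of the PR).
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Complex input (N, C_in, L), weight (C_out, C_in, K) and bias (C_out,)
x = torch.randn(2, 3, 16, dtype=torch.complex128)
w = torch.randn(4, 3, 5, dtype=torch.complex128)
b = torch.randn(4, dtype=torch.complex128)

xr, xi = x.real, x.imag
wr, wi = w.real, w.imag
br, bi = b.real, b.imag

# Direct expansion:
# conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i*(conv(Wi, xr, bi) + conv(Wr, xi, 0))
direct_real = F.conv1d(xr, wr, br) - F.conv1d(xi, wi)
direct_imag = F.conv1d(xr, wi, bi) + F.conv1d(xi, wr)

# Gauss trick: three real convolutions instead of four.
a = F.conv1d(xr, wr, br)
b_ = F.conv1d(xi, wi)
c = F.conv1d(xr + xi, wr + wi, br + bi)
gauss_real = a - b_
gauss_imag = c - a - b_

print(torch.allclose(direct_real, gauss_real))  # True
print(torch.allclose(direct_imag, gauss_imag))  # True

# With a build that includes this PR, F.conv1d also accepts complex tensors directly:
# out = F.conv1d(x, w, b)
# torch.allclose(out, torch.complex(gauss_real, gauss_imag))  # True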
8 changes: 6 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
@@ -11181,8 +11181,9 @@ def ref_pairwise_distance(input1, input2):
OpInfo('nn.functional.conv1d',
aliases=('conv1d',),
aten_name='conv1d',
dtypes=floating_types_and(torch.int64),
dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
dtypes=floating_and_complex_types_and(torch.int64),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
                                            *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),

Contributor
We should also test complex half.
sample_inputs_func=sample_inputs_conv1d,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -11191,6 +11192,9 @@ def ref_pairwise_distance(input1, input2):
# RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
# "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# AssertionError: None mismatch: torch.complex128 is not None
DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules',
'test_custom_rules', dtypes=(torch.complex64, torch.complex128)),
Collaborator Author
Complete Stack Trace:

Traceback (most recent call last):
  File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 361, in test_custom_rules
    self.custom_rules_test_base(device, dtype, op)
  File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 357, in custom_rules_test_base
    self.assert_output_dtype_equal(expected_res, graph)
  File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 323, in assert_output_dtype_equal
    self.assert_tensor_dtype_equal(expected_res, actual_dtype[0])
  File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 332, in assert_tensor_dtype_equal
    self.assertEqual(tensor_output.dtype, graph_dtype)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2160, in assertEqual
    assert_equal(
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1084, in assert_equal
    raise error_metas[0].to_error()
AssertionError: None mismatch: torch.complex128 is not None

Contributor
What does TestDtypeCustomRules.test_custom_rules test? Why is this an expected failure?

Collaborator Author
I am not sure. Looking at the test, it seems to check whether the dtype is propagated correctly through a JIT traced graph. (I was planning to ping someone from the JIT team to unblock this PR.)

def custom_rules_test_base(self, device, dtype, op, allow_eager_fail=False):
    try:
        samples = op.sample_inputs(device, dtype, requires_grad=False)
        sample_input = first_sample(self, samples)
        input_args = [sample_input.input, *sample_input.args]
        expected_res = op(*input_args, **sample_input.kwargs)
    except Exception as e:
        if allow_eager_fail:
            return
        else:
            raise e
    func = op.get_op()
    traced_fn = create_traced_fn(self, func)
    # Have to run the traced function to actually generate the trace
    traced_fn(sample_input.input, *sample_input.args, **sample_input.kwargs)
    # Run the Dtype Analysis
    graph = traced_fn.graph  # Note this is a cached graph
    input_tensors = [t for t in input_args if isinstance(t, torch.Tensor)]
    self.prop_dtype_on_graph(graph, input_tensors)
    self.assert_output_dtype_equal(expected_res, graph)
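For context, a rough way to poke at this outside the test harness (a sketch only; prop_dtype_on_graph is internal to the test class and is not used here) is to trace conv1d with complex inputs and inspect the traced graph, whose output dtype is what the analysis pass tries to re-infer. This assumes a build where complex conv1d already works in eager mode, i.e. this PR.

# Rough repro sketch of the dtype-propagation question (hypothetical, not the test itself).
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 16, dtype=torch.complex128)
w = torch.randn(4, 3, 5, dtype=torch.complex128)

def fn(x, w):
    return F.conv1d(x, w)

eager_out = fn(x, w)
print(eager_out.dtype)  # torch.complex128 in eager mode

traced = torch.jit.trace(fn, (x, w))
# The JIT dtype analysis exercised by TestDtypeCustomRules re-propagates dtypes on this
# graph; per the stack trace above it currently infers None for complex inputs, which is
# why the complex dtypes are marked expectedFailure here.
print(traced.graph)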

),
supports_expanded_weight=True,
supports_out=False,),