[complex] conv1d #75013
```diff
@@ -652,6 +652,88 @@ static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) {
   return tensor.narrow(dim, n * g, n).contiguous();
 }
 
+namespace {
+
+std::pair<Tensor, Tensor> complex_to_real(const Tensor& inp) {
+  auto inp_view_as_complex = at::view_as_real(inp);
+  auto dim_i = inp_view_as_complex.dim() - 1;
+  auto i_r = inp_view_as_complex.select(dim_i, 0);
+  auto i_i = inp_view_as_complex.select(dim_i, 1);
+  return std::make_pair(i_r, i_i);
+}
```
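For illustration, a Python analogue of this helper built on the public `torch.view_as_real` API (a sketch, not part of the diff):

```python
import torch

def complex_to_real(inp):
    # view_as_real exposes real/imag as a trailing dimension of size 2
    v = torch.view_as_real(inp)
    return v.select(-1, 0), v.select(-1, 1)

x = torch.randn(2, 3, dtype=torch.complex64)
xr, xi = complex_to_real(x)
assert torch.equal(xr, x.real) and torch.equal(xi, x.imag)
```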
```diff
+
+at::Tensor complex_convolution(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    IntArrayRef output_padding,
+    int64_t groups) {
+  check_input_same_type_as_parameters(input, weight, bias);
+  Tensor i_r, i_i, w_r, w_i;
+  std::tie(i_r, i_i) = complex_to_real(input.resolve_conj());
+  std::tie(w_r, w_i) = complex_to_real(weight.resolve_conj());
+
+  // [NOTE] Complex Convolution
+  // conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
+  // where W, x and b are all complex inputs.
+  // With Gauss Trick:
+  // a = conv(Wr, xr, br),
+  // b = conv(Wi, xi, 0),
+  // c = conv(Wr + Wi, xr + xi, bi + br)
+  // conv(W, x, b) = a - b + i(c - a - b)
+  Tensor a, b, c;
+  if (!bias.defined()) {
+    a = at::convolution(i_r, w_r, bias, stride, padding, dilation, false, output_padding, groups);
+    b = at::convolution(i_i, w_i, bias, stride, padding, dilation, false, output_padding, groups);
+    c = at::convolution(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, false, output_padding, groups);
+  } else {
+    Tensor b_r, b_i;
+    std::tie(b_r, b_i) = complex_to_real(bias.resolve_conj());
+    a = at::convolution(i_r, w_r, b_r, stride, padding, dilation, false, output_padding, groups);
+    b = at::convolution(i_i, w_i, Tensor(), stride, padding, dilation, false, output_padding, groups);
+    c = at::convolution(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, false, output_padding, groups);
+  }
+
+  auto i = c10::Scalar(c10::complex<double>(0, 1));
+  return a - b + i * (c - a - b);
+}
```
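The [NOTE] above can be sanity-checked numerically. Here is a hedged Python sketch (not part of the diff; it uses only real conv1d calls, bias omitted for brevity) comparing the naive four-convolution expansion against the three-convolution Gauss trick:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 16, dtype=torch.complex64)  # input
w = torch.randn(3, 2, 5, dtype=torch.complex64)   # weight

xr, xi = x.real, x.imag
wr, wi = w.real, w.imag

# Naive expansion: conv(W, x) = (conv(Wr, xr) - conv(Wi, xi))
#                             + i*(conv(Wi, xr) + conv(Wr, xi))
naive = torch.complex(
    F.conv1d(xr, wr) - F.conv1d(xi, wi),
    F.conv1d(xr, wi) + F.conv1d(xi, wr),
)

# Gauss trick: three real convolutions instead of four.
a = F.conv1d(xr, wr)
b = F.conv1d(xi, wi)
c = F.conv1d(xr + xi, wr + wi)
gauss = torch.complex(a - b, c - a - b)

assert torch.allclose(naive, gauss, atol=1e-4)
```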
```diff
+
+at::Tensor complex_convolution_mode(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const c10::optional<at::Tensor>& bias_opt,
+    at::IntArrayRef stride,
+    c10::string_view padding,
+    at::IntArrayRef dilation,
+    int64_t groups) {
+  auto bias = bias_opt.value_or(Tensor());
+  check_input_same_type_as_parameters(input, weight, bias);
+  Tensor i_r, i_i, w_r, w_i;
+  std::tie(i_r, i_i) = complex_to_real(input.resolve_conj());
+  std::tie(w_r, w_i) = complex_to_real(weight.resolve_conj());
+
+  // See [NOTE] Complex Convolution
+  Tensor a, b, c;
+  if (!bias.defined()) {
+    a = at::_convolution_mode(i_r, w_r, bias, stride, padding, dilation, groups);
+    b = at::_convolution_mode(i_i, w_i, bias, stride, padding, dilation, groups);
+    c = at::_convolution_mode(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, groups);
+  } else {
+    Tensor b_r, b_i;
+    std::tie(b_r, b_i) = complex_to_real(bias.resolve_conj());
+    a = at::_convolution_mode(i_r, w_r, b_r, stride, padding, dilation, groups);
+    b = at::_convolution_mode(i_i, w_i, Tensor(), stride, padding, dilation, groups);
+    c = at::_convolution_mode(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, groups);
+  }
+
+  auto i = c10::Scalar(c10::complex<double>(0, 1));
+  return a - b + i * (c - a - b);
+}
+
+} // namespace
+
 at::Tensor conv1d(
     const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
```
```diff
@@ -663,7 +745,12 @@ at::Tensor conv1d(
   Tensor input;
   bool is_batched;
   std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
-  auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
+  Tensor output;
+  if (at::isComplexType(input_.scalar_type())) {
+    output = complex_convolution(input, weight, bias, stride, padding, dilation, {0}, groups);
+  } else {
+    output = at::convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
+  }
   return is_batched ? output : output.squeeze(0);
 }
```
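With this dispatch in place, complex tensors flow through `conv1d` directly. A usage sketch, assuming a PyTorch build that includes this PR:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 16, dtype=torch.complex64)  # (N, C_in, L)
w = torch.randn(4, 2, 3, dtype=torch.complex64)   # (C_out, C_in, K)
b = torch.randn(4, dtype=torch.complex64)

y = F.conv1d(x, w, b, stride=1, padding=1)
print(y.dtype, tuple(y.shape))  # torch.complex64 (1, 4, 16)
```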
```diff
@@ -787,8 +874,12 @@ at::Tensor conv1d(
   Tensor input;
   bool is_batched;
   std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
-  auto output = at::_convolution_mode(
-      input, weight, bias, stride, std::move(padding), dilation, groups);
+  Tensor output;
+  if (at::isComplexType(input_.scalar_type())) {
+    output = complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups);
+  } else {
+    output = at::_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups);
+  }
   return is_batched ? output : output.squeeze(0);
 }
```

Reviewer, on the `isComplexType` check: nice, so this adds support for complex half too?

Author: Not at this point, as …
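The string-padding overload routes complex inputs through `complex_convolution_mode`, so `padding="same"` and `padding="valid"` also work. A sketch under the same assumption as above:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 16, dtype=torch.complex64)
w = torch.randn(4, 2, 5, dtype=torch.complex64)

# String padding dispatches through the _convolution_mode path.
y = F.conv1d(x, w, padding="same")
print(tuple(y.shape))  # (1, 4, 16): "same" preserves the spatial length
```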
```diff
@@ -11181,8 +11181,9 @@ def ref_pairwise_distance(input1, input2):
     OpInfo('nn.functional.conv1d',
            aliases=('conv1d',),
            aten_name='conv1d',
-           dtypes=floating_types_and(torch.int64),
-           dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
+           dtypes=floating_and_complex_types_and(torch.int64),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
+                                                       *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
            sample_inputs_func=sample_inputs_conv1d,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
```

Reviewer, on the `dtypesIfCUDA` line: we should also test complex half
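For context, `floating_and_complex_types_and` appends the extra dtypes to the base floating + complex set; `float16` still has to be listed explicitly because the base set only covers the default floating dtypes. A sketch, assuming the helper's current internal location:

```python
import torch
from torch.testing._internal.common_dtype import floating_and_complex_types_and

# Roughly: float32, float64, complex64, complex128, plus torch.int64
print(floating_and_complex_types_and(torch.int64))
```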
```diff
@@ -11191,6 +11192,9 @@ def ref_pairwise_distance(input1, input2):
                # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
                # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
                DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # AssertionError: None mismatch: torch.complex128 is not None
+               DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules',
+                            'test_custom_rules', dtypes=(torch.complex64, torch.complex128)),
            ),
            supports_expanded_weight=True,
            supports_out=False,),
```

Author, on the `expectedFailure` decorator: Complete stack trace: …

Reviewer: what does `test_custom_rules` test?

Author: I am not sure. Looking at the test, it seems to check whether the dtype is propagated correctly in the JIT traced graph. (I was planning to point someone from the JIT team to unblock this PR.) See test/jit/test_dtype_analysis.py, lines 334 to 357 at ad028e5.