[complex] conv1d #75013
Conversation
CI failures summary and remediations (Dr. CI), as of commit e8cb463: 💚 Looks good so far! There are no failures yet. (This comment was automatically generated by Dr. CI.)
@@ -11191,6 +11192,9 @@ def ref_pairwise_distance(input1, input2):
            # RuntimeError: !lhs.isAliasOf(rhs) INTERNAL ASSERT FAILED at
            # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+           # AssertionError: None mismatch: torch.complex128 is not None
+           DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules',
+                        'test_custom_rules', dtypes=(torch.complex64, torch.complex128)),
Complete Stack Trace:
Traceback (most recent call last):
File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 361, in test_custom_rules
self.custom_rules_test_base(device, dtype, op)
File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 357, in custom_rules_test_base
self.assert_output_dtype_equal(expected_res, graph)
File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 323, in assert_output_dtype_equal
self.assert_tensor_dtype_equal(expected_res, actual_dtype[0])
File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 332, in assert_tensor_dtype_equal
self.assertEqual(tensor_output.dtype, graph_dtype)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2160, in assertEqual
assert_equal(
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1084, in assert_equal
raise error_metas[0].to_error()
AssertionError: None mismatch: torch.complex128 is not None
What does TestDtypeCustomRules.test_custom_rules test? Why is this an expected failure?
I am not sure. Looking at the test, it seems to check whether dtypes are propagated correctly through a JIT-traced graph. (I was planning to ask someone from the JIT team to take a look to unblock this PR.)
pytorch/test/jit/test_dtype_analysis.py
Lines 334 to 357 in ad028e5
def custom_rules_test_base(self, device, dtype, op, allow_eager_fail=False):
    try:
        samples = op.sample_inputs(device, dtype, requires_grad=False)
        sample_input = first_sample(self, samples)
        input_args = [sample_input.input, *sample_input.args]
        expected_res = op(*input_args, **sample_input.kwargs)
    except Exception as e:
        if allow_eager_fail:
            return
        else:
            raise e

    func = op.get_op()
    traced_fn = create_traced_fn(self, func)

    # Have to run the traced function to actually generate the trace
    traced_fn(sample_input.input, *sample_input.args, **sample_input.kwargs)

    # Run the Dtype Analysis
    graph = traced_fn.graph  # Note this is a cached graph
    input_tensors = [t for t in input_args if isinstance(t, torch.Tensor)]
    self.prop_dtype_on_graph(graph, input_tensors)
    self.assert_output_dtype_equal(expected_res, graph)
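For context, a minimal standalone sketch of what this harness exercises: trace an op, then compare the eager output dtype against what dtype analysis annotates on the graph. The shapes and simplified structure here are illustrative assumptions, not the actual test code:

import torch
import torch.nn.functional as F

def fn(x, w):
    return F.conv1d(x, w)

x = torch.randn(1, 2, 8, dtype=torch.complex128)
w = torch.randn(3, 2, 3, dtype=torch.complex128)

expected_dtype = fn(x, w).dtype   # eager result: torch.complex128

# Tracing records a graph; dtype analysis should re-derive the same dtype
# on the graph output. The reported failure means the custom dtype rule
# for convolution returned None for complex inputs instead of complex128.
traced = torch.jit.trace(fn, (x, w))
print(expected_dtype)
print(traced.graph)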
@anjali411 can you please take a look and see if this is going in the right direction? Thanks!
-  auto output = at::_convolution_mode(
-      input, weight, bias, stride, std::move(padding), dilation, groups);
+  Tensor output;
+  if (at::isComplexType(input_.scalar_type())) {
Nice, so this adds support for complex half too?
Not at this point, as +, -, *, and view_as_real will need to be supported first.
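For reference, view_as_real (one of the prerequisites mentioned above) reinterprets a complex tensor as a real one with a trailing size-2 dimension; a quick illustrative sketch:

import torch

z = torch.randn(4, dtype=torch.complex64)
r = torch.view_as_real(z)   # real view: shape (4, 2), last dim holds (real, imag)
print(r.shape, r.dtype)     # torch.Size([4, 2]) torch.float32

# For complex half (torch.complex32), this view plus elementwise +, -, *
# would all need kernel support before conv could be decomposed.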
-                   dtypes=floating_types_and(torch.int64),
-                   dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
+                   dtypes=floating_and_complex_types_and(torch.int64),
+                   dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
we should also test complex half
Granted, this will be slow (since we call into three convolution kernels), but it's cool functionality that several users have asked for, so pretty exciting! This PR looks good to me, but we also need more tests to validate all the cases (extend the tests in test_nn.py), as well as a test verifying that we compute the correct values.
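(For readers wondering where the three convolution kernels come from: a complex convolution can be decomposed into three real convolutions via Gauss's multiplication trick. Below is a minimal sketch of the idea; it is an illustration, not the actual ATen implementation, and assumes a build where F.conv1d already accepts complex tensors.)

import torch
import torch.nn.functional as F

def complex_conv1d(x, w):
    # (a + ib) conv (c + id) = (ac - bd) + i*((a + b) conv (c + d) - ac - bd),
    # so three real convolutions suffice instead of the naive four.
    a, b = x.real, x.imag
    c, d = w.real, w.imag
    ac = F.conv1d(a, c)
    bd = F.conv1d(b, d)
    mixed = F.conv1d(a + b, c + d)
    return torch.complex(ac - bd, mixed - ac - bd)

x = torch.randn(1, 2, 8, dtype=torch.complex64)   # (batch, channels, length)
w = torch.randn(3, 2, 3, dtype=torch.complex64)   # (out_ch, in_ch, kernel)
# Should agree with the native op up to floating-point round-off.
torch.testing.assert_close(complex_conv1d(x, w), F.conv1d(x, w))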
@anjali411 have updated the relevant tests in test_nn.py.
Very excited! Nice job @kshitij12345!
# Global dtype for this test suite is torch.double
# This leads to change in type-promotion
# and conv1d outputs `complex128` for `complex64` input.
nice note
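The promotion rule the note describes can be checked directly; a small sketch:

import torch

# complex64 has float32 real/imag parts; promoting against the suite's
# global torch.double (float64) widens the result to complex128.
print(torch.promote_types(torch.complex64, torch.float64))  # torch.complex128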
great work @kshitij12345
@pytorchbot merge this please
Hey @kshitij12345.
@pytorchbot revert this please
This reverts commit b64e7de. Reverted #75013 on behalf of https://github.com/mruberry
Reland: #75013 Reference: #71108 Pull Request resolved: #75310 Approved by: https://github.com/anjali411
Summary: Reland: #75013 Reference: #71108 Pull Request resolved: #75310 Approved by: https://github.com/anjali411
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/706b9e8b8d8b17cc33f74636bf520a5a53e4d386
Reviewed By: b0noI
Differential Revision: D35437863
fbshipit-source-id: 068ca2191e3be2abac4082bae3b234b21eb1ac0d
Reference: #71108