
[complex] conv1d #75013


Closed

Conversation

@kshitij12345 (Collaborator) commented Mar 31, 2022:

Reference: #71108
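For context, the goal (per #71108) is to let 1-D convolution accept complex tensors. A minimal sketch of what that enables, assuming the post-merge behavior:

import torch
import torch.nn.functional as F

# After this PR, conv1d accepts complex input and weight.
x = torch.randn(1, 2, 16, dtype=torch.cfloat)  # (batch, in_channels, length)
w = torch.randn(4, 2, 3, dtype=torch.cfloat)   # (out_channels, in_channels, kernel)
out = F.conv1d(x, w)
print(out.shape, out.dtype)  # torch.Size([1, 4, 14]) torch.complex64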

@facebook-github-bot (Contributor) commented Mar 31, 2022:

💊 CI failures summary and remediations

As of commit e8cb463 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@@ -11191,6 +11192,9 @@ def ref_pairwise_distance(input1, input2):
# RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
# "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# AssertionError: None mismatch: torch.complex128 is not None
DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules',
'test_custom_rules', dtypes=(torch.complex64, torch.complex128)),
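For readers unfamiliar with the OpInfo test machinery: DecorateInfo attaches a decorator (here unittest.skip or unittest.expectedFailure) to one test class/test name, optionally restricted to certain dtypes. A sketch of the pattern, using the import path from around the time of this PR:

import unittest
import torch
from torch.testing._internal.common_methods_invocations import DecorateInfo

# Marks TestDtypeCustomRules.test_custom_rules as an expected failure,
# but only when the test runs with the listed complex dtypes.
xfail_complex = DecorateInfo(
    unittest.expectedFailure,
    'TestDtypeCustomRules', 'test_custom_rules',
    dtypes=(torch.complex64, torch.complex128),
)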
kshitij12345 (Collaborator, Author) commented on this change:

Complete Stack Trace:

Traceback (most recent call last):
  File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 361, in test_custom_rules
    self.custom_rules_test_base(device, dtype, op)
  File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 357, in custom_rules_test_base
    self.assert_output_dtype_equal(expected_res, graph)
  File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 323, in assert_output_dtype_equal
    self.assert_tensor_dtype_equal(expected_res, actual_dtype[0])
  File "/home/kshiteej/Pytorch/pytorch_complex_convolution.py/test/jit/test_dtype_analysis.py", line 332, in assert_tensor_dtype_equal
    self.assertEqual(tensor_output.dtype, graph_dtype)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2160, in assertEqual
    assert_equal(
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1084, in assert_equal
    raise error_metas[0].to_error()
AssertionError: None mismatch: torch.complex128 is not None

A reviewer (Contributor) asked:

What does TestDtypeCustomRules.test_custom_rules test? Why is this an expected failure?

kshitij12345 (Collaborator, Author) replied:

I am not sure. Looking at the test, it seems to check whether the dtype is propagated correctly in a JIT-traced graph. (I was planning to ping someone from the JIT team to unblock this PR.)

def custom_rules_test_base(self, device, dtype, op, allow_eager_fail=False):
    try:
        samples = op.sample_inputs(device, dtype, requires_grad=False)
        sample_input = first_sample(self, samples)
        input_args = [sample_input.input, *sample_input.args]
        expected_res = op(*input_args, **sample_input.kwargs)
    except Exception as e:
        if allow_eager_fail:
            return
        else:
            raise e

    func = op.get_op()
    traced_fn = create_traced_fn(self, func)
    # Have to run the traced function to actually generate the trace
    traced_fn(sample_input.input, *sample_input.args, **sample_input.kwargs)
    # Run the Dtype Analysis
    graph = traced_fn.graph  # Note this is a cached graph
    input_tensors = [t for t in input_args if isinstance(t, torch.Tensor)]
    self.prop_dtype_on_graph(graph, input_tensors)
    self.assert_output_dtype_equal(expected_res, graph)
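For a rough picture of what goes wrong (a simplified repro sketch, not the real test, which runs an internal dtype-propagation pass via prop_dtype_on_graph): eager mode produces complex128, but the JIT's custom dtype rules for convolution did not yet cover complex, so the propagated output dtype came back as None.

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8, dtype=torch.complex128)
w = torch.randn(1, 1, 3, dtype=torch.complex128)

eager = F.conv1d(x, w)
print(eager.dtype)  # torch.complex128 -- the dtype the assertion expected

# Tracing itself works; the failure is in the dtype-analysis pass that runs
# on this graph afterwards (it returned None for complex inputs).
traced = torch.jit.trace(lambda x, w: F.conv1d(x, w), (x, w))
traced(x, w)  # run once so the cached graph is generated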

@kshitij12345 (Collaborator, Author) commented:

@anjali411 can you please take a look and see whether this is headed in the right direction? Thanks!

auto output = at::_convolution_mode(
    input, weight, bias, stride, std::move(padding), dilation, groups);
Tensor output;
if (at::isComplexType(input_.scalar_type())) {
A reviewer (Contributor) commented:

Nice, so this adds support for complex half too?

kshitij12345 (Collaborator, Author) replied:

Not at this point, as +, -, *, and view_as_real would need to support complex half first.

dtypes=floating_types_and(torch.int64),
dtypesIfCUDA=floating_types_and(torch.float16, *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
dtypes=floating_and_complex_types_and(torch.int64),
dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
A reviewer (Contributor) commented:

We should also test complex half.

@anjali411 (Contributor) left a comment:

Granted, this will be slow (since we call into three convolution kernels), but it's cool functionality that several users have asked for, so pretty exciting!

This PR looks good to me, but we also need more tests to validate all cases (extend the tests in test_nn.py), as well as a test verifying that we compute the correct values.
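The "three convolution kernels" remark refers to decomposing a complex convolution into real ones. A sketch of that idea using Gauss's three-multiplication trick (the review only states that three kernels are invoked; the actual C++ code may differ in detail):

import torch
import torch.nn.functional as F

def complex_conv1d_via_three_real_convs(x, w):
    # (a + ib) conv (c + id) with three real convolutions instead of four:
    # real part = t1 - t2, imaginary part = t3 - t1 - t2
    a, b = x.real, x.imag
    c, d = w.real, w.imag
    t1 = F.conv1d(a, c)
    t2 = F.conv1d(b, d)
    t3 = F.conv1d(a + b, c + d)
    return torch.complex(t1 - t2, t3 - t1 - t2)

x = torch.randn(1, 1, 10, dtype=torch.cdouble)
w = torch.randn(1, 1, 3, dtype=torch.cdouble)
assert torch.allclose(complex_conv1d_via_three_real_convs(x, w), F.conv1d(x, w))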

@kshitij12345 (Collaborator, Author) commented:

@anjali411 I have updated the relevant tests in test_nn and have also added a test for correctness against scipy. PTAL :)
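A minimal sketch of the kind of scipy cross-check described (not the exact test added in the PR): F.conv1d computes cross-correlation without conjugation, so flipping the kernel makes it agree with scipy.signal.convolve.

import numpy as np
import torch
import torch.nn.functional as F
from scipy import signal

x = torch.randn(8, dtype=torch.cdouble)
w = torch.randn(3, dtype=torch.cdouble)

# conv1d expects (batch, channels, length) shapes.
torch_out = F.conv1d(x.view(1, 1, -1), w.view(1, 1, -1)).view(-1)
scipy_out = signal.convolve(x.numpy(), np.flip(w.numpy()), mode='valid')
assert np.allclose(torch_out.numpy(), scipy_out)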

@0x00b1 (Contributor) commented Apr 5, 2022:

Very excited! Nice job @kshitij12345 !

@kshitij12345 kshitij12345 marked this pull request as ready for review April 5, 2022 16:14

# Global dtype for this test suite is torch.double.
# This leads to a change in type promotion,
# and conv1d outputs `complex128` for `complex64` input.
A reviewer (Contributor) commented:

Nice note.
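To illustrate the promotion the note describes (a standalone sketch of PyTorch's type-promotion rule, not the test itself): complex64 combined with float64 promotes to complex128, which is why a double-dtype test suite sees complex128 outputs for complex64 inputs.

import torch

x = torch.randn(4, dtype=torch.complex64)
w = torch.randn(4, dtype=torch.float64)  # created under default dtype double
print(torch.result_type(x, w))  # torch.complex128
print((x * w).dtype)            # torch.complex128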

@anjali411 (Contributor) left a comment:

great work @kshitij12345

@anjali411 (Contributor) commented:

@pytorchbot merge this please

@github-actions bot commented Apr 5, 2022:

Hey @kshitij12345.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@mruberry (Collaborator) commented Apr 5, 2022:

@pytorchbot revert this please

pytorchmergebot added a commit that referenced this pull request Apr 5, 2022
pytorchmergebot pushed a commit that referenced this pull request Apr 6, 2022
Reland : #75013

Reference: #71108
Pull Request resolved: #75310
Approved by: https://github.com/anjali411
facebook-github-bot pushed a commit that referenced this pull request Apr 7, 2022
Summary:
Reland : #75013

Reference: #71108

Pull Request resolved: #75310
Approved by: https://github.com/anjali411

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/706b9e8b8d8b17cc33f74636bf520a5a53e4d386

Reviewed By: b0noI

Differential Revision: D35437863

fbshipit-source-id: 068ca2191e3be2abac4082bae3b234b21eb1ac0d