feat: Add min, max ranges to mark_dynamic API #119737


Closed
wants to merge 6 commits

Conversation

@peri044 peri044 (Contributor) commented Feb 13, 2024

Fixes #115137

This PR adds:

  • The mark_dynamic API now accepts min and max values to create a bounded constraint on the dim (usage sketched below).
  • A test case in test_misc.py that checks that ConstraintViolationError is raised when torch.compile receives an input dimension out of bounds.
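
A minimal usage sketch, assuming the extended signature mark_dynamic(t, index, *, min=None, max=None) that this PR adds, with ConstraintViolationError coming from torch.fx.experimental.symbolic_shapes:

import torch
from torch._dynamo import mark_dynamic

@torch.compile
def fn(t):
    return t * 2

# Mark dim 0 of x as dynamic, bounded to sizes in [2, 16].
x = torch.randn(8, 4)
mark_dynamic(x, 0, min=2, max=16)
fn(x)  # compiles a kernel with a bounded dynamic dim 0

# An input whose dim 0 falls outside [2, 16] should fail the constraint:
y = torch.randn(32, 4)
mark_dynamic(y, 0, min=2, max=16)
# fn(y)  # expected to raise ConstraintViolationError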

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @aakhundov @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
cc: @narendasan

pytorch-bot bot commented Feb 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/119737

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Pending, 3 Unrelated Failures

As of commit 8abdc26 with merge base 8fa6340:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot commented Feb 13, 2024

Please seek CI approval before scheduling CIFlow labels

@colesbury colesbury (Member) commented:

@jansel - would you please review this?

@colesbury colesbury added the "triaged" label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Feb 14, 2024
@colesbury colesbury requested a review from ezyang February 14, 2024 14:11
@jansel jansel (Contributor) left a comment

A nit, though @ezyang should take a look here.


if isinstance(index, int):
    if not hasattr(t, "_dynamo_dynamic_indices"):
        t._dynamo_dynamic_indices = set()
    # TODO(voz): Should we bounds check?
-   t._dynamo_dynamic_indices.add(index)
+   t._dynamo_dynamic_indices.add((index, min, max))
A reviewer (Contributor) commented:

Consider just making a little data structure for this; it's easier than indexing into it.

@peri044 peri044 (Contributor Author) replied Feb 15, 2024:

Sure. I've added a _DimRange dataclass now, so the fields can be accessed via dim_range.min, etc. Please let me know if you have any other preference.
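
A minimal sketch of what such a dataclass might look like; the field names (dim, min, max) are taken from the guard-error log later in this thread, and the actual definition in the PR may differ:

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class _DimRange:
    # frozen=True makes instances hashable, so they can be stored in the
    # _dynamo_dynamic_indices set shown in the diff above.
    dim: int
    min: Optional[int] = None
    max: Optional[int] = None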

if min_val == 2 and not max_val:
    constraint_dim = RelaxedUnspecConstraint(warn_only=False)
else:
    constraint_dim = StrictMinMaxConstraint(
A reviewer (Contributor) commented:

I'm a little concerned about this going from Relaxed to Strict. The main question is whether export requires that there are NO other constraints, or whether it's just testing the value range. If it's just the value range, that should be alright.

@peri044 peri044 (Contributor Author) replied Feb 15, 2024:

As per my understanding (of this snippet and this file), it seems like it's just testing the value range.
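
For context, a hedged sketch of how the truncated branch above might continue, assuming StrictMinMaxConstraint and RelaxedUnspecConstraint from torch.fx.experimental.symbolic_shapes and ValueRanges from torch.utils._sympy.value_ranges; the exact call in the PR may differ:

from torch.fx.experimental.symbolic_shapes import (
    RelaxedUnspecConstraint,
    StrictMinMaxConstraint,
)
from torch.utils._sympy.value_ranges import ValueRanges

min_val, max_val = 2, 16  # stand-ins for the values parsed from mark_dynamic

if min_val == 2 and not max_val:
    # No explicit bounds given: keep the unbounded dynamic-dim behavior.
    constraint_dim = RelaxedUnspecConstraint(warn_only=False)
else:
    # Explicit bounds: restrict the dim's value range to [min_val, max_val].
    constraint_dim = StrictMinMaxConstraint(
        vr=ValueRanges(lower=min_val, upper=max_val),
        warn_only=False,
    )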

@ezyang ezyang (Contributor) commented Feb 14, 2024

This looks good, but there are test failures. Do you know how to solve them?

pytorch-bot bot commented Feb 15, 2024

Please seek CI approval before scheduling CIFlow labels

pytorch-bot bot commented Feb 15, 2024

Please seek CI approval before scheduling CIFlow labels

@ezyang ezyang (Contributor) commented Feb 19, 2024

Holler if you need help fixing CI problems.

pytorch-bot bot commented Feb 20, 2024

Please seek CI approval before scheduling CIFlow labels

@peri044 peri044 (Contributor Author) commented Feb 20, 2024

> Holler if you need help fixing CI problems.

Yes. That would be very helpful.

  1. I pushed a commit to fix the linter issues. I ran lintrunner --force-color -a <filename> --skip CLANGTIDY,CLANGFORMAT to fix them. Please approve the CI so it starts running with this latest commit.
  2. The previous CI error (not visible now) was "_DimRange is not defined" (occurred at test_dynamic_shapes.py -k test_mark_dynamic_with_ranges_dynamic_shapes). I'm not sure why; it passes locally. Please let me know if you have any suggestions in case it shows up again. Thanks.

@jansel jansel (Contributor) commented Feb 20, 2024

I approved the CI for you

@ezyang ezyang (Contributor) commented Feb 22, 2024

Seems like just silly CI errors right now.

@peri044 peri044 (Contributor Author) commented Feb 22, 2024

> Seems like just silly CI errors right now.

It seems so. The error is as follows:

ERROR RUNNING GUARDS my_dyn_fn /home/dperi/Downloads/my_fork/pytorch/test/dynamo/test_misc.py:7106
lambda L, **___kwargs_ignored:
  ___check_global_state() and
  ___check_type_id(L['a'], 94252923105904) and
  ((L['a']._dynamo_dynamic_indices.issubset({_DimRange(dim=0, min=None, max=None)})) if hasattr(L['a'], '_dynamo_dynamic_indices') else True) and
  utils_device.CURRENT_DEVICE == None and
  ___check_current_backend(139619195448208) and
  ___check_tensors(L['a'], tensor_check_names=tensor_check_names) and
  L['a'].size()[0] > 2 and
  2 <= L['a'].size()[0] and
  2 <= L['a'].size()[0]
Malformed guard:
((L['a']._dynamo_dynamic_indices.issubset({_DimRange(dim=0, min=None, max=None)})) if hasattr(L['a'], '_dynamo_dynamic_indices') else True)

As per your suggestion, I made a _DimRange dataclass to store index, min, and max in the _dynamo_dynamic_indices attribute. But from the error log, it seems the guards do not have the _DimRange definition. Do I need to define it elsewhere? Any suggestions? Thanks!

@ezyang ezyang (Contributor) commented Feb 24, 2024

oh blah, this is annoying. Hmm...

So, we can fix the proximal problem by introducing a _DimRange binding to CLOSURE_VARS in torch/_dynamo/guards.py. However, the failure here has made me realize there is another annoying problem: the issubset test is no longer the right thing to do in the presence of this extra information. To motivate this, say you compile some code under the assumption that dim=1 is dynamic. If you later also mark dim=2 dynamic, the guard here will force a recompilation (so that we actually generate a dynamic kernel). If you remove the dim=1 marking, though, we don't recompile, because our dynamic kernel should work for your static case. There's a comment on this in guards.py at

            # A frame is valid for reuse with dynamic dimensions if the new dynamic dimensions are a
            # strict subset of the old.

This is all very delicate, though, and we are pretty inconsistent (I don't think we're guarding on mark_static, lol). So I feel maybe the easiest thing to do is to just store the min/max range on a separate variable and file a bug for follow-up on the guard problem. If you want to bash your way past this, though, then you not only need to do the subset test, you also have to do a containment test on the ranges: a frame is valid to reuse if the new allowed range is a subset of the old (sketched below).
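
A hypothetical sketch of that containment test, reusing the _DimRange fields from the error log above; the function name and placement (e.g. next to the subset check in torch/_dynamo/guards.py) are illustrative, not the PR's actual code:

from dataclasses import dataclass
from typing import Optional, Set

@dataclass(frozen=True)
class _DimRange:  # as sketched earlier in this thread
    dim: int
    min: Optional[int] = None
    max: Optional[int] = None

def frame_reusable(old: Set[_DimRange], new: Set[_DimRange]) -> bool:
    # A compiled frame stays valid if every newly marked dynamic dim was
    # already dynamic, and each new allowed range is contained in the old one.
    old_by_dim = {r.dim: r for r in old}
    for r in new:
        prev = old_by_dim.get(r.dim)
        if prev is None:
            return False  # a newly dynamic dim forces a recompile
        lo_old = prev.min if prev.min is not None else 0
        hi_old = prev.max if prev.max is not None else float("inf")
        lo_new = r.min if r.min is not None else 0
        hi_new = r.max if r.max is not None else float("inf")
        if lo_new < lo_old or hi_new > hi_old:
            return False  # the new range escapes the old one: recompile
    return True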

@peri044 peri044 (Contributor Author) commented Mar 6, 2024

@pytorchbot label "release notes: dynamo"

@peri044 peri044 (Contributor Author) commented Mar 6, 2024

Thank you for all the help @ezyang

@ezyang ezyang added the "topic: new features" label (topic category) Mar 6, 2024
@ezyang ezyang (Contributor) commented Mar 6, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@ezyang ezyang (Contributor) commented Mar 7, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@ezyang ezyang (Contributor) commented Mar 7, 2024

@pytorchbot merge -i

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged while ignoring the following 3 checks: inductor / cuda12.1-py3.10-gcc9-sm86 / test (inductor_torchbench, 2, 2, linux.g5.4xlarge.nvidia.gpu), trunk / win-vs2019-cpu-py3 / test (default, 1, 3, windows.4xlarge.nonephemeral), trunk / win-vs2019-cpu-py3 / test (default, 3, 3, windows.4xlarge.nonephemeral)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.


@peri044 peri044 (Contributor Author) commented Mar 7, 2024

@ezyang The following failures on Windows seem unrelated. Any suggestions?

1)

test_torch.py::TestTorchDeviceTypeCPU::test_grad_scaling_autocast_foreach_cpu FAILED [0.0849s] [  0%]

=================================== RERUNS ====================================
________ TestTorchDeviceTypeCPU.test_grad_scaling_autocast_foreach_cpu ________
Traceback (most recent call last):
  File "C:\actions-runner\_work\pytorch\pytorch\test\test_torch.py", line 5886, in test_grad_scaling_autocast_foreach
    self._grad_scaling_autocast_test(device=device.type, optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": True})
  File "C:\actions-runner\_work\pytorch\pytorch\test\test_torch.py", line 5872, in _grad_scaling_autocast_test
    self._run_scaling_case(
  File "C:\Jenkins\Miniconda3\lib\unittest\case.py", line 226, in __exit__
    self._raiseFailure("{} not raised".format(exc_name))
  File "C:\Jenkins\Miniconda3\lib\unittest\case.py", line 163, in _raiseFailure
    raise self.test_case.failureException(msg)
2)

test_optim.py::TestSWAUtils::test_averaged_model_all_devices_ema_True <- test\optim\test_swa_utils.py FAILED [0.0112s] [ 28%]

=================================== RERUNS ====================================
____________ TestSWAUtils.test_averaged_model_all_devices_ema_True ____________
Traceback (most recent call last):
  File "C:\actions-runner\_work\pytorch\pytorch\test\optim\test_swa_utils.py", line 99, in test_averaged_model_all_devices
    self._test_averaged_model(cpu, cpu, ema)
  File "C:\actions-runner\_work\pytorch\pytorch\test\optim\test_swa_utils.py", line 69, in _test_averaged_model
    self.assertEqual(p_avg, p_swa)
  File "C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 3625, in assertEqual
    raise error_metas.pop()[0].to_error(
AssertionError: Tensor-likes are not close!

@ezyang ezyang (Contributor) commented Mar 7, 2024

@pytorchbot merge -i

Labels: ciflow/inductor, ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: dynamo, open source, release notes: dynamo, topic: new features (topic category), triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Projects: None yet
Development

Successfully merging this pull request may close these issues.

[torch.compile] Request for shape ranges in torch.compile workflow
6 participants