Skip to content

Commit 07fa6e2

Browse files
guangyey authored and pytorchmergebot committed
Fix torch.accelerator api abort when passing invalid device (pytorch#143550)
# Motivation Fix pytorch#143543 # Solution We should raise python exception instead of aborting... # Additional Context without this PR: ```python >>> import torch >>> torch.accelerator.current_stream(torch.accelerator.device_count()) terminate called after throwing an instance of 'c10::Error' what(): device is out of range, device is 2, total number of device is 2. Exception raised from check_device_index at /home/dvrogozh/git/pytorch/pytorch/c10/xpu/XPUFunctions.h:36 (most recent call first): frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xac (0x7f30707eb95c in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10.so) frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7f307078fc57 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10.so) frame #2: <unknown function> + 0x19a3e (0x7f3070c2ba3e in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so) frame #3: c10::xpu::getCurrentXPUStream(signed char) + 0x2f (0x7f3070c2c83f in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so) frame #4: <unknown function> + 0x1ca35 (0x7f3070c2ea35 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libc10_xpu.so) frame #5: <unknown function> + 0x653f15 (0x7f3083391f15 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libtorch_python.so) frame #6: <unknown function> + 0x39e5f2 (0x7f30830dc5f2 in /home/dvrogozh/git/pytorch/pytorch/torch/lib/libtorch_python.so) <omitting python frames> frame #20: <unknown function> + 0x29d90 (0x7f308b19bd90 in /lib/x86_64-linux-gnu/libc.so.6) frame #21: __libc_start_main + 0x80 (0x7f308b19be40 in /lib/x86_64-linux-gnu/libc.so.6) Aborted (core dumped) ``` with this PR: ```python >>> import torch >>> torch.accelerator.current_stream(torch.accelerator.device_count()) Traceback (most recent call last): File "<stdin>", line 1, in <module> 
File "/home/pt-gpu/4T-4652/guangyey/stock-pytorch/torch/accelerator/__init__.py", line 123, in current_stream return torch._C._accelerator_getStream(device_index) RuntimeError: The device index is out of range. It must be in [0, 2), but got 2. ``` Pull Request resolved: pytorch#143550 Approved by: https://github.com/EikanWang, https://github.com/dvrogozh, https://github.com/albanD
1 parent eebc93d commit 07fa6e2

File tree

9 files changed

+23
-17
lines changed

9 files changed

+23
-17
lines changed

aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI
8282
void uncheckedSetDevice(Device d) const noexcept override {
8383
C10_HIP_CHECK_WARN(hipSetDevice(d.index()));
8484
}
85-
Stream getStream(Device d) const noexcept override {
85+
Stream getStream(Device d) const override {
8686
return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap();
8787
}
8888
Stream getDefaultStream(Device d) const override {
@@ -94,7 +94,7 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI
9494
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override {
9595
return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index());
9696
}
97-
Stream exchangeStream(Stream s) const noexcept override {
97+
Stream exchangeStream(Stream s) const override {
9898
HIPStreamMasqueradingAsCUDA cs(s);
9999
auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index());
100100
setCurrentHIPStreamMasqueradingAsCUDA(cs);

aten/src/ATen/mps/MPSGuardImpl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ struct TORCH_API MPSGuardImpl final
6464
// TODO: Currently setting only device 0
6565
}
6666

67-
Stream getStream(Device d) const noexcept override {
67+
Stream getStream(Device d) const override {
6868
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
6969
}
7070

@@ -78,7 +78,7 @@ struct TORCH_API MPSGuardImpl final
7878
}
7979

8080
// NB: These do NOT set the current device
81-
Stream exchangeStream(Stream s) const noexcept override {
81+
Stream exchangeStream(Stream s) const override {
8282
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
8383
}
8484
DeviceIndex deviceCount() const noexcept override {

c10/core/impl/DeviceGuardImplInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct C10_API DeviceGuardImplInterface {
105105
/**
106106
* Get the current stream for a given device.
107107
*/
108-
virtual Stream getStream(Device) const noexcept = 0;
108+
virtual Stream getStream(Device) const = 0;
109109

110110
/**
111111
* Get the default stream for a given device.
@@ -138,7 +138,7 @@ struct C10_API DeviceGuardImplInterface {
138138
* Return the previous stream for that device. You are NOT required
139139
* to set the current device to match the device of this stream.
140140
*/
141-
virtual Stream exchangeStream(Stream) const noexcept = 0;
141+
virtual Stream exchangeStream(Stream) const = 0;
142142

143143
/**
144144
* Destroys the given event.

c10/core/impl/VirtualGuardImpl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface {
3737
void uncheckedSetDevice(Device d) const noexcept override {
3838
impl_->uncheckedSetDevice(d);
3939
}
40-
Stream getStream(Device d) const noexcept override {
40+
Stream getStream(Device d) const override {
4141
return impl_->getStream(d);
4242
}
4343
Stream getNewStream(Device d, int priority = 0) const override {
@@ -50,7 +50,7 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface {
5050
const override {
5151
return impl_->getStreamFromGlobalPool(d, isHighPriority);
5252
}
53-
Stream exchangeStream(Stream s) const noexcept override {
53+
Stream exchangeStream(Stream s) const override {
5454
return impl_->exchangeStream(s);
5555
}
5656
DeviceIndex deviceCount() const noexcept override {

c10/cuda/impl/CUDAGuardImpl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
5656
void uncheckedSetDevice(Device d) const noexcept override {
5757
C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index()));
5858
}
59-
Stream getStream(Device d) const noexcept override {
59+
Stream getStream(Device d) const override {
6060
return getCurrentCUDAStream(d.index()).unwrap();
6161
}
6262
Stream getDefaultStream(Device d) const override {
@@ -70,7 +70,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
7070
return getStreamFromPool(isHighPriority, d.index());
7171
}
7272
// NB: These do NOT set the current device
73-
Stream exchangeStream(Stream s) const noexcept override {
73+
Stream exchangeStream(Stream s) const override {
7474
CUDAStream cs(s);
7575
auto old_stream = getCurrentCUDAStream(s.device().index());
7676
setCurrentCUDAStream(cs);

c10/xpu/XPUFunctions.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ C10_XPU_API void get_device_properties(
3232

3333
C10_XPU_API DeviceIndex get_device_idx_from_pointer(void* ptr);
3434

35-
static inline void check_device_index(DeviceIndex device) {
35+
static inline void check_device_index(DeviceIndex device_index) {
3636
TORCH_CHECK(
37-
device >= 0 && device < c10::xpu::device_count(),
38-
"device is out of range, device is ",
39-
static_cast<int>(device),
40-
", total number of device is ",
37+
device_index >= 0 && device_index < c10::xpu::device_count(),
38+
"The device index is out of range. It must be in [0, ",
4139
static_cast<int>(c10::xpu::device_count()),
40+
"), but got ",
41+
static_cast<int>(device_index),
4242
".");
4343
}
4444

c10/xpu/impl/XPUGuardImpl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
4444
c10::xpu::set_device(d.index());
4545
}
4646

47-
Stream getStream(Device d) const noexcept override {
47+
Stream getStream(Device d) const override {
4848
return getCurrentXPUStream(d.index()).unwrap();
4949
}
5050

@@ -58,7 +58,7 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
5858
}
5959

6060
// NB: These do NOT set the current device
61-
Stream exchangeStream(Stream s) const noexcept override {
61+
Stream exchangeStream(Stream s) const override {
6262
const XPUStream stream(s);
6363
const auto old_stream = getCurrentXPUStream(s.device().index());
6464
setCurrentXPUStream(stream);

test/test_cuda.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,10 @@ def test_stream_compatibility(self):
766766
self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id)
767767
torch.accelerator.set_stream(s2)
768768
self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id)
769+
with self.assertRaisesRegex(
770+
RuntimeError, "device_index >= 0 && device_index < num_gpus"
771+
):
772+
torch.accelerator.current_stream(torch.accelerator.device_count())
769773

770774
def test_record_stream(self):
771775
cycles_per_ms = get_cycles_per_ms()

test/test_xpu.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,8 @@ def test_stream_compatibility(self):
306306
self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id)
307307
torch.accelerator.set_stream(s2)
308308
self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id)
309+
with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
310+
torch.accelerator.current_stream(torch.accelerator.device_count())
309311

310312
def test_generator(self):
311313
torch.manual_seed(2024)

0 commit comments

Comments
 (0)