Skip to content

Commit 9596fab

Browse files
jerome-habanakaushikb11rohitgr7justusschockBorda
authored
Add auto_device_count and device name support (#13423)
Co-authored-by: Kaushik B <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: Justus Schock <[email protected]> Co-authored-by: Jirka <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Akihiro Nitta <[email protected]> Co-authored-by: awaelchli <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: ananthsub <[email protected]> Co-authored-by: mansy <[email protected]> Co-authored-by: manskx <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Mansy <[email protected]> Co-authored-by: otaj <[email protected]> Co-authored-by: Sean Naren <[email protected]> Co-authored-by: Keiichi Kuroyanagi <[email protected]> Co-authored-by: Martino Sorbaro <[email protected]> Co-authored-by: Wang Ran (汪然) <[email protected]> Co-authored-by: Rhys Goodall <[email protected]> Co-authored-by: Siyuan Li <[email protected]> Co-authored-by: Ekagra Ranjan <[email protected]> Co-authored-by: S. Kumano <[email protected]> Co-authored-by: otaj <[email protected]> Co-authored-by: Gautier Dagan <[email protected]> Co-authored-by: Sherin Thomas <[email protected]> Co-authored-by: Cyprien Ricque <[email protected]> Co-authored-by: Masahiro Wada <[email protected]> Co-authored-by: nitinramvelraj <[email protected]> Co-authored-by: donlapark <[email protected]> Co-authored-by: Justin Goheen <[email protected]> Co-authored-by: Shantam Gilra <[email protected]> Co-authored-by: Bibhabasu Mohapatra <[email protected]> Co-authored-by: Jimmy Yao <[email protected]> Co-authored-by: Nikhil Shenoy <[email protected]> Co-authored-by: Sanjay Aradhyamath <[email protected]>
1 parent bc6d735 commit 9596fab

File tree

5 files changed

+35
-6
lines changed

5 files changed

+35
-6
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
147147
- Updated `val_check_interval`(int) to consider total train batches processed instead of `_batches_that_stepped` for validation check during training ([#12832](https://github.com/Lightning-AI/lightning/pull/12832)
148148

149149

150+
- Updated Habana Accelerator's `auto_device_count`, `is_available` & `get_device_name` methods based on the latest torch habana package ([#13423](https://github.com/PyTorchLightning/pytorch-lightning/pull/13423))
151+
152+
153+
-
154+
155+
150156
### Deprecated
151157

152158
- Deprecated `pytorch_lightning.loggers.base.LightningLoggerBase` in favor of `pytorch_lightning.loggers.logger.Logger`, and deprecated `pytorch_lightning.loggers.base` in favor of `pytorch_lightning.loggers.logger` ([#120148](https://github.com/PyTorchLightning/pytorch-lightning/pull/12014))

src/pytorch_lightning/accelerators/hpu.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2222
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
2323

24+
if _HPU_AVAILABLE:
25+
import habana_frameworks.torch.hpu as torch_hpu
26+
2427

2528
class HPUAccelerator(Accelerator):
2629
"""Accelerator for HPU devices."""
@@ -52,13 +55,28 @@ def get_parallel_devices(devices: int) -> List[torch.device]:
5255

5356
@staticmethod
5457
def auto_device_count() -> int:
55-
"""Get the devices when set to auto."""
56-
# TODO(@kaushikb11): Update this when api is exposed by the Habana team
57-
return 8
58+
"""Returns the number of HPU devices when the devices is set to auto."""
59+
try:
60+
return torch_hpu.device_count()
61+
except (AttributeError, NameError):
62+
rank_zero_debug("HPU `auto_device_count` failed, returning default count of 8.")
63+
return 8
5864

5965
@staticmethod
6066
def is_available() -> bool:
61-
return _HPU_AVAILABLE
67+
"""Returns a bool indicating if HPU is currently available."""
68+
try:
69+
return torch_hpu.is_available()
70+
except (AttributeError, NameError):
71+
return False
72+
73+
@staticmethod
74+
def get_device_name() -> str:
75+
"""Returns the name of the HPU device."""
76+
try:
77+
return torch_hpu.get_device_name()
78+
except (AttributeError, NameError):
79+
return ""
6280

6381
@classmethod
6482
def register_accelerators(cls, accelerator_registry: Dict) -> None:

src/pytorch_lightning/strategies/hpu_parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
if _HPU_AVAILABLE:
3434
import habana_frameworks.torch.core as htcore
35-
import habana_frameworks.torch.core.hccl # noqa: F401
35+
import habana_frameworks.torch.distributed.hccl # noqa: F401
3636

3737
log = logging.getLogger(__name__)
3838

src/pytorch_lightning/strategies/single_hpu.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
if _HPU_AVAILABLE:
2626
import habana_frameworks.torch.core as htcore
27-
import habana_frameworks.torch.core.hccl # noqa: F401
2827

2928

3029
class SingleHPUStrategy(SingleDeviceStrategy):

tests/tests_pytorch/accelerators/test_hpu.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ def test_availability():
4040
assert HPUAccelerator.is_available()
4141

4242

43+
@RunIf(hpu=True)
44+
def test_device_name():
45+
assert HPUAccelerator.get_device_name() == "GAUDI"
46+
47+
4348
@pytest.mark.skipif(_HPU_AVAILABLE, reason="test requires non-HPU machine")
4449
def test_fail_if_no_hpus():
4550
with pytest.raises(MisconfigurationException, match="HPUAccelerator can not run on your system"):
@@ -239,6 +244,7 @@ def test_inference_only(tmpdir, hpus):
239244
trainer.predict(model)
240245

241246

247+
@RunIf(hpu=True)
242248
def test_hpu_auto_device_count():
243249
assert HPUAccelerator.auto_device_count() == 8
244250

0 commit comments

Comments
 (0)