Skip to content

Commit b941d10

Browse files
wonjoolee95facebook-github-bot
authored andcommitted
Update torch::lazy::BackendDevice to have a new default ordinal (pytorch#76264)
Summary: Fixes pytorch/xla#3490. Updates `torch::lazy::BackendDevice` with changes below: 1. Remove the no-op string constructor. 2. Update default ordinal to `-1`. 3. Add a `is_valid` function to check if `ordinal` is valid/non-default (`ordinal >= 0`). Pull Request resolved: pytorch#76264 Reviewed By: mrshenli Differential Revision: D35860266 Pulled By: alanwaketan fbshipit-source-id: 554ebe16a0683d37b00270c4f35163bf690bfe28
1 parent ed11e5d commit b941d10

File tree

3 files changed

+17
-16
lines changed

3 files changed

+17
-16
lines changed

test/cpp/lazy/test_backend_device.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ TEST(BackendDeviceTest, Basic1) {
2020
auto device = BackendDevice();
2121

2222
EXPECT_EQ(device.type(), 0);
23-
EXPECT_EQ(device.ordinal(), 0);
24-
EXPECT_STREQ(device.toString().c_str(), "Unknown0");
23+
EXPECT_EQ(device.ordinal(), -1);
24+
EXPECT_FALSE(device.has_index());
25+
EXPECT_STREQ(device.toString().c_str(), "Unknown");
2526
}
2627

2728
TEST(BackendDeviceTest, Basic2) {
@@ -31,6 +32,7 @@ TEST(BackendDeviceTest, Basic2) {
3132

3233
EXPECT_EQ(device.type(), 1);
3334
EXPECT_EQ(device.ordinal(), 1);
35+
EXPECT_TRUE(device.has_index());
3436
EXPECT_STREQ(device.toString().c_str(), "Unknown1");
3537
}
3638

@@ -43,6 +45,7 @@ TEST(BackendDeviceTest, Basic3) {
4345

4446
EXPECT_EQ(device.type(), 0);
4547
EXPECT_EQ(device.ordinal(), 1);
48+
EXPECT_TRUE(device.has_index());
4649
EXPECT_STREQ(device.toString().c_str(), "Test1");
4750
}
4851

@@ -86,8 +89,8 @@ TEST(BackendDeviceTest, FromAten) {
8689
TEST(BackendDeviceTest, ToAten) {
8790
auto device = backendDeviceToAtenDevice(BackendDevice());
8891
EXPECT_EQ(device.type(), c10::kLazy);
89-
EXPECT_TRUE(device.has_index());
90-
EXPECT_EQ(device.index(), 0);
92+
EXPECT_FALSE(device.has_index());
93+
EXPECT_EQ(device.index(), -1);
9194
}
9295

9396
// TODO(alanwaketan): Update the following test once we have TorchScript backend upstreamed.

torch/csrc/lazy/backend/backend_device.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,24 @@ namespace torch {
1010
namespace lazy {
1111

1212
// TODO(alanwaketan): Use the backend API to get the default device type.
13-
// In the future, we should also get the default device ordinal.
1413
BackendDevice::BackendDevice()
15-
: type_(std::make_shared<BackendDeviceType>()) {}
14+
: type_(std::make_shared<BackendDeviceType>()), ordinal_(-1) {}
1615

1716
BackendDevice::BackendDevice(std::shared_ptr<BackendDeviceType>&& type, int64_t ordinal)
1817
: type_(std::move(type)), ordinal_(ordinal) {}
1918

20-
BackendDevice::BackendDevice(const std::string& device_spec)
21-
: BackendDevice::BackendDevice() {}
22-
2319
int8_t BackendDevice::type() const {
2420
TORCH_INTERNAL_ASSERT(type_);
2521
return type_->type;
2622
}
2723

2824
std::string BackendDevice::toString() const {
2925
TORCH_INTERNAL_ASSERT(type_);
30-
return c10::str(type_->toString(), ordinal_);
26+
std::string str = type_->toString();
27+
if (has_index()) {
28+
str.append(std::to_string(ordinal_));
29+
}
30+
return str;
3131
}
3232

3333
int BackendDevice::compare(const BackendDevice& rhs) const {
@@ -42,10 +42,9 @@ std::ostream& operator<<(std::ostream& os, const BackendDevice& device) {
4242
return os;
4343
}
4444

45-
// TODO(whc) refactor this: we need to support non-zero default ordinal for torch/XLA.
4645
BackendDevice atenDeviceToBackendDevice(const c10::Device& device) {
4746
TORCH_CHECK(device.type() == at::kLazy, device);
48-
int64_t ordinal = device.has_index() ? device.index() : 0;
47+
int64_t ordinal = device.has_index() ? device.index() : -1;
4948
return BackendDevice(getBackend()->GetDefaultDeviceType(), ordinal);
5049
}
5150

torch/csrc/lazy/backend/backend_device.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,15 @@ class TORCH_API BackendDevice {
3636
BackendDevice(std::shared_ptr<BackendDeviceType>&& type, int64_t ordinal);
3737

3838
int8_t type() const;
39-
int64_t ordinal() const { return ordinal_; }
39+
int64_t ordinal() const { return ordinal_; }
4040

4141
bool operator==(const BackendDevice& other) const { return compare(other) == 0; }
4242
bool operator!=(const BackendDevice& other) const { return compare(other) != 0; }
4343
bool operator<(const BackendDevice& rhs) const { return compare(rhs) < 0; }
4444

45-
std::string toString() const;
45+
bool has_index() const { return ordinal_ >= 0; }
4646

47-
// The string -> Device conversion should be handled by the backend interface.
48-
C10_DEPRECATED explicit BackendDevice(const std::string& device_spec);
47+
std::string toString() const;
4948

5049
private:
5150
int compare(const BackendDevice& rhs) const;

0 commit comments

Comments
 (0)