Skip to content

Commit cc1b4e6

Browse files
committed
Merge branch 'suyue/gaudi_test' of https://github.com/intel/neural-compressor into suyue/gaudi_test
2 parents 0010f50 + f5333b2 commit cc1b4e6

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed
Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,46 @@
11
import os
22
import sys
3-
import torch
43
import time
5-
from torch.utils.data import DataLoader
6-
from torchvision import transforms, datasets
7-
import torch.nn as nn
8-
import torch.nn.functional as F
94

105
import habana_frameworks.torch.core as htcore
6+
import torch
7+
import torch.nn as nn
8+
import torch.nn.functional as F
9+
from torch.utils.data import DataLoader
10+
from torchvision import datasets, transforms
1111

1212

1313
class Net(nn.Module):
1414
def __init__(self):
1515
super(Net, self).__init__()
16-
self.fc1 = nn.Linear(784, 256)
17-
self.fc2 = nn.Linear(256, 64)
18-
self.fc3 = nn.Linear(64, 10)
16+
self.fc1 = nn.Linear(784, 256)
17+
self.fc2 = nn.Linear(256, 64)
18+
self.fc3 = nn.Linear(64, 10)
19+
1920
def forward(self, x):
20-
out = x.view(-1,28*28)
21+
out = x.view(-1, 28 * 28)
2122
out = F.relu(self.fc1(out))
2223
out = F.relu(self.fc2(out))
2324
out = self.fc3(out)
2425
out = F.log_softmax(out, dim=1)
2526
return out
2627

28+
2729
model = Net()
28-
checkpoint = torch.load('mnist-epoch_20.pth')
30+
checkpoint = torch.load("mnist-epoch_20.pth")
2931
model.load_state_dict(checkpoint)
3032

3133
model = model.eval()
3234

3335
model = model.to("hpu")
3436

3537

36-
37-
model = torch.compile(model,backend="hpu_backend")
38+
model = torch.compile(model, backend="hpu_backend")
3839

3940

40-
transform=transforms.Compose([
41-
transforms.ToTensor(),
42-
transforms.Normalize((0.1307,), (0.3081,))])
41+
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
4342

44-
data_path = './data'
43+
data_path = "./data"
4544
test_dataset = datasets.MNIST(data_path, train=False, download=True, transform=transform)
4645
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32)
4746

@@ -57,4 +56,4 @@ def forward(self, x):
5756
correct += output.argmax(1).eq(label).sum().item()
5857

5958
accuracy = correct / len(test_loader.dataset) * 100
60-
print('Inference with torch.compile Completed. Accuracy: {:.2f}%'.format(accuracy))
59+
print("Inference with torch.compile Completed. Accuracy: {:.2f}%".format(accuracy))

0 commit comments

Comments
 (0)