1
1
import os
2
2
import sys
3
- import torch
4
3
import time
5
- from torch .utils .data import DataLoader
6
- from torchvision import transforms , datasets
7
- import torch .nn as nn
8
- import torch .nn .functional as F
9
4
10
5
import habana_frameworks .torch .core as htcore
6
+ import torch
7
+ import torch .nn as nn
8
+ import torch .nn .functional as F
9
+ from torch .utils .data import DataLoader
10
+ from torchvision import datasets , transforms
11
11
12
12
13
13
class Net (nn .Module ):
14
14
def __init__ (self ):
15
15
super (Net , self ).__init__ ()
16
- self .fc1 = nn .Linear (784 , 256 )
17
- self .fc2 = nn .Linear (256 , 64 )
18
- self .fc3 = nn .Linear (64 , 10 )
16
+ self .fc1 = nn .Linear (784 , 256 )
17
+ self .fc2 = nn .Linear (256 , 64 )
18
+ self .fc3 = nn .Linear (64 , 10 )
19
+
19
20
def forward (self , x ):
20
- out = x .view (- 1 ,28 * 28 )
21
+ out = x .view (- 1 , 28 * 28 )
21
22
out = F .relu (self .fc1 (out ))
22
23
out = F .relu (self .fc2 (out ))
23
24
out = self .fc3 (out )
24
25
out = F .log_softmax (out , dim = 1 )
25
26
return out
26
27
28
+
27
29
model = Net ()
28
- checkpoint = torch .load (' mnist-epoch_20.pth' )
30
+ checkpoint = torch .load (" mnist-epoch_20.pth" )
29
31
model .load_state_dict (checkpoint )
30
32
31
33
model = model .eval ()
32
34
33
35
model = model .to ("hpu" )
34
36
35
37
36
-
37
- model = torch .compile (model ,backend = "hpu_backend" )
38
+ model = torch .compile (model , backend = "hpu_backend" )
38
39
39
40
40
- transform = transforms .Compose ([
41
- transforms .ToTensor (),
42
- transforms .Normalize ((0.1307 ,), (0.3081 ,))])
41
+ transform = transforms .Compose ([transforms .ToTensor (), transforms .Normalize ((0.1307 ,), (0.3081 ,))])
43
42
44
- data_path = ' ./data'
43
+ data_path = " ./data"
45
44
test_dataset = datasets .MNIST (data_path , train = False , download = True , transform = transform )
46
45
test_loader = torch .utils .data .DataLoader (test_dataset , batch_size = 32 )
47
46
@@ -57,4 +56,4 @@ def forward(self, x):
57
56
correct += output .argmax (1 ).eq (label ).sum ().item ()
58
57
59
58
accuracy = correct / len (test_loader .dataset ) * 100
60
- print (' Inference with torch.compile Completed. Accuracy: {:.2f}%' .format (accuracy ))
59
+ print (" Inference with torch.compile Completed. Accuracy: {:.2f}%" .format (accuracy ))
0 commit comments