import matplotlib.pyplot as plt
import torch
+import torchvision.transforms.functional as TF  # noqa: N812
+from einops import rearrange
from IPython.core.display_functions import display
from ipywidgets import interact
from torchvision.transforms import ToTensor
from torchvision.transforms.functional import to_pil_image

+import models
from datamodules import ImagesDataModule
-from models import FullyConnectedAutoEncoder
+from train import LOGS_DIR

# %%
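+# Pick the best available device: CUDA if present, else Apple MPS, else CPU.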
+DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+DEVICE = torch.device("mps") if torch.backends.mps.is_available() else DEVICE
+
+# %%
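+# Model class and dataset under inspection; swap these to explore other runs.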
+ModelClass = models.ConvVAE
+dataset_name = "FashionMNIST"
+datamodule = ImagesDataModule(dataset_name, 1, 10)
+
+# %%
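+# The lookup below expects checkpoints under LOGS_DIR/<model>-<dataset>/<model>-<dataset>.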
+model_name = ModelClass.__name__.lower()
ckpt_dir = (
-    Path("/tmp/logs")
-    / "fullyconnectedautoencodersgd-fashionmnist"
-    / "fullyconnectedautoencodersgd-fashionmnist"
+    LOGS_DIR
+    / f"{model_name}-{dataset_name.lower()}/{model_name}-{dataset_name.lower()}"
)
+
for p in ckpt_dir.parents[::-1] + (ckpt_dir,):
    if not p.exists():
        raise ValueError(f"{p} does not exist")
@@ -53,54 +66,85 @@ def sort_dict(d: dict):
all_ckpts = sort_dict(get_last_fn(subdir) for subdir in ckpt_dir.glob("*"))
display(all_ckpts)

+
# %%
# torch.load(ckpt_dir/list(all_ckpts.values())[-1])['hyper_parameters']

# %%
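+# Wrap checkpoint loading in a helper so later cells can reload the newest weights.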
-model = FullyConnectedAutoEncoder.load_latest_checkpoint(ckpt_dir)
-model.eval()
+
+
+def load_model():
+    return ModelClass.load_latest_checkpoint(ckpt_dir, map_location=DEVICE).eval()
+
+
+model = load_model()
print(model.hparams)
print(model)

# %%
-x_rand = torch.rand(1, 1, 28, 28)
-image = ImagesDataModule("FashionMNIST", 1, 10).dataset()[0][0]
+x_rand = torch.rand(1, 1, 32, 32)
+image, _target = datamodule.dataset()[0]

x_real = ToTensor()(image).unsqueeze(0)
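+# center_crop zero-pads inputs smaller than the target, so the 28x28
+# FashionMNIST images come out as the 32x32 size the ConvVAE expects.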
+x_rand = TF.center_crop(x_rand, 32)
+x_real = TF.center_crop(x_real, 32)
print(x_real.shape)


# %%
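+# normalize stretches each image to the full [0, 1] range before display,
+# so low-contrast decoder outputs stay visible.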
-def show_tensors(imgs: list[torch.Tensor]):
+def show_tensors(imgs: list[torch.Tensor], normalize=True, figsize=None):
    if not isinstance(imgs, list):
        imgs = [imgs]
-    fig, axss = plt.subplots(ncols=len(imgs), squeeze=False)
+    fig, axss = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
    axs = axss[0]
    for i, img in enumerate(imgs):
-        img_clipped = img.detach().clip(0, 1)
-        img_pil = to_pil_image(img_clipped)
+        if normalize:
+            img = (img - img.min()) / (img.max() - img.min())
+        img = img.clamp(0, 1).detach()
+        img_pil = to_pil_image(img)
        axs[i].imshow(img_pil, cmap="gray", vmin=0, vmax=255)
        axs[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


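+# Compare each input with its reconstruction; the forward pass returns an
+# output bundle whose x_hat field holds the reconstruction.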
for x in [x_rand, x_real]:
-    show_tensors([x[0], model(x.cuda())[0]])
+    show_tensors([x[0], model(x.to(DEVICE)).x_hat[0]])

# %%
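+# One slider per latent dimension: interact() turns each (min, max, step)
+# tuple into a float slider and re-decodes on every change.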
-n_latent = 8
+n_latent = model.latent_dim

-lims = (-2, 2, 0.01)
+lims = (-3, 3, 0.01)
all_lims = {f"x{i:02}": lims for i in range(n_latent)}


def show_from_latent(**inputs):
    data = torch.tensor(list(inputs.values()))
-    data = data.view(1, -1).cuda()
+    data = data.view(1, -1).to(DEVICE)
    result = model.decoder(data)[0]
-    show_tensors(result)
+    show_tensors(result, normalize=True)
    plt.show()


interact(show_from_latent, **all_lims)
# %%
+model = load_model()
+
+
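+# Decode an n x n grid of latent codes (the assert guards the 2-D latent
+# assumption) and tile the outputs into one image; average pooling shrinks
+# the result for display.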
+def sample_latent(model, n: int = 30, lim: float = 3.0, downsample_factor: int = 2):
+    x = torch.linspace(-lim, lim, n)
+    y = torch.linspace(-lim, lim, n)
+    z = torch.cartesian_prod(x, y)
+    assert z.shape[1] == 2
+    with torch.inference_mode():
+        outs = model.decoder(z.to(model.device))
+    out = rearrange(outs, "(i j) c h w -> c (i h) (j w)", i=n, j=n)
+    out = torch.nn.functional.avg_pool2d(out, kernel_size=downsample_factor)
+    # out = reduce(out, "c (h i) (w j) -> c h w", i=downsample_factor, j=downsample_factor, reduction="max")
+    return out
+
+
+out = sample_latent(model)
+print(out.shape)
+show_tensors(out, figsize=(10, 10))
+
+# %%