This repository was archived by the owner on Dec 23, 2024. It is now read-only.

Commit ee48b7b: model exploration now uses ConvVAE
1 parent: 347c800

File tree: 1 file changed, 61 additions and 17 deletions

src/explore_model.py: 61 additions & 17 deletions
@@ -18,20 +18,33 @@
 
 import matplotlib.pyplot as plt
 import torch
+import torchvision.transforms.functional as TF  # noqa: N812
+from einops import rearrange
 from IPython.core.display_functions import display
 from ipywidgets import interact
 from torchvision.transforms import ToTensor
 from torchvision.transforms.functional import to_pil_image
 
+import models
 from datamodules import ImagesDataModule
-from models import FullyConnectedAutoEncoder
+from train import LOGS_DIR
 
 # %%
+DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+DEVICE = torch.device("mps") if torch.backends.mps.is_available() else DEVICE
+
+# %%
+ModelClass = models.ConvVAE
+dataset_name = "FashionMNIST"
+datamodule = ImagesDataModule(dataset_name, 1, 10)
+
+# %%
+model_name = ModelClass.__name__.lower()
 ckpt_dir = (
-    Path("/tmp/logs")
-    / "fullyconnectedautoencodersgd-fashionmnist"
-    / "fullyconnectedautoencodersgd-fashionmnist"
+    LOGS_DIR
+    / f"{model_name}-{dataset_name.lower()}/{model_name}-{dataset_name.lower()}"
 )
+
 for p in ckpt_dir.parents[::-1] + (ckpt_dir,):
     if not p.exists():
         raise ValueError(f"{p} not exists")
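
The two DEVICE assignments added above select a backend in the order mps, then cuda, then cpu: the first line picks CUDA when available, and the second overrides it with MPS on Apple silicon. A minimal standalone sketch of the same fallback chain, with the helper name invented here for illustration:

    import torch

    def pick_device() -> torch.device:
        # Same precedence as the diff: MPS wins over CUDA, CPU is the fallback.
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    DEVICE = pick_device()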
@@ -53,54 +66,85 @@ def sort_dict(d: dict):
 all_ckpts = sort_dict(get_last_fn(subdir) for subdir in ckpt_dir.glob("*"))
 display(all_ckpts)
 
+
 # %%
 # torch.load(ckpt_dir/list(all_ckpts.values())[-1])['hyper_parameters']
 
 # %%
-model = FullyConnectedAutoEncoder.load_latest_checkpoint(ckpt_dir)
-model.eval()
+
+
+def load_model():
+    return ModelClass.load_latest_checkpoint(ckpt_dir, map_location=DEVICE).eval()
+
+
+model = load_model()
 print(model.hparams)
 print(model)
 
 # %%
-x_rand = torch.rand(1, 1, 28, 28)
-image = ImagesDataModule("FashionMNIST", 1, 10).dataset()[0][0]
+x_rand = torch.rand(1, 1, 32, 32)
+image, _target = datamodule.dataset()[0]
 
 x_real = ToTensor()(image).unsqueeze(0)
+x_rand = TF.center_crop(x_rand, 32)
+x_real = TF.center_crop(x_real, 32)
 print(x_real.shape)
 
 
 # %%
-def show_tensors(imgs: list[torch.Tensor]):
+def show_tensors(imgs: list[torch.Tensor], normalize=True, figsize=None):
     if not isinstance(imgs, list):
         imgs = [imgs]
-    fig, axss = plt.subplots(ncols=len(imgs), squeeze=False)
+    fig, axss = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
     axs = axss[0]
     for i, img in enumerate(imgs):
-        img_clipped = img.detach().clip(0, 1)
-        img_pil = to_pil_image(img_clipped)
+        if normalize:
+            img = (img - img.min()) / (img.max() - img.min())
+        img = img.clamp(0, 1).detach()
+        img_pil = to_pil_image(img)
         axs[i].imshow(img_pil, cmap="gray", vmin=0, vmax=255)
         axs[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
 
 
 for x in [x_rand, x_real]:
-    show_tensors([x[0], model(x.cuda())[0]])
+    show_tensors([x[0], model(x.to(DEVICE)).x_hat[0]])
 
 # %%
-n_latent = 8
+n_latent = model.latent_dim
 
-lims = (-2, 2, 0.01)
+lims = (-3, 3, 0.01)
 all_lims = {f"x{i:02}": lims for i in range(n_latent)}
 
 
 def show_from_latent(**inputs):
     data = torch.tensor(list(inputs.values()))
-    data = data.view(1, -1).cuda()
+    data = data.view(1, -1).to(DEVICE)
     result = model.decoder(data)[0]
-    show_tensors(result)
+    show_tensors(result, normalize=True)
     plt.show()
 
 
 interact(show_from_latent, **all_lims)
 
 # %%
+model = load_model()
+
+
+def sample_latent(model, n: int = 30, lim: float = 3.0, downsample_factor: int = 2):
+    x = torch.linspace(-lim, lim, n)
+    y = torch.linspace(-lim, lim, n)
+    z = torch.cartesian_prod(x, y)
+    assert z.shape[1] == 2
+    with torch.inference_mode():
+        outs = model.decoder(z.to(model.device))
+    out = rearrange(outs, "(i j) c h w -> c (i h) (j w)", i=n, j=n)
+    out = torch.nn.functional.avg_pool2d(out, kernel_size=downsample_factor)
+    # out = reduce(out, "c (h i) (w j) -> c h w", i=downsample_factor,j=downsample_factor, reduction="max")
+    return out
+
+
+out = sample_latent(model)
+print(out.shape)
+show_tensors(out, figsize=(10, 10))
+
+# %%
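
A note on the input preparation in this hunk: FashionMNIST images are 28x28, but TF.center_crop(x_real, 32) asks for a 32x32 output. torchvision's center_crop zero-pads an image that is smaller than the requested size before cropping, so these lines effectively pad the 28x28 inputs up to the 32x32 resolution the ConvVAE expects. A quick check:

    import torch
    import torchvision.transforms.functional as TF

    x = torch.rand(1, 1, 28, 28)
    y = TF.center_crop(x, 32)  # zero-pads 28x28 up to 32x32
    print(y.shape)             # torch.Size([1, 1, 32, 32])
    print(y[0, 0, 0, :4])      # tensor([0., 0., 0., 0.]) -- padded border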
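
The normalize branch added to show_tensors applies min-max rescaling so each tensor spans [0, 1] before display, which matters for decoder outputs that are not guaranteed to lie in that range. Worked on a toy tensor:

    import torch

    img = torch.tensor([[-2.0, 0.0], [1.0, 6.0]])
    img = (img - img.min()) / (img.max() - img.min())  # min-max rescale
    print(img)  # tensor([[0.0000, 0.2500], [0.3750, 1.0000]])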

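The new sample_latent sweeps a 2-D latent space over an n x n grid and tiles the decoded images into a single large image. Since the ConvVAE decoder and its checkpoint live elsewhere in the repo, here is a self-contained sketch of just the grid-and-tile step, with a stand-in single-layer decoder (purely illustrative) to make the shapes concrete:

    import torch
    from einops import rearrange

    n, c, h, w = 4, 1, 32, 32

    # All (x, y) pairs on an n x n grid over [-3, 3]^2 -> shape (n*n, 2)
    z = torch.cartesian_prod(torch.linspace(-3, 3, n), torch.linspace(-3, 3, n))

    # Stand-in decoder: latent batch (B, 2) -> image batch (B, c, h, w)
    decoder = torch.nn.Sequential(
        torch.nn.Linear(2, c * h * w),
        torch.nn.Unflatten(1, (c, h, w)),
    )

    with torch.inference_mode():
        outs = decoder(z)  # (n*n, c, h, w)

    # Tile the batch into one image: i indexes grid rows, j grid columns
    grid = rearrange(outs, "(i j) c h w -> c (i h) (j w)", i=n, j=n)
    print(grid.shape)  # torch.Size([1, 128, 128])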