This repository was archived by the owner on Dec 23, 2024. It is now read-only.

Commit 07a67c8

Merge pull request #3 from tsvikas/conv-vae

Implement ConvVAE

2 parents: 0d2cfed + 322f7ea

File tree: 10 files changed (+1,323 / -85 lines)

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -5,3 +5,6 @@ __pycache__
 TODO
 .ipynb_checkpoints/
 *.ipynb
+.idea/
+.jupyter_ystore.db
+.virtual_documents/

poetry.lock

Lines changed: 231 additions & 16 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 6 additions & 2 deletions
@@ -8,12 +8,13 @@ authors = ["Tsvika S <[email protected]>"]

 [tool.poetry.dependencies]
 python = "^3.11,!=3.11.0"
-torch = {version = "^2.0.0", source = "pytorch"}
-torchvision = {version = "^0.15.1", source = "pytorch"}
+torch = "^2.0.1"
+torchvision = "^0.15.2"
 pytorch-lightning = "^2.0.2"
 wandb = "^0.15.2"
 rich = "^13.3.5"
 einops = "^0.6.1"
+typer = {extras = ["all"], version = "^0.9.0"}

 [tool.poetry.group.jupyter.dependencies]
 jupyterlab = "^4.0.0"

@@ -26,6 +27,7 @@ python-lsp-server = "^1.7.3"
 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.3.1"
 ruff = "^0.0.264"
+black = "^23.3.0"

 [[tool.poetry.source]]
 name = "pytorch"

@@ -50,12 +52,14 @@ ignore = [
     "C408",
     "TRY003",
     "FBT002",
+    "PLW2901",
 ]
 src = ["src"]

 [tool.ruff.per-file-ignores]
 "__init__.py" = ["F401"]
 "src/train.py" = ["INP001", "T201"]
+"src/models/resnet_vae.py" = ["T201", "PD002"]
 "src/explore_model.py" = [
     "INP001",
     "E703",

src/datamodules/images.py

Lines changed: 25 additions & 7 deletions
@@ -16,19 +16,24 @@ def train_val_split(
     train_transform,
     val_transform,
     dataset_cls,
+    generator: torch.Generator = None,
     **dataset_kwargs,
 ):
     """load a dataset and split it, using a different transform for train and val"""
     lengths = [train_length, val_length]
     with isolate_rng():
         dataset_train = dataset_cls(**dataset_kwargs, transform=train_transform)
-        train_split, _ = torch.utils.data.random_split(dataset_train, lengths)
+        train_split, _ = torch.utils.data.random_split(
+            dataset_train, lengths, generator=generator
+        )
     with isolate_rng():
         dataset_val = dataset_cls(**dataset_kwargs, transform=val_transform)
-        _, val_split = torch.utils.data.random_split(dataset_val, lengths)
+        _, val_split = torch.utils.data.random_split(
+            dataset_val, lengths, generator=generator
+        )
     # repeat to consume the random state
     dataset = dataset_cls(**dataset_kwargs)
-    torch.utils.data.random_split(dataset, lengths)
+    torch.utils.data.random_split(dataset, lengths, generator=generator)
     return train_split, val_split

@@ -84,6 +89,7 @@ def __init__(
         self.val_size_or_frac = val_size_or_frac
         self.target_is_self = target_is_self
         self.noise_transforms = noise_transforms or []
+        self.generator = torch.Generator()

         # defined in self.setup()
         self.train_val_size = None

@@ -154,6 +160,7 @@ def setup(self, stage=None):
             root=self.data_dir,
             train=True,
             download=False,
+            generator=self.generator,
         )
         self.test_set = self.dataset_cls(
             root=self.data_dir,

@@ -165,8 +172,12 @@ def setup(self, stage=None):
             self.train_set = TransformedSelfDataset(
                 self.train_set, transforms=self.noise_transforms
             )
-            self.val_set = TransformedSelfDataset(self.val_set)
-            self.test_set = TransformedSelfDataset(self.test_set)
+            self.val_set = TransformedSelfDataset(
+                self.val_set, transforms=self.noise_transforms
+            )
+            self.test_set = TransformedSelfDataset(
+                self.test_set, transforms=self.noise_transforms
+            )

         # verify num_classes and num_channels
         if (num_classes := len(self.test_set.classes)) != self.num_classes:

@@ -220,6 +231,8 @@ def train_dataloader(self):
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=self.num_workers,
+            pin_memory=True,
+            persistent_workers=True,
         )

     # we can use a x2 batch_size in validation and testing,

@@ -230,6 +243,8 @@ def val_dataloader(self):
             batch_size=self.batch_size * 2,
             shuffle=False,
             num_workers=self.num_workers,
+            pin_memory=True,
+            persistent_workers=True,
         )

     def test_dataloader(self):

@@ -238,6 +253,8 @@ def test_dataloader(self):
             batch_size=self.batch_size * 2,
             shuffle=False,
             num_workers=self.num_workers,
+            pin_memory=True,
+            persistent_workers=True,
         )

@@ -269,5 +286,6 @@ def __getitem__(self, item):
     def __len__(self):
         return len(self.dataset)

-    def __getattr__(self, item):
-        return getattr(self.dataset, item)
+    @property
+    def classes(self):
+        return self.dataset.classes
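
The `generator` change above is what makes the train/val split reproducible: the same seeded `torch.Generator` is threaded through all three `random_split` calls, so the differently-transformed train and val dataset copies are partitioned by identical indices, independent of the global RNG state. A minimal usage sketch, assuming the leading `train_length`/`val_length` parameters implied by the function body (the import path, dataset, and transforms here are illustrative):

import torch
from torchvision import transforms
from torchvision.datasets import FashionMNIST

from datamodules.images import train_val_split  # assumed import path

# a fixed seed pins the split; re-running yields the same train/val indices
generator = torch.Generator().manual_seed(42)

train_split, val_split = train_val_split(
    55_000,  # train_length (assumed leading parameter)
    5_000,   # val_length (assumed leading parameter)
    train_transform=transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.ToTensor()]
    ),
    val_transform=transforms.ToTensor(),
    dataset_cls=FashionMNIST,
    generator=generator,
    root="data",    # forwarded to FashionMNIST via **dataset_kwargs
    download=True,  # forwarded to FashionMNIST via **dataset_kwargs
)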

src/explore_model.py

Lines changed: 61 additions & 17 deletions
@@ -18,20 +18,33 @@

 import matplotlib.pyplot as plt
 import torch
+import torchvision.transforms.functional as TF  # noqa: N812
+from einops import rearrange
 from IPython.core.display_functions import display
 from ipywidgets import interact
 from torchvision.transforms import ToTensor
 from torchvision.transforms.functional import to_pil_image

+import models
 from datamodules import ImagesDataModule
-from models import FullyConnectedAutoEncoder
+from train import LOGS_DIR

 # %%
+DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+DEVICE = torch.device("mps") if torch.backends.mps.is_available() else DEVICE
+
+# %%
+ModelClass = models.ConvVAE
+dataset_name = "FashionMNIST"
+datamodule = ImagesDataModule(dataset_name, 1, 10)
+
+# %%
+model_name = ModelClass.__name__.lower()
 ckpt_dir = (
-    Path("/tmp/logs")
-    / "fullyconnectedautoencodersgd-fashionmnist"
-    / "fullyconnectedautoencodersgd-fashionmnist"
+    LOGS_DIR
+    / f"{model_name}-{dataset_name.lower()}/{model_name}-{dataset_name.lower()}"
 )
+
 for p in ckpt_dir.parents[::-1] + (ckpt_dir,):
     if not p.exists():
         raise ValueError(f"{p} not exists")

@@ -53,54 +66,85 @@ def sort_dict(d: dict):
 all_ckpts = sort_dict(get_last_fn(subdir) for subdir in ckpt_dir.glob("*"))
 display(all_ckpts)

+
 # %%
 # torch.load(ckpt_dir/list(all_ckpts.values())[-1])['hyper_parameters']

 # %%
-model = FullyConnectedAutoEncoder.load_latest_checkpoint(ckpt_dir)
-model.eval()
+
+
+def load_model():
+    return ModelClass.load_latest_checkpoint(ckpt_dir, map_location=DEVICE).eval()
+
+
+model = load_model()
 print(model.hparams)
 print(model)

 # %%
-x_rand = torch.rand(1, 1, 28, 28)
-image = ImagesDataModule("FashionMNIST", 1, 10).dataset()[0][0]
+x_rand = torch.rand(1, 1, 32, 32)
+image, _target = datamodule.dataset()[0]

 x_real = ToTensor()(image).unsqueeze(0)
+x_rand = TF.center_crop(x_rand, 32)
+x_real = TF.center_crop(x_real, 32)
 print(x_real.shape)


 # %%
-def show_tensors(imgs: list[torch.Tensor]):
+def show_tensors(imgs: list[torch.Tensor], normalize=True, figsize=None):
     if not isinstance(imgs, list):
         imgs = [imgs]
-    fig, axss = plt.subplots(ncols=len(imgs), squeeze=False)
+    fig, axss = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
     axs = axss[0]
     for i, img in enumerate(imgs):
-        img_clipped = img.detach().clip(0, 1)
-        img_pil = to_pil_image(img_clipped)
+        if normalize:
+            img = (img - img.min()) / (img.max() - img.min())
+        img = img.clamp(0, 1).detach()
+        img_pil = to_pil_image(img)
         axs[i].imshow(img_pil, cmap="gray", vmin=0, vmax=255)
         axs[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


 for x in [x_rand, x_real]:
-    show_tensors([x[0], model(x.cuda())[0]])
+    show_tensors([x[0], model(x.to(DEVICE)).x_hat[0]])

 # %%
-n_latent = 8
+n_latent = model.latent_dim

-lims = (-2, 2, 0.01)
+lims = (-3, 3, 0.01)
 all_lims = {f"x{i:02}": lims for i in range(n_latent)}


 def show_from_latent(**inputs):
     data = torch.tensor(list(inputs.values()))
-    data = data.view(1, -1).cuda()
+    data = data.view(1, -1).to(DEVICE)
     result = model.decoder(data)[0]
-    show_tensors(result)
+    show_tensors(result, normalize=True)
     plt.show()


 interact(show_from_latent, **all_lims)

 # %%
+model = load_model()
+
+
+def sample_latent(model, n: int = 30, lim: float = 3.0, downsample_factor: int = 2):
+    x = torch.linspace(-lim, lim, n)
+    y = torch.linspace(-lim, lim, n)
+    z = torch.cartesian_prod(x, y)
+    assert z.shape[1] == 2
+    with torch.inference_mode():
+        outs = model.decoder(z.to(model.device))
+    out = rearrange(outs, "(i j) c h w -> c (i h) (j w)", i=n, j=n)
+    out = torch.nn.functional.avg_pool2d(out, kernel_size=downsample_factor)
+    # out = reduce(out, "c (h i) (w j) -> c h w", i=downsample_factor, j=downsample_factor, reduction="max")
+    return out
+
+
+out = sample_latent(model)
+print(out.shape)
+show_tensors(out, figsize=(10, 10))
+
+# %%
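
The new `sample_latent` sweeps an n × n grid over a 2-D latent space, decodes all n² points in one batch, and tiles the reconstructions into a single mosaic with `rearrange`. With the defaults (n = 30, 32×32 decoder outputs, `downsample_factor=2`) the mosaic comes out to 1 × 480 × 480: 30 tiles × 32 px = 960 per side, halved by the average pool. A quick shape check with a stand-in decoder (hypothetical, only to verify the grid arithmetic):

import torch
from einops import rearrange


class DummyDecoder(torch.nn.Module):
    """Stand-in decoder mapping a 2-D latent to a 1x32x32 image."""

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return torch.zeros(z.shape[0], 1, 32, 32)


n = 30
z = torch.cartesian_prod(torch.linspace(-3, 3, n), torch.linspace(-3, 3, n))
outs = DummyDecoder()(z)                                          # (900, 1, 32, 32)
grid = rearrange(outs, "(i j) c h w -> c (i h) (j w)", i=n, j=n)  # (1, 960, 960)
grid = torch.nn.functional.avg_pool2d(grid, kernel_size=2)        # (1, 480, 480)
print(grid.shape)  # torch.Size([1, 480, 480])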

src/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 from .auto_encoder import FullyConnectedAutoEncoder
+from .conv_vae import ConvVAE
 from .mlp import MultiLayerPerceptron
 from .resnet import Resnet
+from .resnet_vae import ResidualAutoencoder
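
`conv_vae.py` itself is not rendered in this view. For orientation, here is a minimal sketch of the interface that `explore_model.py` relies on: a `latent_dim` hyperparameter, a `decoder` attribute mapping latents to images, and a forward pass whose result exposes `.x_hat`. All layer sizes are assumptions, and the real class presumably subclasses LightningModule rather than plain `nn.Module`:

from typing import NamedTuple

import torch
from torch import nn


class VAEOutput(NamedTuple):
    x_hat: torch.Tensor   # reconstruction
    mu: torch.Tensor      # posterior mean
    logvar: torch.Tensor  # posterior log-variance


class ConvVAESketch(nn.Module):
    def __init__(self, latent_dim: int = 2, channels: int = 1):
        super().__init__()
        self.latent_dim = latent_dim
        # 32x32 -> 16x16 -> 8x8 -> 4x4 feature maps (assumed sizes)
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(64 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(64 * 4 * 4, latent_dim)
        # decoder maps a latent vector back to a channels x 32 x 32 image
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64 * 4 * 4),
            nn.Unflatten(1, (64, 4, 4)),
            nn.ConvTranspose2d(64, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, channels, 4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> VAEOutput:
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        # reparameterization trick: z = mu + sigma * eps
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        return VAEOutput(self.decoder(z), mu, logvar)


model = ConvVAESketch(latent_dim=2)
out = model(torch.rand(4, 1, 32, 32))
print(out.x_hat.shape, out.mu.shape)  # torch.Size([4, 1, 32, 32]) torch.Size([4, 2])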
