This repository was archived by the owner on Dec 23, 2024. It is now read-only.

Implement ConvVAE #3

Merged: 19 commits, Jun 23, 2023
3 changes: 3 additions & 0 deletions .gitignore
@@ -5,3 +5,6 @@ __pycache__
 TODO
 .ipynb_checkpoints/
 *.ipynb
+.idea/
+.jupyter_ystore.db
+.virtual_documents/
247 changes: 231 additions & 16 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions pyproject.toml
@@ -8,12 +8,13 @@ authors = ["Tsvika S <[email protected]>"]
 
 [tool.poetry.dependencies]
 python = "^3.11,!=3.11.0"
-torch = {version = "^2.0.0", source = "pytorch"}
-torchvision = {version = "^0.15.1", source = "pytorch"}
+torch = "^2.0.1"
+torchvision = "^0.15.2"
 pytorch-lightning = "^2.0.2"
 wandb = "^0.15.2"
 rich = "^13.3.5"
 einops = "^0.6.1"
+typer = {extras = ["all"], version = "^0.9.0"}
 
 [tool.poetry.group.jupyter.dependencies]
 jupyterlab = "^4.0.0"
@@ -26,6 +27,7 @@ python-lsp-server = "^1.7.3"
 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.3.1"
 ruff = "^0.0.264"
+black = "^23.3.0"
 
 [[tool.poetry.source]]
 name = "pytorch"
@@ -50,12 +52,14 @@ ignore = [
     "C408",
     "TRY003",
     "FBT002",
+    "PLW2901",
 ]
 src = ["src"]
 
 [tool.ruff.per-file-ignores]
 "__init__.py" = ["F401"]
 "src/train.py" = ["INP001", "T201"]
+"src/models/resnet_vae.py" = ["T201", "PD002"]
 "src/explore_model.py" = [
     "INP001",
     "E703",
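Note: the torch/torchvision pins move from the custom "pytorch" package source to plain PyPI version ranges, and typer is added as a runtime dependency, which suggests a command-line entry point. A minimal sketch of a Typer CLI (hypothetical command and options, not code from this PR):

# Hypothetical Typer CLI sketch; names and defaults are illustrative.
import typer

app = typer.Typer()

@app.command()
def train(dataset: str = "FashionMNIST", batch_size: int = 128) -> None:
    """Train a model on the given dataset."""
    typer.echo(f"training on {dataset} with batch_size={batch_size}")

if __name__ == "__main__":
    app()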
32 changes: 25 additions & 7 deletions src/datamodules/images.py
@@ -16,19 +16,24 @@ def train_val_split(
     train_transform,
     val_transform,
     dataset_cls,
+    generator: torch.Generator = None,
     **dataset_kwargs,
 ):
     """load a dataset and split it, using a different transform for train and val"""
     lengths = [train_length, val_length]
     with isolate_rng():
         dataset_train = dataset_cls(**dataset_kwargs, transform=train_transform)
-        train_split, _ = torch.utils.data.random_split(dataset_train, lengths)
+        train_split, _ = torch.utils.data.random_split(
+            dataset_train, lengths, generator=generator
+        )
     with isolate_rng():
         dataset_val = dataset_cls(**dataset_kwargs, transform=val_transform)
-        _, val_split = torch.utils.data.random_split(dataset_val, lengths)
+        _, val_split = torch.utils.data.random_split(
+            dataset_val, lengths, generator=generator
+        )
     # repeat to consume the random state
     dataset = dataset_cls(**dataset_kwargs)
-    torch.utils.data.random_split(dataset, lengths)
+    torch.utils.data.random_split(dataset, lengths, generator=generator)
     return train_split, val_split


@@ -84,6 +89,7 @@ def __init__(
         self.val_size_or_frac = val_size_or_frac
         self.target_is_self = target_is_self
         self.noise_transforms = noise_transforms or []
+        self.generator = torch.Generator()
 
         # defined in self.setup()
         self.train_val_size = None
@@ -154,6 +160,7 @@ def setup(self, stage=None):
             root=self.data_dir,
             train=True,
             download=False,
+            generator=self.generator,
         )
         self.test_set = self.dataset_cls(
             root=self.data_dir,
@@ -165,8 +172,12 @@
             self.train_set = TransformedSelfDataset(
                 self.train_set, transforms=self.noise_transforms
             )
-            self.val_set = TransformedSelfDataset(self.val_set)
-            self.test_set = TransformedSelfDataset(self.test_set)
+            self.val_set = TransformedSelfDataset(
+                self.val_set, transforms=self.noise_transforms
+            )
+            self.test_set = TransformedSelfDataset(
+                self.test_set, transforms=self.noise_transforms
+            )
 
         # verify num_classes and num_channels
         if (num_classes := len(self.test_set.classes)) != self.num_classes:
@@ -220,6 +231,8 @@ def train_dataloader(self):
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=self.num_workers,
+            pin_memory=True,
+            persistent_workers=True,
         )
 
     # we can use a x2 batch_size in validation and testing,
@@ -230,6 +243,8 @@ def val_dataloader(self):
             batch_size=self.batch_size * 2,
             shuffle=False,
             num_workers=self.num_workers,
+            pin_memory=True,
+            persistent_workers=True,
         )
 
     def test_dataloader(self):
@@ -238,6 +253,8 @@
             batch_size=self.batch_size * 2,
             shuffle=False,
             num_workers=self.num_workers,
+            pin_memory=True,
+            persistent_workers=True,
         )
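
pin_memory=True stages batches in page-locked host memory so host-to-GPU copies are faster, and persistent_workers=True keeps worker processes alive between epochs instead of respawning them (it requires num_workers > 0). A standalone sketch of the same settings (illustrative values, not code from this PR):

# Illustrative DataLoader settings matching the diff above.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 1, 32, 32))
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,  # persistent_workers needs at least one worker
    pin_memory=True,
    persistent_workers=True,
)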


@@ -269,5 +286,6 @@ def __getitem__(self, item):
     def __len__(self):
         return len(self.dataset)
 
-    def __getattr__(self, item):
-        return getattr(self.dataset, item)
+    @property
+    def classes(self):
+        return self.dataset.classes
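
Swapping the catch-all __getattr__ for an explicit property narrows the wrapper's surface: only classes is delegated, and any other attribute typo now raises AttributeError immediately instead of being silently forwarded to the wrapped dataset. A side-by-side sketch of the two patterns (illustrative, not code from this PR):

# Illustrative comparison of the two delegation styles.
class ForwardEverything:
    def __init__(self, dataset):
        self.dataset = dataset

    def __getattr__(self, item):
        # only called when normal lookup fails; forwards *any* name,
        # which can mask bugs and confuse pickling/introspection
        return getattr(self.dataset, item)


class ForwardOnlyClasses:
    def __init__(self, dataset):
        self.dataset = dataset

    @property
    def classes(self):
        # delegate exactly one, known attribute
        return self.dataset.classes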
78 changes: 61 additions & 17 deletions src/explore_model.py
@@ -18,20 +18,33 @@
 
 import matplotlib.pyplot as plt
 import torch
+import torchvision.transforms.functional as TF  # noqa: N812
+from einops import rearrange
 from IPython.core.display_functions import display
 from ipywidgets import interact
 from torchvision.transforms import ToTensor
 from torchvision.transforms.functional import to_pil_image
 
+import models
 from datamodules import ImagesDataModule
-from models import FullyConnectedAutoEncoder
+from train import LOGS_DIR
+
 # %%
+DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+DEVICE = torch.device("mps") if torch.backends.mps.is_available() else DEVICE
+
+# %%
+ModelClass = models.ConvVAE
+dataset_name = "FashionMNIST"
+datamodule = ImagesDataModule(dataset_name, 1, 10)
+
+# %%
+model_name = ModelClass.__name__.lower()
 ckpt_dir = (
-    Path("/tmp/logs")
-    / "fullyconnectedautoencodersgd-fashionmnist"
-    / "fullyconnectedautoencodersgd-fashionmnist"
+    LOGS_DIR
+    / f"{model_name}-{dataset_name.lower()}/{model_name}-{dataset_name.lower()}"
 )
 
 for p in ckpt_dir.parents[::-1] + (ckpt_dir,):
     if not p.exists():
         raise ValueError(f"{p} not exists")
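
get_last_fn and load_latest_checkpoint are project helpers whose definitions are outside this diff; a plausible sketch of the idea, picking the newest checkpoint by modification time (an assumption, not the project's actual implementation):

# Hypothetical helper: newest .ckpt under a directory by mtime.
from pathlib import Path

def latest_checkpoint(ckpt_dir: Path) -> Path:
    ckpts = sorted(ckpt_dir.rglob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    if not ckpts:
        raise FileNotFoundError(f"no .ckpt files under {ckpt_dir}")
    return ckpts[-1]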
@@ -53,54 +66,85 @@ def sort_dict(d: dict):
 all_ckpts = sort_dict(get_last_fn(subdir) for subdir in ckpt_dir.glob("*"))
 display(all_ckpts)
 
+
 # %%
 # torch.load(ckpt_dir/list(all_ckpts.values())[-1])['hyper_parameters']
 
 # %%
-model = FullyConnectedAutoEncoder.load_latest_checkpoint(ckpt_dir)
-model.eval()
+def load_model():
+    return ModelClass.load_latest_checkpoint(ckpt_dir, map_location=DEVICE).eval()
+
+
+model = load_model()
+print(model.hparams)
+print(model)
 
 # %%
-x_rand = torch.rand(1, 1, 28, 28)
-image = ImagesDataModule("FashionMNIST", 1, 10).dataset()[0][0]
+x_rand = torch.rand(1, 1, 32, 32)
+image, _target = datamodule.dataset()[0]
+
 x_real = ToTensor()(image).unsqueeze(0)
+x_rand = TF.center_crop(x_rand, 32)
+x_real = TF.center_crop(x_real, 32)
+print(x_real.shape)
 
 
 # %%
-def show_tensors(imgs: list[torch.Tensor]):
+def show_tensors(imgs: list[torch.Tensor], normalize=True, figsize=None):
     if not isinstance(imgs, list):
         imgs = [imgs]
-    fig, axss = plt.subplots(ncols=len(imgs), squeeze=False)
+    fig, axss = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize)
     axs = axss[0]
     for i, img in enumerate(imgs):
-        img_clipped = img.detach().clip(0, 1)
-        img_pil = to_pil_image(img_clipped)
+        if normalize:
+            img = (img - img.min()) / (img.max() - img.min())
+        img = img.clamp(0, 1).detach()
+        img_pil = to_pil_image(img)
         axs[i].imshow(img_pil, cmap="gray", vmin=0, vmax=255)
         axs[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
 
 
 for x in [x_rand, x_real]:
-    show_tensors([x[0], model(x.cuda())[0]])
+    show_tensors([x[0], model(x.to(DEVICE)).x_hat[0]])
 
 # %%
-n_latent = 8
+n_latent = model.latent_dim
 
-lims = (-2, 2, 0.01)
+lims = (-3, 3, 0.01)
 all_lims = {f"x{i:02}": lims for i in range(n_latent)}
 
 
 def show_from_latent(**inputs):
     data = torch.tensor(list(inputs.values()))
-    data = data.view(1, -1).cuda()
+    data = data.view(1, -1).to(DEVICE)
     result = model.decoder(data)[0]
-    show_tensors(result)
+    show_tensors(result, normalize=True)
     plt.show()
 
 
 interact(show_from_latent, **all_lims)
 
+# %%
+model = load_model()
+
+
+def sample_latent(model, n: int = 30, lim: float = 3.0, downsample_factor: int = 2):
+    x = torch.linspace(-lim, lim, n)
+    y = torch.linspace(-lim, lim, n)
+    z = torch.cartesian_prod(x, y)
+    assert z.shape[1] == 2
+    with torch.inference_mode():
+        outs = model.decoder(z.to(model.device))
+    out = rearrange(outs, "(i j) c h w -> c (i h) (j w)", i=n, j=n)
+    out = torch.nn.functional.avg_pool2d(out, kernel_size=downsample_factor)
+    # out = reduce(out, "c (h i) (w j) -> c h w", i=downsample_factor,j=downsample_factor, reduction="max")
+    return out
+
+
+out = sample_latent(model)
+print(out.shape)
+show_tensors(out, figsize=(10, 10))
+
 # %%
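
The einops pattern "(i j) c h w -> c (i h) (j w)" in sample_latent tiles the n*n decoded samples into a single n-by-n image mosaic. A self-contained check of that reshaping (illustrative shapes, not code from this PR):

# Illustrative check of the grid-tiling rearrange used above.
import torch
from einops import rearrange

i, j, c, h, w = 3, 3, 1, 32, 32
batch = torch.randn(i * j, c, h, w)
grid = rearrange(batch, "(i j) c h w -> c (i h) (j w)", i=i, j=j)
assert grid.shape == (c, i * h, j * w)  # (1, 96, 96)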
2 changes: 2 additions & 0 deletions src/models/__init__.py
@@ -1,3 +1,5 @@
 from .auto_encoder import FullyConnectedAutoEncoder
+from .conv_vae import ConvVAE
 from .mlp import MultiLayerPerceptron
 from .resnet import Resnet
+from .resnet_vae import ResidualAutoencoder
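
The new src/models/conv_vae.py itself is not rendered in this excerpt. For orientation, a generic sketch of a small convolutional VAE with the reparameterization trick (illustrative only, not the PR's implementation; layer sizes and names are assumptions):

# Generic ConvVAE sketch for a 1x32x32 input; not the PR's actual code.
import torch
from torch import nn

class TinyConvVAE(nn.Module):
    def __init__(self, latent_dim: int = 2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 4, stride=2, padding=1),  # 32 -> 16
            nn.ReLU(),
            nn.Conv2d(16, 32, 4, stride=2, padding=1),  # 16 -> 8
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(32 * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(32 * 8 * 8, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 32 * 8 * 8),
            nn.Unflatten(1, (32, 8, 8)),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),  # 8 -> 16
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1),  # 16 -> 32
            nn.Sigmoid(),
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)  # reparameterization trick
        return self.decoder(z), mu, logvar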