Skip to content
This repository was archived by the owner on Dec 23, 2024. It is now read-only.

Commit 322f7ea

Browse files
committed
use enum instead of Literal
1 parent ee48b7b commit 322f7ea

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/train.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import tempfile
44
import time
55
import warnings
6+
from enum import Enum
67
from pathlib import Path
7-
from typing import Literal
88

99
import torch
1010
import typer
@@ -117,6 +117,11 @@ def get_model(
117117
)
118118

119119

120+
class AvailableDatasets(str, Enum):
121+
FashionMNIST = "FashionMNIST"
122+
KMNIST = "KMNIST"
123+
124+
120125
@app.command()
121126
def train(
122127
seed: int = 42,
@@ -127,10 +132,10 @@ def train(
127132
checkpoint_path: str = None,
128133
batch_size: int = 2048,
129134
kl_weight: float = 0.005,
130-
dataset: Literal["FashionMNSIT", "KMNIST"] = "FashionMNIST",
135+
dataset: AvailableDatasets = AvailableDatasets.FashionMNIST,
131136
):
132137
seed = seed_everything(seed)
133-
datamodule = get_datamodule(batch_size=batch_size, dataset=dataset)
138+
datamodule = get_datamodule(batch_size=batch_size, dataset=dataset.value)
134139
model = get_model(
135140
num_channels=datamodule.num_channels,
136141
latent_dim=latent_dim,

0 commit comments

Comments
 (0)