diff --git a/.gitignore b/.gitignore index 82963e9..fef4590 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ __pycache__ TODO .ipynb_checkpoints/ *.ipynb +.idea/ +.jupyter_ystore.db +.virtual_documents/ diff --git a/poetry.lock b/poetry.lock index 99a9754..110509b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "aiofiles" @@ -360,6 +360,53 @@ soupsieve = ">1.2" html5lib = ["html5lib"] lxml = ["lxml"] +[[package]] +name = "black" +version = "23.3.0" +description = "The uncompromising code formatter." +optional = false +python-versions = ">=3.7" +files = [ + {file = "black-23.3.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:0945e13506be58bf7db93ee5853243eb368ace1c08a24c65ce108986eac65915"}, + {file = "black-23.3.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:67de8d0c209eb5b330cce2469503de11bca4085880d62f1628bd9972cc3366b9"}, + {file = "black-23.3.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:7c3eb7cea23904399866c55826b31c1f55bbcd3890ce22ff70466b907b6775c2"}, + {file = "black-23.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32daa9783106c28815d05b724238e30718f34155653d4d6e125dc7daec8e260c"}, + {file = "black-23.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:35d1381d7a22cc5b2be2f72c7dfdae4072a3336060635718cc7e1ede24221d6c"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:a8a968125d0a6a404842fa1bf0b349a568634f856aa08ffaff40ae0dfa52e7c6"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c7ab5790333c448903c4b721b59c0d80b11fe5e9803d8703e84dcb8da56fec1b"}, + {file = "black-23.3.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:a6f6886c9869d4daae2d1715ce34a19bbc4b95006d20ed785ca00fa03cba312d"}, + {file = "black-23.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f3c333ea1dd6771b2d3777482429864f8e258899f6ff05826c3a4fcc5ce3f70"}, + {file = "black-23.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:11c410f71b876f961d1de77b9699ad19f939094c3a677323f43d7a29855fe326"}, + {file = "black-23.3.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:1d06691f1eb8de91cd1b322f21e3bfc9efe0c7ca1f0e1eb1db44ea367dff656b"}, + {file = "black-23.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50cb33cac881766a5cd9913e10ff75b1e8eb71babf4c7104f2e9c52da1fb7de2"}, + {file = "black-23.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e114420bf26b90d4b9daa597351337762b63039752bdf72bf361364c1aa05925"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:48f9d345675bb7fbc3dd85821b12487e1b9a75242028adad0333ce36ed2a6d27"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:714290490c18fb0126baa0fca0a54ee795f7502b44177e1ce7624ba1c00f2331"}, + {file = "black-23.3.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:064101748afa12ad2291c2b91c960be28b817c0c7eaa35bec09cc63aa56493c5"}, + {file = "black-23.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:562bd3a70495facf56814293149e51aa1be9931567474993c7942ff7d3533961"}, + {file = "black-23.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:e198cf27888ad6f4ff331ca1c48ffc038848ea9f031a3b40ba36aced7e22f2c8"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:3238f2aacf827d18d26db07524e44741233ae09a584273aa059066d644ca7b30"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:f0bd2f4a58d6666500542b26354978218a9babcdc972722f4bf90779524515f3"}, + {file = "black-23.3.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:92c543f6854c28a3c7f39f4d9b7694f9a6eb9d3c5e2ece488c327b6e7ea9b266"}, + {file = "black-23.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a150542a204124ed00683f0db1f5cf1c2aaaa9cc3495b7a3b5976fb136090ab"}, + {file = "black-23.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:6b39abdfb402002b8a7d030ccc85cf5afff64ee90fa4c5aebc531e3ad0175ddb"}, + {file = "black-23.3.0-py3-none-any.whl", hash = "sha256:ec751418022185b0c1bb7d7736e6933d40bbb14c14a0abcf9123d1b159f98dd4"}, + {file = "black-23.3.0.tar.gz", hash = "sha256:1c7b8d606e728a41ea1ccbd7264677e494e87cf630e399262ced92d4a8dac940"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + [[package]] name = "bleach" version = "6.0.0" @@ -1257,16 +1304,17 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt [[package]] name = "jupyter-collaboration" -version = "1.0.0a8" +version = "1.0.0a9" description = "JupyterLab Extension enabling Real-Time Collaboration" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_collaboration-1.0.0a8-py3-none-any.whl", hash = "sha256:cc50897bbf5107dccb9f00beb1228e3d90e8d3159d1bc85ef10fbb8031e2771f"}, - {file = "jupyter_collaboration-1.0.0a8.tar.gz", hash = "sha256:2848ddb49bc64095e25bfb0bc74fb262488adb6b321e34a28a9bdf71301895cc"}, + {file = "jupyter_collaboration-1.0.0a9-py3-none-any.whl", hash = "sha256:a17bbcfb0e20aead43c0c67867d89e1ac7325df31b70029b0d9ec625e205597d"}, + {file = "jupyter_collaboration-1.0.0a9.tar.gz", hash = "sha256:d1941e9d621cba4109c36cb807ebcc4999eea169d0311484a240c40ca6c6b079"}, ] [package.dependencies] +jupyter-events = "*" jupyter-server = ">=2.0.0,<3.0.0" jupyter-server-fileid = ">=0.6.0,<1" jupyter-ydoc = ">=1.0.1,<2.0.0" @@ -1274,8 +1322,8 @@ ypy-websocket = ">=0.8.3,<0.9.0" [package.extras] dev = ["click", "jupyter-releaser", "pre-commit"] -docs = ["jupyterlab (>=4.0.0a32)", "myst-parser", "pydata-sphinx-theme", "sphinx"] -test = ["coverage", "jupyter-server[test] (>=2.0.0)", "pytest (>=7.0)", "pytest-cov"] +docs = ["jupyterlab (>=4.0.0)", "myst-parser", "pydata-sphinx-theme", "sphinx"] +test = ["coverage", "jupyter-server[test] (>=2.0.0)", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov"] [[package]] name = "jupyter-core" @@ -1430,13 +1478,13 @@ test = ["pre-commit", "pytest", "pytest-asyncio", "websockets (>=10.0)", "ypy-we [[package]] name = "jupyterlab" -version = "4.0.0" +version = "4.0.1" description = "JupyterLab computational environment" optional = false python-versions = ">=3.8" files = [ - {file = "jupyterlab-4.0.0-py3-none-any.whl", hash = "sha256:e2f67c189f833963c271a89df6bfa3eec4d5c8f7827ad3059538c5f467de193b"}, - {file = "jupyterlab-4.0.0.tar.gz", hash = "sha256:ce656d04828b2e4ee0758e22c862cc99aedec66a10319d09f0fd5ea51be68dd8"}, + {file = "jupyterlab-4.0.1-py3-none-any.whl", hash = "sha256:f3ebd90e41d3ba1b8152c8eda2bd1a18e0de490192b4be1a6ec132517cfe43ef"}, + {file = "jupyterlab-4.0.1.tar.gz", hash = "sha256:4dc3901f7bbfd4704c994b7a893a49955256abf57dba9831f4825e3f3165b8bb"}, ] [package.dependencies] @@ -1453,9 +1501,9 @@ tornado = ">=6.2.0" traitlets = "*" [package.extras] -dev = ["black[jupyter] (==23.3.0)", "build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.0.263)"] +dev = ["black[jupyter] (==23.3.0)", "build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.0.267)"] docs = ["jsx-lexer", "myst-parser", "pydata-sphinx-theme (>=0.13.0)", "pytest", "pytest-check-links", "pytest-tornasync", "sphinx (>=1.8)", "sphinx-copybutton"] -docs-screenshots = ["altair (==4.2.2)", "ipython (==8.13.1)", "ipywidgets (==8.0.6)", "jupyterlab-geojson (==3.3.1)", "jupyterlab-language-pack-zh-cn (==3.6.post1)", "matplotlib (==3.7.1)", "nbconvert (>=7.0.0)", "pandas (==2.0.1)", "scipy (==1.10.1)", "vega-datasets (==0.9.0)"] +docs-screenshots = ["altair (==4.2.2)", "ipython (==8.13.1)", "ipywidgets (==8.0.6)", "jupyterlab-geojson (==3.3.1)", "jupyterlab-language-pack-zh-cn (==3.6.post2)", "matplotlib (==3.7.1)", "nbconvert (>=7.0.0)", "pandas (==2.0.1)", "scipy (==1.10.1)", "vega-datasets (==0.9.0)"] test = ["coverage", "pytest (>=7.0)", "pytest-check-links (>=0.7)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter (>=0.5.3)", "pytest-timeout", "pytest-tornasync", "requests", "requests-cache", "virtualenv"] [[package]] @@ -1824,6 +1872,23 @@ files = [ {file = "mistune-2.0.5.tar.gz", hash = "sha256:0246113cb2492db875c6be56974a7c893333bf26cd92891c85f63151cee09d34"}, ] +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +optional = false +python-versions = "*" +files = [ + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, +] + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + [[package]] name = "multidict" version = "6.0.4" @@ -1907,6 +1972,17 @@ files = [ {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, ] +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "nbclient" version = "0.8.0" @@ -1998,6 +2074,24 @@ files = [ {file = "nest_asyncio-1.5.6.tar.gz", hash = "sha256:d267cc1ff794403f7df692964d1d2a3fa9418ffea2a3f6859a439ff482fef290"}, ] +[[package]] +name = "networkx" +version = "3.1" +description = "Python package for creating and manipulating graphs and networks" +optional = false +python-versions = ">=3.8" +files = [ + {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, + {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, +] + +[package.extras] +default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] +developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] +doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] +test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] + [[package]] name = "nodeenv" version = "1.8.0" @@ -2114,6 +2208,17 @@ files = [ qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] testing = ["docopt", "pytest (<6.0.0)"] +[[package]] +name = "pathspec" +version = "0.11.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"}, + {file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"}, +] + [[package]] name = "pathtools" version = "0.1.2" @@ -2981,6 +3086,17 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-g testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +[[package]] +name = "shellingham" +version = "1.5.0.post1" +description = "Tool to Detect Surrounding Shell" +optional = false +python-versions = ">=3.7" +files = [ + {file = "shellingham-1.5.0.post1-py2.py3-none-any.whl", hash = "sha256:368bf8c00754fd4f55afb7bbb86e272df77e4dc76ac29dbcbb81a59e9fc15744"}, + {file = "shellingham-1.5.0.post1.tar.gz", hash = "sha256:823bc5fb5c34d60f285b624e7264f4dda254bc803a3774a147bf99c0e3004a28"}, +] + [[package]] name = "six" version = "1.16.0" @@ -3044,6 +3160,20 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "sympy" +version = "1.12" +description = "Computer algebra system (CAS) in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"}, + {file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"}, +] + +[package.dependencies] +mpmath = ">=0.19" + [[package]] name = "terminado" version = "0.17.1" @@ -3098,8 +3228,39 @@ name = "torch" version = "2.0.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false -python-versions = "*" -files = [] +python-versions = ">=3.8.0" +files = [ + {file = "torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8ced00b3ba471856b993822508f77c98f48a458623596a4c43136158781e306a"}, + {file = "torch-2.0.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:359bfaad94d1cda02ab775dc1cc386d585712329bb47b8741607ef6ef4950747"}, + {file = "torch-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:7c84e44d9002182edd859f3400deaa7410f5ec948a519cc7ef512c2f9b34d2c4"}, + {file = "torch-2.0.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:567f84d657edc5582d716900543e6e62353dbe275e61cdc36eda4929e46df9e7"}, + {file = "torch-2.0.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:787b5a78aa7917465e9b96399b883920c88a08f4eb63b5a5d2d1a16e27d2f89b"}, + {file = "torch-2.0.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:e617b1d0abaf6ced02dbb9486803abfef0d581609b09641b34fa315c9c40766d"}, + {file = "torch-2.0.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b6019b1de4978e96daa21d6a3ebb41e88a0b474898fe251fd96189587408873e"}, + {file = "torch-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:dbd68cbd1cd9da32fe5d294dd3411509b3d841baecb780b38b3b7b06c7754434"}, + {file = "torch-2.0.1-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:ef654427d91600129864644e35deea761fb1fe131710180b952a6f2e2207075e"}, + {file = "torch-2.0.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:25aa43ca80dcdf32f13da04c503ec7afdf8e77e3a0183dd85cd3e53b2842e527"}, + {file = "torch-2.0.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5ef3ea3d25441d3957348f7e99c7824d33798258a2bf5f0f0277cbcadad2e20d"}, + {file = "torch-2.0.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:0882243755ff28895e8e6dc6bc26ebcf5aa0911ed81b2a12f241fc4b09075b13"}, + {file = "torch-2.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:f66aa6b9580a22b04d0af54fcd042f52406a8479e2b6a550e3d9f95963e168c8"}, + {file = "torch-2.0.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:1adb60d369f2650cac8e9a95b1d5758e25d526a34808f7448d0bd599e4ae9072"}, + {file = "torch-2.0.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:1bcffc16b89e296826b33b98db5166f990e3b72654a2b90673e817b16c50e32b"}, + {file = "torch-2.0.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:e10e1597f2175365285db1b24019eb6f04d53dcd626c735fc502f1e8b6be9875"}, + {file = "torch-2.0.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:423e0ae257b756bb45a4b49072046772d1ad0c592265c5080070e0767da4e490"}, + {file = "torch-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:8742bdc62946c93f75ff92da00e3803216c6cce9b132fbca69664ca38cfb3e18"}, + {file = "torch-2.0.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:c62df99352bd6ee5a5a8d1832452110435d178b5164de450831a3a8cc14dc680"}, + {file = "torch-2.0.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:671a2565e3f63b8fe8e42ae3e36ad249fe5e567435ea27b94edaa672a7d0c416"}, +] + +[package.dependencies] +filelock = "*" +jinja2 = "*" +networkx = "*" +sympy = "*" +typing-extensions = "*" + +[package.extras] +opt-einsum = ["opt-einsum (>=3.3)"] [[package]] name = "torchmetrics" @@ -3131,8 +3292,38 @@ name = "torchvision" version = "0.15.2" description = "image and video datasets and models for torch deep learning" optional = false -python-versions = "*" -files = [] +python-versions = ">=3.8" +files = [ + {file = "torchvision-0.15.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7754088774e810c5672b142a45dcf20b1bd986a5a7da90f8660c43dc43fb850c"}, + {file = "torchvision-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37eb138e13f6212537a3009ac218695483a635c404b6cc1d8e0d0d978026a86d"}, + {file = "torchvision-0.15.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:54143f7cc0797d199b98a53b7d21c3f97615762d4dd17ad45a41c7e80d880e73"}, + {file = "torchvision-0.15.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:1eefebf5fbd01a95fe8f003d623d941601c94b5cec547b420da89cb369d9cf96"}, + {file = "torchvision-0.15.2-cp310-cp310-win_amd64.whl", hash = "sha256:96fae30c5ca8423f4b9790df0f0d929748e32718d88709b7b567d2f630c042e3"}, + {file = "torchvision-0.15.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5f35f6bd5bcc4568e6522e4137fa60fcc72f4fa3e615321c26cd87e855acd398"}, + {file = "torchvision-0.15.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:757505a0ab2be7096cb9d2bf4723202c971cceddb72c7952a7e877f773de0f8a"}, + {file = "torchvision-0.15.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:012ad25cfd9019ff9b0714a168727e3845029be1af82296ff1e1482931fa4b80"}, + {file = "torchvision-0.15.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b02a7ffeaa61448737f39a4210b8ee60234bda0515a0c0d8562f884454105b0f"}, + {file = "torchvision-0.15.2-cp311-cp311-win_amd64.whl", hash = "sha256:10be76ceded48329d0a0355ac33da131ee3993ff6c125e4a02ab34b5baa2472c"}, + {file = "torchvision-0.15.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8f12415b686dba884fb086f53ac803f692be5a5cdd8a758f50812b30fffea2e4"}, + {file = "torchvision-0.15.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:31211c01f8b8ec33b8a638327b5463212e79a03e43c895f88049f97af1bd12fd"}, + {file = "torchvision-0.15.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c55f9889e436f14b4f84a9c00ebad0d31f5b4626f10cf8018e6c676f92a6d199"}, + {file = "torchvision-0.15.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9a192f2aa979438f23c20e883980b23d13268ab9f819498774a6d2eb021802c2"}, + {file = "torchvision-0.15.2-cp38-cp38-win_amd64.whl", hash = "sha256:c07071bc8d02aa8fcdfe139ab6a1ef57d3b64c9e30e84d12d45c9f4d89fb6536"}, + {file = "torchvision-0.15.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4790260fcf478a41c7ecc60a6d5200a88159fdd8d756e9f29f0f8c59c4a67a68"}, + {file = "torchvision-0.15.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:987ab62225b4151a11e53fd06150c5258ced24ac9d7c547e0e4ab6fbca92a5ce"}, + {file = "torchvision-0.15.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:63df26673e66cba3f17e07c327a8cafa3cce98265dbc3da329f1951d45966838"}, + {file = "torchvision-0.15.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b85f98d4cc2f72452f6792ab4463a3541bc5678a8cdd3da0e139ba2fe8b56d42"}, + {file = "torchvision-0.15.2-cp39-cp39-win_amd64.whl", hash = "sha256:07c462524cc1bba5190c16a9d47eac1fca024d60595a310f23c00b4ffff18b30"}, +] + +[package.dependencies] +numpy = "*" +pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" +requests = "*" +torch = "2.0.1" + +[package.extras] +scipy = ["scipy"] [[package]] name = "tornado" @@ -3189,6 +3380,30 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] +[[package]] +name = "typer" +version = "0.9.0" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." +optional = false +python-versions = ">=3.6" +files = [ + {file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"}, + {file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"}, +] + +[package.dependencies] +click = ">=7.1.1,<9.0.0" +colorama = {version = ">=0.4.3,<0.5.0", optional = true, markers = "extra == \"all\""} +rich = {version = ">=10.11.0,<14.0.0", optional = true, markers = "extra == \"all\""} +shellingham = {version = ">=1.3.0,<2.0.0", optional = true, markers = "extra == \"all\""} +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] +doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] + [[package]] name = "typing-extensions" version = "4.6.2" @@ -3610,4 +3825,4 @@ test = ["mypy", "pre-commit", "pytest", "pytest-asyncio", "websockets (>=10.0)"] [metadata] lock-version = "2.0" python-versions = "^3.11,!=3.11.0" -content-hash = "0694a43e5ab4020ccb5d212f1f73bfe5a5e9511ecfb0c20001d6c172cbb492cf" +content-hash = "be3c1881613436341bf9a7786e1162ebde4bd38420d6505d58cbcc71ed4f04e2" diff --git a/pyproject.toml b/pyproject.toml index 38d0242..30986c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,12 +8,13 @@ authors = ["Tsvika S "] [tool.poetry.dependencies] python = "^3.11,!=3.11.0" -torch = {version = "^2.0.0", source = "pytorch"} -torchvision = {version = "^0.15.1", source = "pytorch"} +torch = "^2.0.1" +torchvision = "^0.15.2" pytorch-lightning = "^2.0.2" wandb = "^0.15.2" rich = "^13.3.5" einops = "^0.6.1" +typer = {extras = ["all"], version = "^0.9.0"} [tool.poetry.group.jupyter.dependencies] jupyterlab = "^4.0.0" @@ -26,6 +27,7 @@ python-lsp-server = "^1.7.3" [tool.poetry.group.dev.dependencies] pre-commit = "^3.3.1" ruff = "^0.0.264" +black = "^23.3.0" [[tool.poetry.source]] name = "pytorch" @@ -50,12 +52,14 @@ ignore = [ "C408", "TRY003", "FBT002", + "PLW2901", ] src = ["src"] [tool.ruff.per-file-ignores] "__init__.py" = ["F401"] "src/train.py" = ["INP001", "T201"] +"src/models/resnet_vae.py" = ["T201", "PD002"] "src/explore_model.py" = [ "INP001", "E703", diff --git a/src/datamodules/images.py b/src/datamodules/images.py index 0f0435b..4de3633 100644 --- a/src/datamodules/images.py +++ b/src/datamodules/images.py @@ -16,19 +16,24 @@ def train_val_split( train_transform, val_transform, dataset_cls, + generator: torch.Generator = None, **dataset_kwargs, ): """load a dataset and split it, using a different transform for train and val""" lengths = [train_length, val_length] with isolate_rng(): dataset_train = dataset_cls(**dataset_kwargs, transform=train_transform) - train_split, _ = torch.utils.data.random_split(dataset_train, lengths) + train_split, _ = torch.utils.data.random_split( + dataset_train, lengths, generator=generator + ) with isolate_rng(): dataset_val = dataset_cls(**dataset_kwargs, transform=val_transform) - _, val_split = torch.utils.data.random_split(dataset_val, lengths) + _, val_split = torch.utils.data.random_split( + dataset_val, lengths, generator=generator + ) # repeat to consume the random state dataset = dataset_cls(**dataset_kwargs) - torch.utils.data.random_split(dataset, lengths) + torch.utils.data.random_split(dataset, lengths, generator=generator) return train_split, val_split @@ -84,6 +89,7 @@ def __init__( self.val_size_or_frac = val_size_or_frac self.target_is_self = target_is_self self.noise_transforms = noise_transforms or [] + self.generator = torch.Generator() # defined in self.setup() self.train_val_size = None @@ -154,6 +160,7 @@ def setup(self, stage=None): root=self.data_dir, train=True, download=False, + generator=self.generator, ) self.test_set = self.dataset_cls( root=self.data_dir, @@ -165,8 +172,12 @@ def setup(self, stage=None): self.train_set = TransformedSelfDataset( self.train_set, transforms=self.noise_transforms ) - self.val_set = TransformedSelfDataset(self.val_set) - self.test_set = TransformedSelfDataset(self.test_set) + self.val_set = TransformedSelfDataset( + self.val_set, transforms=self.noise_transforms + ) + self.test_set = TransformedSelfDataset( + self.test_set, transforms=self.noise_transforms + ) # verify num_classes and num_channels if (num_classes := len(self.test_set.classes)) != self.num_classes: @@ -220,6 +231,8 @@ def train_dataloader(self): batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, + pin_memory=True, + persistent_workers=True, ) # we can use a x2 batch_size in validation and testing, @@ -230,6 +243,8 @@ def val_dataloader(self): batch_size=self.batch_size * 2, shuffle=False, num_workers=self.num_workers, + pin_memory=True, + persistent_workers=True, ) def test_dataloader(self): @@ -238,6 +253,8 @@ def test_dataloader(self): batch_size=self.batch_size * 2, shuffle=False, num_workers=self.num_workers, + pin_memory=True, + persistent_workers=True, ) @@ -269,5 +286,6 @@ def __getitem__(self, item): def __len__(self): return len(self.dataset) - def __getattr__(self, item): - return getattr(self.dataset, item) + @property + def classes(self): + return self.dataset.classes diff --git a/src/explore_model.py b/src/explore_model.py index 35329cb..c2a54b2 100644 --- a/src/explore_model.py +++ b/src/explore_model.py @@ -18,20 +18,33 @@ import matplotlib.pyplot as plt import torch +import torchvision.transforms.functional as TF # noqa: N812 +from einops import rearrange from IPython.core.display_functions import display from ipywidgets import interact from torchvision.transforms import ToTensor from torchvision.transforms.functional import to_pil_image +import models from datamodules import ImagesDataModule -from models import FullyConnectedAutoEncoder +from train import LOGS_DIR # %% +DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +DEVICE = torch.device("mps") if torch.backends.mps.is_available() else DEVICE + +# %% +ModelClass = models.ConvVAE +dataset_name = "FashionMNIST" +datamodule = ImagesDataModule(dataset_name, 1, 10) + +# %% +model_name = ModelClass.__name__.lower() ckpt_dir = ( - Path("/tmp/logs") - / "fullyconnectedautoencodersgd-fashionmnist" - / "fullyconnectedautoencodersgd-fashionmnist" + LOGS_DIR + / f"{model_name}-{dataset_name.lower()}/{model_name}-{dataset_name.lower()}" ) + for p in ckpt_dir.parents[::-1] + (ckpt_dir,): if not p.exists(): raise ValueError(f"{p} not exists") @@ -53,54 +66,85 @@ def sort_dict(d: dict): all_ckpts = sort_dict(get_last_fn(subdir) for subdir in ckpt_dir.glob("*")) display(all_ckpts) + # %% # torch.load(ckpt_dir/list(all_ckpts.values())[-1])['hyper_parameters'] # %% -model = FullyConnectedAutoEncoder.load_latest_checkpoint(ckpt_dir) -model.eval() + + +def load_model(): + return ModelClass.load_latest_checkpoint(ckpt_dir, map_location=DEVICE).eval() + + +model = load_model() print(model.hparams) print(model) # %% -x_rand = torch.rand(1, 1, 28, 28) -image = ImagesDataModule("FashionMNIST", 1, 10).dataset()[0][0] +x_rand = torch.rand(1, 1, 32, 32) +image, _target = datamodule.dataset()[0] x_real = ToTensor()(image).unsqueeze(0) +x_rand = TF.center_crop(x_rand, 32) +x_real = TF.center_crop(x_real, 32) print(x_real.shape) # %% -def show_tensors(imgs: list[torch.Tensor]): +def show_tensors(imgs: list[torch.Tensor], normalize=True, figsize=None): if not isinstance(imgs, list): imgs = [imgs] - fig, axss = plt.subplots(ncols=len(imgs), squeeze=False) + fig, axss = plt.subplots(ncols=len(imgs), squeeze=False, figsize=figsize) axs = axss[0] for i, img in enumerate(imgs): - img_clipped = img.detach().clip(0, 1) - img_pil = to_pil_image(img_clipped) + if normalize: + img = (img - img.min()) / (img.max() - img.min()) + img = img.clamp(0, 1).detach() + img_pil = to_pil_image(img) axs[i].imshow(img_pil, cmap="gray", vmin=0, vmax=255) axs[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) for x in [x_rand, x_real]: - show_tensors([x[0], model(x.cuda())[0]]) + show_tensors([x[0], model(x.to(DEVICE)).x_hat[0]]) # %% -n_latent = 8 +n_latent = model.latent_dim -lims = (-2, 2, 0.01) +lims = (-3, 3, 0.01) all_lims = {f"x{i:02}": lims for i in range(n_latent)} def show_from_latent(**inputs): data = torch.tensor(list(inputs.values())) - data = data.view(1, -1).cuda() + data = data.view(1, -1).to(DEVICE) result = model.decoder(data)[0] - show_tensors(result) + show_tensors(result, normalize=True) plt.show() interact(show_from_latent, **all_lims) # %% +model = load_model() + + +def sample_latent(model, n: int = 30, lim: float = 3.0, downsample_factor: int = 2): + x = torch.linspace(-lim, lim, n) + y = torch.linspace(-lim, lim, n) + z = torch.cartesian_prod(x, y) + assert z.shape[1] == 2 + with torch.inference_mode(): + outs = model.decoder(z.to(model.device)) + out = rearrange(outs, "(i j) c h w -> c (i h) (j w)", i=n, j=n) + out = torch.nn.functional.avg_pool2d(out, kernel_size=downsample_factor) + # out = reduce(out, "c (h i) (w j) -> c h w", i=downsample_factor,j=downsample_factor, reduction="max") + return out + + +out = sample_latent(model) +print(out.shape) +show_tensors(out, figsize=(10, 10)) + +# %% diff --git a/src/models/__init__.py b/src/models/__init__.py index 5e4ed97..2afa3dc 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -1,3 +1,5 @@ from .auto_encoder import FullyConnectedAutoEncoder +from .conv_vae import ConvVAE from .mlp import MultiLayerPerceptron from .resnet import Resnet +from .resnet_vae import ResidualAutoencoder diff --git a/src/models/base.py b/src/models/base.py index 8fc2a57..9d638d8 100644 --- a/src/models/base.py +++ b/src/models/base.py @@ -1,11 +1,29 @@ +import dataclasses from pathlib import Path import pytorch_lightning as pl import torch import torch.nn.functional as F # noqa: N812 +from einops import rearrange from torchmetrics.functional.classification import multiclass_accuracy +class HasXHat: + x_hat: torch.Tensor + + +@dataclasses.dataclass +class AutoEncoderOutput: + x_hat: torch.Tensor + + +@dataclasses.dataclass +class VAEOutput(AutoEncoderOutput): + x_hat: torch.Tensor + mu: torch.Tensor + log_var_2: torch.Tensor + + class SimpleLightningModule(pl.LightningModule): def __init__(self): super().__init__() @@ -37,9 +55,11 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): + self.log("trainer/total_examples", float(self.total_examples)) self.step(batch, batch_idx, "validation", evaluate=True) def test_step(self, batch, batch_idx): + self.log("trainer/total_examples", float(self.total_examples)) self.step(batch, batch_idx, "test", evaluate=True) @@ -186,22 +206,108 @@ class AutoEncoder(LightningModuleWithScheduler): n_images_to_save = 8 + def __init__( + self, + sampler_lim: float = 3.0, + sampler_n: int = 30, + sampler_downsample_factor: int = 1, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.sampler_n = sampler_n + self.sampler_downsample_factor = sampler_downsample_factor + self.register_buffer( + "sampler_x", torch.linspace(-sampler_lim, sampler_lim, sampler_n) + ) + self.register_buffer( + "sampler_y", torch.linspace(-sampler_lim, sampler_lim, sampler_n) + ) + self.register_buffer( + "sampler_xy", torch.cartesian_prod(self.sampler_x, self.sampler_y) + ) + + def sample_latent(self): + if self.latent_dim != 2: + return None + with torch.inference_mode(): + outs = self.decoder(self.sampler_xy) + out = rearrange( + outs, "(i j) c h w -> c (i h) (j w)", i=self.sampler_n, j=self.sampler_n + ) + out = torch.nn.functional.avg_pool2d( + out, kernel_size=self.sampler_downsample_factor + ) + return out + + def loss_function(self, batch, out: HasXHat): + x, target = batch + assert out.x_hat.shape == x.shape + loss = F.mse_loss(out.x_hat, target) + return loss + def step(self, batch, batch_idx, stage: str, *, evaluate=False): x, target = batch assert x.shape == target.shape - x2 = self(x) - assert x2.shape == x.shape - loss = F.mse_loss(x2, target) - self.log(f"loss/{stage}", loss, prog_bar=evaluate) - if stage == "validation": - assert torch.equal(x, target) - if self.global_step == 0 and batch_idx == 0: + out: AutoEncoderOutput = self(x) + loss = self.loss_function(batch, out) + + if isinstance(loss, dict): + assert "loss" in loss, "Primary loss must be present" + for k, v in loss.items(): + self.log(f"{k}/{stage}", v, prog_bar=evaluate) + loss = loss["loss"] + else: + self.log(f"loss/{stage}", loss, prog_bar=evaluate) + + if stage == "validation" and self.logger: + if batch_idx == 0: self.logger.log_image("image/src", list(x[: self.n_images_to_save])) + if self.global_step == 0 and batch_idx == 0: + self.logger.log_image( + "image/target", list(target[: self.n_images_to_save]) + ) if batch_idx == 0: - self.logger.log_image("image/pred", list(x2[: self.n_images_to_save])) + self.logger.log_image( + "image/pred", list(out.x_hat[: self.n_images_to_save]) + ) + + if stage == "validation" and self.logger and batch_idx == 0: + sampled_images = self.sample_latent() + if sampled_images is not None: + self.logger.log_image("image/sampled", [sampled_images]) + return loss +class VAE(AutoEncoder): + def __init__(self, kl_weight: float = 0.005, *args, **kwargs): + """ + Magic constant from https://github.com/AntixK/PyTorch-VAE/ + """ + super().__init__(*args, **kwargs) + self.kl_weight = kl_weight + + def loss_function(self, batch, out: VAEOutput) -> dict[str, torch.Tensor]: + x, target = batch + assert out.x_hat.shape == x.shape + reconstruction_loss = F.mse_loss(out.x_hat, target) + log_var_2 = out.log_var_2 + mu = out.mu + + kl_loss = ( + 0.5 + * (-log_var_2 + log_var_2.exp() + mu**2 - 1).sum(dim=1).mean(dim=0) + * self.kl_weight + ) + loss = reconstruction_loss + kl_loss + return dict( + loss=loss, + reconstruction_loss=reconstruction_loss.detach(), + kl_loss=kl_loss.detach(), + ) + + class ImageAutoEncoder(AutoEncoder): def __init__( self, @@ -233,3 +339,38 @@ def __init__( self.example_input_array = torch.empty( sample_batch_size, num_channels, self.image_size, self.image_size ) + + +class ImageVAE(VAE): + def __init__( + self, + image_size: int, + num_channels: int, + kl_weight: float, + *, + optimizer_cls=None, + optimizer_kwargs=None, + scheduler_cls=None, + scheduler_kwargs=None, + scheduler_interval="epoch", + scheduler_frequency=None, + scheduler_add_total_steps=False, + scheduler_monitor=None, + ): + super().__init__( + kl_weight=kl_weight, + optimizer_cls=optimizer_cls, + optimizer_kwargs=optimizer_kwargs, + scheduler_cls=scheduler_cls, + scheduler_kwargs=scheduler_kwargs, + scheduler_interval=scheduler_interval, + scheduler_frequency=scheduler_frequency, + scheduler_add_total_steps=scheduler_add_total_steps, + scheduler_monitor=scheduler_monitor, + ) + sample_batch_size = 32 + self.image_size = image_size or 96 + self.num_channels = num_channels + self.example_input_array = torch.empty( + sample_batch_size, num_channels, self.image_size, self.image_size + ) diff --git a/src/models/conv_vae.py b/src/models/conv_vae.py new file mode 100644 index 0000000..e1b5165 --- /dev/null +++ b/src/models/conv_vae.py @@ -0,0 +1,261 @@ +from collections.abc import Callable + +import torch +import torch.nn as nn +from einops.layers.torch import Rearrange + +from . import base +from .base import VAEOutput + +ActivationT = Callable[[torch.Tensor], torch.Tensor] + + +class DownBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + act_fn: ActivationT, + ): + super().__init__() + self.shortcut = nn.Sequential( + nn.AvgPool2d(kernel_size=2, stride=2), + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity(), + ) + + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, padding=1 + ) + self.act = act_fn + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + residual = self.shortcut(x) + x = self.act(self.conv1(x)) + x = self.act(self.conv2(x)) + x = self.bn(x) + x = x + residual + return x + + +class UpBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + act_fn: ActivationT = nn.functional.gelu, + output_padding: int = 1, + ): + super().__init__() + self.shortcut = nn.Sequential( + nn.ConvTranspose2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + if in_channels != out_channels + else nn.Identity(), + nn.Upsample(scale_factor=2), + ) + self.conv_t1 = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + output_padding=output_padding, + ) + self.conv_t2 = nn.ConvTranspose2d( + out_channels, out_channels, kernel_size=3, padding=1 + ) + self.bn = nn.BatchNorm2d(out_channels) + self.act = act_fn + + def forward(self, x): + residual = self.shortcut(x) + x = self.act(self.conv_t1(x)) + x = self.act(self.conv_t2(x)) + x = self.bn(x) + x = x + residual + return x + + +class Encoder(nn.Module): + def __init__( + self, + num_input_channels: int, + channels: tuple[int], + latent_dim: int, + act_fn: ActivationT = nn.functional.gelu, + latent_act_fn: type[nn.Module] = nn.Identity, + first_kernel_size: int = 7, + image_size: int = 32, + ): + """ + Args: + num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3 + base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers might use a duplicate of it. + latent_dim : Dimensionality of latent representation z + act_fn : Activation function used throughout the encoder network + """ + super().__init__() + self.act = act_fn + self.image_size = image_size + self.bottleneck_size = image_size // 2 // 2 // 2 // 2 + + self.conv = nn.Conv2d( + num_input_channels, + channels[0], + kernel_size=first_kernel_size, + padding=first_kernel_size // 2, + stride=2, + bias=False, + ) + + self.down1 = DownBlock(channels[0], channels[1], act_fn) + self.down2 = DownBlock(channels[1], channels[2], act_fn) + self.down3 = DownBlock(channels[2], channels[3], act_fn) + + self.flatten = Rearrange( + "b c h w -> b (c h w)", + h=self.bottleneck_size, + w=self.bottleneck_size, + c=channels[3], + ) + self.mu = nn.Linear( + self.bottleneck_size * self.bottleneck_size * channels[3], latent_dim + ) + self.log_var = nn.Linear( + self.bottleneck_size * self.bottleneck_size * channels[3], latent_dim + ) + self.latent_act = latent_act_fn() + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x = self.act(self.conv(x)) + x = self.down1(x) + x = self.down2(x) + x = self.down3(x) + x = self.flatten(x) + x = self.latent_act(x) + mu = self.mu(x) + log_var = self.log_var(x) + return mu, log_var + + +# %% +class Decoder(nn.Module): + def __init__( + self, + num_input_channels: int, + channels: tuple[int], + latent_dim: int, + act_fn: ActivationT = nn.functional.gelu, + first_kernel_size: int = 7, + image_size: int = 32, + ): + """ + Args: + num_input_channels : Number of channels of the image to reconstruct. For CIFAR, this parameter is 3 + base_channel_size : Number of channels we use in the last convolutional layers. Early layers might use a duplicate of it. + latent_dim : Dimensionality of latent representation z + act_fn : Activation function used throughout the decoder network + """ + super().__init__() + self.act = act_fn + self.image_size = image_size + self.bottleneck_size = image_size // 2 // 2 // 2 // 2 + + self.linear = nn.Linear( + latent_dim, self.bottleneck_size * self.bottleneck_size * channels[3] + ) + self.reshape = Rearrange( + "b (c h w) -> b c h w", h=self.bottleneck_size, w=self.bottleneck_size + ) + self.up1 = UpBlock(channels[3], channels[2], act_fn) + self.up2 = UpBlock(channels[2], channels[1], act_fn) + self.up3 = UpBlock(channels[1], channels[0], act_fn) + self.final_layer_up = nn.Sequential( + nn.ConvTranspose2d( + channels[0], + channels[0], + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + ), + nn.BatchNorm2d(channels[0]), + ) + self.sharpen = nn.Sequential( + nn.Conv2d( + channels[0], out_channels=num_input_channels, kernel_size=3, padding=1 + ), + nn.Tanh(), + ) + + def forward(self, x): + x = self.act(self.linear(x)) + x = self.reshape(x) + x = self.up1(x) + x = self.up2(x) + x = self.up3(x) + x = self.final_layer_up(x) + x = self.act(x) + x = self.sharpen(x) + x = 3 * x + return x + + +class ConvVAE(base.ImageVAE): + def __init__( + self, + channels: tuple[int, int, int, int] = (16, 16, 16, 16), + latent_dim: int = 8, + encoder_class: type[nn.Module] = Encoder, + decoder_class: type[nn.Module] = Decoder, + num_channels: int = 1, + latent_noise: float = 0.0, + first_kernel_size: int = 5, + image_size: int = 32, + act_fn=nn.functional.gelu, + **kwargs, + ): + super().__init__(**kwargs, num_channels=num_channels, image_size=image_size) + self.save_hyperparameters() + # Creating encoder and decoder + self.encoder = encoder_class( + num_channels, + channels, + latent_dim, + first_kernel_size=first_kernel_size, + image_size=image_size, + act_fn=act_fn, + ) + self.decoder = decoder_class( + num_channels, + channels, + latent_dim, + first_kernel_size=first_kernel_size, + image_size=image_size, + act_fn=act_fn, + ) + + self.latent_dim = latent_dim + self.num_input_channels = num_channels + self.latent_noise = latent_noise + + def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: + """ + Reparameterization trick to sample from N(mu, var) from N(0,1). + """ + std = torch.exp(log_var * 0.5) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, x: torch.Tensor) -> VAEOutput: + """The forward function takes in an image and returns the reconstructed image.""" + mu, log_var_2 = self.encoder(x) + # z is the latent representation + z = self.reparameterize(mu, log_var_2) + x_hat = self.decoder(z) + return VAEOutput(x_hat=x_hat, mu=mu, log_var_2=log_var_2) diff --git a/src/models/resnet_vae.py b/src/models/resnet_vae.py new file mode 100644 index 0000000..1d4f3cd --- /dev/null +++ b/src/models/resnet_vae.py @@ -0,0 +1,490 @@ +import copy # noqa: I001 +import itertools # noqa: F401 +import time +from typing import Callable, Optional, Union # noqa: F401, UP035 + +import torch +from torch import nn +from torch import Tensor +from torchvision.models import resnet18 # noqa: F401 +import torch.utils.data +import torchvision.datasets +from torchvision import transforms +from tqdm import tqdm + +import wandb + +from . import base + + +def conv3x3( + in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 +) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1, **_kw) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +def transposed_conv1x1( + in_planes: int, out_planes: int, stride: int = 1, output_padding: int = 0 +) -> nn.ConvTranspose2d: + """1x1 transposed convolution""" + return nn.ConvTranspose2d( + in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False, + output_padding=output_padding, + ) + + +def transposed_conv3x3( + in_planes: int, + out_planes: int, + stride: int = 1, + groups: int = 1, + _dilation: int = 1, + output_padding: int = 1, +) -> nn.ConvTranspose2d: + """3x3 convolution with padding""" + return nn.ConvTranspose2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + output_padding=output_padding, + groups=groups, + bias=False, + # dilation=dilation, + ) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, # noqa: UP007 + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, # noqa: UP007 + output_padding: int = 1, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + if inplanes > planes: + self.conv1 = transposed_conv3x3( + inplanes, planes, stride, output_padding=output_padding + ) + else: + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResidualAutoencoder(base.ImageAutoEncoder): + def __init__(self, bottleneck, layers=[2, 2, 2, 2], **kw): # noqa: B006 + super().__init__(**kw) + self.save_hyperparameters() + self.dilation = 1 + self.inplanes = 4 + self.conv1 = nn.Conv2d( + 1, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False + ) + self.layer1 = self._make_layer(4, layers[0]) + self.layer2 = self._make_layer(8, layers[1], stride=2) + self.layer3 = self._make_layer(16, layers[2], stride=2) + self.bottleneck = nn.Linear(1024, bottleneck) + self.decode_bottleneck = nn.Linear(bottleneck, 1024) + self.layer3_ = self._make_layer_( + 8, layers[2], stride=2, output_padding=1, downsample_output_padding=1 + ) + self.layer2_ = self._make_layer_( + 4, layers[1], stride=2, output_padding=1, downsample_output_padding=1 + ) + self.layer1_ = self._make_layer_(1, layers[1]) + # self.tconv1 = nn.ConvTranspose2d( + # self.inplanes, 1, kernel_size=1, stride=1, padding=0, bias=False + # ) + + def _make_layer( + self, + planes: int, + blocks: int, + stride: int = 1, + downsample_output_padding: int = 0, + output_padding: int = 0, + resample: Callable = conv1x1, + ) -> nn.Sequential: + downsample = None + previous_dilation = self.dilation + if stride != 1 or self.inplanes != planes: + downsample = nn.Sequential( + resample( + self.inplanes, + planes, + stride, + output_padding=downsample_output_padding, + ), + nn.BatchNorm2d(planes), + ) + + layers = [] + layers.append( + BasicBlock( + self.inplanes, + planes, + stride, + downsample, + 1, + 64, + previous_dilation, + nn.BatchNorm2d, + output_padding=output_padding, + ) + ) + self.inplanes = planes + for _ in range(1, blocks): + layers.append( + BasicBlock( + self.inplanes, + planes, + groups=1, + base_width=64, + dilation=self.dilation, + norm_layer=nn.BatchNorm2d, + output_padding=output_padding, + ) + ) + + return nn.Sequential(*layers) + + def _make_layer_( + self, + planes: int, + blocks: int, + stride: int = 1, + output_padding: int = 0, + downsample_output_padding: int = 0, + ) -> nn.Sequential: + return self._make_layer( + planes, + blocks, + stride, + resample=transposed_conv1x1, + output_padding=output_padding, + downsample_output_padding=downsample_output_padding, + ) + + def forward(self, x): + shape = x.shape + + x = self.conv1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + image_shape = x.shape + x = x.reshape(shape[0], -1) + x = self.bottleneck(x) + x = self.decode_bottleneck(x) + x = x.reshape(image_shape) + + x = self.layer3_(x) + x = self.layer2_(x) + x = self.layer1_(x) + + # x = self.tconv1(x) + + assert shape == x.shape, (shape, x.shape) + return x + + +def train_model(model, dataloaders, criterion, optimizer, num_epochs=25): + since = time.time() + + best_model_wts = copy.deepcopy(model.state_dict()) + best_acc = 0.0 + + for epoch in range(num_epochs): + print("Epoch {}/{}".format(epoch, num_epochs - 1)) # noqa: UP032 + print("-" * 10) + + # Each epoch has a training and validation phase + for phase in ["train", "val"]: + if phase == "train": + model.train() # Set model to training mode + else: + model.eval() # Set model to evaluate mode + + running_loss = 0.0 + running_corrects = 0 + + # Iterate over data. + for inputs, labels in tqdm(dataloaders[phase]): + labels = labels.to(device) + # zero the parameter gradients + optimizer.zero_grad() + + # forward + # track history if only in train + with torch.set_grad_enabled(phase == "train"): + # Get model outputs and calculate loss + outputs = model(inputs) + B, _C, H, W = inputs.shape # noqa: N806 + assert _C == 1 + assert outputs.shape == (B, 10) + assert labels.shape == (B,) + loss = criterion(outputs, labels) + _, preds = torch.max(outputs, 1) + + # backward + optimize only if in training phase + if phase == "train": + loss.backward() + optimizer.step() + + # statistics + running_loss += loss.item() * inputs.size(0) + running_corrects += torch.sum(preds == labels.data) + + epoch_loss = running_loss / len(dataloaders[phase].dataset) + epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) + + wandb.log( + {"epoch": epoch, phase + "/loss": epoch_loss, phase + "/acc": epoch_acc} + ) + + # deep copy the model + if phase == "val" and epoch_acc > best_acc: + best_acc = epoch_acc + best_model_wts = copy.deepcopy(model.state_dict()) + + print() + + time_elapsed = time.time() - since + print( + "Training complete in {:.0f}m {:.0f}s".format( + time_elapsed // 60, time_elapsed % 60 + ) + ) + print("Best val Acc: {:4f}".format(best_acc)) # noqa: UP032 + + # load best model weights + model.load_state_dict(best_model_wts) + return model + + +def train_autoencoder(model, dataloaders, criterion, optimizer, num_epochs=25): + since = time.time() + + best_model_wts = copy.deepcopy(model.state_dict()) + best_loss = 1e100 + # best_acc = 0.0 + + for epoch in range(num_epochs): + print("Epoch {}/{}".format(epoch, num_epochs - 1)) # noqa: UP032 + print("-" * 10) + + # Each epoch has a training and validation phase + for phase in ["train", "val"]: + if phase == "train": + model.train() # Set model to training mode + else: + model.eval() # Set model to evaluate mode + + running_loss = 0.0 + # running_corrects = 0 + + # Iterate over data. + for inputs, labels in tqdm(dataloaders[phase]): # noqa: B007 + # labels = labels.to(device) + # zero the parameter gradients + optimizer.zero_grad() + + # forward + # track history if only in train + with torch.set_grad_enabled(phase == "train"): + # Get model outputs and calculate loss + outputs = model(inputs) + B, _C, H, W = inputs.shape # noqa: N806 + assert _C == 1 + assert outputs.shape == inputs.shape + # assert labels.shape == (B,) + loss = criterion(outputs.reshape(B, -1), inputs.reshape(B, -1)) + # _, preds = torch.max(outputs, 1) + + # backward + optimize only if in training phase + if phase == "train": + loss.backward() + optimizer.step() + + # statistics + running_loss += loss.item() * inputs.size(0) + # running_corrects += torch.sum(preds == labels.data) + + epoch_loss = running_loss / len(dataloaders[phase].dataset) + # epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) + + wandb.log( + { + "epoch": epoch, + phase + "/loss": epoch_loss, + } # , phase + "/acc": epoch_acc} + ) + + # deep copy the model + if phase == "val" and epoch_loss < best_loss: + # best_acc = epoch_acc + best_loss = epoch_loss + best_model_wts = copy.deepcopy(model.state_dict()) + + inputs = next(iter(dataloaders[phase]))[0][:8] + outputs = model(inputs) + image_array = torch.concat((inputs, outputs)) + + wandb.log( + { + "examples": wandb.Image( + image_array, caption="Top: Input, Bottom: Output" + ) + } + ) + + print() + + time_elapsed = time.time() - since + print( + "Training complete in {:.0f}m {:.0f}s".format( + time_elapsed // 60, time_elapsed % 60 + ) + ) + # print("Best val Acc: {:4f}".format(best_acc)) + print("Best val Loss: {:4f}".format(best_loss)) # noqa: UP032 + + # load best model weights + model.load_state_dict(best_model_wts) + return model + + +if __name__ == "__main__": + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + torch.manual_seed(42) + + args = {} + + normalize = transforms.Normalize([72.9404 / 255], [90.0212 / 255]) + data_transforms = { + "train": transforms.Compose( + [ + # transforms.RandomResizedCrop(input_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + transforms.Lambda(lambda x: x.to(device)), + ] + ), + "val": transforms.Compose( + [ + # transforms.Resize(input_size), + # transforms.CenterCrop(input_size), + transforms.ToTensor(), + normalize, + transforms.Lambda(lambda x: x.to(device)), + ] + ), + } + + train_data = torchvision.datasets.FashionMNIST( + "fashion-mnist", + train=True, + download=True, + transform=data_transforms["train"], + ) + train_loader = torch.utils.data.DataLoader( + train_data, + batch_size=64, + shuffle=True, + ) + + test_data = torchvision.datasets.FashionMNIST( + "fashion-mnist", train=False, transform=data_transforms["val"] + ) + test_loader = torch.utils.data.DataLoader( + test_data, + batch_size=64, + shuffle=True, + ) + + """ + model = resnet18() + model.conv1 = torch.nn.Conv2d( + 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False + ) + model.fc = torch.nn.Linear(in_features=512, out_features=10, bias=True) + model = model.to(device) + """ + # wandb.init(config=args, save_code=True) + model = ResidualAutoencoder(16).to(device) + print(model) + criterion = torch.nn.CrossEntropyLoss() + criterion = nn.MSELoss() + optimizer = torch.optim.AdamW(model.parameters()) + try: + model = train_autoencoder( + model, + {"train": train_loader, "val": test_loader}, + criterion, + optimizer, + num_epochs=5, + ) + torch.save(model.state_dict(), "weights.pkl") + finally: + pass # wandb.finish() diff --git a/src/train.py b/src/train.py index be29094..96d6542 100755 --- a/src/train.py +++ b/src/train.py @@ -2,9 +2,12 @@ import os import tempfile import time +import warnings +from enum import Enum from pathlib import Path import torch +import typer import wandb from pytorch_lightning import Trainer, callbacks, loggers, seed_everything from torchvision import transforms @@ -13,6 +16,13 @@ import models from datamodules import noise +app = typer.Typer(pretty_exceptions_enable=False) + +warnings.filterwarnings("ignore", ".*does not have many workers.*") +warnings.filterwarnings("ignore", ".*but CUDA is not available.*") +warnings.filterwarnings( + "ignore", ".*is supported for historical reasons but its usage is discouraged.*" +) LOGS_DIR = Path(tempfile.gettempdir()) / "logs" @@ -30,35 +40,57 @@ def get_logger(project_name: str): return logger -def get_datamodule(): +def get_datamodule(dataset: str, batch_size: int = 512): return datamodules.ImagesDataModule( # see torchvision.datasets for available datasets - "FashionMNIST", + dataset, num_channels=1, num_classes=10, - batch_size=512 if torch.cuda.is_available() else 64, + batch_size=batch_size, num_workers=os.cpu_count() - 1, - train_transforms=[transforms.CenterCrop(28)], - eval_transforms=[transforms.CenterCrop(28)], + train_transforms=[ + # transforms.RandomHorizontalFlip(), + transforms.CenterCrop(32), + ], + eval_transforms=[ + # transforms.RandomHorizontalFlip(), + transforms.CenterCrop(32), + ], target_is_self=True, - noise_transforms=[noise.GaussianNoise(0.1), noise.SaltPepperNoise(0.1, 0.1)], + noise_transforms=[ + transforms.RandomApply([transforms.RandomErasing()], p=0.5), + transforms.RandomApply([noise.SaltPepperNoise(0.05, 0.05)], p=0.5), + transforms.RandomApply([noise.GaussianNoise(0.05)], p=0.5), + ], ) -def get_model(num_channels): - return models.FullyConnectedAutoEncoder( +def get_model( + num_channels: int, + latent_dim: int = 32, + latent_noise: float = 0.1, + channels: tuple[int, int, int, int] = (16, 16, 32, 32), + kl_weight=0.005, +): + return models.ConvVAE( + latent_dim=latent_dim, + image_size=32, + latent_noise=latent_noise, num_channels=num_channels, - hidden_sizes=(256, 64, 8), - encoder_last_layer=torch.nn.LayerNorm, - encoder_last_layer_args=(8,), - decoder_last_layer=torch.nn.Identity, - decoder_last_layer_args=(), + channels=channels, + kl_weight=kl_weight, + # # FullyConnectedAutoEncoder + # hidden_sizes=(256, 64, 8), + # encoder_last_layer=torch.nn.LayerNorm, + # encoder_last_layer_args=(8,), + # decoder_last_layer=torch.nn.Identity, + # decoder_last_layer_args=(), # # SGD # optimizer_cls=torch.optim.SGD, # optimizer_kwargs=dict(lr=0.1, momentum=0.9, weight_decay=5e-4), # # AdamW optimizer_cls=torch.optim.AdamW, - optimizer_kwargs=dict(lr=0.01), + optimizer_kwargs=dict(lr=0.0003), # # ReduceLROnPlateau # scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau, # scheduler_kwargs=dict(patience=1, threshold=0.05, factor=0.1), @@ -78,67 +110,95 @@ def get_model(num_channels): # optimizer_cls=torch.optim.Adam, # optimizer_kwargs=dict(lr=0.05), # # ExponentialLR - scheduler_cls=torch.optim.lr_scheduler.ExponentialLR, - scheduler_kwargs=dict(gamma=0.95), - scheduler_interval="epoch", - scheduler_add_total_steps=False, + # scheduler_cls=torch.optim.lr_scheduler.ExponentialLR, + # scheduler_kwargs=dict(gamma=0.95), + # scheduler_interval="epoch", + # scheduler_add_total_steps=False, ) -def train(seed): +class AvailableDatasets(str, Enum): + FashionMNIST = "FashionMNIST" + KMNIST = "KMNIST" + + +@app.command() +def train( + seed: int = 42, + max_epochs: int = 50, + latent_dim: int = 32, + latent_noise: float = 0.1, + channels: tuple[int, int, int, int] = (32, 64, 128, 256), + checkpoint_path: str = None, + batch_size: int = 2048, + kl_weight: float = 0.005, + dataset: AvailableDatasets = AvailableDatasets.FashionMNIST, +): seed = seed_everything(seed) - datamodule = get_datamodule() - model = get_model(datamodule.num_channels) - logger = get_logger( - project_name=f"{type(model).__name__.lower()}-{datamodule.dataset_name.lower()}" + datamodule = get_datamodule(batch_size=batch_size, dataset=dataset.value) + model = get_model( + num_channels=datamodule.num_channels, + latent_dim=latent_dim, + latent_noise=latent_noise, + channels=channels, + kl_weight=kl_weight, ) # trainer settings - max_epochs = 30 trainer_callbacks = [ - callbacks.EarlyStopping("loss/validation", min_delta=0.001), + # callbacks.EarlyStopping("loss/validation", min_delta=0.0, patience=10), + # callbacks.StochasticWeightAveraging(swa_lrs=1e-2), ] # set precision - torch.set_float32_matmul_precision("medium") - precision = "bf16-mixed" + precision = 16 + if torch.cuda.is_available(): + torch.set_float32_matmul_precision("medium") + precision = "16-mixed" # fast_dev_run, to prevent logging of failed runs trainer_fast = Trainer( + accelerator="auto", fast_dev_run=True, enable_model_summary=False, enable_progress_bar=False, precision=precision, + logger=False, + callbacks=[ + callbacks.RichModelSummary(max_depth=4), + ], ) trainer_fast.fit(model, datamodule=datamodule) # set trainer + logger = get_logger( + project_name=f"{type(model).__name__.lower()}-{datamodule.dataset_name.lower()}" + ) trainer = Trainer( + accelerator="auto", max_epochs=max_epochs, logger=logger, callbacks=[ - callbacks.RichModelSummary(max_depth=2), callbacks.RichProgressBar(), - callbacks.LearningRateMonitor(logging_interval="step"), + callbacks.LearningRateMonitor(logging_interval="step", log_momentum=True), *trainer_callbacks, ], precision=precision, enable_model_summary=False, + log_every_n_steps=5 if len(datamodule.train_dataloader()) > 5 else 1, ) trainer.logger.log_hyperparams({"seed": seed}) # run trainer trainer.test(model, datamodule=datamodule, verbose=False) t_start = time.time() - trainer.fit(model, datamodule=datamodule) + trainer.fit( + model, datamodule=datamodule, ckpt_path=checkpoint_path and str(checkpoint_path) + ) t_total = time.time() - t_start trainer.logger.log_metrics({"trainer/total_time": t_total}) trainer.test(model, datamodule=datamodule) -def main(): - train(seed=None) - - if __name__ == "__main__": - main() + app()