diff --git a/diffusers/geodiff_molecule_conformation.ipynb b/diffusers/geodiff_molecule_conformation.ipynb new file mode 100644 index 00000000..ad337e7b --- /dev/null +++ b/diffusers/geodiff_molecule_conformation.ipynb @@ -0,0 +1,3653 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "F88mignPnalS" + }, + "source": [ + "# Introduction\n", + "\n", + "This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n", + "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n", + "\n", + "The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule. \n", + "\n", + "This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\n", + "\n", + "> Colab made by [natolambert](https://twitter.com/natolambert).\n", + "\n", + "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7cnwXMocnuzB" + }, + "source": [ + "## Installations\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Install Conda" + ], + "metadata": { + "id": "ff9SxWnaNId9" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1g_6zOabItDk" + }, + "source": [ + "Here we check the `cuda` version of colab. When this was built, the version was always 11.1, which impacts some installation decisions below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "K0ofXobG5Y-X", + "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "nvcc: NVIDIA (R) Cuda compiler driver\n", + "Copyright (c) 2005-2021 NVIDIA Corporation\n", + "Built on Sun_Feb_14_21:12:58_PST_2021\n", + "Cuda compilation tools, release 11.2, V11.2.152\n", + "Build cuda_11.2.r11.2/compiler.29618528_0\n" + ] + } + ], + "source": [ + "!nvcc --version" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VfthW90vI0nw" + }, + "source": [ + "Install Conda for some more complex dependencies for geometric networks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2WNFzSnbiE0k", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q condacolab" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NUsbWYCUI7Km" + }, + "source": [ + "Setup Conda" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FZelreINdmd0", + "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "✨🍰✨ Everything looks OK!\n" + ] + } + ], + "source": [ + "import condacolab\n", + "condacolab.install()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JzDHaPU7I9Sn" + }, + "source": [ + "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JMxRjHhL7w8V", + "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n", + "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - cudatoolkit=11.1\n", + " - pytorch\n", + " - torchaudio\n", + " - torchvision\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 960 KB\n", + "\n", + "The following packages will be UPDATED:\n", + "\n", + " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n", + "Preparing transaction: / \b\bdone\n", + "Verifying transaction: \\ \b\bdone\n", + "Executing transaction: / \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } + ], + "source": [ + "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n", + "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Need to remove a pathspec for colab that specifies the incorrect cuda version." + ], + "metadata": { + "id": "QDS6FPZ0Tu5b" + } + }, + { + "cell_type": "code", + "source": [ + "!rm /usr/local/conda-meta/pinned" + ], + "metadata": { + "id": "dq1lxR10TtrR", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z1L3DdZOJB30" + }, + "source": [ + "Install torch geometric (used in the model later)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D5ukfCOWfjzK", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "\n", + "## Package Plan ##\n", + "\n", + " environment location: /usr/local\n", + "\n", + " added / updated specs:\n", + " - pytorch-geometric=1.7.2\n", + "\n", + "\n", + "The following packages will be downloaded:\n", + "\n", + " package | build\n", + " ---------------------------|-----------------\n", + " decorator-4.4.2 | py_0 11 KB conda-forge\n", + " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n", + " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n", + " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n", + " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n", + " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n", + " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n", + " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n", + " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n", + " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n", + " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n", + " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n", + " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n", + " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n", + " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n", + " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n", + " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n", + " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n", + " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n", + " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n", + " ------------------------------------------------------------\n", + " Total: 55.9 MB\n", + "\n", + "The following NEW packages will be INSTALLED:\n", + "\n", + " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n", + " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n", + " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n", + " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n", + " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n", + " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n", + " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n", + " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n", + " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n", + " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n", + " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n", + " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n", + " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n", + " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n", + " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n", + " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n", + " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n", + " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n", + " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n", + "\n", + "The following packages will be DOWNGRADED:\n", + "\n", + " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n", + "\n", + "\n", + "\n", + "Downloading and Extracting Packages\n", + "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n", + "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n", + "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n", + "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n", + "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n", + "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n", + "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n", + "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n", + "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n", + "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n", + "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n", + "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n", + "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n", + "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n", + "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n", + "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n", + "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n", + "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n", + "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n", + "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n", + "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n", + "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n", + "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n", + "Retrieving notices: ...working... done\n" + ] + } + ], + "source": [ + "!conda install -c rusty1s pytorch-geometric=1.7.2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ppxv6Mdkalbc" + }, + "source": [ + "### Install Diffusers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mgQA_XN-XGY2", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "85392615-b6a4-4052-9d2a-79604be62c94" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content\n", + "Cloning into 'diffusers'...\n", + "remote: Enumerating objects: 9298, done.\u001b[K\n", + "remote: Counting objects: 100% (40/40), done.\u001b[K\n", + "remote: Compressing objects: 100% (23/23), done.\u001b[K\n", + "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n", + "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n", + "Resolving deltas: 100% (6168/6168), done.\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "%cd /content\n", + "\n", + "# install latest HF diffusers (will update to the release once added)\n", + "!git clone https://github.com/huggingface/diffusers.git\n", + "!pip install -q /content/diffusers \n", + "\n", + "# dependencies for diffusers\n", + "!pip install -q datasets transformers " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LZO6AJKuJKO8" + }, + "source": [ + "Check that torch is installed correctly and utilizing the GPU in the colab" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gZt7BNi1e1PA", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "True\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'1.8.2'" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 8 + } + ], + "source": [ + "import torch\n", + "print(torch.cuda.is_available())\n", + "torch.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KLE7CqlfJNUO" + }, + "source": [ + "### Install Chemistry-specific Dependencies\n", + "\n", + "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0CPv_NvehRz3", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting rdkit\n", + " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n", + "Installing collected packages: rdkit\n", + "Successfully installed rdkit-2022.3.5\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install rdkit" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "88GaDbDPxJ5I" + }, + "source": [ + "### Get viewer from nglview\n", + "\n", + "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors). \n", + "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed). \n", + "The rdmol in this object is a source of ground truth for the generated molecules.\n", + "\n", + "You will use one rendering function from nglviewer later!\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jcl8GCS2mz6t", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting nglview\n", + " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n", + "Collecting jupyterlab-widgets\n", + " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipywidgets>=7\n", + " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting widgetsnbextension~=4.0\n", + " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipython>=6.1.0\n", + " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ipykernel>=4.5.1\n", + " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting traitlets>=4.3.1\n", + " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n", + "Collecting pyzmq>=17\n", + " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting matplotlib-inline>=0.1\n", + " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n", + "Collecting tornado>=6.1\n", + " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nest-asyncio\n", + " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n", + "Collecting debugpy>=1.0\n", + " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting psutil\n", + " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jupyter-client>=6.1.12\n", + " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pickleshare\n", + " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n", + "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n", + "Collecting backcall\n", + " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n", + "Collecting pexpect>4.3\n", + " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting pygments\n", + " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jedi>=0.16\n", + " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n", + " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n", + "Collecting parso<0.9.0,>=0.8.0\n", + " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n", + "Collecting entrypoints\n", + " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n", + "Collecting jupyter-core>=4.9.2\n", + " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ptyprocess>=0.5\n", + " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n", + "Collecting wcwidth\n", + " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n", + "Building wheels for collected packages: nglview\n", + " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n", + " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n", + "Successfully built nglview\n", + "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n", + "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "pexpect", + "pickleshare", + "wcwidth" + ] + } + } + }, + "metadata": {} + } + ], + "source": [ + "!pip install nglview" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Create a diffusion model" + ], + "metadata": { + "id": "8t8_e_uVLdKB" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Model class(es)" + ], + "metadata": { + "id": "G0rMncVtNSqU" + } + }, + { + "cell_type": "markdown", + "source": [ + "Imports" + ], + "metadata": { + "id": "L5FEXz5oXkzt" + } + }, + { + "cell_type": "code", + "source": [ + "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n", + "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n", + "from dataclasses import dataclass\n", + "from typing import Callable, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor, nn\n", + "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n", + "\n", + "from torch_geometric.nn import MessagePassing, radius, radius_graph\n", + "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n", + "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n", + "from torch_scatter import scatter_add\n", + "from torch_sparse import SparseTensor, coalesce\n", + "\n", + "from diffusers.configuration_utils import ConfigMixin, register_to_config\n", + "from diffusers.modeling_utils import ModelMixin\n", + "from diffusers.utils import BaseOutput\n" + ], + "metadata": { + "id": "-3-P4w5sXkRU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Helper classes" + ], + "metadata": { + "id": "EzJQXPN_XrMX" + } + }, + { + "cell_type": "code", + "source": [ + "@dataclass\n", + "class MoleculeGNNOutput(BaseOutput):\n", + " \"\"\"\n", + " Args:\n", + " sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n", + " Hidden states output. Output of last layer of model.\n", + " \"\"\"\n", + "\n", + " sample: torch.FloatTensor\n", + "\n", + "\n", + "class MultiLayerPerceptron(nn.Module):\n", + " \"\"\"\n", + " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n", + " Args:\n", + " input_dim (int): input dimension\n", + " hidden_dim (list of int): hidden dimensions\n", + " activation (str or function, optional): activation function\n", + " dropout (float, optional): dropout rate\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n", + " super(MultiLayerPerceptron, self).__init__()\n", + "\n", + " self.dims = [input_dim] + hidden_dims\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " print(f\"Warning, activation passed {activation} is not string and ignored\")\n", + " self.activation = None\n", + " if dropout > 0:\n", + " self.dropout = nn.Dropout(dropout)\n", + " else:\n", + " self.dropout = None\n", + "\n", + " self.layers = nn.ModuleList()\n", + " for i in range(len(self.dims) - 1):\n", + " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\"\"\"\n", + " for i, layer in enumerate(self.layers):\n", + " x = layer(x)\n", + " if i < len(self.layers) - 1:\n", + " if self.activation:\n", + " x = self.activation(x)\n", + " if self.dropout:\n", + " x = self.dropout(x)\n", + " return x\n", + "\n", + "\n", + "class ShiftedSoftplus(torch.nn.Module):\n", + " def __init__(self):\n", + " super(ShiftedSoftplus, self).__init__()\n", + " self.shift = torch.log(torch.tensor(2.0)).item()\n", + "\n", + " def forward(self, x):\n", + " return F.softplus(x) - self.shift\n", + "\n", + "\n", + "class CFConv(MessagePassing):\n", + " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n", + " super(CFConv, self).__init__(aggr=\"add\")\n", + " self.lin1 = Linear(in_channels, num_filters, bias=False)\n", + " self.lin2 = Linear(num_filters, out_channels)\n", + " self.nn = mlp\n", + " self.cutoff = cutoff\n", + " self.smooth = smooth\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " torch.nn.init.xavier_uniform_(self.lin1.weight)\n", + " torch.nn.init.xavier_uniform_(self.lin2.weight)\n", + " self.lin2.bias.data.fill_(0)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " if self.smooth:\n", + " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n", + " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n", + " else:\n", + " C = (edge_length <= self.cutoff).float()\n", + " W = self.nn(edge_attr) * C.view(-1, 1)\n", + "\n", + " x = self.lin1(x)\n", + " x = self.propagate(edge_index, x=x, W=W)\n", + " x = self.lin2(x)\n", + " return x\n", + "\n", + " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n", + " return x_j * W\n", + "\n", + "\n", + "class InteractionBlock(torch.nn.Module):\n", + " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n", + " super(InteractionBlock, self).__init__()\n", + " mlp = Sequential(\n", + " Linear(num_gaussians, num_filters),\n", + " ShiftedSoftplus(),\n", + " Linear(num_filters, num_filters),\n", + " )\n", + " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n", + " self.act = ShiftedSoftplus()\n", + " self.lin = Linear(hidden_channels, hidden_channels)\n", + "\n", + " def forward(self, x, edge_index, edge_length, edge_attr):\n", + " x = self.conv(x, edge_index, edge_length, edge_attr)\n", + " x = self.act(x)\n", + " x = self.lin(x)\n", + " return x\n", + "\n", + "\n", + "class SchNetEncoder(Module):\n", + " def __init__(\n", + " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.hidden_channels = hidden_channels\n", + " self.num_filters = num_filters\n", + " self.num_interactions = num_interactions\n", + " self.cutoff = cutoff\n", + "\n", + " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n", + "\n", + " self.interactions = ModuleList()\n", + " for _ in range(num_interactions):\n", + " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n", + " self.interactions.append(block)\n", + "\n", + " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n", + " if embed_node:\n", + " assert z.dim() == 1 and z.dtype == torch.long\n", + " h = self.embedding(z)\n", + " else:\n", + " h = z\n", + " for interaction in self.interactions:\n", + " h = h + interaction(h, edge_index, edge_length, edge_attr)\n", + "\n", + " return h\n", + "\n", + "\n", + "class GINEConv(MessagePassing):\n", + " \"\"\"\n", + " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n", + " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n", + " \"\"\"\n", + "\n", + " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n", + " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n", + " self.nn = mlp\n", + " self.initial_eps = eps\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " if train_eps:\n", + " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n", + " else:\n", + " self.register_buffer(\"eps\", torch.Tensor([eps]))\n", + "\n", + " def forward(\n", + " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n", + " ) -> torch.Tensor:\n", + " \"\"\"\"\"\"\n", + " if isinstance(x, torch.Tensor):\n", + " x: OptPairTensor = (x, x)\n", + "\n", + " # Node and edge feature dimensionalites need to match.\n", + " if isinstance(edge_index, torch.Tensor):\n", + " assert edge_attr is not None\n", + " assert x[0].size(-1) == edge_attr.size(-1)\n", + " elif isinstance(edge_index, SparseTensor):\n", + " assert x[0].size(-1) == edge_index.size(-1)\n", + "\n", + " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n", + " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n", + "\n", + " x_r = x[1]\n", + " if x_r is not None:\n", + " out += (1 + self.eps) * x_r\n", + "\n", + " return self.nn(out)\n", + "\n", + " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n", + " if self.activation:\n", + " return self.activation(x_j + edge_attr)\n", + " else:\n", + " return x_j + edge_attr\n", + "\n", + " def __repr__(self):\n", + " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n", + "\n", + "\n", + "class GINEncoder(torch.nn.Module):\n", + " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n", + " super().__init__()\n", + "\n", + " self.hidden_dim = hidden_dim\n", + " self.num_convs = num_convs\n", + " self.short_cut = short_cut\n", + " self.concat_hidden = concat_hidden\n", + " self.node_emb = nn.Embedding(100, hidden_dim)\n", + "\n", + " if isinstance(activation, str):\n", + " self.activation = getattr(F, activation)\n", + " else:\n", + " self.activation = None\n", + "\n", + " self.convs = nn.ModuleList()\n", + " for i in range(self.num_convs):\n", + " self.convs.append(\n", + " GINEConv(\n", + " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n", + " activation=activation,\n", + " )\n", + " )\n", + "\n", + " def forward(self, z, edge_index, edge_attr):\n", + " \"\"\"\n", + " Input:\n", + " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n", + " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n", + " Output:\n", + " node_feature: graph feature\n", + " \"\"\"\n", + "\n", + " node_attr = self.node_emb(z) # (num_node, hidden)\n", + "\n", + " hiddens = []\n", + " conv_input = node_attr # (num_node, hidden)\n", + "\n", + " for conv_idx, conv in enumerate(self.convs):\n", + " hidden = conv(conv_input, edge_index, edge_attr)\n", + " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n", + " hidden = self.activation(hidden)\n", + " assert hidden.shape == conv_input.shape\n", + " if self.short_cut and hidden.shape == conv_input.shape:\n", + " hidden += conv_input\n", + "\n", + " hiddens.append(hidden)\n", + " conv_input = hidden\n", + "\n", + " if self.concat_hidden:\n", + " node_feature = torch.cat(hiddens, dim=-1)\n", + " else:\n", + " node_feature = hiddens[-1]\n", + "\n", + " return node_feature\n", + "\n", + "\n", + "class MLPEdgeEncoder(Module):\n", + " def __init__(self, hidden_dim=100, activation=\"relu\"):\n", + " super().__init__()\n", + " self.hidden_dim = hidden_dim\n", + " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n", + " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n", + "\n", + " @property\n", + " def out_channels(self):\n", + " return self.hidden_dim\n", + "\n", + " def forward(self, edge_length, edge_type):\n", + " \"\"\"\n", + " Input:\n", + " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n", + " Returns:\n", + " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n", + " \"\"\"\n", + " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n", + " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n", + " return d_emb * edge_attr # (num_edge, hidden)\n", + "\n", + "\n", + "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n", + " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n", + " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n", + " return h_pair\n", + "\n", + "\n", + "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n", + " \"\"\"\n", + " Args:\n", + " num_nodes: Number of atoms.\n", + " edge_index: Bond indices of the original graph.\n", + " edge_type: Bond types of the original graph.\n", + " order: Extension order.\n", + " Returns:\n", + " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n", + " \"\"\"\n", + "\n", + " def binarize(x):\n", + " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n", + "\n", + " def get_higher_order_adj_matrix(adj, order):\n", + " \"\"\"\n", + " Args:\n", + " adj: (N, N)\n", + " type_mat: (N, N)\n", + " Returns:\n", + " Following attributes will be updated:\n", + " - edge_index\n", + " - edge_type\n", + " Following attributes will be added to the data object:\n", + " - bond_edge_index: Original edge_index.\n", + " \"\"\"\n", + " adj_mats = [\n", + " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n", + " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n", + " ]\n", + "\n", + " for i in range(2, order + 1):\n", + " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n", + " order_mat = torch.zeros_like(adj)\n", + "\n", + " for i in range(1, order + 1):\n", + " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n", + "\n", + " return order_mat\n", + "\n", + " num_types = 22\n", + " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n", + " # from rdkit.Chem.rdchem import BondType as BT\n", + " N = num_nodes\n", + " adj = to_dense_adj(edge_index).squeeze(0)\n", + " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n", + "\n", + " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n", + " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n", + " assert (type_mat * type_highorder == 0).all()\n", + " type_new = type_mat + type_highorder\n", + "\n", + " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n", + " _, edge_order = dense_to_sparse(adj_order)\n", + "\n", + " # data.bond_edge_index = data.edge_index # Save original edges\n", + " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n", + " assert edge_type.dim() == 1\n", + " N = pos.size(0)\n", + "\n", + " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n", + "\n", + " if is_sidechain is None:\n", + " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n", + " else:\n", + " # fetch sidechain and its batch index\n", + " is_sidechain = is_sidechain.bool()\n", + " dummy_index = torch.arange(pos.size(0), device=pos.device)\n", + " sidechain_pos = pos[is_sidechain]\n", + " sidechain_index = dummy_index[is_sidechain]\n", + " sidechain_batch = batch[is_sidechain]\n", + "\n", + " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n", + " r_edge_index_x = assign_index[1]\n", + " r_edge_index_y = assign_index[0]\n", + " r_edge_index_y = sidechain_index[r_edge_index_y]\n", + "\n", + " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n", + " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n", + " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n", + " # delete self loop\n", + " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n", + "\n", + " rgraph_adj = torch.sparse.LongTensor(\n", + " rgraph_edge_index,\n", + " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n", + " torch.Size([N, N]),\n", + " )\n", + "\n", + " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n", + "\n", + " new_edge_index = composed_adj.indices()\n", + " new_edge_type = composed_adj.values().long()\n", + "\n", + " return new_edge_index, new_edge_type\n", + "\n", + "\n", + "def extend_graph_order_radius(\n", + " num_nodes,\n", + " pos,\n", + " edge_index,\n", + " edge_type,\n", + " batch,\n", + " order=3,\n", + " cutoff=10.0,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + "):\n", + " if extend_order:\n", + " edge_index, edge_type = _extend_graph_order(\n", + " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n", + " )\n", + "\n", + " if extend_radius:\n", + " edge_index, edge_type = _extend_to_radius_graph(\n", + " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n", + " )\n", + "\n", + " return edge_index, edge_type\n", + "\n", + "\n", + "def get_distance(pos, edge_index):\n", + " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n", + "\n", + "\n", + "def graph_field_network(score_d, pos, edge_index, edge_length):\n", + " \"\"\"\n", + " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n", + " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n", + " \"\"\"\n", + " N = pos.size(0)\n", + " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n", + " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n", + " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n", + " ) # (N, 3)\n", + " return score_pos\n", + "\n", + "\n", + "def clip_norm(vec, limit, p=2):\n", + " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n", + " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n", + " return vec * denom\n", + "\n", + "\n", + "def is_local_edge(edge_type):\n", + " return edge_type > 0\n" + ], + "metadata": { + "id": "oR1Y56QiLY90" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Main model class!" + ], + "metadata": { + "id": "QWrHJFcYXyUB" + } + }, + { + "cell_type": "code", + "source": [ + "class MoleculeGNN(ModelMixin, ConfigMixin):\n", + " @register_to_config\n", + " def __init__(\n", + " self,\n", + " hidden_dim=128,\n", + " num_convs=6,\n", + " num_convs_local=4,\n", + " cutoff=10.0,\n", + " mlp_act=\"relu\",\n", + " edge_order=3,\n", + " edge_encoder=\"mlp\",\n", + " smooth_conv=True,\n", + " ):\n", + " super().__init__()\n", + " self.cutoff = cutoff\n", + " self.edge_encoder = edge_encoder\n", + " self.edge_order = edge_order\n", + "\n", + " \"\"\"\n", + " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n", + " in SchNetEncoder\n", + " \"\"\"\n", + " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n", + "\n", + " \"\"\"\n", + " The graph neural network that extracts node-wise features.\n", + " \"\"\"\n", + " self.encoder_global = SchNetEncoder(\n", + " hidden_channels=hidden_dim,\n", + " num_filters=hidden_dim,\n", + " num_interactions=num_convs,\n", + " edge_channels=self.edge_encoder_global.out_channels,\n", + " cutoff=cutoff,\n", + " smooth=smooth_conv,\n", + " )\n", + " self.encoder_local = GINEncoder(\n", + " hidden_dim=hidden_dim,\n", + " num_convs=num_convs_local,\n", + " )\n", + "\n", + " \"\"\"\n", + " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n", + " gradients w.r.t. edge_length (out_dim = 1).\n", + " \"\"\"\n", + " self.grad_global_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " self.grad_local_dist_mlp = MultiLayerPerceptron(\n", + " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n", + " )\n", + "\n", + " \"\"\"\n", + " Incorporate parameters together\n", + " \"\"\"\n", + " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n", + " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n", + "\n", + " def _forward(\n", + " self,\n", + " atom_type,\n", + " pos,\n", + " bond_index,\n", + " bond_type,\n", + " batch,\n", + " time_step, # NOTE, model trained without timestep performed best\n", + " edge_index=None,\n", + " edge_type=None,\n", + " edge_length=None,\n", + " return_edges=False,\n", + " extend_order=True,\n", + " extend_radius=True,\n", + " is_sidechain=None,\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " atom_type: Types of atoms, (N, ).\n", + " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n", + " bond_type: Bond types, (E, ).\n", + " batch: Node index to graph index, (N, ).\n", + " \"\"\"\n", + " N = atom_type.size(0)\n", + " if edge_index is None or edge_type is None or edge_length is None:\n", + " edge_index, edge_type = extend_graph_order_radius(\n", + " num_nodes=N,\n", + " pos=pos,\n", + " edge_index=bond_index,\n", + " edge_type=bond_type,\n", + " batch=batch,\n", + " order=self.edge_order,\n", + " cutoff=self.cutoff,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " is_sidechain=is_sidechain,\n", + " )\n", + " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n", + " local_edge_mask = is_local_edge(edge_type) # (E, )\n", + "\n", + " # with the parameterization of NCSNv2\n", + " # DDPM loss implicit handle the noise variance scale conditioning\n", + " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n", + "\n", + " # Encoding global\n", + " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + "\n", + " # Global\n", + " node_attr_global = self.encoder_global(\n", + " z=atom_type,\n", + " edge_index=edge_index,\n", + " edge_length=edge_length,\n", + " edge_attr=edge_attr_global,\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_global = assemble_atom_pair_feature(\n", + " node_attr=node_attr_global,\n", + " edge_index=edge_index,\n", + " edge_attr=edge_attr_global,\n", + " ) # (E_global, 2H)\n", + " # Invariant features of edges (radius graph, global)\n", + " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n", + "\n", + " # Encoding local\n", + " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n", + " # edge_attr += temb_edge\n", + "\n", + " # Local\n", + " node_attr_local = self.encoder_local(\n", + " z=atom_type,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " )\n", + " # Assemble pairwise features\n", + " h_pair_local = assemble_atom_pair_feature(\n", + " node_attr=node_attr_local,\n", + " edge_index=edge_index[:, local_edge_mask],\n", + " edge_attr=edge_attr_local[local_edge_mask],\n", + " ) # (E_local, 2H)\n", + "\n", + " # Invariant features of edges (bond graph, local)\n", + " if isinstance(sigma_edge, torch.Tensor):\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n", + " 1.0 / sigma_edge[local_edge_mask]\n", + " ) # (E_local, 1)\n", + " else:\n", + " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n", + "\n", + " if return_edges:\n", + " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n", + " else:\n", + " return edge_inv_global, edge_inv_local\n", + "\n", + " def forward(\n", + " self,\n", + " sample,\n", + " timestep: Union[torch.Tensor, float, int],\n", + " return_dict: bool = True,\n", + " sigma=1.0,\n", + " global_start_sigma=0.5,\n", + " w_global=1.0,\n", + " extend_order=False,\n", + " extend_radius=True,\n", + " clip_local=None,\n", + " clip_global=1000.0,\n", + " ) -> Union[MoleculeGNNOutput, Tuple]:\n", + " r\"\"\"\n", + " Args:\n", + " sample: packed torch geometric object\n", + " timestep (`torch.FloatTensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n", + " return_dict (`bool`, *optional*, defaults to `True`):\n", + " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n", + " Returns:\n", + " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n", + " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n", + " \"\"\"\n", + "\n", + " # unpack sample\n", + " atom_type = sample.atom_type\n", + " bond_index = sample.edge_index\n", + " bond_type = sample.edge_type\n", + " num_graphs = sample.num_graphs\n", + " pos = sample.pos\n", + "\n", + " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n", + "\n", + " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n", + " atom_type=atom_type,\n", + " pos=sample.pos,\n", + " bond_index=bond_index,\n", + " bond_type=bond_type,\n", + " batch=sample.batch,\n", + " time_step=timesteps,\n", + " return_edges=True,\n", + " extend_order=extend_order,\n", + " extend_radius=extend_radius,\n", + " ) # (E_global, 1), (E_local, 1)\n", + "\n", + " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n", + " node_eq_local = graph_field_network(\n", + " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n", + " )\n", + " if clip_local is not None:\n", + " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n", + "\n", + " # Global\n", + " if sigma < global_start_sigma:\n", + " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n", + " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n", + " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n", + " else:\n", + " node_eq_global = 0\n", + "\n", + " # Sum\n", + " eps_pos = node_eq_local + node_eq_global * w_global\n", + "\n", + " if not return_dict:\n", + " return (-eps_pos,)\n", + "\n", + " return MoleculeGNNOutput(sample=torch.FloatTensor(-eps_pos).to(pos.device))" + ], + "metadata": { + "id": "MCeZA1qQXzoK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CCIrPYSJj9wd" + }, + "source": [ + "### Load pretrained model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YdrAr6Ch--Ab" + }, + "source": [ + "#### Load a model\n", + "The model used is a design an\n", + "equivariant convolutional layer, named graph field network (GFN).\n", + "\n", + "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DyCo0nsqjbml", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 172, + "referenced_widgets": [ + "d90f304e9560472eacfbdd11e46765eb", + "1c6246f15b654f4daa11c9bcf997b78c", + "c2321b3bff6f490ca12040a20308f555", + "b7feb522161f4cf4b7cc7c1a078ff12d", + "e2d368556e494ae7ae4e2e992af2cd4f", + "bbef741e76ec41b7ab7187b487a383df", + "561f742d418d4721b0670cc8dd62e22c", + "872915dd1bb84f538c44e26badabafdd", + "d022575f1fa2446d891650897f187b4d", + "fdc393f3468c432aa0ada05e238a5436", + "2c9362906e4b40189f16d14aa9a348da", + "6010fc8daa7a44d5aec4b830ec2ebaa1", + "7e0bb1b8d65249d3974200686b193be2", + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "6526646be5ed415c84d1245b040e629b", + "24d31fc3576e43dd9f8301d2ef3a37ab", + "2918bfaadc8d4b1a9832522c40dfefb8", + "a4bfdca35cc54dae8812720f1b276a08", + "e4901541199b45c6a18824627692fc39", + "f915cf874246446595206221e900b2fe", + "a9e388f22a9742aaaf538e22575c9433", + "42f6c3db29d7484ba6b4f73590abd2f4" + ] + }, + "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading: 0%| | 0.00/3.27M [00:00] 124.78K 180KB/s in 0.7s \n", + "\n", + "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n", + "\n" + ] + } + ], + "source": [ + "import torch \n", + "import numpy as np\n", + "\n", + "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n", + "dataset = torch.load('/content/molecules.pkl')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QZcmy1EvKQRk" + }, + "source": [ + "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JVjz6iH_H6Eh", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")" + ] + }, + "metadata": {}, + "execution_count": 20 + } + ], + "source": [ + "dataset[0]" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Run the diffusion process" + ], + "metadata": { + "id": "vHNiZAUxNgoy" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jZ1KZrxKqENg" + }, + "source": [ + "#### Helper Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s240tYueqKKf" + }, + "outputs": [], + "source": [ + "from torch_geometric.data import Data, Batch\n", + "from torch_scatter import scatter_add, scatter_mean\n", + "from tqdm import tqdm\n", + "import copy\n", + "import os\n", + "\n", + "def repeat_data(data: Data, num_repeat) -> Batch:\n", + " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n", + " return Batch.from_data_list(datas)\n", + "\n", + "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n", + " datas = batch.to_data_list()\n", + " new_data = []\n", + " for i in range(num_repeat):\n", + " new_data += copy.deepcopy(datas)\n", + " return Batch.from_data_list(new_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AMnQTk0eqT7Z" + }, + "source": [ + "#### Constants" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WYGkzqgzrHmF" + }, + "outputs": [], + "source": [ + "num_samples = 1 # solutions per molecule\n", + "num_molecules = 3\n", + "\n", + "DEVICE = 'cuda'\n", + "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n", + "# constants for inference\n", + "w_global = 0.5 #0,.3 for qm9\n", + "global_start_sigma = 0.5\n", + "eta = 1.0\n", + "clip_local = None \n", + "clip_pos = None\n", + "\n", + "# constands for data handling\n", + "save_traj = False\n", + "save_data = False\n", + "output_dir = '/content/'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-xD5bJ3SqM7t" + }, + "source": [ + "#### Generate samples!\n", + "Note that the 3d representation of a molecule is referred to as the **conformation**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x9xuLUNg26z1", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " after removing the cwd from sys.path.\n", + "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n" + ] + } + ], + "source": [ + "results = []\n", + "\n", + "# define sigmas\n", + "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n", + "sigmas = sigmas.to(DEVICE)\n", + "\n", + "for count, data in enumerate(tqdm(dataset)): \n", + " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n", + "\n", + " data_input = data.clone()\n", + " data_input['pos_ref'] = None\n", + " batch = repeat_data(data_input, num_samples).to(DEVICE)\n", + "\n", + " # initial configuration\n", + " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n", + "\n", + " # for logging animation of denoising\n", + " pos_traj = []\n", + " with torch.no_grad():\n", + "\n", + " # scale initial sample\n", + " pos = pos_init * sigmas[-1]\n", + " for t in scheduler.timesteps:\n", + " batch.pos = pos\n", + "\n", + " # generate geometry with model, then filter it\n", + " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n", + "\n", + " # Update\n", + " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n", + "\n", + " pos = reconstructed_pos\n", + "\n", + " if torch.isnan(pos).any():\n", + " print(\"NaN detected. Please restart.\")\n", + " raise FloatingPointError()\n", + "\n", + " # recenter graph of positions for next iteration\n", + " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n", + "\n", + " # optional clipping\n", + " if clip_pos is not None:\n", + " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n", + " pos_traj.append(pos.clone().cpu())\n", + "\n", + " pos_gen = pos.cpu()\n", + " if save_traj:\n", + " pos_gen_traj = pos_traj.cpu()\n", + " data.pos_gen = torch.stack(pos_gen_traj)\n", + " else:\n", + " data.pos_gen = pos_gen\n", + " results.append(data)\n", + "\n", + "\n", + "if save_data:\n", + " save_path = os.path.join(output_dir, 'samples_all.pkl')\n", + "\n", + " with open(save_path, 'wb') as f:\n", + " pickle.dump(results, f)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Render the results!" + ], + "metadata": { + "id": "fSApwSaZNndW" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d47Zxo2OKdgZ" + }, + "source": [ + "This function allows us to render 3d in colab." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e9Cd0kCAv9b8" + }, + "outputs": [], + "source": [ + "from google.colab import output\n", + "output.enable_custom_widget_manager()" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Helper functions" + ], + "metadata": { + "id": "RjaVuR15NqzF" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "28rBYa9NKhlz" + }, + "source": [ + "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LKdKdwxcyTQ6" + }, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "def set_rdmol_positions(rdkit_mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " mol = deepcopy(rdkit_mol)\n", + " set_rdmol_positions_(mol, pos)\n", + " return mol\n", + "\n", + "def set_rdmol_positions_(mol, pos):\n", + " \"\"\"\n", + " Args:\n", + " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n", + " pos: (N_atoms, 3)\n", + " \"\"\"\n", + " for i in range(pos.shape[0]):\n", + " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n", + " return mol\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NuE10hcpKmzK" + }, + "source": [ + "Process the generated data to make it easy to view." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KieVE1vc0_Vs", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "collect 5 generated molecules in `mols`\n" + ] + } + ], + "source": [ + "# the model can generate multiple conformations per 2d geometry\n", + "num_gen = results[0]['pos_gen'].shape[0]\n", + "\n", + "# init storage objects\n", + "mols_gen = []\n", + "mols_orig = []\n", + "for to_process in results:\n", + "\n", + " # store the reference 3d position\n", + " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", + "\n", + " # store the generated 3d position\n", + " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n", + "\n", + " # copy data to new object\n", + " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n", + "\n", + " # append results\n", + " mols_gen.append(new_mol)\n", + " mols_orig.append(to_process.rdmol)\n", + "\n", + "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tin89JwMKp4v" + }, + "source": [ + "Import tools to visualize the 2d chemical diagram of the molecule." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yqV6gllSZn38" + }, + "outputs": [], + "source": [ + "from rdkit.Chem import AllChem\n", + "from rdkit import Chem\n", + "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n", + "from IPython.display import SVG, display" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TFNKmGddVoOk" + }, + "source": [ + "Select molecule to visualize" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KzuwLlrrVaGc" + }, + "outputs": [], + "source": [ + "idx = 0\n", + "assert idx < len(results), \"selected molecule that was not generated\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Viewing" + ], + "metadata": { + "id": "hkb8w0_SNtU8" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I3R4QBQeKttN" + }, + "source": [ + "This 2D rendering is the equivalent of the **input to the model**!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gkQRWjraaKex", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 321 + }, + "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" + }, + "metadata": {} + } + ], + "source": [ + "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n", + "molSize=(450,300)\n", + "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n", + "drawer.DrawMolecule(mc)\n", + "drawer.FinishDrawing()\n", + "svg = drawer.GetDrawingText()\n", + "display(SVG(svg.replace('svg:','')))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "z4FDMYMxKw2I" + }, + "source": [ + "Generate the 3d molecule! " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aT1Bkb8YxJfV", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "695ab5bbf30a4ab19df1f9f33469f314", + "eac6a8dcdc9d4335a2e51031793ead29" + ] + }, + "outputId": "b98870ae-049d-4386-b676-166e9526bda2" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "695ab5bbf30a4ab19df1f9f33469f314" + } + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + } + } + } + } + } + ], + "source": [ + "from nglview import show_rdkit as show" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pxtq8I-I18C-", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 337, + "referenced_widgets": [ + "be446195da2b4ff2aec21ec5ff963a54", + "c6596896148b4a8a9c57963b67c7782f", + "2489b5e5648541fbbdceadb05632a050", + "01e0ba4e5da04914b4652b8d58565d7b", + "c30e6c2f3e2a44dbbb3d63bd519acaa4", + "f31c6e40e9b2466a9064a2669933ecd5", + "19308ccac642498ab8b58462e3f1b0bb", + "4a081cdc2ec3421ca79dd933b7e2b0c4", + "e5c0d75eb5e1447abd560c8f2c6017e1", + "5146907ef6764654ad7d598baebc8b58", + "144ec959b7604a2cabb5ca46ae5e5379", + "abce2a80e6304df3899109c6d6cac199", + "65195cb7a4134f4887e9dd19f3676462" + ] + }, + "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "NGLWidget()" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "be446195da2b4ff2aec21ec5ff963a54" + } + }, + "metadata": { + "application/vnd.jupyter.widget-view+json": { + "colab": { + "custom_widget_manager": { + "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js" + } + } + } + } + } + ], + "source": [ + "# new molecule\n", + "show(mols_gen[idx])" + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "KJr4h2mwXeTo" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "d90f304e9560472eacfbdd11e46765eb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c", + "IPY_MODEL_c2321b3bff6f490ca12040a20308f555", + "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d" + ], + "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f" + } + }, + "1c6246f15b654f4daa11c9bcf997b78c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df", + "placeholder": "​", + "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c", + "value": "Downloading: 100%" + } + }, + "c2321b3bff6f490ca12040a20308f555": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd", + "max": 3271865, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d", + "value": 3271865 + } + }, + "b7feb522161f4cf4b7cc7c1a078ff12d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436", + "placeholder": "​", + "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da", + "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]" + } + }, + "e2d368556e494ae7ae4e2e992af2cd4f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bbef741e76ec41b7ab7187b487a383df": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "561f742d418d4721b0670cc8dd62e22c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "872915dd1bb84f538c44e26badabafdd": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d022575f1fa2446d891650897f187b4d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "fdc393f3468c432aa0ada05e238a5436": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2c9362906e4b40189f16d14aa9a348da": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6010fc8daa7a44d5aec4b830ec2ebaa1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2", + "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a", + "IPY_MODEL_6526646be5ed415c84d1245b040e629b" + ], + "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab" + } + }, + "7e0bb1b8d65249d3974200686b193be2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8", + "placeholder": "​", + "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08", + "value": "Downloading: 100%" + } + }, + "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39", + "max": 401, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f915cf874246446595206221e900b2fe", + "value": 401 + } + }, + "6526646be5ed415c84d1245b040e629b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433", + "placeholder": "​", + "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4", + "value": " 401/401 [00:00<00:00, 13.5kB/s]" + } + }, + "24d31fc3576e43dd9f8301d2ef3a37ab": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2918bfaadc8d4b1a9832522c40dfefb8": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a4bfdca35cc54dae8812720f1b276a08": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e4901541199b45c6a18824627692fc39": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f915cf874246446595206221e900b2fe": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a9e388f22a9742aaaf538e22575c9433": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "42f6c3db29d7484ba6b4f73590abd2f4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "695ab5bbf30a4ab19df1f9f33469f314": { + "model_module": "nglview-js-widgets", + "model_name": "ColormakerRegistryModel", + "model_module_version": "3.0.1", + "state": { + "_dom_classes": [], + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "ColormakerRegistryModel", + "_msg_ar": [], + "_msg_q": [], + "_ready": false, + "_view_count": null, + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "ColormakerRegistryView", + "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29" + } + }, + "eac6a8dcdc9d4335a2e51031793ead29": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be446195da2b4ff2aec21ec5ff963a54": { + "model_module": "nglview-js-widgets", + "model_name": "NGLModel", + "model_module_version": "3.0.1", + "state": { + "_camera_orientation": [ + -15.519693580202304, + -14.065056548036177, + -23.53197484807691, + 0, + -23.357853515109753, + 20.94055073042662, + 2.888695042134944, + 0, + 14.352363398292777, + 18.870825741878015, + -20.744689572909344, + 0, + 0.2724999189376831, + 0.6940000057220459, + -0.3734999895095825, + 1 + ], + "_camera_str": "orthographic", + "_dom_classes": [], + "_gui_theme": null, + "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050", + "_igui": null, + "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b", + "_model_module": "nglview-js-widgets", + "_model_module_version": "3.0.1", + "_model_name": "NGLModel", + "_ngl_color_dict": {}, + "_ngl_coordinate_resource": {}, + "_ngl_full_stage_parameters": { + "impostor": true, + "quality": "medium", + "workerDefault": true, + "sampleLevel": 0, + "backgroundColor": "white", + "rotateSpeed": 2, + "zoomSpeed": 1.2, + "panSpeed": 1, + "clipNear": 0, + "clipFar": 100, + "clipDist": 10, + "fogNear": 50, + "fogFar": 100, + "cameraFov": 40, + "cameraEyeSep": 0.3, + "cameraType": "perspective", + "lightColor": 14540253, + "lightIntensity": 1, + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "hoverTimeout": 0, + "tooltip": true, + "mousePreset": "default" + }, + "_ngl_msg_archive": [ + { + "target": "Stage", + "type": "call_method", + "methodName": "loadFile", + "reconstruc_color_scheme": false, + "args": [ + { + "type": "blob", + "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n", + "binary": false + } + ], + "kwargs": { + "defaultRepresentation": true, + "ext": "pdb" + } + } + ], + "_ngl_original_stage_parameters": { + "impostor": true, + "quality": "medium", + "workerDefault": true, + "sampleLevel": 0, + "backgroundColor": "white", + "rotateSpeed": 2, + "zoomSpeed": 1.2, + "panSpeed": 1, + "clipNear": 0, + "clipFar": 100, + "clipDist": 10, + "fogNear": 50, + "fogFar": 100, + "cameraFov": 40, + "cameraEyeSep": 0.3, + "cameraType": "perspective", + "lightColor": 14540253, + "lightIntensity": 1, + "ambientColor": 14540253, + "ambientIntensity": 0.2, + "hoverTimeout": 0, + "tooltip": true, + "mousePreset": "default" + }, + "_ngl_repr_dict": { + "0": { + "0": { + "type": "ball+stick", + "params": { + "lazy": false, + "visible": true, + "quality": "high", + "sphereDetail": 2, + "radialSegments": 20, + "openEnded": true, + "disableImpostor": false, + "aspectRatio": 1.5, + "lineOnly": false, + "cylinderOnly": false, + "multipleBond": "off", + "bondScale": 0.3, + "bondSpacing": 0.75, + "linewidth": 2, + "radiusType": "size", + "radiusData": {}, + "radiusSize": 0.15, + "radiusScale": 2, + "assembly": "default", + "defaultAssembly": "", + "clipNear": 0, + "clipRadius": 0, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 + }, + "flatShaded": false, + "opacity": 1, + "depthWrite": true, + "side": "double", + "wireframe": false, + "colorScheme": "element", + "colorScale": "", + "colorReverse": false, + "colorValue": 9474192, + "colorMode": "hcl", + "roughness": 0.4, + "metalness": 0, + "diffuse": 16777215, + "diffuseInterior": false, + "useInteriorColor": true, + "interiorColor": 2236962, + "interiorDarkening": 0, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "disablePicking": false, + "sele": "" + } + } + }, + "1": { + "0": { + "type": "ball+stick", + "params": { + "lazy": false, + "visible": true, + "quality": "high", + "sphereDetail": 2, + "radialSegments": 20, + "openEnded": true, + "disableImpostor": false, + "aspectRatio": 1.5, + "lineOnly": false, + "cylinderOnly": false, + "multipleBond": "off", + "bondScale": 0.3, + "bondSpacing": 0.75, + "linewidth": 2, + "radiusType": "size", + "radiusData": {}, + "radiusSize": 0.15, + "radiusScale": 2, + "assembly": "default", + "defaultAssembly": "", + "clipNear": 0, + "clipRadius": 0, + "clipCenter": { + "x": 0, + "y": 0, + "z": 0 + }, + "flatShaded": false, + "opacity": 1, + "depthWrite": true, + "side": "double", + "wireframe": false, + "colorScheme": "element", + "colorScale": "", + "colorReverse": false, + "colorValue": 9474192, + "colorMode": "hcl", + "roughness": 0.4, + "metalness": 0, + "diffuse": 16777215, + "diffuseInterior": false, + "useInteriorColor": true, + "interiorColor": 2236962, + "interiorDarkening": 0, + "matrix": { + "elements": [ + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1 + ] + }, + "disablePicking": false, + "sele": "" + } + } + } + }, + "_ngl_serialize": false, + "_ngl_version": "", + "_ngl_view_id": [ + "FB989FD1-5B9C-446B-8914-6B58AF85446D" + ], + "_player_dict": {}, + "_scene_position": {}, + "_scene_rotation": {}, + "_synced_model_ids": [], + "_synced_repr_model_ids": [], + "_view_count": null, + "_view_height": "", + "_view_module": "nglview-js-widgets", + "_view_module_version": "3.0.1", + "_view_name": "NGLView", + "_view_width": "", + "background": "white", + "frame": 0, + "gui_style": null, + "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f", + "max_frame": 0, + "n_components": 2, + "picked": {} + } + }, + "c6596896148b4a8a9c57963b67c7782f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2489b5e5648541fbbdceadb05632a050": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ButtonModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ButtonView", + "button_style": "", + "description": "", + "disabled": false, + "icon": "compress", + "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199", + "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462", + "tooltip": "" + } + }, + "01e0ba4e5da04914b4652b8d58565d7b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1", + "IPY_MODEL_5146907ef6764654ad7d598baebc8b58" + ], + "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379" + } + }, + "c30e6c2f3e2a44dbbb3d63bd519acaa4": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f31c6e40e9b2466a9064a2669933ecd5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "19308ccac642498ab8b58462e3f1b0bb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4a081cdc2ec3421ca79dd933b7e2b0c4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "SliderStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "SliderStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "", + "handle_color": null + } + }, + "e5c0d75eb5e1447abd560c8f2c6017e1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "PlayModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "PlayModel", + "_playing": false, + "_repeat": false, + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "PlayView", + "description": "", + "description_tooltip": null, + "disabled": false, + "interval": 100, + "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4", + "max": 0, + "min": 0, + "show_repeat": true, + "step": 1, + "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5", + "value": 0 + } + }, + "5146907ef6764654ad7d598baebc8b58": { + "model_module": "@jupyter-widgets/controls", + "model_name": "IntSliderModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "IntSliderModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "IntSliderView", + "continuous_update": true, + "description": "", + "description_tooltip": null, + "disabled": false, + "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb", + "max": 0, + "min": 0, + "orientation": "horizontal", + "readout": true, + "readout_format": "d", + "step": 1, + "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4", + "value": 0 + } + }, + "144ec959b7604a2cabb5ca46ae5e5379": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "abce2a80e6304df3899109c6d6cac199": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "34px" + } + }, + "65195cb7a4134f4887e9dd19f3676462": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ButtonStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/diffusers/reinforcement_learning_with_diffusers.ipynb b/diffusers/reinforcement_learning_with_diffusers.ipynb new file mode 100644 index 00000000..95144236 --- /dev/null +++ b/diffusers/reinforcement_learning_with_diffusers.ipynb @@ -0,0 +1,4252 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "UwC-_9kK37yN" + }, + "source": [ + "### Introduction\n", + "This notebook is designed to run inference on the [Diffuser](https://arxiv.org/abs/2205.09991) planning model for model-based RL. The notebook is modified from the authors' [original](https://colab.research.google.com/drive/1YajKhu-CUIGBJeQPehjVPJcK_b38a8Nc?usp=sharing#scrollTo=57hSzI4mCgat). For those new to reinforcement learning, consider checking out the HuggingFace [Reinforcement Learning Course](https://huggingface.co/blog/deep-rl-intro) for a primer.\n", + "\n", + "> Colab made by [Nathan Lambert](https://natolambert.com) and [Ben Glickenhaus](https://www.linkedin.com/in/benjamin-glickenhaus-859532a3).\n", + "\n", + "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7Cy2P-c4XFTx" + }, + "source": [ + "### Installing Packages" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Tj7eyweNapes" + }, + "source": [ + "#### `apt-get install` requirements \n", + "\n", + "These requirements primarily pertain to install mujoco and run it in the colab.\n", + "Source was inspired by this (fairly recent) [demo](https://colab.research.google.com/drive/1KGMZdRq6AemfcNscKjgpRzXqfhUtCf-V?usp=sharing)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HKMZc5zvfoY1", + "outputId": "c173b13e-98f2-48f1-92ef-dbad90d347c4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Reading package lists... Done\n", + "Building dependency tree \n", + "Reading state information... Done\n", + "libglew-dev is already the newest version (2.0.0-5).\n", + "libgl1-mesa-dev is already the newest version (20.0.8-0ubuntu1~18.04.1).\n", + "libgl1-mesa-glx is already the newest version (20.0.8-0ubuntu1~18.04.1).\n", + "libosmesa6-dev is already the newest version (20.0.8-0ubuntu1~18.04.1).\n", + "software-properties-common is already the newest version (0.96.24.32.18).\n", + "The following package was automatically installed and is no longer required:\n", + " libnvidia-common-460\n", + "Use 'apt autoremove' to remove it.\n", + "0 upgraded, 0 newly installed, 0 to remove and 27 not upgraded.\n", + "Reading package lists... Done\n", + "Building dependency tree \n", + "Reading state information... Done\n", + "patchelf is already the newest version (0.9-1).\n", + "The following package was automatically installed and is no longer required:\n", + " libnvidia-common-460\n", + "Use 'apt autoremove' to remove it.\n", + "0 upgraded, 0 newly installed, 0 to remove and 27 not upgraded.\n" + ] + } + ], + "source": [ + "# installations primiarly needed for Mujoco\n", + "!apt-get install -y \\\n", + " libgl1-mesa-dev \\\n", + " libgl1-mesa-glx \\\n", + " libglew-dev \\\n", + " libosmesa6-dev \\\n", + " software-properties-common\n", + "\n", + "!apt-get install -y patchelf" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ppxv6Mdkalbc" + }, + "source": [ + "#### Install Diffusers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mgQA_XN-XGY2", + "outputId": "e1a26091-ccc5-44c9-c9aa-0d51875114c9" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content\n", + "Found existing installation: diffusers 0.5.0.dev0\n", + "Uninstalling diffusers-0.5.0.dev0:\n", + " Successfully uninstalled diffusers-0.5.0.dev0\n", + "Cloning into 'diffusers'...\n", + "remote: Enumerating objects: 10356, done.\u001b[K\n", + "remote: Counting objects: 100% (502/502), done.\u001b[K\n", + "remote: Compressing objects: 100% (251/251), done.\u001b[K\n", + "remote: Total 10356 (delta 277), reused 384 (delta 201), pack-reused 9854\u001b[K\n", + "Receiving objects: 100% (10356/10356), 7.81 MiB | 17.77 MiB/s, done.\n", + "Resolving deltas: 100% (6885/6885), done.\n", + "\u001b[33m DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.\n", + " pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.\u001b[0m\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for diffusers (PEP 517) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "%cd /content\n", + "!pip uninstall -y diffusers\n", + "# install latest HF diffusers\n", + "!rm -rf /content/diffusers/\n", + "!git clone -b rl https://github.com/huggingface/diffusers.git\n", + "!pip install -q /content/diffusers \n", + "!pip install -q datasets transformers " + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "q4C1hwIAaeZQ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3ILOcqnwhzrj" + }, + "source": [ + "#### `pip install` requirements" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "x29OgI2uh8Iv", + "outputId": "c9de1921-2978-47df-c5a3-fa8aa8fb8c20" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Looking in links: https://download.pytorch.org/whl/torch_stable.html\n", + "Collecting git+https://github.com/rail-berkeley/d4rl.git\n", + " Cloning https://github.com/rail-berkeley/d4rl.git to /tmp/pip-req-build-7j2y8u6t\n", + " Running command git clone -q https://github.com/rail-berkeley/d4rl.git /tmp/pip-req-build-7j2y8u6t\n", + "Requirement already satisfied: free-mujoco-py in /usr/local/lib/python3.7/dist-packages (2.1.6)\n", + "Requirement already satisfied: einops in /usr/local/lib/python3.7/dist-packages (0.5.0)\n", + "Requirement already satisfied: gym in /usr/local/lib/python3.7/dist-packages (0.24.1)\n", + "Requirement already satisfied: protobuf==3.20.1 in /usr/local/lib/python3.7/dist-packages (3.20.1)\n", + "Requirement already satisfied: mediapy in /usr/local/lib/python3.7/dist-packages (1.1.2)\n", + "Requirement already satisfied: Pillow==9.0.0 in /usr/local/lib/python3.7/dist-packages (9.0.0)\n", + "Collecting mjrl@ git+https://github.com/aravindr93/mjrl@master#egg=mjrl\n", + " Cloning https://github.com/aravindr93/mjrl (to revision master) to /tmp/pip-install-g98wzheg/mjrl_0abe064c9aa541e98742a70535434798\n", + " Running command git clone -q https://github.com/aravindr93/mjrl /tmp/pip-install-g98wzheg/mjrl_0abe064c9aa541e98742a70535434798\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from D4RL==1.1) (1.21.6)\n", + "Requirement already satisfied: mujoco_py in /usr/local/lib/python3.7/dist-packages (from D4RL==1.1) (2.1.2.14)\n", + "Requirement already satisfied: pybullet in /usr/local/lib/python3.7/dist-packages (from D4RL==1.1) (3.2.5)\n", + "Requirement already satisfied: h5py in /usr/local/lib/python3.7/dist-packages (from D4RL==1.1) (3.1.0)\n", + "Requirement already satisfied: termcolor in /usr/local/lib/python3.7/dist-packages (from D4RL==1.1) (2.0.1)\n", + "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from D4RL==1.1) (7.1.2)\n", + "Requirement already satisfied: dm_control>=1.0.3 in /usr/local/lib/python3.7/dist-packages (from D4RL==1.1) (1.0.8)\n", + "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from gym) (0.0.8)\n", + "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from gym) (1.5.0)\n", + "Requirement already satisfied: importlib-metadata>=4.8.0 in /usr/local/lib/python3.7/dist-packages (from gym) (4.13.0)\n", + "Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (1.3.0)\n", + "Requirement already satisfied: lxml in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (4.9.1)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (2.23.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (4.64.1)\n", + "Requirement already satisfied: pyparsing<3.0.0 in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (2.4.7)\n", + "Requirement already satisfied: mujoco>=2.3.0 in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (2.3.0)\n", + "Requirement already satisfied: pyopengl>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (3.1.6)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (1.7.3)\n", + "Requirement already satisfied: dm-tree!=0.1.2 in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (0.1.7)\n", + "Requirement already satisfied: glfw in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (1.12.0)\n", + "Requirement already satisfied: dm-env in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (1.5)\n", + "Requirement already satisfied: labmaze in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (1.0.5)\n", + "Requirement already satisfied: setuptools!=50.0.0 in /usr/local/lib/python3.7/dist-packages (from dm_control>=1.0.3->D4RL==1.1) (57.4.0)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.8.0->gym) (3.9.0)\n", + "Requirement already satisfied: typing-extensions>=3.6.4 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.8.0->gym) (4.1.1)\n", + "Requirement already satisfied: Cython<0.30.0,>=0.29.24 in /usr/local/lib/python3.7/dist-packages (from free-mujoco-py) (0.29.32)\n", + "Requirement already satisfied: cffi<2.0.0,>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from free-mujoco-py) (1.15.1)\n", + "Requirement already satisfied: fasteners==0.15 in /usr/local/lib/python3.7/dist-packages (from free-mujoco-py) (0.15)\n", + "Requirement already satisfied: imageio<3.0.0,>=2.9.0 in /usr/local/lib/python3.7/dist-packages (from free-mujoco-py) (2.9.0)\n", + "Requirement already satisfied: monotonic>=0.1 in /usr/local/lib/python3.7/dist-packages (from fasteners==0.15->free-mujoco-py) (1.6)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from fasteners==0.15->free-mujoco-py) (1.15.0)\n", + "Requirement already satisfied: pycparser in /usr/local/lib/python3.7/dist-packages (from cffi<2.0.0,>=1.15.0->free-mujoco-py) (2.21)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from mediapy) (3.2.2)\n", + "Requirement already satisfied: ipython in /usr/local/lib/python3.7/dist-packages (from mediapy) (7.9.0)\n", + "Requirement already satisfied: cached-property in /usr/local/lib/python3.7/dist-packages (from h5py->D4RL==1.1) (1.5.2)\n", + "Requirement already satisfied: traitlets>=4.2 in /usr/local/lib/python3.7/dist-packages (from ipython->mediapy) (5.1.1)\n", + "Requirement already satisfied: pickleshare in /usr/local/lib/python3.7/dist-packages (from ipython->mediapy) (0.7.5)\n", + "Requirement already satisfied: backcall in /usr/local/lib/python3.7/dist-packages (from ipython->mediapy) (0.2.0)\n", + "Requirement already satisfied: jedi>=0.10 in /usr/local/lib/python3.7/dist-packages (from ipython->mediapy) (0.18.1)\n", + "Requirement already satisfied: pexpect in /usr/local/lib/python3.7/dist-packages (from ipython->mediapy) (4.8.0)\n", + "Requirement already satisfied: pygments in /usr/local/lib/python3.7/dist-packages (from ipython->mediapy) (2.6.1)\n", + "Requirement already satisfied: prompt-toolkit<2.1.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from ipython->mediapy) (2.0.10)\n", + "Requirement already satisfied: decorator in /usr/local/lib/python3.7/dist-packages (from ipython->mediapy) (4.4.2)\n", + "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /usr/local/lib/python3.7/dist-packages (from jedi>=0.10->ipython->mediapy) (0.8.3)\n", + "Requirement already satisfied: wcwidth in /usr/local/lib/python3.7/dist-packages (from prompt-toolkit<2.1.0,>=2.0.0->ipython->mediapy) (0.2.5)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mediapy) (0.11.0)\n", + "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mediapy) (2.8.2)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mediapy) (1.4.4)\n", + "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.7/dist-packages (from pexpect->ipython->mediapy) (0.7.0)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->dm_control>=1.0.3->D4RL==1.1) (2.10)\n", + "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->dm_control>=1.0.3->D4RL==1.1) (3.0.4)\n", + "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->dm_control>=1.0.3->D4RL==1.1) (1.25.11)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->dm_control>=1.0.3->D4RL==1.1) (2022.9.24)\n" + ] + } + ], + "source": [ + "# primarily RL-sepcific requirements\n", + "%pip install -f https://download.pytorch.org/whl/torch_stable.html \\\n", + " free-mujoco-py \\\n", + " einops \\\n", + " gym \\\n", + " protobuf==3.20.1 \\\n", + " git+https://github.com/rail-berkeley/d4rl.git \\\n", + " mediapy \\\n", + " Pillow==9.0.0 \n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7OkjKjMPfSZR" + }, + "source": [ + "#### Import D4RL to initialize Mujoco\n", + "[Mujoco](https://github.com/deepmind/mujoco) is a physics simulator used extensively in reinforcement learning research. Here, we import [D4RL](https://github.com/rail-berkeley/d4rl) (a library of datasets and environments for Offline RL), which results in the building of Mujoco." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rFVGxWIuVj5F", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "6a641f6a-e51e-4a79-e5e0-1d3937ea1ed8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Warning: Gym version v0.24.1 has a number of critical issues with `gym.make` such that environment observation and action spaces are incorrectly evaluated, raising incorrect errors and warning . It is recommend to downgrading to v0.23.1 or upgrading to v0.25.1\n", + "Warning: Flow failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.\n", + "No module named 'flow'\n", + "Warning: CARLA failed to import. Set the environment variable D4RL_SUPPRESS_IMPORT_ERROR=1 to suppress this message.\n", + "No module named 'carla'\n", + "/usr/local/lib/python3.7/dist-packages/gym/envs/registration.py:416: UserWarning: \u001b[33mWARN: The `registry.env_specs` property along with `EnvSpecTree` is deprecated. Please use `registry` directly as a dictionary instead.\u001b[0m\n", + " \"The `registry.env_specs` property along with `EnvSpecTree` is deprecated. Please use `registry` directly as a dictionary instead.\"\n" + ] + } + ], + "source": [ + "## cythonize mujoco-py at first import\n", + "import d4rl" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e3fx2pZrgIu3" + }, + "source": [ + "\n", + "\n", + "---\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0qKnJbCXssgw" + }, + "source": [ + "### Environment & Model Setup\n", + "In this section, we will create the environment, handle the data, and run the diffusion model." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R0CRaEtNVq8C" + }, + "source": [ + "#### Imports\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MDTifo67l-zN" + }, + "outputs": [], + "source": [ + "import torch\n", + "import tqdm\n", + "import numpy as np\n", + "import gym " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8Wtw1hYkgqCR" + }, + "source": [ + "#### Create environment\n", + "This colab is designed to run with pretrained models from the hopper environment. As more models are trained, this can be extended.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FKpK2gz1gvpn", + "outputId": "7e5a38c6-fe81-4e77-d6b5-837ed909ad50" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/gym/envs/mujoco/mujoco_env.py:47: UserWarning: \u001b[33mWARN: This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).\u001b[0m\n", + " \"This version of the mujoco environments depends \"\n", + "/usr/local/lib/python3.7/dist-packages/gym/spaces/box.py:112: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", + " logger.warn(f\"Box bound precision lowered by casting to {self.dtype}\")\n", + "/usr/local/lib/python3.7/dist-packages/gym/utils/passive_env_checker.py:70: UserWarning: \u001b[33mWARN: Agent's minimum action space value is -infinity. This is probably too low.\u001b[0m\n", + " \"Agent's minimum action space value is -infinity. This is probably too low.\"\n", + "/usr/local/lib/python3.7/dist-packages/gym/utils/passive_env_checker.py:74: UserWarning: \u001b[33mWARN: Agent's maximum action space value is infinity. This is probably too high\u001b[0m\n", + " \"Agent's maximum action space value is infinity. This is probably too high\"\n", + "/usr/local/lib/python3.7/dist-packages/gym/utils/passive_env_checker.py:98: UserWarning: \u001b[33mWARN: We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html\u001b[0m\n", + " \"We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) \"\n", + "load datafile: 19%|█▉ | 4/21 [00:00<00:03, 5.38it/s]/usr/local/lib/python3.7/dist-packages/h5py/_hl/dataset.py:767: DeprecationWarning: Passing None into shape arguments as an alias for () is deprecated.\n", + " arr = numpy.ndarray(selection.mshape, dtype=new_dtype)\n", + "load datafile: 100%|██████████| 21/21 [00:01<00:00, 15.70it/s]\n" + ] + } + ], + "source": [ + "env_name = \"hopper-medium-v2\"\n", + "env = gym.make(env_name)\n", + "data = env.get_dataset() # dataset is only used for normalization in this colab" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wWAyYOvIgzQH" + }, + "source": [ + "#### Define constants" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5UORiIwOg2XK" + }, + "outputs": [], + "source": [ + "# Cuda settings for colab\n", + "torch.cuda.get_device_name(0)\n", + "DEVICE = 'cuda:0'\n", + "DTYPE = torch.float\n", + "\n", + "# diffusion model settings\n", + "n_samples = 4 # number of trajectories planned via diffusion\n", + "horizon = 128 # length of sampled trajectories\n", + "state_dim = env.observation_space.shape[0] \n", + "action_dim = env.action_space.shape[0]\n", + "num_inference_steps = 20 # number of difusion steps" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Vxe7lRYApfqb" + }, + "source": [ + "#### Helper functions\n", + "* `normalize` scales the state values corresponding to the training data-set in D4RL,\n", + "* `de_normalize` unscales the data for correct rendering,\n", + "* `to_torch` handles casting to torch for both numpy arrays and dicts (used for conditionning the model, see `reset_x0`)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qGCapQB_l8ph" + }, + "outputs": [], + "source": [ + "def normalize(x_in, data, key):\n", + " means = data[key].mean(axis=0)\n", + " stds = data[key].std(axis=0)\n", + " return (x_in - means) / stds\n", + "\n", + "\n", + "def de_normalize(x_in, data, key):\n", + " means = data[key].mean(axis=0)\n", + " stds = data[key].std(axis=0)\n", + " return x_in * stds + means\n", + "\t\n", + "def to_torch(x_in, dtype=None, device=None):\n", + "\tdtype = dtype or DTYPE\n", + "\tdevice = device or DEVICE\n", + "\tif type(x_in) is dict:\n", + "\t\treturn {k: to_torch(v, dtype, device) for k, v in x_in.items()}\n", + "\telif torch.is_tensor(x_in):\n", + "\t\treturn x_in.to(device).type(dtype)\n", + "\treturn torch.tensor(x_in, dtype=dtype, device=device)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Bw2wObJOVt-l" + }, + "source": [ + "#### Sample env. initial state" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "t9VdEBJLlOAA" + }, + "outputs": [], + "source": [ + "## Can set environment seed for debugging\n", + "# torch.manual_seed(0)\n", + "# np.random.seed(0)\n", + "# env.seed(1996)\n", + "\n", + "obs = env.reset()\n", + "obs_raw = obs\n", + "\n", + "# normalize observations for forward passes\n", + "obs = normalize(obs, data, 'observations')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "77DJISGlhAom" + }, + "source": [ + "### Run the Diffusion Process -- from Scratch" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sKWZhFGY9LYn" + }, + "source": [ + "#### Initialize model\n", + "In this section, we create a scheduler and load a pretrained model from the Hub. An important detail in the RL application space is to save `conditions` which will allow the model to optimize trajectories only from the current state (which is cruical to making decisions!). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4z78D4ikh1lm" + }, + "outputs": [], + "source": [ + "from diffusers import DDPMScheduler, UNet1DModel\n", + "\n", + "# Two generators for different parts of the diffusion loop to work in colab\n", + "generator = torch.Generator(device='cuda')\n", + "generator_cpu = torch.Generator(device='cpu')\n", + "\n", + "scheduler = DDPMScheduler(num_train_timesteps=100,beta_schedule=\"squaredcos_cap_v2\")\n", + "\n", + "# The horizion represents the length of trajectories used in training.\n", + "network = UNet1DModel.from_pretrained(\"bglick13/hopper-medium-v2-value-function-hor32\", subfolder=\"unet\").to(device=DEVICE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yDdFDvWkh5iY" + }, + "source": [ + "#### Planning helper function\n", + "`reset_x0` is used to constrain the diffusion process to trajectories starting at the current state of the agent. \n", + "Without this, the diffusion process would generate arbitrary high-reward trajectories, rather than trajectories beginning at the current state." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uh7x-zg7g_el" + }, + "outputs": [], + "source": [ + "def reset_x0(x_in, cond, act_dim):\n", + "\tfor key, val in cond.items():\n", + "\t\tx_in[:, key, act_dim:] = val.clone()\n", + "\treturn x_in" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4_L3i61niRMW" + }, + "source": [ + "#### Setup for denoising\n", + "`conditions` is the variable used to hold the first state of the planned trajectories to the current state (it is passed into `reset_x0`)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "U2SkcJD49NNr" + }, + "outputs": [], + "source": [ + "## add a batch dimension and repeat for multiple samples\n", + "## [ observation_dim ] --> [ n_samples x observation_dim ]\n", + "obs = obs[None].repeat(n_samples, axis=0)\n", + "conditions = {\n", + " 0: to_torch(obs, device=DEVICE)\n", + " }\n", + "\n", + "# constants for inference\n", + "batch_size = len(conditions[0])\n", + "shape = (batch_size, horizon, state_dim+action_dim)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qjMfrIi8iTtJ" + }, + "source": [ + "#### Sample initial noise" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HIVNaHnaiYKR" + }, + "outputs": [], + "source": [ + "# sample random initial noise vector\n", + "x1 = torch.randn(shape, device=DEVICE, generator=generator)\n", + "\n", + "# this model is conditioned from an initial state, so you will see this function\n", + "# multiple times to change the initial state of generated data to the state \n", + "# generated via env.reset() above or env.step() below\n", + "x = reset_x0(x1, conditions, action_dim)\n", + "\n", + "# convert a np observation to torch for model forward pass\n", + "x = to_torch(x)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "315o-PvOsXp_" + }, + "source": [ + "#### Generate trajectories\n", + "The diffusion process for trajectories has 4 central components:\n", + "1. sampling an predicted original sample from the model (note that this model directly predicts the sample, rather than the error term `epsilon` used in many diffusion models),\n", + "2. use the scheduler to predict the sample at the previous timestep,\n", + "3. [optional] add posterior noise to the sample,\n", + "4. condition the trajectory to constrain the initial state." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "llzMmLk227jK", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "457c55b3-2ab7-4b83-c077-88ee4cf54754" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 100/100 [00:01<00:00, 78.56it/s]\n" + ] + } + ], + "source": [ + "eta = 1.0 # noise factor for sampling reconstructed state\n", + "\n", + "# run the diffusion process\n", + "# for i in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):\n", + "for i in tqdm.tqdm(scheduler.timesteps):\n", + "\n", + " # create batch of timesteps to pass into model\n", + " timesteps = torch.full((batch_size,), i, device=DEVICE, dtype=torch.long)\n", + " \n", + " # 1. generate prediction from model\n", + " with torch.no_grad():\n", + " residual = network(x.permute(0, 2, 1), timesteps).sample\n", + " residual = residual.permute(0, 2, 1) # needed to match model params to original \n", + "\n", + " # 2. use the model prediction to reconstruct an observation (de-noise)\n", + " obs_reconstruct = scheduler.step(residual, i, x, predict_epsilon=False)[\"prev_sample\"]\n", + "\n", + " # 3. [optional] add posterior noise to the sample\n", + " if eta > 0:\n", + " noise = torch.randn(obs_reconstruct.shape, generator=generator_cpu).to(obs_reconstruct.device)\n", + " posterior_variance = scheduler._get_variance(i) # * noise\n", + " # no noise when t == 0\n", + " # NOTE: original implementation missing sqrt on posterior_variance\n", + " obs_reconstruct = obs_reconstruct + int(i>0) * (0.5 * posterior_variance) * eta* noise # MJ had as log var, exponentiated\n", + "\n", + " # 4. apply conditions to the trajectory\n", + " obs_reconstruct_postcond = reset_x0(obs_reconstruct, conditions, action_dim)\n", + " x = to_torch(obs_reconstruct_postcond)\n" + ] + }, + { + "cell_type": "code", + "source": [ + "x.shape" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0Dwa7VabDMGP", + "outputId": "df8f831b-109a-4c97-ce07-13cad4412bc7" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([4, 128, 14])" + ] + }, + "metadata": {}, + "execution_count": 29 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nkfuBJDTigOE" + }, + "source": [ + "\n", + "\n", + "---\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OhHZC48kVxGM" + }, + "source": [ + "### Render the samples" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g4wYTq74pGCY" + }, + "source": [ + "#### Renderering Tools\n", + "Rendering from Mujoco is historically not easy. Here is a modified version from the original paper. Additionally, a TODO is to investigate this web-based [viewer](https://github.com/kevinzakka/mjc_viewer)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D7FxIwtopERr" + }, + "source": [ + "##### Video helpers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3g8N_n8VRLPs" + }, + "outputs": [], + "source": [ + "import os\n", + "import mediapy as media\n", + "\n", + "def to_np(x_in):\n", + "\tif torch.is_tensor(x_in):\n", + "\t\tx_in = x_in.detach().cpu().numpy()\n", + "\treturn x_in\n", + "\n", + "# from MJ's Diffuser code \n", + "# https://github.com/jannerm/diffuser/blob/76ae49ae85ba1c833bf78438faffdc63b8b4d55d/diffuser/utils/colab.py#L79\n", + "def mkdir(savepath):\n", + " \"\"\"\n", + " returns `True` iff `savepath` is created\n", + " \"\"\"\n", + " if not os.path.exists(savepath):\n", + " os.makedirs(savepath)\n", + " return True\n", + " else:\n", + " return False\n", + "\n", + "\n", + "def show_sample(renderer, observations, filename='sample.mp4', savebase='/content/videos'):\n", + " '''\n", + " observations : [ batch_size x horizon x observation_dim ]\n", + " '''\n", + "\n", + " mkdir(savebase)\n", + " savepath = os.path.join(savebase, filename)\n", + "\n", + " images = []\n", + " for rollout in observations:\n", + " ## [ horizon x height x width x channels ]\n", + " img = renderer._renders(rollout, partial=True)\n", + " images.append(img)\n", + "\n", + " ## [ horizon x height x (batch_size * width) x channels ]\n", + " images = np.concatenate(images, axis=2)\n", + "\n", + " media.show_video(images, codec='h264', fps=60)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RsvU93pIt26I" + }, + "source": [ + "##### Renderer helpers\n", + "These functions involve setting the state of the environment and reading it out in a pixel form." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Yg9JiztlpH1o" + }, + "outputs": [], + "source": [ + "# Code adapted from Michael Janner\n", + "# source: https://github.com/jannerm/diffuser/blob/main/diffuser/utils/rendering.py\n", + "import mujoco_py as mjc\n", + "\n", + "def env_map(env_name):\n", + " '''\n", + " map D4RL dataset names to custom fully-observed\n", + " variants for rendering\n", + " '''\n", + " if 'halfcheetah' in env_name:\n", + " return 'HalfCheetahFullObs-v2'\n", + " elif 'hopper' in env_name:\n", + " return 'HopperFullObs-v2'\n", + " elif 'walker2d' in env_name:\n", + " return 'Walker2dFullObs-v2'\n", + " else:\n", + " return env_name\n", + "\n", + "def get_image_mask(img):\n", + " background = (img == 255).all(axis=-1, keepdims=True)\n", + " mask = ~background.repeat(3, axis=-1)\n", + " return mask\n", + "\n", + "def atmost_2d(x):\n", + " while x.ndim > 2:\n", + " x = x.squeeze(0)\n", + " return x\n", + "\n", + "def set_state(env, state):\n", + " qpos_dim = env.sim.data.qpos.size\n", + " qvel_dim = env.sim.data.qvel.size\n", + " if not state.size == qpos_dim + qvel_dim:\n", + " warnings.warn(\n", + " f'[ utils/rendering ] Expected state of size {qpos_dim + qvel_dim}, '\n", + " f'but got state of size {state.size}')\n", + " state = state[:qpos_dim + qvel_dim]\n", + "\n", + " env.set_state(state[:qpos_dim], state[qpos_dim:])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hbjK273Yt7AD" + }, + "source": [ + "##### Rendering class\n", + "Use the previously defined helpers to programatically render pixel sequences from a trajectory of states. \n", + "This class takes the re-scaled outputs of the diffusion process and visualizes them." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s9iDg9oBt6aL" + }, + "outputs": [], + "source": [ + "class MuJoCoRenderer:\n", + " '''\n", + " default mujoco renderer\n", + " '''\n", + "\n", + " def __init__(self, env):\n", + " if type(env) is str:\n", + " env = env_map(env)\n", + " self.env = gym.make(env)\n", + " else:\n", + " self.env = env\n", + " ## - 1 because the envs in renderer are fully-observed\n", + " ## @TODO : clean up\n", + " self.observation_dim = np.prod(self.env.observation_space.shape) - 1\n", + " self.action_dim = np.prod(self.env.action_space.shape)\n", + " try:\n", + " self.viewer = mjc.MjRenderContextOffscreen(self.env.sim)\n", + " except:\n", + " print('[ utils/rendering ] Warning: could not initialize offscreen renderer')\n", + " self.viewer = None\n", + "\n", + " def pad_observation(self, observation):\n", + " state = np.concatenate([\n", + " np.zeros(1),\n", + " observation,\n", + " ])\n", + " return state\n", + "\n", + " def pad_observations(self, observations):\n", + " qpos_dim = self.env.sim.data.qpos.size\n", + " ## xpos is hidden\n", + " xvel_dim = qpos_dim - 1\n", + " xvel = observations[:, xvel_dim]\n", + " xpos = np.cumsum(xvel) * self.env.dt\n", + " states = np.concatenate([\n", + " xpos[:,None],\n", + " observations,\n", + " ], axis=-1)\n", + " return states\n", + "\n", + " def render(self, observation, dim=256, partial=False, qvel=True, render_kwargs=None, conditions=None):\n", + "\n", + " if type(dim) == int:\n", + " dim = (dim, dim)\n", + "\n", + " if self.viewer is None:\n", + " return np.zeros((*dim, 3), np.uint8)\n", + "\n", + " if render_kwargs is None:\n", + " xpos = observation[0] if not partial else 0\n", + " render_kwargs = {\n", + " 'trackbodyid': 2,\n", + " 'distance': 3,\n", + " 'lookat': [xpos, -0.5, 1],\n", + " 'elevation': -20\n", + " }\n", + "\n", + " for key, val in render_kwargs.items():\n", + " if key == 'lookat':\n", + " self.viewer.cam.lookat[:] = val[:]\n", + " else:\n", + " setattr(self.viewer.cam, key, val)\n", + "\n", + " if partial:\n", + " state = self.pad_observation(observation)\n", + " else:\n", + " state = observation\n", + "\n", + " qpos_dim = self.env.sim.data.qpos.size\n", + " if not qvel or state.shape[-1] == qpos_dim:\n", + " qvel_dim = self.env.sim.data.qvel.size\n", + " state = np.concatenate([state, np.zeros(qvel_dim)])\n", + "\n", + " set_state(self.env, state)\n", + "\n", + " self.viewer.render(*dim)\n", + " data = self.viewer.read_pixels(*dim, depth=False)\n", + " data = data[::-1, :, :]\n", + " return data\n", + "\n", + " def _renders(self, observations, **kwargs):\n", + " images = []\n", + " for observation in observations:\n", + " img = self.render(observation, **kwargs)\n", + " images.append(img)\n", + " return np.stack(images, axis=0)\n", + "\n", + " def renders(self, samples, partial=False, **kwargs):\n", + " if partial:\n", + " samples = self.pad_observations(samples)\n", + " partial = False\n", + "\n", + " sample_images = self._renders(samples, partial=partial, **kwargs)\n", + "\n", + " composite = np.ones_like(sample_images[0]) * 255\n", + "\n", + " for img in sample_images:\n", + " mask = get_image_mask(img)\n", + " composite[mask] = img[mask]\n", + "\n", + " return composite\n", + "\n", + " def __call__(self, *args, **kwargs):\n", + " return self.renders(*args, **kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qIjF7ToPp_f0" + }, + "source": [ + "#### Show Plans\n", + "This section renders 4 trajectories chosen from the same initial state in the environment." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MuLscSfUiYio" + }, + "source": [ + "##### Initialize renderer class for the environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Q0dgJHKJsBHl" + }, + "outputs": [], + "source": [ + "render = MuJoCoRenderer(env)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f4pegV-iibvJ" + }, + "source": [ + "##### Show the video\n", + "Show the states generated by the diffusion model in the real environment. \n", + "Not that the actions are dropped from the data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Kxhu-7PiHnuF", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 279 + }, + "outputId": "ead27e9d-7e45-4faa-c8ed-3e1848bc2f82" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "
" + ] + }, + "metadata": {} + } + ], + "source": [ + "de_normalized = de_normalize(to_np(x[:,:,action_dim:]), data, 'observations')\n", + "show_sample(render, de_normalized)\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Run Value Guided Diffusion -- with Pipeline\n", + "\n", + "In this section, we repeat the above code, but we use a pre-trained pipeline in Diffusers!" + ], + "metadata": { + "id": "Eub5xlfTd0zm" + } + }, + { + "cell_type": "code", + "source": [ + "from diffusers import ValueGuidedRLPipeline" + ], + "metadata": { + "id": "0vXxDa-gd4fc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "env_name = \"hopper-medium-v2\"\n", + "env = gym.make(env_name)\n", + "data = env.get_dataset() # dataset is only used for normalization in this colab\n", + "render = MuJoCoRenderer(env)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "hfkHRFNndapS", + "outputId": "8b5861de-9a1d-4ba6-f042-32491456e53a" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/gym/envs/mujoco/mujoco_env.py:47: UserWarning: \u001b[33mWARN: This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).\u001b[0m\n", + " \"This version of the mujoco environments depends \"\n", + "/usr/local/lib/python3.7/dist-packages/gym/spaces/box.py:112: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", + " logger.warn(f\"Box bound precision lowered by casting to {self.dtype}\")\n", + "/usr/local/lib/python3.7/dist-packages/gym/utils/passive_env_checker.py:70: UserWarning: \u001b[33mWARN: Agent's minimum action space value is -infinity. This is probably too low.\u001b[0m\n", + " \"Agent's minimum action space value is -infinity. This is probably too low.\"\n", + "/usr/local/lib/python3.7/dist-packages/gym/utils/passive_env_checker.py:74: UserWarning: \u001b[33mWARN: Agent's maximum action space value is infinity. This is probably too high\u001b[0m\n", + " \"Agent's maximum action space value is infinity. This is probably too high\"\n", + "/usr/local/lib/python3.7/dist-packages/gym/utils/passive_env_checker.py:98: UserWarning: \u001b[33mWARN: We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html\u001b[0m\n", + " \"We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) \"\n", + "load datafile: 19%|█▉ | 4/21 [00:00<00:03, 5.16it/s]/usr/local/lib/python3.7/dist-packages/h5py/_hl/dataset.py:767: DeprecationWarning: Passing None into shape arguments as an alias for () is deprecated.\n", + " arr = numpy.ndarray(selection.mshape, dtype=new_dtype)\n", + "load datafile: 100%|██████████| 21/21 [00:01<00:00, 15.05it/s]\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "state_dim = env.observation_space.shape[0]\n", + "action_dim = env.action_space.shape[0]\n", + "DEVICE = \"cuda\"" + ], + "metadata": { + "id": "NcLN2M-CdjNR" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Load the pipeline!" + ], + "metadata": { + "id": "EZwHD83YY_XN" + } + }, + { + "cell_type": "code", + "source": [ + "pipeline = ValueGuidedRLPipeline.from_pretrained(\n", + " \"bglick13/hopper-medium-v2-value-function-hor32\",\n", + " env=env,\n", + " )" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 103, + "referenced_widgets": [ + "f6fee88020804250a48a1ee45806d7be", + "c52a439ef13d48b1b0f9e00b2696ea52", + "e937d5d11ec244deb19a681d9aafbbfd", + "5688c1ab38ef440180cc5309cb54b5d6", + "a918d6b301354c7a89aecd3523e77359", + "2f4d8c9b7bd544899531c4d36d62dd3e", + "1dc28c75e7f047a7b9baf4cba0276757", + "6558ad009b92422493ddcd3b46f3d089", + "a1c4efdd3e494a269b25b03e5e99d6fe", + "c977bec486204045867bc56a59b59620", + "e37678f1a748493cacfed2cb8d5fc0cb" + ] + }, + "id": "oRlo8Z841uBW", + "outputId": "04a00750-1058-4158-fe56-af78b8eefafc" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Fetching 11 files: 0%| | 0/11 [00:00" + ], + "text/html": [ + "
" + ] + }, + "metadata": {} + } + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "provenance": [], + "toc_visible": true + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "f6fee88020804250a48a1ee45806d7be": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_c52a439ef13d48b1b0f9e00b2696ea52", + "IPY_MODEL_e937d5d11ec244deb19a681d9aafbbfd", + "IPY_MODEL_5688c1ab38ef440180cc5309cb54b5d6" + ], + "layout": "IPY_MODEL_a918d6b301354c7a89aecd3523e77359" + } + }, + "c52a439ef13d48b1b0f9e00b2696ea52": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2f4d8c9b7bd544899531c4d36d62dd3e", + "placeholder": "​", + "style": "IPY_MODEL_1dc28c75e7f047a7b9baf4cba0276757", + "value": "Fetching 11 files: 100%" + } + }, + "e937d5d11ec244deb19a681d9aafbbfd": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_6558ad009b92422493ddcd3b46f3d089", + "max": 11, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_a1c4efdd3e494a269b25b03e5e99d6fe", + "value": 11 + } + }, + "5688c1ab38ef440180cc5309cb54b5d6": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c977bec486204045867bc56a59b59620", + "placeholder": "​", + "style": "IPY_MODEL_e37678f1a748493cacfed2cb8d5fc0cb", + "value": " 11/11 [00:00<00:00, 195.84it/s]" + } + }, + "a918d6b301354c7a89aecd3523e77359": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2f4d8c9b7bd544899531c4d36d62dd3e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1dc28c75e7f047a7b9baf4cba0276757": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6558ad009b92422493ddcd3b46f3d089": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a1c4efdd3e494a269b25b03e5e99d6fe": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c977bec486204045867bc56a59b59620": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e37678f1a748493cacfed2cb8d5fc0cb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file