├── VibeGen ├── __init__.py ├── TrainerPack_advanced.py ├── UtilityPack.py ├── JointSamplingPack.py ├── DataSetPack.py └── imagen_x_imagen_pytorch.py ├── setup.py ├── .gitignore ├── README.md ├── requirements.txt └── LICENSE.txt /VibeGen/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # VibeGen/__init__.py 3 | 4 | # import .UtilityPack as UPack 5 | 6 | # import .DataSetPack as DPack 7 | 8 | # import .ModelPack as MPack 9 | 10 | # import .TrainerPack as TPack 11 | 12 | # import .JointSamplingPack as SPack 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | # Function to read requirements.txt 4 | def parse_requirements(filename): 5 | with open(filename, 'r') as req_file: 6 | return req_file.read().splitlines() 7 | 8 | setup( 9 | name='VibeGen', 10 | version='0.1.0', 11 | packages=find_packages(), 12 | install_requires=parse_requirements('requirements.txt'), 13 | description='VibeGen: End-to-end de novo protein generation targeting normal mode vibrations using a language diffusion model duo', 14 | author='Bo Ni', 15 | url='https://github.com/lamm-mit/ModeShapeDiffusionDesign', 16 | ) 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | # +++ 165 | Local_Store/ 166 | 167 | VibeGen/working_note.MD 168 | 169 | wk_dir/ 170 | 171 | trained_duo/ 172 | 173 | VibeGen_env/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VibeGen: Agentic End-to-End De Novo Protein Design for Tailored Dynamics Using a Language Diffusion Model 2 | 3 | Bo Ni1,2, Markus J. Buehler1,3,4* 4 | 5 | 1 Laboratory for Atomistic and Molecular Mechanics (LAMM), Massachusetts Institute of Technology 6 | 7 | 2 Department of Materials Science and Engineering, Carnegie Mellon University 8 | 9 | 3 Center for Computational Science and Engineering, Schwarzman College of Computing, Massachusetts Institute of Technology 10 | 11 | 4 Lead contact 12 | 13 | * Correspondence: mbuehler@MIT.EDU 14 | 15 | Proteins are dynamic molecular machines whose biological functions, spanning enzymatic catalysis, signal transduction, and structural adaptation, are intrinsically linked to their motions. We introduce VibeGen, a generative AI model based on an agentic dual-model architecture, comprising a protein designer that generates sequence candidates based on specified vibrational modes and a protein predictor that evaluates their dynamic accuracy. Via direct validation using full-atom molecular simulations, we demonstrate that the designed proteins accurately reproduce the prescribed normal mode amplitudes across the backbone while adopting various stable, functionally relevant structures. 
Generated sequences are de novo, exhibiting no significant similarity to natural proteins, thereby expanding the accessible protein space beyond evolutionary constraints. Our model establishes a direct, bidirectional link between sequence and vibrational behavior, unlocking new pathways for engineering biomolecules with tailored dynamical and functional properties. This work holds broad implications for the rational design of enzymes, dynamic scaffolds, and biomaterials via dynamics-informed protein engineering. 16 | 17 | ![plot](./assets/TOC.svg) 18 | 19 | ## Installation 20 | 21 | Create a virtual environment: 22 | 23 | ```bash 24 | conda create --prefix=./VibeGen_env 25 | conda activate ./VibeGen_env 26 | 27 | ``` 28 | 29 | Install: 30 | ```bash 31 | pip install git+https://github.com/lamm-mit/ModeShapeDiffusionDesign.git 32 | 33 | ``` 34 | If you want an editable installation, clone the repository using `git`: 35 | ```bash 36 | git clone https://github.com/lamm-mit/ModeShapeDiffusionDesign.git 37 | cd ModeShapeDiffusionDesign 38 | ``` 39 | Then, install: 40 | ```bash 41 | pip install -r requirements.txt 42 | pip install -e . 43 | ``` 44 | 45 | ### Directory structure 46 | ``` 47 | ModeShapeDiffusionDesign/ 48 | │ 49 | ├── VibeGen/ # Source code directory 50 | │ ├── DataSetPack.py 51 | │ ├── ModelPack.py 52 | │ ├── TrainerPack.py 53 | │ ├── UtilityPack.py 54 | │ ├── JointSamplingPack.py 55 | │ └── ... 56 | │ 57 | ├── demo_1_Inferrence_with_trained_duo.ipynb # demo 1: make an inference 58 | │ 59 | ├── colab_demo/ # demos for colab 60 | │ ├── Inference_demo.ipynb # demo 1: make an inference 61 | │ └── ... 62 | │ 63 | ├── setup.py # The setup file for packaging 64 | ├── requirements.txt # List of dependencies 65 | ├── README.md # Documentation 66 | ├── assets/ # Support materials 67 | └── ... 68 | ``` 69 | 70 | ## Usage 71 | 72 | ### Inference notebooks 73 | In the following example, for each input normal-mode shape condition, the trained ProteinDesigner proposes 20 sequence candidates. The trained ProteinPredictor then picks the best two and the worst two among them based on its predictions. The chosen sequences are folded with OmegaFold and their secondary structure is analyzed. (A minimal pseudocode sketch of this workflow is given at the end of this README.) 74 | 75 | ``` 76 | demo_1_inference_with_trained_duo.ipynb 77 | ``` 78 | 79 | Alternatively, a similar demo can be run on Colab. 80 | 81 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lamm-mit/ModeShapeDiffusionDesign/blob/main/colab_demo/Inference_demo.ipynb) 82 | 83 | ### Pretrained models 84 | The checkpoints of the pretrained models that make up the agentic system are hosted in the [VibeGen repository](https://huggingface.co/lamm-mit/VibeGen) on Hugging Face. 85 | 86 | ### Reference 87 | 88 | ```bibtex 89 | @misc{BoBuehler2025VibeGen, 90 | title={VibeGen: Agentic End-to-End De Novo Protein Design for Tailored Dynamics Using a Language Diffusion Model}, 91 | author={Bo Ni and Markus J. Buehler}, 92 | year={2025}, 93 | eprint={2502.10173}, 94 | archivePrefix={arXiv}, 95 | primaryClass={q-bio.BM}, 96 | url={https://arxiv.org/abs/2502.10173}, 97 | } 98 | ``` 99 | 100 | Our implementation is inspired by the [imagen-pytorch](https://github.com/lucidrains/imagen-pytorch) repository by [Phil Wang](https://github.com/lucidrains).
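### Workflow sketch

For orientation, below is a minimal, self-contained sketch of the rank-and-select step described under Usage: the designer proposes candidates for a target mode shape, and the predictor scores them so the best and worst can be selected. The two model calls are stand-in stubs; the real classes, checkpoints, and call signatures live in `VibeGen/ModelPack.py`, `VibeGen/JointSamplingPack.py`, and the demo notebook, so the function names and signatures used here are assumptions rather than the actual API.

```python
# Illustrative stand-ins only -- the real ProteinDesigner / ProteinPredictor come from
# the trained checkpoints loaded in the demo notebook.
import random

def designer_sample(mode_shape, num_samples):
    # stand-in for the trained ProteinDesigner (hypothetical signature)
    alphabet = "LAGVSERTIDPKQNFYMHWC"
    return ["".join(random.choices(alphabet, k=len(mode_shape))) for _ in range(num_samples)]

def predictor_error(sequence, mode_shape):
    # stand-in for the trained ProteinPredictor; lower = better agreement (hypothetical)
    return random.random()

mode_shape = [0.1] * 64                      # placeholder normal-mode amplitude profile
candidates = designer_sample(mode_shape, num_samples=20)

# rank candidates by predicted agreement with the prescribed mode shape
ranked = sorted(candidates, key=lambda seq: predictor_error(seq, mode_shape))
best_two, worst_two = ranked[:2], ranked[-2:]
print(best_two, worst_two)
```

The selected sequences would then be folded (e.g., with OmegaFold) and analyzed for secondary structure, as done in the notebook.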
101 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.32.1 3 | aiofiles # @ file:///croot/aiofiles_1683773582346/work 4 | aiohttp==3.9.5 5 | aiosignal==1.3.1 6 | aiosqlite # @ file:///croot/aiosqlite_1683773899903/work 7 | annotated-types==0.7.0 8 | anyio # @ file:///tmp/build/80754af9/anyio_1644481695334/work/dist 9 | argon2-cffi # @ file:///opt/conda/conda-bld/argon2-cffi_1645000214183/work 10 | argon2-cffi-bindings # @ file:///tmp/build/80754af9/argon2-cffi-bindings_1644553347904/work 11 | asttokens # @ file:///opt/conda/conda-bld/asttokens_1646925590279/work 12 | astunparse==1.6.3 13 | async-timeout==4.0.3 14 | attrs # @ file:///croot/attrs_1695717823297/work 15 | Babel # @ file:///croot/babel_1671781930836/work 16 | backcall # @ file:///home/ktietz/src/ci/backcall_1611930011877/work 17 | beartype==0.18.5 18 | beautifulsoup4 # @ file:///croot/beautifulsoup4-split_1681493039619/work 19 | biopython==1.84 20 | bleach # @ file:///opt/conda/conda-bld/bleach_1641577558959/work 21 | Brotli # @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work 22 | certifi # @ file:///croot/certifi_1700501669400/work/certifi 23 | cffi # @ file:///croot/cffi_1700254295673/work 24 | charset-normalizer # @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 25 | comm # @ file:///croot/comm_1671231121260/work 26 | contourpy==1.2.1 27 | cryptography # @ file:///croot/cryptography_1694444244250/work 28 | cycler==0.12.1 29 | datasets==2.20.0 30 | debugpy # @ file:///croot/debugpy_1690905042057/work 31 | decorator # @ file:///opt/conda/conda-bld/decorator_1643638310831/work 32 | defusedxml # @ file:///tmp/build/80754af9/defusedxml_1615228127516/work 33 | dill==0.3.8 34 | einops==0.8.0 35 | einops-exts==0.0.4 36 | ema-pytorch==0.5.2 37 | evaluate==0.4.2 38 | exceptiongroup # @ file:///croot/exceptiongroup_1668714342571/work 39 | executing # @ file:///opt/conda/conda-bld/executing_1646925071911/work 40 | fair-esm==2.0.0 41 | fastjsonschema # @ file:///opt/conda/conda-bld/python-fastjsonschema_1661371079312/work 42 | filelock # @ file:///croot/filelock_1700591183607/work 43 | flatbuffers==24.3.25 44 | fonttools==4.53.1 45 | frozenlist==1.4.1 46 | fsspec==2024.5.0 47 | gast==0.6.0 48 | gmpy2 # @ file:///tmp/build/80754af9/gmpy2_1645455533097/work 49 | google-pasta==0.2.0 50 | grpcio==1.65.4 51 | h5py==3.11.0 52 | huggingface-hub==0.23.4 53 | idna # @ file:///croot/idna_1666125576474/work 54 | ipykernel # @ file:///croot/ipykernel_1691121631942/work 55 | ipython # @ file:///croot/ipython_1694181358621/work 56 | ipython-genutils # @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work 57 | ipywidgets # @ file:///croot/ipywidgets_1679394798311/work 58 | jedi # @ file:///tmp/build/80754af9/jedi_1644315229345/work 59 | Jinja2 # @ file:///croot/jinja2_1666908132255/work 60 | joblib==1.4.2 61 | json5 # @ file:///tmp/build/80754af9/json5_1624432770122/work 62 | jsonschema # @ file:///croot/jsonschema_1699041609003/work 63 | jsonschema-specifications # @ file:///croot/jsonschema-specifications_1699032386549/work 64 | jupyter # @ file:///tmp/abs_33h4eoipez/croots/recipe/jupyter_1659349046347/work 65 | jupyter-console # @ file:///croot/jupyter_console_1679999630278/work 66 | jupyter-events # @ file:///croot/jupyter_events_1699282461638/work 67 | jupyter-ydoc # @ file:///croot/jupyter_ydoc_1683747223142/work 68 | 
jupyter_client # @ file:///croot/jupyter_client_1699455897726/work 69 | jupyter_core # @ file:///croot/jupyter_core_1698937308754/work 70 | jupyter_server # @ file:///croot/jupyter_server_1699466442171/work 71 | jupyter_server_fileid # @ file:///croot/jupyter_server_fileid_1684273577568/work 72 | jupyter_server_terminals # @ file:///croot/jupyter_server_terminals_1686870725608/work 73 | jupyter_server_ydoc # @ file:///croot/jupyter_server_ydoc_1686767404829/work 74 | jupyterlab # @ file:///croot/jupyterlab_1686179668131/work 75 | jupyterlab-pygments # @ file:///croot/jupyterlab_pygments_1700168593176/work 76 | jupyterlab-widgets # @ file:///croot/jupyterlab_widgets_1700168618520/work 77 | jupyterlab_server # @ file:///croot/jupyterlab_server_1699555425460/work 78 | keras==3.4.1 79 | kiwisolver==1.4.5 80 | kornia==0.7.3 81 | kornia_rs==0.1.5 82 | libclang==18.1.1 83 | Markdown==3.6 84 | markdown-it-py==3.0.0 85 | MarkupSafe # @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work 86 | matplotlib==3.9.1 87 | matplotlib-inline # @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work 88 | mdurl==0.1.2 89 | mistune # @ file:///opt/conda/conda-bld/mistune_1661496219659/work 90 | mkl-fft # @ file:///croot/mkl_fft_1695058164594/work 91 | mkl-random # @ file:///croot/mkl_random_1695059800811/work 92 | # mkl-service==2.4.0 93 | mkl-service==2.4.2 94 | ml-dtypes==0.4.0 95 | mpmath # @ file:///croot/mpmath_1690848262763/work 96 | multidict==6.0.5 97 | multiprocess==0.70.16 98 | namex==0.0.8 99 | nbclassic # @ file:///croot/nbclassic_1699542793266/work 100 | nbclient # @ file:///croot/nbclient_1698934205032/work 101 | nbconvert # @ file:///croot/nbconvert_1699022732553/work 102 | nbformat # @ file:///croot/nbformat_1694616755618/work 103 | nest-asyncio # @ file:///croot/nest-asyncio_1672387112409/work 104 | networkx # @ file:///croot/networkx_1720002482208/work 105 | notebook # @ file:///croot/notebook_1681756172480/work 106 | notebook_shim # @ file:///croot/notebook-shim_1699455894279/work 107 | numpy # @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp310-cp310-linux_x86_64.whl#sha256=d8cd837ed43e87f77e6efaa08e8de927ca030a1c9c5d04624432d6fb9a74a5ee 108 | nvidia-cublas-cu12==12.3.4.1 109 | nvidia-cuda-cupti-cu12==12.3.101 110 | nvidia-cuda-nvcc-cu12==12.3.107 111 | nvidia-cuda-nvrtc-cu12==12.3.107 112 | nvidia-cuda-runtime-cu12==12.3.101 113 | nvidia-cudnn-cu12==8.9.7.29 114 | nvidia-cufft-cu12==11.0.12.1 115 | nvidia-curand-cu12==10.3.4.107 116 | nvidia-cusolver-cu12==11.5.4.101 117 | nvidia-cusparse-cu12==12.2.0.103 118 | nvidia-nccl-cu12==2.19.3 119 | nvidia-nvjitlink-cu12==12.3.101 120 | OmegaFold @ git+https://github.com/Bo-Ni/OmegaFold_0.git@3db771f153c247dd3686abdf4495735a4f36d933 121 | opt-einsum==3.3.0 122 | optree==0.12.1 123 | overrides # @ file:///croot/overrides_1699371140756/work 124 | packaging # @ file:///croot/packaging_1693575174725/work 125 | pandas==2.2.2 126 | pandocfilters # @ file:///opt/conda/conda-bld/pandocfilters_1643405455980/work 127 | parso # @ file:///opt/conda/conda-bld/parso_1641458642106/work 128 | PeptideBuilder==1.1.0 129 | pexpect # @ file:///tmp/build/80754af9/pexpect_1605563209008/work 130 | pickleshare # @ file:///tmp/build/80754af9/pickleshare_1606932040724/work 131 | pillow # @ file:///croot/pillow_1714398848491/work 132 | platformdirs # @ file:///croot/platformdirs_1692205439124/work 133 | ply==3.11 134 | prometheus-client # @ file:///tmp/abs_d3zeliano1/croots/recipe/prometheus_client_1659455100375/work 135 | 
prompt-toolkit # @ file:///croot/prompt-toolkit_1672387306916/work 136 | protobuf==4.25.4 137 | psutil # @ file:///opt/conda/conda-bld/psutil_1656431268089/work 138 | ptyprocess # @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 139 | pure-eval # @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work 140 | py3Dmol==2.2.0 141 | pyarrow==16.1.0 142 | pyarrow-hotfix==0.6 143 | pycparser # @ file:///tmp/build/80754af9/pycparser_1636541352034/work 144 | pydantic==2.8.2 145 | pydantic_core==2.20.1 146 | Pygments # @ file:///croot/pygments_1684279966437/work 147 | pyOpenSSL # @ file:///croot/pyopenssl_1690223430423/work 148 | pyparsing==3.1.2 149 | PyQt5==5.15.10 150 | PyQt5-sip # @ file:///croot/pyqt-split_1698769088074/work/pyqt_sip 151 | PySocks # @ file:///home/builder/ci_310/pysocks_1640793678128/work 152 | python-dateutil # @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work 153 | python-json-logger # @ file:///croot/python-json-logger_1683823803357/work 154 | pytorch-warmup==0.1.1 155 | pytz # @ file:///croot/pytz_1695131579487/work 156 | PyYAML # @ file:///croot/pyyaml_1698096049011/work 157 | pyzmq # @ file:///croot/pyzmq_1686601365461/work 158 | qtconsole # @ file:///croot/qtconsole_1700160644874/work 159 | QtPy # @ file:///croot/qtpy_1700144840038/work 160 | referencing # @ file:///croot/referencing_1699012038513/work 161 | regex==2024.5.15 162 | requests==2.32.3 163 | rfc3339-validator # @ file:///croot/rfc3339-validator_1683077044675/work 164 | rfc3986-validator # @ file:///croot/rfc3986-validator_1683058983515/work 165 | rich==13.7.1 166 | rpds-py # @ file:///croot/rpds-py_1698945930462/work 167 | safetensors==0.4.3 168 | scikit-learn==1.5.1 169 | scipy==1.14.0 170 | seaborn==0.13.2 171 | Send2Trash # @ file:///croot/send2trash_1699371139552/work 172 | sentencepiece==0.2.0 173 | sip # @ file:///croot/sip_1698675935381/work 174 | six # @ file:///tmp/build/80754af9/six_1644875935023/work 175 | sniffio # @ file:///home/builder/ci_310/sniffio_1640794799774/work 176 | soupsieve # @ file:///croot/soupsieve_1696347547217/work 177 | stack-data # @ file:///opt/conda/conda-bld/stack_data_1646927590127/work 178 | sympy # @ file:///croot/sympy_1701397643339/work 179 | tensorboard==2.17.0 180 | tensorboard-data-server==0.7.2 181 | tensorflow==2.17.0 182 | tensorflow-io-gcs-filesystem==0.37.1 183 | termcolor==2.4.0 184 | terminado # @ file:///croot/terminado_1671751832461/work 185 | threadpoolctl==3.5.0 186 | tinycss2 # @ file:///croot/tinycss2_1668168815555/work 187 | tokenizers==0.19.1 188 | tomli # @ file:///opt/conda/conda-bld/tomli_1657175507142/work 189 | torch==2.3.1 190 | torchaudio==2.3.1 191 | torchinfo==1.8.0 192 | torchvision==0.18.1 193 | tornado # @ file:///croot/tornado_1696936946304/work 194 | tqdm==4.66.4 195 | traitlets # @ file:///croot/traitlets_1671143879854/work 196 | transformers==4.42.4 197 | triton==2.3.1 198 | typing_extensions==4.8.0 199 | tzdata==2024.1 200 | urllib3 # @ file:///croot/urllib3_1698257533958/work 201 | wcwidth # @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work 202 | webencodings==0.5.1 203 | websocket-client # @ file:///home/builder/ci_310/websocket-client_1640795866898/work 204 | Werkzeug==3.0.3 205 | widgetsnbextension # @ file:///croot/widgetsnbextension_1679313860248/work 206 | wrapt==1.16.0 207 | xxhash==3.4.1 208 | y-py # @ file:///croot/y-py_1683555143488/work 209 | yarl==1.9.4 210 | ypy-websocket # @ file:///croot/ypy-websocket_1684171737040/work 211 | 
-------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /VibeGen/TrainerPack_advanced.py: -------------------------------------------------------------------------------- 1 | """ 2 | Task: 3 | 1. create a trainer for ProteinDesigner 4 | 2. include train_loop, sample_loop 5 | 6 | Bo Ni, Sep 8, 2024 7 | """ 8 | 9 | # ////////////////////////////////////////////////////// 10 | # 0. load in packages 11 | # ////////////////////////////////////////////////////// 12 | 13 | import os 14 | from math import ceil 15 | from contextlib import contextmanager, nullcontext 16 | from functools import partial, wraps 17 | from collections.abc import Iterable 18 | 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | from torch.utils.data import random_split, DataLoader 23 | from torch.optim import Adam 24 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR 25 | from torch.cuda.amp import autocast, GradScaler 26 | 27 | import pytorch_warmup as warmup 28 | 29 | from packaging import version 30 | 31 | import numpy as np 32 | 33 | from ema_pytorch import EMA 34 | 35 | from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs 36 | 37 | from fsspec.core import url_to_fs 38 | from fsspec.implementations.local import LocalFileSystem 39 | 40 | # ////////////////////////////////////////////////////////////// 41 | # 2. special packages 42 | # ////////////////////////////////////////////////////////////// 43 | from VibeGen.ModelPack import ( 44 | ProteinDesigner_Base 45 | ) 46 | from VibeGen.imagen_x_imagen_pytorch import ( 47 | ElucidatedImagen_OneD, eval_decorator 48 | ) 49 | 50 | # ////////////////////////////////////////////////////////////// 51 | # 3. local setup parameters: for debug purpose 52 | # ////////////////////////////////////////////////////////////// 53 | PT_Init_Level = 1 54 | PT_Forw_Level = 1 55 | 56 | # ////////////////////////////////////////////////////////////// 57 | # 4. 
helper functions 58 | # ////////////////////////////////////////////////////////////// 59 | def cycle(dl): 60 | while True: 61 | for data in dl: 62 | yield data 63 | 64 | def exists(val): 65 | return val is not None 66 | 67 | def default(val, d): 68 | if exists(val): 69 | return val 70 | return d() if callable(d) else d 71 | 72 | def cast_tuple(val, length = 1): 73 | if isinstance(val, list): 74 | val = tuple(val) 75 | 76 | return val if isinstance(val, tuple) else ((val,) * length) 77 | 78 | def find_first(fn, arr): 79 | for ind, el in enumerate(arr): 80 | if fn(el): 81 | return ind 82 | return -1 83 | 84 | def pick_and_pop(keys, d): 85 | values = list(map(lambda key: d.pop(key), keys)) 86 | return dict(zip(keys, values)) 87 | 88 | def group_dict_by_key(cond, d): 89 | return_val = [dict(),dict()] 90 | for key in d.keys(): 91 | match = bool(cond(key)) 92 | ind = int(not match) 93 | return_val[ind][key] = d[key] 94 | return (*return_val,) 95 | 96 | def string_begins_with(prefix, str): 97 | return str.startswith(prefix) 98 | 99 | def group_by_key_prefix(prefix, d): 100 | return group_dict_by_key(partial(string_begins_with, prefix), d) 101 | 102 | def groupby_prefix_and_trim(prefix, d): 103 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 104 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 105 | return kwargs_without_prefix, kwargs 106 | 107 | def num_to_groups(num, divisor): 108 | groups = num // divisor 109 | remainder = num % divisor 110 | arr = [divisor] * groups 111 | if remainder > 0: 112 | arr.append(remainder) 113 | return arr 114 | 115 | # url to fs, bucket, path - for checkpointing to cloud 116 | 117 | def url_to_bucket(url): 118 | if '://' not in url: 119 | return url 120 | 121 | _, suffix = url.split('://') 122 | 123 | if prefix in {'gs', 's3'}: 124 | return suffix.split('/')[0] 125 | else: 126 | raise ValueError(f'storage type prefix "{prefix}" is not supported yet') 127 | 128 | # decorators 129 | 130 | def eval_decorator(fn): 131 | def inner(model, *args, **kwargs): 132 | was_training = model.training 133 | model.eval() 134 | out = fn(model, *args, **kwargs) 135 | model.train(was_training) 136 | return out 137 | return inner 138 | 139 | def cast_torch_tensor(fn, cast_fp16 = False): 140 | @wraps(fn) 141 | def inner(model, *args, **kwargs): 142 | device = kwargs.pop('_device', model.device) 143 | cast_device = kwargs.pop('_cast_device', True) 144 | 145 | should_cast_fp16 = cast_fp16 and model.cast_half_at_training 146 | 147 | kwargs_keys = kwargs.keys() 148 | all_args = (*args, *kwargs.values()) 149 | split_kwargs_index = len(all_args) - len(kwargs_keys) 150 | all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args)) 151 | 152 | if cast_device: 153 | all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args)) 154 | 155 | if should_cast_fp16: 156 | all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args)) 157 | 158 | args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:] 159 | kwargs = dict(tuple(zip(kwargs_keys, kwargs_values))) 160 | 161 | out = fn(model, *args, **kwargs) 162 | return out 163 | return inner 164 | 165 | # gradient accumulation functions 166 | 167 | def split_iterable(it, split_size): 168 | accum = [] 169 | for ind in range(ceil(len(it) / split_size)): 170 | 
start_index = ind * split_size 171 | accum.append(it[start_index: (start_index + split_size)]) 172 | return accum 173 | 174 | def split(t, split_size = None): 175 | if not exists(split_size): 176 | return t 177 | 178 | if isinstance(t, torch.Tensor): 179 | return t.split(split_size, dim = 0) 180 | 181 | if isinstance(t, Iterable): 182 | return split_iterable(t, split_size) 183 | 184 | return TypeError 185 | 186 | def find_first(cond, arr): 187 | for el in arr: 188 | if cond(el): 189 | return el 190 | return None 191 | 192 | def split_args_and_kwargs(*args, split_size = None, **kwargs): 193 | all_args = (*args, *kwargs.values()) 194 | len_all_args = len(all_args) 195 | first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args) 196 | assert exists(first_tensor) 197 | 198 | batch_size = len(first_tensor) 199 | split_size = default(split_size, batch_size) 200 | num_chunks = ceil(batch_size / split_size) 201 | 202 | dict_len = len(kwargs) 203 | dict_keys = kwargs.keys() 204 | split_kwargs_index = len_all_args - dict_len 205 | 206 | split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args] 207 | chunk_sizes = num_to_groups(batch_size, split_size) 208 | 209 | for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)): 210 | chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:] 211 | chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values))) 212 | chunk_size_frac = chunk_size / batch_size 213 | yield chunk_size_frac, (chunked_args, chunked_kwargs) 214 | 215 | 216 | # imagen trainer 217 | 218 | def imagen_sample_in_chunks(fn): 219 | @wraps(fn) 220 | def inner(self, *args, max_batch_size = None, **kwargs): 221 | if not exists(max_batch_size): 222 | return fn(self, *args, **kwargs) 223 | 224 | if self.imagen.unconditional: 225 | batch_size = kwargs.get('batch_size') 226 | batch_sizes = num_to_groups(batch_size, max_batch_size) 227 | outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes] 228 | else: 229 | outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)] 230 | 231 | if isinstance(outputs[0], torch.Tensor): 232 | return torch.cat(outputs, dim = 0) 233 | 234 | return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs)))) 235 | 236 | return inner 237 | 238 | 239 | def restore_parts(state_dict_target, state_dict_from): 240 | for name, param in state_dict_from.items(): 241 | 242 | if name not in state_dict_target: 243 | continue 244 | 245 | if param.size() == state_dict_target[name].size(): 246 | state_dict_target[name].copy_(param) 247 | else: 248 | print(f"layer {name}({param.size()} different than target: {state_dict_target[name].size()}") 249 | 250 | return state_dict_target 251 | 252 | # ////////////////////////////////////////////////////////////// 253 | # 5. Main class: 254 | # ////////////////////////////////////////////////////////////// 255 | 256 | class ProteinDesigner_Trainer(nn.Module): 257 | locked = False 258 | 259 | def __init__( 260 | self, 261 | # 1. on models 262 | ProtDesi = None, # provide a object 263 | ProtDesi_checkpoint_path = None, # provide a checkpoint path 264 | only_train_unet_number = None, 265 | # 2. 
on optimizer 266 | use_ema = True, 267 | lr = 1e-4, 268 | eps = 1e-8, 269 | beta1 = 0.9, 270 | beta2 = 0.99, 271 | max_grad_norm = None, 272 | group_wd_params = True, 273 | warmup_steps = None, 274 | cosine_decay_max_steps = None, 275 | 276 | fp16 = False, 277 | precision = None, 278 | split_batches = True, 279 | dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'), 280 | verbose = True, 281 | split_valid_fraction = 0.025, 282 | split_valid_from_train = False, 283 | split_random_seed = 42, 284 | checkpoint_path = None, 285 | checkpoint_every = None, 286 | checkpoint_fs = None, 287 | fs_kwargs: dict = None, 288 | max_checkpoints_keep = 20, 289 | # ++ 290 | CKeys = {'Debug_Level':0}, 291 | **kwargs 292 | ): 293 | super().__init__() 294 | 295 | # 0. asserts some 296 | # ..................................................... 297 | assert not ProteinDesigner_Trainer.locked, 'ProteinDesigner_Trainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)' 298 | 299 | assert exists(ProtDesi) ^ exists(ProtDesi_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config' 300 | 301 | # ++ 302 | self.CKeys = CKeys 303 | if self.CKeys['Debug_Level']==PT_Init_Level: 304 | print (f"|||||||||||||||||||||||||||||||||||||||||||||||||||") 305 | print (f"Initialize Protein_Designer Trainer object...") 306 | 307 | # determine filesystem, using fsspec, for saving to local filesystem or cloud 308 | self.fs = checkpoint_fs 309 | 310 | if not exists(self.fs): 311 | fs_kwargs = default(fs_kwargs, {}) 312 | self.fs, _ = url_to_fs( 313 | default(checkpoint_path, './'), **fs_kwargs 314 | ) 315 | # ++ 316 | if self.CKeys['Debug_Level']==PT_Init_Level: 317 | print (f"file system: .fs: {self.fs}") 318 | 319 | assert isinstance(ProtDesi, (ProteinDesigner_Base)), \ 320 | "ProtDesi is not from ProteinDesigner_Base" 321 | 322 | ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) 323 | # ++ 324 | if self.CKeys['Debug_Level']==PT_Init_Level: 325 | print (f"ema_kwargs: {ema_kwargs}") 326 | print (f"kwargs: {kwargs}") 327 | 328 | self.is_elucidated = isinstance( 329 | ProtDesi.diffuser_core, ElucidatedImagen_OneD 330 | ) 331 | 332 | # create accelerator instance 333 | 334 | accelerate_kwargs, kwargs = groupby_prefix_and_trim( 335 | 'accelerate_', kwargs 336 | ) 337 | # ++ 338 | if self.CKeys['Debug_Level']==PT_Init_Level: 339 | print (f"create acce instance...") 340 | print (f"accelerate_kwargs: {accelerate_kwargs}") 341 | print (f"kwargs: {kwargs}") 342 | 343 | assert not (fp16 and exists(precision)), \ 344 | 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator' 345 | accelerator_mixed_precision = default( 346 | precision, 347 | 'fp16' if fp16 else 'no' 348 | ) 349 | 350 | self.accelerator = Accelerator(**{ 351 | 'split_batches': split_batches, 352 | 'mixed_precision': accelerator_mixed_precision, 353 | 'kwargs_handlers': [ 354 | DistributedDataParallelKwargs(find_unused_parameters = True) 355 | ], 356 | **accelerate_kwargs}) 357 | 358 | # .is_distributed is a self fun 359 | ProteinDesigner_Trainer.locked = self.is_distributed 360 | # ++ 361 | if self.CKeys['Debug_Level']==PT_Init_Level: 362 | print (f".is_distributed or .locked: {ProteinDesigner_Trainer.locked}") 363 | 364 | # cast data to fp16 at training time if needed 365 | self.cast_half_at_training 
= accelerator_mixed_precision == 'fp16' 366 | 367 | # grad scaler must be managed outside of accelerator 368 | grad_scaler_enabled = fp16 369 | 370 | # ProteinDesigner, imagen, unets and ema unets 371 | self.ProtDesi = ProtDesi 372 | self.imagen = ProtDesi.diffuser_core # imagen 373 | self.num_unets = len(self.imagen.unets) 374 | 375 | self.use_ema = use_ema and self.is_main 376 | self.ema_unets = nn.ModuleList([]) 377 | # ++ 378 | if self.CKeys['Debug_Level']==PT_Init_Level: 379 | print (f".num_unets: {self.num_unets}") 380 | print (f".use_ema: {self.use_ema}") 381 | 382 | # keep track of what unet is being trained on 383 | # only going to allow 1 unet training at a time 384 | 385 | self.ema_unet_being_trained_index = -1 386 | # keeps track of which ema unet is being trained on 387 | 388 | # data related functions 389 | 390 | self.train_dl_iter = None 391 | self.train_dl = None 392 | 393 | self.valid_dl_iter = None 394 | self.valid_dl = None 395 | 396 | self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names 397 | 398 | # auto splitting validation from training, if dataset is passed in 399 | 400 | self.split_valid_from_train = split_valid_from_train 401 | 402 | assert 0 <= split_valid_fraction <= 1, \ 403 | 'split valid fraction must be between 0 and 1' 404 | self.split_valid_fraction = split_valid_fraction 405 | self.split_random_seed = split_random_seed 406 | 407 | # be able to finely customize learning rate, weight decay 408 | # per unet 409 | 410 | # ++ 411 | if self.CKeys['Debug_Level']==PT_Init_Level: 412 | print (f" Finely customize learning rate, weight decay") 413 | 414 | lr, eps, warmup_steps, cosine_decay_max_steps = map( 415 | partial(cast_tuple, length = self.num_unets), 416 | (lr, eps, warmup_steps, cosine_decay_max_steps) 417 | ) 418 | 419 | for ind, ( 420 | unet, unet_lr, unet_eps, 421 | unet_warmup_steps, unet_cosine_decay_max_steps 422 | ) in enumerate( 423 | zip( 424 | self.imagen.unets, 425 | lr, eps, warmup_steps, cosine_decay_max_steps 426 | ) 427 | ): 428 | 429 | optimizer = Adam( 430 | unet.parameters(), 431 | lr = unet_lr, 432 | eps = unet_eps, 433 | betas = (beta1, beta2), 434 | **kwargs 435 | ) 436 | 437 | if self.use_ema: 438 | self.ema_unets.append(EMA(unet, **ema_kwargs)) 439 | 440 | scaler = GradScaler(enabled = grad_scaler_enabled) 441 | 442 | scheduler = warmup_scheduler = None 443 | 444 | if exists(unet_cosine_decay_max_steps): 445 | scheduler = CosineAnnealingLR( 446 | optimizer, 447 | T_max = unet_cosine_decay_max_steps 448 | ) 449 | 450 | if exists(unet_warmup_steps): 451 | warmup_scheduler = warmup.LinearWarmup( 452 | optimizer, 453 | warmup_period = unet_warmup_steps 454 | ) 455 | 456 | if not exists(scheduler): 457 | scheduler = LambdaLR( 458 | optimizer, 459 | lr_lambda = lambda step: 1.0 460 | ) 461 | 462 | # set on object 463 | 464 | setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers 465 | setattr(self, f'scaler{ind}', scaler) 466 | setattr(self, f'scheduler{ind}', scheduler) 467 | setattr(self, f'warmup{ind}', warmup_scheduler) 468 | 469 | # ++ 470 | if self.CKeys['Debug_Level']==PT_Init_Level: 471 | print (f" on Unit-{ind}") 472 | print (f" scaler: {scaler}") 473 | print (f" scheduler: {scheduler}") 474 | print (f" warmup_scheduler: {warmup_scheduler}") 475 | 476 | 477 | # gradient clipping if needed 478 | 479 | self.max_grad_norm = max_grad_norm 480 | 481 | # step tracker and misc 482 | 483 | self.register_buffer('steps', torch.tensor([0] * self.num_unets)) 484 | 485 | 
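        # Descriptive note: the remaining setup stores the verbosity flag, moves the
        # designer (and this trainer module) onto the accelerator-selected device, and
        # configures optional checkpointing. Checkpoints are handled through fsspec, so
        # `checkpoint_path` may be a local folder or a cloud bucket (gs:// or s3://);
        # when a path is given, the latest checkpoint found there is loaded automatically.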
self.verbose = verbose 486 | 487 | # automatic set devices based on what accelerator decided 488 | 489 | # self.imagen.to(self.device) 490 | self.ProtDesi.to(self.device) 491 | self.to(self.device) 492 | 493 | # checkpointing 494 | 495 | assert not (exists(checkpoint_path) ^ exists(checkpoint_every)) 496 | self.checkpoint_path = checkpoint_path 497 | self.checkpoint_every = checkpoint_every 498 | self.max_checkpoints_keep = max_checkpoints_keep 499 | 500 | self.can_checkpoint = self.is_local_main \ 501 | if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main 502 | 503 | # ++ 504 | if self.CKeys['Debug_Level']==PT_Init_Level: 505 | print (f".checkpoint_path: {self.checkpoint_path}") 506 | print (f".checkpoint_every: {self.checkpoint_every}") 507 | print (f".max_checkpoints_keep: {self.max_checkpoints_keep}") 508 | print (f".can_checkpoint: {self.can_checkpoint}") 509 | 510 | if exists(checkpoint_path) and self.can_checkpoint: 511 | bucket = url_to_bucket(checkpoint_path) 512 | 513 | if not self.fs.exists(bucket): 514 | self.fs.mkdir(bucket) 515 | 516 | self.load_from_checkpoint_folder() 517 | 518 | # only allowing training for unet 519 | 520 | self.only_train_unet_number = only_train_unet_number 521 | self.prepared = False 522 | # ++ 523 | if self.CKeys['Debug_Level']==PT_Init_Level: 524 | print (f".only_train_unet_number: {self.only_train_unet_number}") 525 | print (f".prepared: {self.prepared}") 526 | 527 | # |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||| 528 | # computed values 529 | @property 530 | def device(self): 531 | return self.accelerator.device 532 | 533 | @property 534 | def is_distributed(self): 535 | return not ( 536 | self.accelerator.distributed_type == DistributedType.NO \ 537 | and self.accelerator.num_processes == 1 538 | ) 539 | 540 | @property 541 | def is_main(self): 542 | return self.accelerator.is_main_process 543 | 544 | @property 545 | def is_local_main(self): 546 | return self.accelerator.is_local_main_process 547 | 548 | @property 549 | def unwrapped_unet(self): 550 | return self.accelerator.unwrap_model(self.unet_being_trained) 551 | 552 | # |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||| 553 | # optimizer helper functions 554 | 555 | 556 | # |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||| 557 | 558 | def load_from_checkpoint_folder( 559 | self, 560 | last_total_steps = -1 561 | ): 562 | if last_total_steps != -1: 563 | filepath = os.path.join( 564 | self.checkpoint_path, 565 | f'checkpoint.{last_total_steps}.pt' 566 | ) 567 | self.load(filepath) 568 | return 569 | 570 | sorted_checkpoints = self.all_checkpoints_sorted 571 | 572 | if len(sorted_checkpoints) == 0: 573 | self.print( 574 | f'no checkpoints found to load from at {self.checkpoint_path}' 575 | ) 576 | return 577 | 578 | last_checkpoint = sorted_checkpoints[0] 579 | self.load(last_checkpoint) 580 | 581 | 582 | 583 | 584 | 585 | 586 | # |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||| 587 | # Forward_Pack 588 | 589 | # validating the unet number 590 | 591 | def validate_unet_number( 592 | self, 593 | unet_number = None 594 | ): 595 | if self.num_unets == 1: 596 | unet_number = default(unet_number, 1) 597 | 598 | assert 0 < unet_number <= self.num_unets, \ 599 | f'unet number should be in between 1 and {self.num_unets}' 600 | 601 | return unet_number 602 | 603 | # function for allowing only one unet from being trained at a time 604 | 605 | def validate_and_set_unet_being_trained(self, unet_number = None): 606 | if exists(unet_number): 
607 | self.validate_unet_number(unet_number) 608 | 609 | assert not exists(self.only_train_unet_number) or \ 610 | self.only_train_unet_number == unet_number, \ 611 | 'you can only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet' 612 | 613 | self.only_train_unet_number = unet_number 614 | self.imagen.only_train_unet_number = unet_number 615 | 616 | if not exists(unet_number): 617 | return 618 | 619 | self.wrap_unet(unet_number) 620 | 621 | 622 | def wrap_unet(self, unet_number): 623 | if hasattr(self, 'one_unet_wrapped'): 624 | return 625 | 626 | unet = self.imagen.get_unet(unet_number) 627 | unet_index = unet_number - 1 628 | 629 | optimizer = getattr(self, f'optim{unet_index}') 630 | scheduler = getattr(self, f'scheduler{unet_index}') 631 | 632 | if self.train_dl: 633 | self.unet_being_trained, self.train_dl, optimizer\ 634 | = self.accelerator.prepare( 635 | unet, self.train_dl, optimizer 636 | ) 637 | else: 638 | self.unet_being_trained, optimizer\ 639 | = self.accelerator.prepare(unet, optimizer) 640 | 641 | if exists(scheduler): 642 | scheduler = self.accelerator.prepare(scheduler) 643 | 644 | setattr(self, f'optim{unet_index}', optimizer) 645 | setattr(self, f'scheduler{unet_index}', scheduler) 646 | 647 | self.one_unet_wrapped = True 648 | 649 | # hacking accelerator due to not having separate gradscaler per optimizer 650 | 651 | def set_accelerator_scaler(self, unet_number): 652 | 653 | def patch_optimizer_step(accelerated_optimizer, method): 654 | def patched_step(*args, **kwargs): 655 | accelerated_optimizer._accelerate_step_called = True 656 | return method(*args, **kwargs) 657 | return patched_step 658 | 659 | unet_number = self.validate_unet_number(unet_number) 660 | scaler = getattr(self, f'scaler{unet_number - 1}') 661 | 662 | self.accelerator.scaler = scaler 663 | for optimizer in self.accelerator._optimizers: 664 | optimizer.scaler = scaler 665 | optimizer._accelerate_step_called = False 666 | optimizer._optimizer_original_step_method = optimizer.optimizer.step 667 | optimizer._optimizer_patched_step_method = patch_optimizer_step( 668 | optimizer, optimizer.optimizer.step 669 | ) 670 | 671 | 672 | 673 | @partial(cast_torch_tensor, cast_fp16 = True) 674 | def forward( 675 | self, 676 | *args, 677 | unet_number = None, 678 | max_batch_size = None, 679 | **kwargs 680 | ): 681 | # ++ 682 | if self.CKeys['Debug_Level']==PT_Forw_Level: 683 | print (f"Debug mode for trainer.forward...") 684 | 685 | unet_number = self.validate_unet_number(unet_number) # check if unet_number is in the range 686 | # ++ 687 | if self.CKeys['Debug_Level']==PT_Forw_Level: 688 | print (f"Train UNet number: {unet_number}") 689 | 690 | self.validate_and_set_unet_being_trained(unet_number) 691 | self.set_accelerator_scaler(unet_number) 692 | 693 | assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}' 694 | 695 | total_loss = 0. 
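        # Note on gradient accumulation (descriptive comment): when `max_batch_size` is
        # set, split_args_and_kwargs() below yields sub-batches no larger than that size;
        # each chunk's loss is scaled by its fraction of the full batch, backpropagated
        # separately, and accumulated into total_loss, so the returned value approximates
        # the loss over the full batch.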
696 | 697 | for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): 698 | 699 | with self.accelerator.autocast(): 700 | #-- 701 | # loss = self.imagen( 702 | #++ 703 | loss = self.ProtDesi( 704 | *chunked_args, 705 | unet = self.unet_being_trained, 706 | unet_number = unet_number, **chunked_kwargs 707 | ) 708 | loss = loss * chunk_size_frac 709 | # ++ 710 | if self.CKeys['Debug_Level']==PT_Forw_Level: 711 | print (f"get loss for a fraction: {loss}") 712 | 713 | total_loss += loss.item() 714 | # ++ 715 | if self.CKeys['Debug_Level']==PT_Forw_Level: 716 | print (f"update tot_loss: {total_loss}") 717 | 718 | if self.training: 719 | self.accelerator.backward(loss) 720 | 721 | return total_loss 722 | -------------------------------------------------------------------------------- /VibeGen/UtilityPack.py: -------------------------------------------------------------------------------- 1 | # ========================================================== 2 | # Utility functions 3 | # ========================================================== 4 | import os 5 | from scipy.interpolate import CubicSpline, PchipInterpolator, Akima1DInterpolator 6 | import numpy as np 7 | import math 8 | import matplotlib.pyplot as plt 9 | 10 | from Bio.PDB import PDBParser 11 | from Bio.PDB.DSSP import DSSP 12 | from Bio.PDB import PDBList 13 | 14 | import torch 15 | from einops import rearrange 16 | import esm 17 | 18 | import json 19 | 20 | # ========================================================= 21 | # 22 | def Print(this_line): 23 | # may update for multi-core case later 24 | print (this_line) 25 | 26 | def print_dict_content(this_dict): 27 | 28 | for this_key in this_dict.keys(): 29 | print (f" {this_key}: {this_dict[this_key]}") 30 | 31 | # ========================================================= 32 | # create a folder path if not exist 33 | def create_path(this_path): 34 | if not os.path.exists(this_path): 35 | print('Creating the given path...') 36 | os.mkdir (this_path) 37 | path_stat = 1 38 | print('Done.') 39 | else: 40 | print('The given path already exists!') 41 | path_stat = 2 42 | return path_stat 43 | 44 | # ============================================================ 45 | # on esm, rebuild AA sequence from embedding 46 | # ============================================================ 47 | 48 | def decode_one_ems_token_rec(this_token, esm_alphabet): 49 | # print( (this_token==esm_alphabet.cls_idx).nonzero(as_tuple=True)[0] ) 50 | # print( (this_token==esm_alphabet.eos_idx).nonzero(as_tuple=True)[0] ) 51 | # print( (this_token==100).nonzero(as_tuple=True)[0]==None ) 52 | 53 | id_b=(this_token==esm_alphabet.cls_idx).nonzero(as_tuple=True)[0] 54 | id_e=(this_token==esm_alphabet.eos_idx).nonzero(as_tuple=True)[0] 55 | 56 | 57 | if len(id_e)==0: 58 | # no ending for this one, so id_e points to the end 59 | id_e=len(this_token) 60 | else: 61 | id_e=id_e[0] 62 | if len(id_b)==0: 63 | id_b=0 64 | else: 65 | id_b=id_b[-1] 66 | 67 | this_seq = [] 68 | # this_token_used = [] 69 | for ii in range(id_b+1,id_e,1): 70 | # this_token_used.append(this_token[ii]) 71 | this_seq.append( 72 | esm_alphabet.get_tok(this_token[ii]) 73 | ) 74 | 75 | this_seq = "".join(this_seq) 76 | 77 | # print(this_seq) 78 | # print(len(this_seq)) 79 | # # print(this_token[id_b+1:id_e]) 80 | return this_seq 81 | 82 | 83 | def decode_many_ems_token_rec(batch_tokens, esm_alphabet): 84 | rev_y_seq = [] 85 | for jj in range(len(batch_tokens)): 86 | # do for one seq: this_seq 87 | 
this_seq = decode_one_ems_token_rec( 88 | batch_tokens[jj], esm_alphabet 89 | ) 90 | rev_y_seq.append(this_seq) 91 | return rev_y_seq 92 | 93 | def Print_model_params (model): 94 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 95 | pytorch_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 96 | 97 | Print ( 98 | f"Total model parameters: {pytorch_total_params}\nTrainable parameters: {pytorch_total_params_trainable}\n" 99 | ) 100 | 101 | def get_model_params (model): 102 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 103 | pytorch_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 104 | 105 | resu = { 106 | 'tot': pytorch_total_params, 107 | 'trainable': pytorch_total_params_trainable, 108 | 'freezed': pytorch_total_params-pytorch_total_params_trainable 109 | } 110 | 111 | return resu 112 | 113 | def write_one_line_to_file( 114 | this_line, 115 | file_name, 116 | mode, 117 | accelerator=None 118 | ): 119 | with open(file_name, mode) as f: 120 | f.write(this_line) 121 | 122 | # ============================================================== 123 | # 124 | # def convert_into_tokens_using_prob( 125 | # prob_result, 126 | # pLM_Model_Name 127 | # ): 128 | # if pLM_Model_Name=='esm2_t33_650M_UR50D' \ 129 | # or pLM_Model_Name=='esm2_t36_3B_UR50D' \ 130 | # or pLM_Model_Name=='esm2_t30_150M_UR50D' \ 131 | # or pLM_Model_Name=='esm2_t12_35M_UR50D' : 132 | 133 | # repre=rearrange( 134 | # prob_result, 135 | # 'b c l -> b l c' 136 | # ) 137 | # # with torch.no_grad(): 138 | # # logits=model.lm_head(repre) # (b, l, token_dim) 139 | # logits = repre 140 | 141 | # tokens=logits.max(2).indices # (b,l) 142 | 143 | # else: 144 | # print("pLM_Model is not defined...") 145 | # return tokens,logits 146 | 147 | def read_mask_from_input( 148 | # consider different type of inputs 149 | # raw data: x_data (sequences) 150 | # tokenized: x_data_tokenized 151 | tokenized_data=None, # X_train_batch, 152 | mask_value=None, 153 | seq_data=None, # Y_train_batch, 154 | max_seq_length=None, 155 | ): 156 | # # old: 157 | # mask = X_train_batch!=mask_value 158 | # new 159 | if seq_data!=None: 160 | # use the real sequence length to create mask 161 | n_seq = len(seq_data) 162 | mask = torch.zeros(n_seq, max_seq_length) 163 | for ii in range(n_seq): 164 | this_len = len(seq_data[ii]) 165 | mask[ii,1:1+this_len]=1 166 | mask = mask==1 167 | # 168 | elif tokenized_data!=None: 169 | n_seq = len(tokenized_data) 170 | mask = tokenized_data!=mask_value 171 | # fix the beginning part: 0+content+00, not 00+content+00 172 | for ii in range(n_seq): 173 | # get all nonzero index 174 | id_1 = (mask[ii]==True).nonzero(as_tuple=True)[0] 175 | # correction for ForcPath, 176 | # pick up 0.0 for zero-force padding at the beginning 177 | mask[ii,1:id_1[0]]=True 178 | 179 | return mask 180 | 181 | # on pLM tokens 182 | # basic 20 in abr order: ARNDCEQGHILKMFPSTWYV 183 | # in esm, tot = 33 184 | # basic 20 in esm order: LAGVSERTIDPKQNFYMHWC 185 | # others (4): 186 | # special (9): X B U Z O . 
- <null_1> <mask> 187 | # LAGVSERTIDPKQNFYMHWC: take the channels: 4-23 188 | # full dict 189 | esm_tok_to_idx = \ 190 | {'<cls>': 0, '<pad>': 1, '<eos>': 2, '<unk>': 3, 'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10, 'T': 11, 'I': 12, 'D': 13, 'P': 14, 'K': 15, 'Q': 16, 'N': 17, 'F': 18, 'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23, 'X': 24, 'B': 25, 'U': 26, 'Z': 27, 'O': 28, '.': 29, '-': 30, '<null_1>': 31, '<mask>': 32} 191 | 192 | esm_idx_to_tok = \ 193 | {'0': '<cls>', '1': '<pad>', '2': '<eos>', '3': '<unk>', '4': 'L', '5': 'A', '6': 'G', '7': 'V', '8': 'S', '9': 'E', '10': 'R', '11': 'T', '12': 'I', '13': 'D', '14': 'P', '15': 'K', '16': 'Q', '17': 'N', '18': 'F', '19': 'Y', '20': 'M', '21': 'H', '22': 'W', '23': 'C', '24': 'X', '25': 'B', '26': 'U', '27': 'Z', '28': 'O', '29': '.', '30': '-', '31': '<null_1>', '32': '<mask>'} 194 | 195 | common_AA_list = "LAGVSERTIDPKQNFYMHWC" 196 | 197 | 198 | # common_AA_idx_in_esm = [] 199 | # for ii in range(len(common_AA_list)): 200 | # common_AA_idx_in_esm.append( 201 | # esm_tok_to_idx[ 202 | # common_AA_list[ii] 203 | # ] 204 | # ) 205 | 206 | common_AA_idx_in_esm = [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] 207 | 208 | def keep_only_20AA_channels_in_one_pLM_logits( 209 | full_logits, # (seq_len, channel) 210 | keep_channels=common_AA_idx_in_esm 211 | ): 212 | assert full_logits.shape[-1]==33, \ 213 | "Not ESM logits shape" 214 | 215 | n_channel = full_logits.shape[-1] 216 | for this_c in range(n_channel): 217 | if not (this_c in keep_channels): 218 | full_logits[:,this_c]=-float('inf') 219 | 220 | return full_logits 221 | 222 | def get_toks_list_from_Y_batch( 223 | batch_GT, # (b, seq_len) 224 | batch_mask, # (b, seq_len) 225 | ): 226 | toks_list = [] 227 | seqs_list = [] 228 | 229 | for ii in range(len(batch_GT)): 230 | this_GT = batch_GT[ii] 231 | this_mask = batch_mask[ii] 232 | this_GT = this_GT[this_mask==True] 233 | # 234 | toks_list.append(this_GT) 235 | this_seq = [ 236 | esm_idx_to_tok[str(jj.item())] for jj in this_GT 237 | ] 238 | this_seq = "".join(this_seq) 239 | seqs_list.append(this_seq) 240 | 241 | 242 | return toks_list, seqs_list 243 | 244 | def compare_two_seq_strings(seq_PR, seq_GT): 245 | # take seq_GT as the ref, 246 | # assume len(seq_GT)>=len(seq_PR) 247 | len_comp = min( len(seq_PR), len(seq_GT)) 248 | num_hit = 0 249 | for ii in range(len_comp): 250 | if seq_PR[ii]==seq_GT[ii]: 251 | num_hit += 1 252 | ratio_hit = num_hit/len_comp 253 | 254 | return ratio_hit 255 | 256 | def save_2d_tensor_as_np_arr_txt( 257 | X_tensor, # (a, b) 258 | mask = None, # (b) 259 | outname = None, 260 | ): 261 | assert X_tensor.dim() == 2 262 | 263 | if not (mask is None): 264 | assert mask.dim() == 1 265 | 266 | if not (mask is None): 267 | X_tensor = X_tensor[:, mask] 268 | 269 | 270 | test_one_X_arr = X_tensor.cpu().detach().numpy() 271 | if outname is None: 272 | print (test_one_X_arr) 273 | else: 274 | np.savetxt(outname, test_one_X_arr) 275 | # # to read back as a 2d np arr 276 | # test_one_X_arr_1 = np.loadtxt(test_file) 277 | 278 | # ++ read back for checking 279 | def read_2d_np_arr_from_txt( 280 | test_file 281 | ): 282 | test_one_X_arr_1 = np.loadtxt(test_file) 283 | return test_one_X_arr_1 284 | 285 | def string_diff (seq1, seq2): 286 | return sum(1 for a, b in zip(seq1, seq2) if a != b) + abs(len(seq1) - len(seq2)) 287 | 288 | # def write_fasta_file( 289 | # this_seq, 290 | # this_head, 291 | # this_file 292 | # ): 293 | # with open(this_file, mode = 'w') as f: 294 | # f.write (f">{this_head}\n") 295 | # f.write (f"{this_seq}") 296 | 297 | def write_fasta_file( 298
| this_seq_list, 299 | this_head_list, 300 | this_file 301 | ): 302 | n_seq = len(this_seq_list) 303 | 304 | with open(this_file, mode = 'w') as f: 305 | for i_seq in range(n_seq): 306 | 307 | f.write (f">{this_head_list[i_seq]}\n") 308 | f.write (f"{this_seq_list[i_seq]}\n") 309 | 310 | # ++ 311 | def read_recover_AAs_only(test_fasta_file): 312 | 313 | file1 = open(test_fasta_file, 'r') 314 | Lines = file1.readlines() 315 | # only get AA 316 | AA_GT = Lines[1].strip() 317 | AA_recon_GT = Lines[3].strip() 318 | 319 | resu = {} 320 | resu['AA_GT'] = AA_GT 321 | resu['AA_recon_GT'] = AA_recon_GT 322 | 323 | return resu 324 | 325 | # =================================================================================== 326 | # old one 327 | def fold_one_AA_to_SS_using_omegafold_for_5_Diffusionfold( 328 | sequence, 329 | num_cycle=16, 330 | device=None, 331 | # ++++++++++++++ 332 | prefix=None, 333 | AA_file_path=None, 334 | PDB_file_path=None, # output file path 335 | head_note=None, 336 | ): 337 | AA_file_name = f"{AA_file_path}/{prefix}_.fasta" 338 | print ("Writing FASTA file: ", AA_file_name) 339 | head_line = f"{head_note}" 340 | with open (AA_file_name, mode ='w') as f: 341 | f.write (f'>{head_line}\n') 342 | f.write (f'{sequence}') 343 | # 344 | # 345 | PDB_result=f"{PDB_file_path}/{head_line}.pdb" 346 | if not os.path.exists(PDB_result): 347 | print (f"Now run OmegaFold.... on device={device}") 348 | # !omegafold $filename $prefix --num_cycle $num_cycle --device=$device 349 | cmd_line=F"omegafold {AA_file_name} {PDB_file_path} --num_cycle {num_cycle} --device={device}" 350 | print(os.popen(cmd_line).read()) 351 | 352 | print ("Done OmegaFold") 353 | 354 | # PDB_result=f"{prefix}{OUTFILE}.PDB" 355 | 356 | print (f"Resulting PDB file...: {PDB_result}") 357 | else: 358 | print (f"PDB file already exist.") 359 | 360 | return PDB_result, AA_file_name 361 | # 362 | # =================================================================================== 363 | # new one: need to install the modified omegafold from self-hold repo 364 | # https://github.com/Bo-Ni/OmegaFold_0.git 365 | def get_subbatch_size(L): 366 | if L < 500: return 500 367 | if L < 1000: return 500 # 500 # 200 368 | return 150 369 | 370 | def fold_one_AA_to_SS_using_omegafold( 371 | sequence, 372 | num_cycle=16, 373 | device=None, 374 | # ++++++++++++++ 375 | prefix="Temp", # None, 376 | AA_file_path="./", # None, 377 | PDB_file_path="./", # output file path 378 | head_note="Temp_", # None, 379 | ): 380 | AA_file_name = f"{AA_file_path}/{prefix}_.fasta" 381 | print ("Writing FASTA file: ", AA_file_name) 382 | head_line = f"{head_note}" 383 | with open (AA_file_name, mode ='w') as f: 384 | f.write (f'>{head_line}\n') 385 | f.write (f'{sequence}') 386 | # 387 | subbatch_size = get_subbatch_size(len(sequence)) 388 | # 389 | PDB_result=f"{PDB_file_path}/{head_line}.pdb" 390 | 391 | if not os.path.exists(PDB_result): 392 | Print (f"Now run OmegaFold.... 
on device={device}\n\n") 393 | # !omegafold $filename $prefix --num_cycle $num_cycle --device=$device 394 | # cmd_line=F"omegafold {AA_file_name} {PDB_file_path} --num_cycle {num_cycle} --device={device}" 395 | cmd_line=F"omegafold {AA_file_name} {PDB_file_path} --subbatch_size {str(subbatch_size)} --num_cycle {num_cycle} --device={device}" 396 | 397 | Print(os.popen(cmd_line).read()) 398 | 399 | Print ("Done OmegaFold") 400 | 401 | # PDB_result=f"{prefix}{OUTFILE}.PDB" 402 | 403 | Print (f"Resulting PDB file...: {PDB_result}\n\n") 404 | else: 405 | Print (f"PDB file already exist.") 406 | 407 | return PDB_result, AA_file_name 408 | # 409 | # =================================================================================== 410 | # plot 411 | import py3Dmol 412 | 413 | def plot_plddt_legend(dpi=100): 414 | thresh = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)'] 415 | plt.figure(figsize=(1,0.1),dpi=dpi) 416 | ######################################## 417 | for c in ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]: 418 | plt.bar(0, 0, color=c) 419 | plt.legend(thresh, frameon=False, 420 | loc='center', ncol=6, 421 | handletextpad=1, 422 | columnspacing=1, 423 | markerscale=0.5,) 424 | plt.axis(False) 425 | return plt 426 | 427 | color = "lDDT" # choose from ["chain", "lDDT", "rainbow"] 428 | show_sidechains = False #choose from {type:"boolean"} 429 | show_mainchains = False #choose from {type:"boolean"} 430 | 431 | def show_pdb( 432 | pdb_file, 433 | flag=0, 434 | show_sidechains=False, 435 | show_mainchains=False, 436 | color="lDDT" 437 | ): 438 | model_name = f"Flag_{flag}" 439 | view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',) 440 | view.addModel(open(pdb_file,'r').read(),'pdb') 441 | 442 | if color == "lDDT": 443 | view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}}) 444 | elif color == "rainbow": 445 | view.setStyle({'cartoon': {'color':'spectrum'}}) 446 | elif color == "chain": 447 | chains = len(queries[0][1]) + 1 if is_complex else 1 448 | for n,chain,color in zip( 449 | range(chains),list("ABCDEFGH"), 450 | ["lime","cyan","magenta","yellow","salmon","white","blue","orange"] 451 | ): 452 | view.setStyle({'chain':chain},{'cartoon': {'color':color}}) 453 | 454 | if show_sidechains: 455 | BB = ['C','O','N'] 456 | view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]}, 457 | {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) 458 | view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]}, 459 | {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) 460 | view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]}, 461 | {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) 462 | if show_mainchains: 463 | BB = ['C','O','N','CA'] 464 | view.addStyle({'atom':BB},{'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) 465 | 466 | view.zoomTo() 467 | if color == "lDDT": 468 | plot_plddt_legend().show() 469 | # 470 | return view 471 | # 472 | # =================================================================================== 473 | # SecStr 474 | from Bio.PDB import PDBParser 475 | from Bio.PDB.DSSP import DSSP 476 | from Bio.PDB import PDBList 477 | 478 | Unique_SecStr_Q8_String="HET~BGIS" 479 | Unique_SecStr_Q3_String="HEC" 480 | # 481 | # ============================================= 482 | # count statistics of Q8 based SecStr 483 | # 484 | def count_ratio_for_Q8( 485 | this_secstr, 486 | Unique_SecStr_Q8_String=Unique_SecStr_Q8_String, 487 | 
): 488 | resu = {} 489 | seq_len = len(this_secstr) 490 | for this_char in Unique_SecStr_Q8_String: 491 | resu[this_char] = this_secstr.count(this_char)/seq_len 492 | # 493 | return resu 494 | # ============================================= 495 | # count statistics of Q3 based SecStr 496 | # 497 | def count_ratio_for_Q3( 498 | this_secstr, 499 | Unique_SecStr_Q3_String=Unique_SecStr_Q3_String, 500 | ): 501 | resu = {} 502 | seq_len = len(this_secstr) 503 | for this_char in Unique_SecStr_Q3_String: 504 | resu[this_char] = this_secstr.count(this_char)/seq_len 505 | # 506 | return resu 507 | # =============================================== 508 | # 509 | def analyze_SS_Q8_Q3_for_df( 510 | df_smo_recon_BSDB_4P_expanded, 511 | Unique_SecStr_Q8_String=Unique_SecStr_Q8_String, 512 | Unique_SecStr_Q3_String=Unique_SecStr_Q3_String, 513 | ): 514 | # 515 | # do statistics on Q8 516 | this_key_to_add = 'stat_Q8' 517 | if not (this_key_to_add in df_smo_recon_BSDB_4P_expanded.keys()): 518 | print (f"Add new key {this_key_to_add}") 519 | df_smo_recon_BSDB_4P_expanded[this_key_to_add] = df_smo_recon_BSDB_4P_expanded.apply( 520 | # ================ change this part =========================== 521 | lambda row: count_ratio_for_Q8( 522 | row['SS_Q8'], 523 | Unique_SecStr_Q8_String=Unique_SecStr_Q8_String, 524 | ), 525 | # ================ change this part =========================== 526 | axis=1, 527 | ) 528 | 529 | # do statistics on Q3 530 | this_key_to_add = 'stat_Q3' 531 | if not (this_key_to_add in df_smo_recon_BSDB_4P_expanded.keys()): 532 | print (f"Add new key {this_key_to_add}") 533 | df_smo_recon_BSDB_4P_expanded[this_key_to_add] = df_smo_recon_BSDB_4P_expanded.apply( 534 | # ================ change this part =========================== 535 | lambda row: count_ratio_for_Q3( 536 | row['SS_Q3'], 537 | Unique_SecStr_Q3_String=Unique_SecStr_Q3_String, 538 | ), 539 | # ================ change this part =========================== 540 | axis=1, 541 | ) 542 | # 543 | # expand to df columns 544 | for this_char in Unique_SecStr_Q3_String: 545 | print (f"working on Q3 {this_char}") 546 | this_key_to_add = 'stat_Q3_'+this_char 547 | if not (this_key_to_add in df_smo_recon_BSDB_4P_expanded.keys()): 548 | print (f"Add new key {this_key_to_add}") 549 | df_smo_recon_BSDB_4P_expanded[this_key_to_add] = df_smo_recon_BSDB_4P_expanded.apply( 550 | # ================ change this part =========================== 551 | lambda row: row['stat_Q3'][this_char], 552 | # ================ change this part =========================== 553 | axis=1, 554 | ) 555 | # expand to Q8 556 | for this_char in Unique_SecStr_Q8_String: 557 | print (f"working on Q8 {this_char}") 558 | this_key_to_add = 'stat_Q8_'+this_char 559 | if not (this_key_to_add in df_smo_recon_BSDB_4P_expanded.keys()): 560 | print (f"Add new key {this_key_to_add}") 561 | df_smo_recon_BSDB_4P_expanded[this_key_to_add] = df_smo_recon_BSDB_4P_expanded.apply( 562 | # ================ change this part =========================== 563 | lambda row: row['stat_Q8'][this_char], 564 | # ================ change this part =========================== 565 | axis=1, 566 | ) 567 | 568 | return df_smo_recon_BSDB_4P_expanded 569 | # ================================================== 570 | 571 | def get_DSSP_result (fname): 572 | pdb_list = [fname] 573 | 574 | # parse structure 575 | p = PDBParser() 576 | for i in pdb_list: 577 | structure = p.get_structure(i, fname) 578 | # use only the first model 579 | model = structure[0] 580 | # calculate DSSP 581 | dssp = DSSP(model, fname, 
file_type='PDB' ) 582 | # extract sequence and secondary structure from the DSSP tuple 583 | sequence = '' 584 | sec_structure = '' 585 | for z in range(len(dssp)): 586 | a_key = list(dssp.keys())[z] 587 | sequence += dssp[a_key][1] 588 | sec_structure += dssp[a_key][2] 589 | 590 | # print extracted sequence and structure 591 | #print(i) 592 | #print(sequence) 593 | #print(sec_structure) 594 | # 595 | # The DSSP codes for secondary structure used here are: 596 | # ===== ==== 597 | # Code Structure 598 | # ===== ==== 599 | # H Alpha helix (4-12) 600 | # B Isolated beta-bridge residue 601 | # E Strand 602 | # G 3-10 helix 603 | # I Pi helix 604 | # T Turn 605 | # S Bend 606 | # - None 607 | # ===== ==== 608 | # 609 | 610 | sec_structure = sec_structure.replace('-', '~') 611 | sec_structure_3state=sec_structure 612 | 613 | 614 | # if desired, convert DSSP's 8-state assignments into 3-state [C - coil, E - extended (beta-strand), H - helix] 615 | sec_structure_3state = sec_structure_3state.replace('~', 'C') 616 | sec_structure_3state = sec_structure_3state.replace('I', 'C') 617 | sec_structure_3state = sec_structure_3state.replace('T', 'C') 618 | sec_structure_3state = sec_structure_3state.replace('S', 'C') 619 | sec_structure_3state = sec_structure_3state.replace('G', 'H') 620 | sec_structure_3state = sec_structure_3state.replace('B', 'E') 621 | 622 | return sec_structure,sec_structure_3state, sequence 623 | 624 | # ++ for postprocess 625 | def get_DSSP_set_result(fname): 626 | sec_structure,sec_structure_3state, sequence = get_DSSP_result (fname) 627 | resu={} 628 | resu['SecStr_Q8']=sec_structure 629 | resu['SecStr_Q3']=sec_structure_3state 630 | resu['AA_from_DSSP']=sequence 631 | 632 | return resu 633 | 634 | def write_DSSP_result_to_json( 635 | sec_structure, 636 | sec_structure_3state, 637 | sequence, 638 | filename, 639 | ): 640 | resu = { 641 | "Q8": sec_structure, 642 | "Q3": sec_structure_3state, 643 | "AA_from_DSSP": sequence 644 | } 645 | resu_json = json.dumps(resu, indent=4) 646 | 647 | with open(filename, "w") as f: 648 | f.write(resu_json) 649 | 650 | # # to read back 651 | # with open(filename, 'r') as openfile: 652 | # # Reading from json file 653 | # json_object = json.load(openfile) 654 | 655 | # print(json_object) 656 | # print(type(json_object)) # dict 657 | 658 | # ============================================================== 659 | # pick some Normal Mode from a df 660 | # For NMS vectors only 661 | def build_XCond_list_from_df( 662 | df, 663 | key_list, 664 | pick_id_list, 665 | ): 666 | n_mode = len(key_list) 667 | n_samp = len(pick_id_list) 668 | resu = [] 669 | for id_samp in pick_id_list: 670 | this_X_list = [] 671 | for this_key in key_list: 672 | add_one = df[this_key].values[id_samp] 673 | 674 | this_X_list.append( 675 | add_one 676 | ) 677 | this_X = np.array(this_X_list) 678 | resu.append(this_X) 679 | 680 | return resu 681 | 682 | # For AA Seq only 683 | def build_AA_list_from_df( 684 | df, 685 | AA_key, 686 | pick_id_list, 687 | ): 688 | n_samp = len(pick_id_list) 689 | resu = [] 690 | for id_samp in pick_id_list: 691 | resu.append( 692 | df[AA_key].values[id_samp] 693 | ) 694 | 695 | return resu 696 | 697 | # ============================================================== 698 | # add for Protein Predictor 699 | def get_nms_vec_as_arr_list_from_batch_using_mask( 700 | result_mask, # (b, seq_len) # torch.tensor 701 | output_diffuser, # (b, n_mode, seq_len) 702 | NormFac_list, # (n_mode, ) 703 | ): 704 | n_samp = output_diffuser.shape[0] 705 | n_mode = 
output_diffuser.shape[1] 706 | 707 | nms_vecs_list = [] 708 | for i_samp in range(n_samp): 709 | this_mask = result_mask[i_samp] # (seq_len, ) 710 | this_nms_vecs = output_diffuser[i_samp] 711 | 712 | # to take care of multi-modes 713 | this_nms_arr = [] 714 | for i_mode in range(n_mode): 715 | this_add = this_nms_vecs[i_mode][this_mask==True] # only work for 1D tensor 716 | this_add = this_add * NormFac_list[i_mode] # map it back to real values 717 | this_nms_arr.append( 718 | this_add.cpu().detach().numpy() 719 | ) 720 | this_nms_arr = np.array(this_nms_arr) # convert into np.arr 721 | 722 | # deliver to the list to store 723 | nms_vecs_list.append(this_nms_arr) 724 | 725 | return nms_vecs_list 726 | 727 | # compare two nms_vecs 728 | def compare_two_nms_vecs_arr( 729 | PR_nms_vecs, 730 | GT_nms_vecs, 731 | ): 732 | n_mode = GT_nms_vecs.shape[0] 733 | # calculate error for each mode and the tot 734 | # calculate rela_L2 error 735 | resu = {} 736 | for i_mode in range(n_mode): 737 | resu["rela_L2_err_Mode_"+str(i_mode)]=np.linalg.norm(PR_nms_vecs[i_mode]-GT_nms_vecs[i_mode])/np.linalg.norm(GT_nms_vecs[i_mode]) 738 | # 739 | # calculate for multi-modes 740 | resu["rela_L2_err_MulMode"]=np.linalg.norm(PR_nms_vecs-GT_nms_vecs)/np.linalg.norm(GT_nms_vecs) 741 | 742 | return resu 743 | 744 | # ====================================================== 745 | 746 | def translate_seqs_list_into_idx_tensor_w_pLM( 747 | # 1. model converter 748 | esm_batch_converter, 749 | AA_seq_max_len, 750 | # 2. on input 751 | raw_condition_list, 752 | device 753 | ): 754 | 755 | seqs_ext=[] 756 | # add a fake one to make sure the padding length 757 | dummy_seq = 'A'*(AA_seq_max_len-2) 758 | seqs_ext.append( 759 | (" ", dummy_seq) 760 | ) 761 | 762 | for i in range(len(raw_condition_list)): 763 | seqs_ext.append( 764 | (" ", raw_condition_list[i]) 765 | ) 766 | # batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext) 767 | _, y_strs, y_data = esm_batch_converter(seqs_ext) 768 | # y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1) 769 | # print(batch_tokens.shape) 770 | # 771 | # ++ remove the dummy one 772 | y_data = y_data[1:] 773 | seqs_ext = seqs_ext[1:] 774 | 775 | y_data = y_data.to(device) 776 | 777 | return y_data 778 | 779 | # ================================================== 780 | 781 | # def cal_err_list_using_ -------------------------------------------------------------------------------- /VibeGen/JointSamplingPack.py: -------------------------------------------------------------------------------- 1 | """ 2 | Task: 3 | 1. create a trainer for ProteinDesigner 4 | 2. include train_loop, sample_loop 5 | 6 | Bo Ni, Sep 8, 2024 7 | """ 8 | 9 | # ////////////////////////////////////////////////////// 10 | # 0. 
load in packages 11 | # ////////////////////////////////////////////////////// 12 | 13 | import os 14 | from math import ceil 15 | from contextlib import contextmanager, nullcontext 16 | from functools import partial, wraps 17 | from collections.abc import Iterable 18 | 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | from torch.utils.data import random_split, DataLoader 23 | from torch.optim import Adam 24 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR 25 | from torch.cuda.amp import autocast, GradScaler 26 | 27 | import pytorch_warmup as warmup 28 | 29 | from packaging import version 30 | 31 | import numpy as np 32 | import math 33 | import pandas as pd 34 | 35 | from ema_pytorch import EMA 36 | 37 | from einops import rearrange 38 | 39 | from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs 40 | 41 | from fsspec.core import url_to_fs 42 | from fsspec.implementations.local import LocalFileSystem 43 | 44 | import shutil 45 | import matplotlib.pyplot as plt 46 | 47 | from sklearn.metrics import r2_score 48 | from scipy.stats import spearmanr, pearsonr 49 | 50 | # ////////////////////////////////////////////////////////////// 51 | # 2. special packages 52 | # ////////////////////////////////////////////////////////////// 53 | from VibeGen.DataSetPack import ( 54 | pad_a_np_arr_esm_for_NMS 55 | ) 56 | from VibeGen.ModelPack import ( 57 | ProteinDesigner_Base, 58 | ProteinPredictor_Base 59 | ) 60 | from VibeGen.imagen_x_imagen_pytorch import ( 61 | ElucidatedImagen_OneD, eval_decorator 62 | ) 63 | # 64 | from VibeGen.UtilityPack import ( 65 | Print, Print_model_params, 66 | create_path, 67 | get_toks_list_from_Y_batch, 68 | save_2d_tensor_as_np_arr_txt, 69 | write_fasta_file, 70 | compare_two_seq_strings, 71 | fold_one_AA_to_SS_using_omegafold, 72 | show_pdb, 73 | get_DSSP_result, 74 | write_DSSP_result_to_json, 75 | write_one_line_to_file, 76 | decode_many_ems_token_rec, 77 | get_nms_vec_as_arr_list_from_batch_using_mask, 78 | compare_two_nms_vecs_arr, 79 | translate_seqs_list_into_idx_tensor_w_pLM 80 | ) 81 | 82 | # ////////////////////////////////////////////////////////////// 83 | # 3. local setup parameters: for debug purpose 84 | # ////////////////////////////////////////////////////////////// 85 | PT_Init_Level = 1 86 | PT_Forw_Level = 1 87 | 88 | Local_Debug_Level = 0 89 | # ////////////////////////////////////////////////////////////// 90 | # 4. helper functions 91 | # ////////////////////////////////////////////////////////////// 92 | def merge_two_topk( 93 | y_goo, 94 | y_bad, 95 | ): 96 | y={} 97 | y['indices']=torch.concatenate( 98 | (y_goo.indices,y_bad.indices) 99 | ) 100 | y['values']= torch.concatenate( 101 | (y_goo.values,y_bad.values) 102 | ) 103 | 104 | len_goo = len(y_goo.indices) 105 | len_bad = len(y_bad.indices) 106 | name_list_goo = [f'min_err_{ii}' for ii in range(len_goo)] 107 | name_list_bad = [f'max_err_{ii}' for ii in range(len_bad)] 108 | name_list = name_list_goo+name_list_bad 109 | 110 | y['name_type']=name_list 111 | 112 | # indices=torch.concatenate( 113 | # (y1.indices,y2.indices) 114 | # ) 115 | # values= torch.concatenate( 116 | # (y1.values,y2.values) 117 | # ) 118 | # y=torch.return_types.topk_out( 119 | # values=values, 120 | # indices=indices 121 | # ) 122 | return y 123 | 124 | # ////////////////////////////////////////////////////////////// 125 | # 5. 
Main class/functions: a base trainer wrap for ProteinDesigner 126 | # ////////////////////////////////////////////////////////////// 127 | 128 | 129 | # |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||| 130 | 131 | # joint PD & PP for sampling 132 | 133 | def joint_sampling_w_PD_and_PP( 134 | # 1. model 135 | PD_wk_ProteinDesigner, 136 | PP_wk_ProteinPredictor, 137 | # 2. data 138 | PD_test_set_condition_list, # input as a list of NMS vecs 139 | PD_test_set_AA_list=None, # whether GT is provided 140 | PD_DataKeys=None, 141 | # 3. control param 142 | n_try_w_PD = 100, # For PD, try this number times as a batch 143 | n_keep_w_PP_goo = 2, # Use PP to pick the top this_number samples 144 | # ++ 145 | n_keep_w_PP_bad=2, # also keep the worest one as ref: 146 | # 147 | PD_cond_scal = 7.5, 148 | PP_cond_scal = 7.5, 149 | # 4. outputs 150 | joint_sampling_dir = None, 151 | joint_sampling_prefix = f"TestSet_", # 3. on postprocessing 152 | # 153 | IF_plot_PP = True, 154 | IF_showfig = True, 155 | IF_save_pred_pack = True, 156 | IF_plot_PD = True, 157 | IF_fold_seq = True, 158 | IF_show_foldding = True, 159 | IF_DSSP = True, 160 | # others 161 | device = None, 162 | ): 163 | 164 | if not (PD_test_set_AA_list is None): 165 | assert len(PD_test_set_condition_list)==len(PD_test_set_AA_list), \ 166 | "the input Conditioning and GT don't have the same length..." 167 | else: 168 | Print(f"Only input Conditioning is provided...") 169 | 170 | # prepare wk dir 171 | if not os.path.exists(joint_sampling_dir): 172 | Print(f"Create joint sampling path...") 173 | create_path(joint_sampling_dir) 174 | Print(f"Done.") 175 | else: 176 | Print(f"Dir exists. Use caution...") 177 | 178 | # model status 179 | PD_wk_ProteinDesigner.turn_on_eval_mode() 180 | PP_wk_ProteinPredictor.turn_on_eval_mode() 181 | 182 | 183 | # prepare 184 | # on PD 185 | text_len_max = PD_wk_ProteinDesigner.text_max_len 186 | img_len_max = PD_wk_ProteinDesigner.diffuser_core.image_sizes[0] 187 | len_in = min(text_len_max, img_len_max) # depend on problem statement, may change 188 | 189 | text_embed_input_dim = PD_wk_ProteinDesigner.text_embed_input_dim 190 | cond_img_channels = PD_wk_ProteinDesigner.diffuser_core.unets[0].cond_images_channels 191 | mode_in = min(text_embed_input_dim, cond_img_channels) 192 | 193 | print (len_in) 194 | print (mode_in) 195 | 196 | # on PP 197 | AA_seq_max_len = PP_wk_ProteinPredictor.seq_obj_max_size 198 | esm_batch_converter = PP_wk_ProteinPredictor.pLM_alphabet.get_batch_converter( 199 | truncation_seq_length=AA_seq_max_len-2 200 | ) 201 | PP_wk_ProteinPredictor.pLM.eval() 202 | 203 | PR_err_for_PP = torch.nn.MSELoss(reduction='none') 204 | 205 | n_keep_w_PP = n_keep_w_PP_goo+n_keep_w_PP_bad 206 | 207 | # ++ 208 | # ++ get GT for PD if exists 209 | # 6. translate back to a batch for PP 210 | if not (PD_test_set_AA_list is None): 211 | GT_test_set_AA_batch_for_PD = \ 212 | translate_seqs_list_into_idx_tensor_w_pLM( 213 | # 1. model converter 214 | esm_batch_converter, 215 | AA_seq_max_len, 216 | # 2. on input 217 | raw_condition_list=PD_test_set_AA_list, 218 | # 3. 
on outpt 219 | device=device 220 | ) # (batch, seq_len) 221 | # 222 | # get len info as shape mask 223 | mask_from_Y_all = PD_wk_ProteinDesigner.read_mask_from_seq_toks_using_pLM( 224 | GT_test_set_AA_batch_for_PD 225 | ) 226 | # 227 | GT_idx_list, GT_seqs_list = get_toks_list_from_Y_batch( 228 | GT_test_set_AA_batch_for_PD, 229 | mask_from_Y_all 230 | ) 231 | else: 232 | GT_test_set_AA_batch_for_PD = None 233 | mask_from_Y_all = None 234 | GT_idx_list = None 235 | GT_seqs_list = None 236 | 237 | # ++ for picking up from the previous runs 238 | # ............................................................. 239 | reco_csv = joint_sampling_dir+'/'+joint_sampling_prefix+\ 240 | f'Try_{n_try_w_PD}_Pick_{n_keep_w_PP}'+'_reco.csv' 241 | 242 | Print (f"Use reco file: \n{reco_csv}") 243 | if not os.path.isfile(reco_csv): 244 | # first time 245 | Print (f"First run of the sampling...\n\n") 246 | # write the top line 247 | csv_top_line = f"root_path,error_L2,r2" 248 | write_one_line_to_file( 249 | this_line=csv_top_line+'\n', 250 | file_name=reco_csv, 251 | mode='w', 252 | ) 253 | n_pick_finished = 0 254 | n_samp_finished = 0 255 | else: 256 | df_reco = pd.read_csv(reco_csv) 257 | n_pick_finished = len(df_reco)//n_keep_w_PP 258 | n_samp_finished = len(df_reco)%n_keep_w_PP 259 | Print (f"Previously, finished input #: {n_pick_finished}") 260 | Print (f"finished samp #: {n_samp_finished}\n\n") 261 | 262 | 263 | 264 | X_file_list = [] 265 | 266 | # pick one sample 267 | for i_pick in range(len(PD_test_set_condition_list)): 268 | 269 | if i_pick > n_pick_finished-1: # pick up from the previous 270 | 271 | Print (f"\n\nWorking on Input #: {i_pick}\n\n") 272 | 273 | # i_pick = 1 274 | 275 | # 1. get X data padded 276 | X_arr = np.zeros( 277 | (mode_in, len_in) 278 | ) # (n_mode, seq_len) 279 | 280 | for j in range(mode_in): 281 | X_arr[j, :] = pad_a_np_arr_esm_for_NMS( 282 | PD_test_set_condition_list[i_pick][j, :], 283 | 0, 284 | len_in 285 | ) 286 | print (X_arr.shape) 287 | 288 | # 2. get X normalized and formated 289 | for j in range(mode_in): 290 | X_arr[j, :] = X_arr[j, :]/PD_DataKeys['Xnormfac'][j] 291 | 292 | X_train = torch.from_numpy(X_arr).float() # (c, seq_len) 293 | 294 | # 3. expand into a batch 295 | X_train = X_train.unsqueeze(0).repeat(n_try_w_PD,1,1) 296 | 297 | print (X_train.shape) 298 | 299 | X_train_batch = X_train.to(device) 300 | 301 | # 4. prep the GT for NMS vecs 302 | seq_len_pick = PD_test_set_condition_list[i_pick].shape[1] 303 | GT_NMS_tensor_pick = torch.from_numpy( 304 | PD_test_set_condition_list[i_pick] 305 | ).float() # (n_mode, this_seq_len) 306 | GT_NMS_tensor = GT_NMS_tensor_pick.unsqueeze(0).repeat( 307 | n_try_w_PD,1,1 308 | ) # (batch, n_mode, this_seq_len) 309 | GT_NMS_tensor = GT_NMS_tensor.to(device) 310 | 311 | # 5. make prediction w. PD 312 | print (f"\n\nPD making {str(n_try_w_PD)} designs ...\n\n") 313 | 314 | PR_toks_list, PR_seqs_list, result_mask = \ 315 | PD_wk_ProteinDesigner.sample_to_pLM_idx_seq( 316 | # 317 | common_AA_only=True, # False, 318 | mask_from_Y=None, # mask_from_Y 319 | # if none, will use mask from X, cond_img then text 320 | # 321 | text_con_input = X_train_batch, 322 | cond_images = X_train_batch, 323 | # 324 | cond_scale = PD_cond_scal, 325 | ) 326 | # result_mask: (batch, seq_len) 327 | 328 | # 6. translate back to a batch for PP 329 | print (f"\n\nPP predicting performances ...\n\n") 330 | 331 | y_data_for_PP = \ 332 | translate_seqs_list_into_idx_tensor_w_pLM( 333 | # 1. 
model converter 334 | esm_batch_converter, 335 | AA_seq_max_len, 336 | # 2. on input 337 | raw_condition_list=PR_seqs_list, 338 | # 3. on outpt 339 | device=device 340 | ) # (batch, seq_len) 341 | 342 | print (y_data_for_PP.shape) 343 | 344 | # 7. make prediction w PP 345 | PR_NMS_arr_list = PP_wk_ProteinPredictor.sample_to_NMS_list( 346 | # mask 347 | mask_from_Y = result_mask, 348 | NormFac_list = PD_DataKeys['Xnormfac'], 349 | # 350 | text_con_input = y_data_for_PP, 351 | cond_images = y_data_for_PP, 352 | # 353 | cond_scale = PP_cond_scal, 354 | ) # list (n_mode, seq_len) 355 | # make the list into a tensor 356 | PR_NMS_tensor = torch.from_numpy( 357 | np.stack(PR_NMS_arr_list, axis=0) # (b, n_mode, this_seq_len) 358 | ).float() 359 | PR_NMS_tensor = PR_NMS_tensor.to(device) 360 | 361 | # 8. calc the error for NMS vecs 362 | 363 | PR_NMS_err_batch = PR_err_for_PP( 364 | PR_NMS_tensor, 365 | GT_NMS_tensor, 366 | ) # (b, n_mode, this_seq_len) 367 | PR_NMS_err_batch = torch.sum( 368 | PR_NMS_err_batch, 369 | dim=(1,2) 370 | ) # (b, ) 371 | 372 | print (f"Pick the best {n_keep_w_PP_goo}...") 373 | 374 | idxs_vals_to_pick_goo = torch.topk( 375 | PR_NMS_err_batch, 376 | k=n_keep_w_PP_goo, 377 | largest=False, 378 | ) 379 | # have indices and values 380 | # (n_keep_w_PP, ) 381 | print (f"Pick the worst {n_keep_w_PP_bad}...") 382 | idxs_vals_to_pick_bad = torch.topk( 383 | PR_NMS_err_batch, 384 | k=n_keep_w_PP_bad, 385 | largest=True, 386 | ) 387 | 388 | print (f"N\n\now, fold the picked best {n_keep_w_PP_goo} and worst {n_keep_w_PP_bad} samples...") 389 | idxs_vals_to_pick = merge_two_topk( 390 | y_goo=idxs_vals_to_pick_goo, 391 | y_bad=idxs_vals_to_pick_bad, 392 | ) 393 | 394 | 395 | # 9. postprocess 396 | for i_in_k in range(n_keep_w_PP): 397 | 398 | if i_in_k > n_samp_finished-1: # pick up from the previous 399 | 400 | Print (f"\n\nProcessing Picked #: Input {i_pick+1} -- Design {i_in_k+1}\n\n") 401 | if i_in_k n_samp_finished-1: # pick up from the previous 718 | 719 | else: 720 | pass # this record is already finished 721 | 722 | 723 | -------------------------------------------------------------------------------- /VibeGen/DataSetPack.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.preprocessing import text, sequence 2 | from tensorflow.keras.preprocessing.text import Tokenizer 3 | 4 | from torch.utils.data import DataLoader,Dataset 5 | import pandas as pd 6 | import seaborn as sns 7 | 8 | import torchvision 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | 13 | from torch import nn 14 | from torch import optim 15 | import torch.nn.functional as F 16 | from torchvision import datasets, transforms, models 17 | 18 | import torch.optim as optim 19 | from torch.optim.lr_scheduler import ExponentialLR, StepLR 20 | from functools import partial, wraps 21 | 22 | from sklearn.model_selection import train_test_split 23 | from sklearn.preprocessing import QuantileTransformer 24 | from sklearn.preprocessing import RobustScaler 25 | 26 | from matplotlib.ticker import MaxNLocator 27 | 28 | import torch 29 | 30 | import esm 31 | 32 | # special packages 33 | 34 | import VibeGen.UtilityPack as UPack 35 | from VibeGen.UtilityPack import ( 36 | decode_one_ems_token_rec, 37 | decode_many_ems_token_rec 38 | ) 39 | 40 | # 41 | DPack_Random = 123456 42 | 43 | class RegressionDataset(Dataset): 44 | 45 | def __init__(self, X_data, y_data): 46 | self.X_data = X_data 47 | self.y_data = y_data 48 | 49 | def __getitem__(self, index): 
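        # return the paired sample at this index: the conditioning input X_data[index]
        # and its target y_data[index], as assembled by the dataset-building functions below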
50 | return self.X_data[index], self.y_data[index] 51 | 52 | def __len__ (self): 53 | return len(self.X_data) 54 | 55 | # ============================================================ 56 | # handle NMA result 57 | # 58 | # 1. screen the dataset 59 | # ============================================================ 60 | def screen_dataset_MD_NMS_MultiModes( 61 | # # -- 62 | # file_path, 63 | # ++ 64 | csv_file=None, 65 | pk_file =None, 66 | PKeys=None, 67 | CKeys=None, 68 | ): 69 | # unload the parameters 70 | 71 | store_path = PKeys['data_dir'] 72 | IF_SaveFig = CKeys['SilentRun'] 73 | min_AASeq_len = PKeys['min_AA_seq_len'] 74 | max_AASeq_len = PKeys['max_AA_seq_len'] 75 | max_used_Seg_Num = PKeys['max_used_Seg_Num'] 76 | 77 | # max_used_Smo_F = PKeys['max_Force_cap'] 78 | 79 | # working part 80 | if csv_file != None: 81 | # not used for now 82 | # functions 83 | print('=============================================') 84 | print('1. read in the csv file...') 85 | print('=============================================') 86 | arr_key = PKeys['arr_key'] 87 | 88 | df_raw = pd.read_csv(csv_file) 89 | UPack.Print("Raw df has keys:") 90 | UPack.Print(df_raw.keys()) 91 | 92 | # convert string array back to array 93 | for this_key in arr_key: 94 | # np.array(list(map(float, one_record.split(" ")))) 95 | df_raw[this_key] = df_raw[this_key].apply(lambda x: np.array(list(map(float, x.split(" "))))) 96 | # ===================================================== 97 | # adjust if needed 98 | # patch up 99 | df_raw.rename(columns={"sample_FORCEpN_data":"sample_FORCE_data"}, inplace=True) 100 | print('Updated keys: \n', df_raw.keys()) 101 | 102 | elif pk_file != None: 103 | # functions 104 | print('=============================================') 105 | print('1. read in the pk file...') 106 | print('=============================================') 107 | # 108 | df_raw = pd.read_pickle(pk_file) 109 | 110 | UPack.Print("Raw df has keys:") 111 | UPack.Print(df_raw.keys()) 112 | 113 | # .............................................................................. 114 | # -- 115 | fig = plt.figure(figsize=(24,16),dpi=200) 116 | fig, ax0 = plt.subplots() 117 | for ii in range(len( df_raw )): 118 | if df_raw['AA_Eff_Len'][ii]<=6400: 119 | # # + 120 | # ax0.plot( 121 | # df_disp_forc_smo['normalized_pull_gap_data'][ii], 122 | # df_disp_forc_smo['forc_data'][ii], 123 | # color="blue",label='full data' 124 | # ) 125 | # # 126 | ax0.plot( 127 | df_raw['Norm_Resi_Ind_List'][ii], 128 | # df_raw['sample_FORCEpN_data'][ii], 129 | df_raw['Mode7_NormDisAmp'][ii], 130 | alpha=0.1, 131 | # color="green",label='simplified data', 132 | # linestyle='None',marker='^' 133 | ) 134 | # ============================================ 135 | # # too slow to do this 136 | # ax0.scatter( 137 | # df_raw['NormResiIndx_At_MaxVibrAmp_Mode7'][ii], 138 | # df_raw['NormDisAmp_At_MaxVibrAmp_Mode7'][ii], 139 | # ) 140 | else: 141 | print(df_raw['pdb_id'][ii]) 142 | # we see mistakes in: 1. wrong len of the AA; 2. wrong # of residue of the beginning and end 143 | plt.xlabel('Normalized residue index') 144 | plt.ylabel('Normalized vibrational disp. amp.') 145 | outname = store_path+'CSV_0_NMS_Mode7_Dist.jpg' 146 | if IF_SaveFig==1: 147 | plt.savefig(outname, dpi=200) 148 | else: 149 | plt.show() 150 | plt.close() 151 | 152 | print('=============================================') 153 | print('2. 
screen the entries...') 154 | print('=============================================') 155 | # 156 | df_isnull = pd.DataFrame( 157 | round( 158 | (df_raw.isnull().sum().sort_values(ascending=False)/df_raw.shape[0])*100, 159 | 1 160 | ) 161 | ).reset_index() 162 | df_isnull.style.format({'% of Missing Data': lambda x:'{:.1%}'.format(abs(x))}) 163 | cm = sns.light_palette("skyblue", as_cmap=True) 164 | df_isnull = df_isnull.style.background_gradient(cmap=cm) 165 | print('Check null...') 166 | print( df_isnull ) 167 | 168 | print('Working on a dataframe with useful keywords') 169 | # suppose to be a smaller one 170 | # Focus on mode 7 For the moment 171 | # Expand to modes 7,8,9 172 | protein_df = pd.DataFrame().assign( 173 | pdb_id=df_raw['pdb_id'], 174 | AA=df_raw['AA_Full'], 175 | seq_len=df_raw['AA_Eff_Len'], 176 | AA_Seg_Num=df_raw['AA_Seg_Num'], 177 | Norm_Resi_Ind_List=df_raw['Norm_Resi_Ind_List'], 178 | # on mode 7 179 | Mode7_NormDisAmp=df_raw['Mode7_NormDisAmp'], 180 | ScaFac_7=df_raw['ScaFac_7'], 181 | Mode7_NormDis=df_raw['Mode7_NormDis'], 182 | Mode7_Freq=df_raw['Mode7_Freq'], 183 | NormResiIndx_At_MaxVibrAmp_Mode7=df_raw['NormResiIndx_At_MaxVibrAmp_Mode7'], 184 | NormDisAmp_At_MaxVibrAmp_Mode7=df_raw['NormDisAmp_At_MaxVibrAmp_Mode7'], 185 | Mode7_FixLen_NormDisAmp=df_raw['Mode7_FixLen_NormDisAmp'], 186 | # on mode 8 187 | Mode8_NormDisAmp=df_raw['Mode8_NormDisAmp'], 188 | ScaFac_8=df_raw['ScaFac_8'], 189 | Mode8_NormDis=df_raw['Mode8_NormDis'], 190 | Mode8_Freq=df_raw['Mode8_Freq'], 191 | NormResiIndx_At_MaxVibrAmp_Mode8=df_raw['NormResiIndx_At_MaxVibrAmp_Mode8'], 192 | NormDisAmp_At_MaxVibrAmp_Mode8=df_raw['NormDisAmp_At_MaxVibrAmp_Mode8'], 193 | # Mode8_FixLen_NormDisAmp=df_raw['Mode8_FixLen_NormDisAmp'], 194 | # on mode 9 195 | Mode9_NormDisAmp=df_raw['Mode9_NormDisAmp'], 196 | ScaFac_9=df_raw['ScaFac_9'], 197 | Mode9_NormDis=df_raw['Mode9_NormDis'], 198 | Mode9_Freq=df_raw['Mode9_Freq'], 199 | NormResiIndx_At_MaxVibrAmp_Mode9=df_raw['NormResiIndx_At_MaxVibrAmp_Mode9'], 200 | NormDisAmp_At_MaxVibrAmp_Mode9=df_raw['NormDisAmp_At_MaxVibrAmp_Mode9'], 201 | # Mode9_FixLen_NormDisAmp=df_raw['Mode9_FixLen_NormDisAmp'], 202 | ) 203 | # ++ add new keys on energy if needed 204 | 205 | # screen using AA length 206 | print('a. 
screen using sequence length...') 207 | print('original sequences #: ', len(protein_df)) 208 | # 209 | protein_df.drop( 210 | protein_df[protein_df['seq_len']>max_AASeq_len-2].index, 211 | inplace = True 212 | ) 213 | protein_df.drop( 214 | protein_df[protein_df['seq_len'] max_used_Seg_Num].index, 225 | inplace = True 226 | ) 227 | # protein_df.drop( 228 | # protein_df[protein_df['seq_len'] 598 | # 599 | x1 = x0.copy() 600 | x1 = np.insert(x1,0,add_x) 601 | n0 = len(x1) 602 | if n0 616 | # # x1 = [add_x]+x1 # somehow, this one doesn't work 617 | # # print(x1) 618 | # # print('x1 len: ',len(x1) ) 619 | # n0 = len(x1) 620 | # # 621 | # if n00: 680 | jj_0 = jj 681 | break 682 | for jj in range(max_AA_len): 683 | if np.fabs(this_print_arr[-(jj+1)])>0: 684 | jj_1 = jj 685 | break 686 | jj_data_len = max_AA_len-jj_0-jj_1 687 | print (f" Begin_padding: {jj_0}") 688 | print (f" End_padding: {jj_1}") 689 | print (f" Data_len: {jj_data_len}") 690 | 691 | UPack.Print (f"Now, calculate Normalization Factor for each mode") 692 | UPack.Print (f"Upper bound of the NFs: {np.amax(X)}") 693 | 694 | X_NF_List = [] 695 | for ii, this_X_Key in enumerate(X_Keys): 696 | this_x_max = np.amax( 697 | X[:,ii,:] 698 | ) 699 | X_NF_List.append(this_x_max) 700 | # normalization 701 | X[:,ii,:] = X[:,ii,:]/this_x_max 702 | 703 | UPack.Print (f"X_NF_List: {X_NF_List}") 704 | 705 | UPack.Print("======================================================") 706 | UPack.Print("2. work on Y data: AA Sequence") 707 | UPack.Print("======================================================") 708 | # take care of the y part: AA encoding 709 | #create and fit tokenizer for AA sequences 710 | seqs = protein_df.AA.values 711 | # ++ for pLM: esm 712 | # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 713 | print("pLM model: ", PKeys['ESM-2_Model']) 714 | 715 | if PKeys['ESM-2_Model']=='esm2_t33_650M_UR50D': 716 | # print('Debug block') 717 | # embed dim: 1280 718 | esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D() 719 | len_toks=len(esm_alphabet.all_toks) 720 | elif PKeys['ESM-2_Model']=='esm2_t12_35M_UR50D': 721 | # embed dim: 480 722 | esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D() 723 | len_toks=len(esm_alphabet.all_toks) 724 | elif PKeys['ESM-2_Model']=='esm2_t36_3B_UR50D': 725 | # embed dim: 2560 726 | esm_model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D() 727 | len_toks=len(esm_alphabet.all_toks) 728 | elif PKeys['ESM-2_Model']=='esm2_t30_150M_UR50D': 729 | # embed dim: 640 730 | esm_model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D() 731 | len_toks=len(esm_alphabet.all_toks) 732 | else: 733 | print("protein language model is not defined.") 734 | # 735 | # for check 736 | print("esm_alphabet.use_msa: ", esm_alphabet.use_msa) 737 | print("# of tokens in AA alphabet: ", len_toks) 738 | # need to save 2 positions for and 739 | esm_batch_converter = esm_alphabet.get_batch_converter( 740 | truncation_seq_length=PKeys['max_AA_seq_len']-2 741 | ) 742 | esm_model.eval() # disables dropout for deterministic results 743 | # prepare seqs for the "esm_batch_converter..." 
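    # note: the ESM batch converter maps a list of (label, sequence) pairs to
    # (labels, strings, tokens); "tokens" is a LongTensor of shape (batch, longest_seq_in_batch + 2),
    # with <cls> prepended, <eos> appended, and shorter entries filled with esm_alphabet.padding_idx.
    # Because padding only reaches the longest sequence in the batch, a dummy full-length entry is
    # prepended where a fixed width is required (see translate_seqs_list_into_idx_tensor_w_pLM in UtilityPack).
    # Minimal illustrative call, with hypothetical sequences (not from the dataset):
    #   _, _, toks = esm_batch_converter([(" ", "MKTAYIAK"), (" ", "GSHM")])
    #   toks.shape  # torch.Size([2, 10]) with the ESM-2 alphabet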
744 | # add dummy labels 745 | seqs_ext=[] 746 | for i in range(len(seqs)): 747 | seqs_ext.append( 748 | (" ", seqs[i]) 749 | ) 750 | # batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext) 751 | _, y_strs, y_data = esm_batch_converter(seqs_ext) 752 | y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1) 753 | # print(batch_tokens.shape) 754 | print ("y_data.dim: ", y_data.dtype) 755 | 756 | fig_handle = sns.histplot( 757 | data=pd.DataFrame({'AA code': np.array(y_data).flatten()}), 758 | x='AA code', 759 | bins=np.array([i-0.5 for i in range(0,33+3,1)]), # np.array([i-0.5 for i in range(0,20+3,1)]) 760 | # binwidth=1, 761 | ) 762 | fig = fig_handle.get_figure() 763 | fig_handle.set_xlim(-1, 33+1) 764 | # fig_handle.set_ylim(0, 100000) 765 | outname=store_path+'CSV_5_DataSet_AACode_dist.jpg' 766 | if IF_SaveFig==1: 767 | plt.savefig(outname, dpi=200) 768 | else: 769 | plt.show() 770 | plt.close() 771 | 772 | # ----------------------------------------------------------- 773 | # print ("#################################") 774 | # print ("DICTIONARY y_data") 775 | # dictt=tokenizer_y.get_config() 776 | # print (dictt) 777 | # num_words = len(tokenizer_y.word_index) + 1 778 | # print ("################## y max token: ",num_words ) 779 | # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 780 | print ("#################################") 781 | print ("DICTIONARY y_data: esm-", PKeys['ESM-2_Model']) 782 | print ("################## y max token: ",len_toks ) 783 | 784 | #revere 785 | print ("TEST REVERSE: ") 786 | 787 | # # -------------------------------------------------------------- 788 | # y_data_reversed=tokenizer_y.sequences_to_texts (y_data) 789 | 790 | # for iii in range (len(y_data_reversed)): 791 | # y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "") 792 | 793 | # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 794 | # assume y_data is reversiable 795 | y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet) 796 | 797 | 798 | print ("Element 0", y_data_reversed[0]) 799 | print ("Number of y samples",len (y_data_reversed) ) 800 | 801 | for iii in [0,2,6]: 802 | print("Ori and REVERSED SEQ: ", iii) 803 | print(seqs[iii]) 804 | print(y_data_reversed[iii]) 805 | 806 | # print ("Original: ", y_data[:3,:]) 807 | # print ("REVERSED TEXT 0..2: ", y_data_reversed[0:3]) 808 | 809 | print ("Len 0 as example: ", len (y_data_reversed[0]) ) 810 | print ("CHeck ori: ", len (seqs[0]) ) 811 | print ("Len 2 as example: ", len (y_data_reversed[2]) ) 812 | print ("CHeck ori: ", len (seqs[2]) ) 813 | 814 | # placeholder 815 | tokenizer_X = None 816 | tokenizer_Y = None 817 | 818 | return X, X_NF_List, y_data, y_data_reversed,tokenizer_X, tokenizer_Y 819 | 820 | # ============================================================= 821 | # build loaders 822 | def build_dataloaders( 823 | X, 824 | y_data, 825 | protein_df, 826 | PKeys=None, 827 | CKeys=None, 828 | ): 829 | # unload the parameters 830 | store_path = PKeys['data_dir'] 831 | IF_SaveFig = CKeys['SilentRun'] 832 | 833 | batch_size = PKeys['batch_size'] 834 | TestSet_ratio = PKeys['testset_ratio'] 835 | maxdata=PKeys['maxdata'] 836 | 837 | 838 | if maxdata(-1,1) 143 | def normalize_neg_one_to_one(img): 144 | return img * 2 - 1 145 | 146 | def unnormalize_zero_to_one(normed_img): 147 | return (normed_img + 1) * 0.5 148 | 149 | def compact(input_dict): 150 | return {key: value for key, value in input_dict.items() if exists(value)} 151 | 152 | def 
maybe_transform_dict_key(input_dict, key, fn): 153 | if key not in input_dict: 154 | return input_dict 155 | 156 | copied_dict = input_dict.copy() 157 | copied_dict[key] = fn(copied_dict[key]) 158 | return copied_dict 159 | 160 | # tensor helpers 161 | 162 | def log(t, eps: float = 1e-12): 163 | return torch.log(t.clamp(min = eps)) 164 | 165 | # 166 | # =========================================================== 167 | # =========================================================== 168 | # =========================================================== 169 | # main class: ElucidatedImagen 170 | # =========================================================== 171 | # =========================================================== 172 | # =========================================================== 173 | # 174 | # on diffusion scheduler 175 | # 176 | # gaussian diffusion with continuous time helper functions and classes 177 | # large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py 178 | 179 | @torch.jit.script 180 | def beta_linear_log_snr(t): 181 | return -torch.log(expm1(1e-4 + 10 * (t ** 2))) 182 | 183 | @torch.jit.script 184 | def alpha_cosine_log_snr(t, s: float = 0.008): 185 | return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version 186 | 187 | def log_snr_to_alpha_sigma(log_snr): 188 | return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr)) 189 | # 190 | class GaussianDiffusionContinuousTimes(nn.Module): 191 | def __init__( 192 | self, 193 | *, 194 | noise_schedule, 195 | timesteps = 1000, 196 | ): 197 | super().__init__() 198 | 199 | if noise_schedule == "linear": 200 | self.log_snr = beta_linear_log_snr 201 | elif noise_schedule == "cosine": 202 | self.log_snr = alpha_cosine_log_snr 203 | else: 204 | raise ValueError(f'invalid noise schedule {noise_schedule}') 205 | 206 | self.num_timesteps = timesteps 207 | 208 | def get_times( 209 | self, 210 | batch_size, 211 | noise_level, 212 | *, 213 | device 214 | ): 215 | return torch.full( 216 | (batch_size,), 217 | noise_level, 218 | device = device, 219 | dtype = torch.float32 220 | ) 221 | 222 | def sample_random_times( 223 | self, 224 | batch_size, 225 | *, 226 | device 227 | ): 228 | return torch.zeros( 229 | (batch_size,), 230 | device = device 231 | ).float().uniform_(0, 1) 232 | 233 | def get_condition(self, times): 234 | return maybe(self.log_snr)(times) 235 | 236 | def get_sampling_timesteps( 237 | self, 238 | batch, 239 | *, 240 | device 241 | ): 242 | times = torch.linspace( 243 | 1., 244 | 0., 245 | self.num_timesteps + 1, 246 | device = device 247 | ) 248 | times = repeat(times, 't -> b t', b = batch) 249 | times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0) 250 | times = times.unbind(dim = -1) 251 | return times 252 | 253 | def q_posterior( 254 | self, 255 | x_start, 256 | x_t, 257 | t, 258 | *, 259 | t_next = None 260 | ): 261 | t_next = default( 262 | t_next, 263 | lambda: (t - 1. / self.num_timesteps).clamp(min = 0.) 
264 | ) 265 | 266 | """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """ 267 | log_snr = self.log_snr(t) 268 | log_snr_next = self.log_snr(t_next) 269 | log_snr, log_snr_next = map( 270 | partial(right_pad_dims_to, x_t), 271 | (log_snr, log_snr_next) 272 | ) 273 | 274 | alpha, sigma = log_snr_to_alpha_sigma(log_snr) 275 | alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next) 276 | 277 | # c - as defined near eq 33 278 | c = -expm1(log_snr - log_snr_next) 279 | posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start) 280 | 281 | # following (eq. 33) 282 | posterior_variance = (sigma_next ** 2) * c 283 | posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20) 284 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 285 | 286 | def q_sample( 287 | self, 288 | x_start, 289 | t, 290 | noise = None 291 | ): 292 | dtype = x_start.dtype 293 | 294 | if isinstance(t, float): 295 | batch = x_start.shape[0] 296 | t = torch.full((batch,), t, device = x_start.device, dtype = dtype) 297 | 298 | noise = default(noise, lambda: torch.randn_like(x_start)) 299 | log_snr = self.log_snr(t).type(dtype) 300 | log_snr_padded_dim = right_pad_dims_to(x_start, log_snr) 301 | alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim) 302 | 303 | return alpha * x_start + sigma * noise, log_snr, alpha, sigma 304 | 305 | def q_sample_from_to( 306 | self, 307 | x_from, 308 | from_t, 309 | to_t, 310 | noise = None 311 | ): 312 | shape, device, dtype = x_from.shape, x_from.device, x_from.dtype 313 | batch = shape[0] 314 | 315 | if isinstance(from_t, float): 316 | from_t = torch.full((batch,), from_t, device = device, dtype = dtype) 317 | 318 | if isinstance(to_t, float): 319 | to_t = torch.full((batch,), to_t, device = device, dtype = dtype) 320 | 321 | noise = default(noise, lambda: torch.randn_like(x_from)) 322 | 323 | log_snr = self.log_snr(from_t) 324 | log_snr_padded_dim = right_pad_dims_to(x_from, log_snr) 325 | alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim) 326 | 327 | log_snr_to = self.log_snr(to_t) 328 | log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to) 329 | alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to) 330 | 331 | return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha 332 | 333 | def predict_start_from_v(self, x_t, t, v): 334 | log_snr = self.log_snr(t) 335 | log_snr = right_pad_dims_to(x_t, log_snr) 336 | alpha, sigma = log_snr_to_alpha_sigma(log_snr) 337 | return alpha * x_t - sigma * v 338 | 339 | def predict_start_from_noise(self, x_t, t, noise): 340 | log_snr = self.log_snr(t) 341 | log_snr = right_pad_dims_to(x_t, log_snr) 342 | alpha, sigma = log_snr_to_alpha_sigma(log_snr) 343 | return (x_t - sigma * noise) / alpha.clamp(min = 1e-8) 344 | 345 | # =========================================================== 346 | # constants 347 | 348 | Hparams_fields = [ 349 | 'num_sample_steps', 350 | 'sigma_min', 351 | 'sigma_max', 352 | 'sigma_data', 353 | 'rho', 354 | 'P_mean', 355 | 'P_std', 356 | 'S_churn', 357 | 'S_tmin', 358 | 'S_tmax', 359 | 'S_noise' 360 | ] 361 | 362 | Hparams = namedtuple('Hparams', Hparams_fields) 363 | 364 | 365 | # =========================================================== 366 | # =========================================================== 367 | # add for OneD data format 368 | # 369 | class ElucidatedImagen_OneD(nn.Module): 370 | def __init__( 371 | self, 372 | # 1. 
unets: many setups of UNet will be passed on via UNet itself 373 | unets, 374 | *, 375 | channels = 3, 376 | # 2. in-output image size 377 | image_sizes, # for cascading ddpm, image size at each stage 378 | # 3. on text conditioning 379 | text_encoder_name = None, # TBU: DEFAULT_T5_NAME, 380 | text_embed_dim = None, 381 | cond_drop_prob = 0.1, 382 | condition_on_text = True, 383 | # 384 | random_crop_sizes = None, 385 | resize_mode = 'nearest', 386 | temporal_downsample_factor = 1, 387 | resize_cond_video_frames = True, 388 | lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level 389 | per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find 390 | auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader 391 | dynamic_thresholding = True, 392 | dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper 393 | only_train_unet_number = None, 394 | lowres_noise_schedule = 'linear', 395 | num_sample_steps = 32, # number of sampling steps 396 | sigma_min = 0.002, # min noise level 397 | sigma_max = 80, # max noise level 398 | sigma_data = 0.5, # standard deviation of data distribution 399 | rho = 7, # controls the sampling schedule 400 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training 401 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training 402 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper 403 | S_tmin = 0.05, 404 | S_tmax = 50, 405 | S_noise = 1.003, 406 | # ++ 407 | CKeys = {'Debug_Level':0}, # for debug purpose: 0--silence mode 408 | ): 409 | super().__init__() 410 | 411 | # ++ for debug 412 | self.CKeys = CKeys 413 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 414 | print ("Debug mode: Initialization of EImagen...\n") 415 | 416 | self.only_train_unet_number = only_train_unet_number 417 | # ++ 418 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 419 | print (f".only_train_unet_number: {self.only_train_unet_number}") 420 | 421 | # conditioning hparams 422 | 423 | self.condition_on_text = condition_on_text 424 | self.unconditional = not condition_on_text 425 | # ++ 426 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 427 | print (f".condition_on_text: {self.condition_on_text}") 428 | print (f".unconditional: {self.unconditional}") 429 | 430 | # channels 431 | 432 | self.channels = channels 433 | # ++ 434 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 435 | print (f".channels: {self.channels}") 436 | 437 | # automatically take care of ensuring that first unet is unconditional 438 | # while the rest of the unets are conditioned on the low resolution image produced by previous unet 439 | 440 | unets = cast_tuple(unets) 441 | num_unets = len(unets) 442 | 443 | # randomly cropping for upsampler training 444 | 445 | self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets) 446 | assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass 
in `random_crop_sizes = (None, 128, 256)` as example' 447 | # may get rid of this when moving to 1d case 448 | # ++ 449 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 450 | print (f".random_crop_sizes: {self.random_crop_sizes}") 451 | 452 | # lowres augmentation noise schedule 453 | 454 | self.lowres_noise_schedule = GaussianDiffusionContinuousTimes( 455 | noise_schedule = lowres_noise_schedule 456 | ) 457 | # ++ 458 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 459 | print (f".lowres_noise_schedule: {self.lowres_noise_schedule}") 460 | 461 | # get text encoder 462 | 463 | self.text_encoder_name = text_encoder_name 464 | # --: a dull one is enough 465 | # self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name)) 466 | # ++ 467 | self.text_embed_dim = text_embed_dim 468 | # ++ 469 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 470 | print (f".text_encoder_name: {self.text_encoder_name}") 471 | print (f".text_embed_dim: {self.text_embed_dim}") 472 | 473 | # -- text channel is not updated yet 474 | # self.encode_text = partial(t5_encode_text, name = text_encoder_name) 475 | # ++: TBU if needed 476 | self.encode_text = None 477 | # ++ 478 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 479 | print (f".encode_text: {self.encode_text}") 480 | 481 | # construct unets 482 | 483 | self.unets = nn.ModuleList([]) 484 | self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment 485 | 486 | for ind, one_unet in enumerate(unets): 487 | # check the class of the unet: accept Unet_OneD, NullUnet 488 | assert isinstance(one_unet, (Unet_OneD, Unet3D, NullUnet)) 489 | is_first = ind == 0 490 | 491 | one_unet = one_unet.cast_model_parameters( 492 | lowres_cond = not is_first, # may open this channel 493 | cond_on_text = self.condition_on_text, 494 | text_embed_dim = self.text_embed_dim if self.condition_on_text else None, 495 | channels = self.channels, 496 | channels_out = self.channels, 497 | ) 498 | 499 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 500 | print (f"Add one UNet: ") 501 | print (one_unet) 502 | print (f"======================================== ") 503 | 504 | self.unets.append(one_unet) 505 | 506 | # determine whether we are training on images or video 507 | 508 | is_video = any([isinstance(unet, Unet3D) for unet in self.unets]) 509 | self.is_video = is_video 510 | # ++ 511 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 512 | print (f".is_video: {self.is_video}") 513 | 514 | self.right_pad_dims_to_datatype = partial( 515 | rearrange, 516 | # -- 517 | # pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1') 518 | # ++ may think adding one for video of 1d data 519 | pattern = ('b -> b 1 1' if not is_video else 'b -> b 1 1 1 1') 520 | ) 521 | # ++ 522 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 523 | print (f".right_pad_dims_to_datatype: {self.right_pad_dims_to_datatype}") 524 | 525 | self.resize_to = resize_video_to if is_video else resize_2d_image_to 526 | # only triggered when the last dimension doesn't match the traget one 527 | # input: (mini-batch, channels, width) # assume it works for 1d 528 | # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html 529 | 530 | self.resize_to = partial(self.resize_to, mode = resize_mode) 531 | # ++ 532 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 533 | print (f".resize_to: {self.resize_to}") 534 | 535 | # unet image sizes 536 | 537 | self.image_sizes = cast_tuple(image_sizes) 538 | assert num_unets == 
len(self.image_sizes), \ 539 | f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}' 540 | # ++ 541 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 542 | print (f".image_size: {self.image_sizes}") 543 | 544 | self.sample_channels = cast_tuple(self.channels, num_unets) 545 | # ++ 546 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 547 | print (f".sample_channels: {self.sample_channels}") 548 | 549 | # cascading ddpm related stuff 550 | 551 | lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets)) 552 | assert lowres_conditions == (False, *((True,) * (num_unets - 1))), \ 553 | 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True' 554 | 555 | self.lowres_sample_noise_level = lowres_sample_noise_level 556 | self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level 557 | # ++ 558 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 559 | print (f".lowres_sample_noise_level: {self.lowres_sample_noise_level}") 560 | print (f".per_sample_random_aug_noise_level: {self.per_sample_random_aug_noise_level}") 561 | 562 | # classifier free guidance 563 | 564 | self.cond_drop_prob = cond_drop_prob 565 | self.can_classifier_guidance = cond_drop_prob > 0. 566 | # ++ 567 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 568 | print (f".cond_drop_prob: {self.cond_drop_prob}") 569 | print (f".can_classifier_guidance: {self.can_classifier_guidance}") 570 | 571 | # normalize and unnormalize image functions 572 | 573 | self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity 574 | self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity 575 | self.input_image_range = (0. if auto_normalize_img else -1., 1.) 
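# Illustrative note (assuming the standard imagen-pytorch helpers defined elsewhere in this file):
#   normalize_neg_one_to_one(x)   ~   x * 2 - 1        maps data in [0, 1] to [-1, 1]
#   unnormalize_zero_to_one(x)    ~   (x + 1) * 0.5    maps model output back to [0, 1]
# With auto_normalize_img = False both reduce to the identity, input_image_range becomes (-1., 1.),
# and the dataloader is expected to supply the 1D signals already scaled to [-1, 1].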
576 | # ++ 577 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 578 | print (f".normalize_img: {self.normalize_img}") 579 | print (f".unnormalize_img: {self.unnormalize_img}") 580 | print (f".input_image_range: {self.input_image_range}") 581 | 582 | # dynamic thresholding 583 | 584 | self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets) 585 | self.dynamic_thresholding_percentile = dynamic_thresholding_percentile 586 | # ++ 587 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 588 | print (f".dynamic_thresholding: {self.dynamic_thresholding}") 589 | print (f".dynamic_thresholding_percentile: {self.dynamic_thresholding_percentile}") 590 | 591 | # temporal interpolations 592 | 593 | temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets) 594 | self.temporal_downsample_factor = temporal_downsample_factor 595 | 596 | self.resize_cond_video_frames = resize_cond_video_frames 597 | self.temporal_downsample_divisor = temporal_downsample_factor[0] 598 | # ++ 599 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 600 | print (f".temporal_downsample_factor: {self.temporal_downsample_factor}") 601 | print (f".resize_cond_video_frames: {self.resize_cond_video_frames}") 602 | print (f".temporal_downsample_divisor: {self.temporal_downsample_divisor}") 603 | 604 | assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1' 605 | assert tuple(sorted(temporal_downsample_factor, reverse = True)) == temporal_downsample_factor, 'temporal downsample factor must be in order of descending' 606 | 607 | # elucidating parameters 608 | 609 | hparams = [ 610 | num_sample_steps, 611 | sigma_min, 612 | sigma_max, 613 | sigma_data, 614 | rho, 615 | P_mean, 616 | P_std, 617 | S_churn, 618 | S_tmin, 619 | S_tmax, 620 | S_noise, 621 | ] 622 | 623 | hparams = [cast_tuple(hp, num_unets) for hp in hparams] 624 | self.hparams = [Hparams(*unet_hp) for unet_hp in zip(*hparams)] 625 | # ++ 626 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 627 | print (f".hparams: {self.hparams}") 628 | 629 | # one temp parameter for keeping track of device 630 | 631 | self.register_buffer('_temp', torch.tensor([0.]), persistent = False) 632 | 633 | # default to device of unets passed in 634 | 635 | self.to(next(self.unets.parameters()).device) 636 | # ++ 637 | if self.CKeys['Debug_Level']==Imagen_Init_Level: 638 | print (f".device: {next(self.unets.parameters()).device}") 639 | 640 | def force_unconditional_(self): 641 | self.condition_on_text = False 642 | self.unconditional = True 643 | 644 | for unet in self.unets: 645 | unet.cond_on_text = False 646 | 647 | @property 648 | def device(self): 649 | return self._temp.device 650 | 651 | def get_unet(self, unet_number): 652 | assert 0 < unet_number <= len(self.unets) 653 | index = unet_number - 1 654 | 655 | if isinstance(self.unets, nn.ModuleList): 656 | unets_list = [unet for unet in self.unets] 657 | delattr(self, 'unets') 658 | self.unets = unets_list 659 | 660 | if index != self.unet_being_trained_index: 661 | for unet_index, unet in enumerate(self.unets): 662 | unet.to(self.device if unet_index == index else 'cpu') 663 | 664 | self.unet_being_trained_index = index 665 | return self.unets[index] 666 | 667 | def reset_unets_all_one_device(self, device = None): 668 | device = default(device, self.device) 669 | self.unets = nn.ModuleList([*self.unets]) 670 | self.unets.to(device) 671 | 672 | self.unet_being_trained_index = -1 673 | 674 | @contextmanager 675 | def one_unet_in_gpu(self, unet_number = None, unet = None): 676 
| assert exists(unet_number) ^ exists(unet) 677 | 678 | if exists(unet_number): 679 | unet = self.unets[unet_number - 1] 680 | 681 | cpu = torch.device('cpu') 682 | 683 | devices = [module_device(unet) for unet in self.unets] 684 | 685 | self.unets.to(cpu) 686 | unet.to(self.device) 687 | 688 | yield 689 | 690 | for unet, device in zip(self.unets, devices): 691 | unet.to(device) 692 | 693 | # overriding state dict functions 694 | 695 | def state_dict(self, *args, **kwargs): 696 | self.reset_unets_all_one_device() 697 | return super().state_dict(*args, **kwargs) 698 | 699 | def load_state_dict(self, *args, **kwargs): 700 | self.reset_unets_all_one_device() 701 | return super().load_state_dict(*args, **kwargs) 702 | 703 | # dynamic thresholding 704 | 705 | def threshold_x_start(self, x_start, dynamic_threshold = True): 706 | if not dynamic_threshold: 707 | return x_start.clamp(-1., 1.) 708 | 709 | s = torch.quantile( 710 | rearrange(x_start, 'b ... -> b (...)').abs(), 711 | self.dynamic_thresholding_percentile, 712 | dim = -1 713 | ) 714 | 715 | s.clamp_(min = 1.) 716 | s = right_pad_dims_to(x_start, s) 717 | return x_start.clamp(-s, s) / s 718 | 719 | # derived preconditioning params - Table 1 720 | 721 | def c_skip(self, sigma_data, sigma): 722 | return (sigma_data ** 2) / (sigma ** 2 + sigma_data ** 2) 723 | 724 | def c_out(self, sigma_data, sigma): 725 | return sigma * sigma_data * (sigma_data ** 2 + sigma ** 2) ** -0.5 726 | 727 | def c_in(self, sigma_data, sigma): 728 | return 1 * (sigma ** 2 + sigma_data ** 2) ** -0.5 729 | 730 | def c_noise(self, sigma): 731 | return log(sigma) * 0.25 732 | 733 | # preconditioned network output 734 | # equation (7) in the paper 735 | 736 | def preconditioned_network_forward( 737 | self, 738 | unet_forward, 739 | noised_images, 740 | sigma, 741 | *, 742 | sigma_data, 743 | clamp = False, 744 | dynamic_threshold = True, 745 | **kwargs 746 | ): 747 | batch, device = noised_images.shape[0], noised_images.device 748 | 749 | if isinstance(sigma, float): 750 | sigma = torch.full((batch,), sigma, device = device) 751 | 752 | padded_sigma = self.right_pad_dims_to_datatype(sigma) 753 | 754 | net_out = unet_forward( 755 | self.c_in(sigma_data, padded_sigma) * noised_images, 756 | self.c_noise(sigma), 757 | **kwargs 758 | ) 759 | 760 | out = self.c_skip(sigma_data, padded_sigma) * noised_images \ 761 | + self.c_out(sigma_data, padded_sigma) * net_out 762 | 763 | if not clamp: 764 | return out 765 | 766 | return self.threshold_x_start(out, dynamic_threshold) 767 | 768 | # sampling 769 | 770 | # sample schedule 771 | # equation (5) in the paper 772 | 773 | def sample_schedule( 774 | self, 775 | num_sample_steps, 776 | rho, 777 | sigma_min, 778 | sigma_max 779 | ): 780 | N = num_sample_steps 781 | inv_rho = 1 / rho 782 | 783 | steps = torch.arange( 784 | num_sample_steps, 785 | device = self.device, 786 | dtype = torch.float32 787 | ) 788 | sigmas = ( 789 | sigma_max ** inv_rho \ 790 | + steps / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho) 791 | ) ** rho 792 | 793 | sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0. 
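# Illustrative note on the schedule built above (Karras et al., "Elucidating the Design Space...", eq. 5):
#   sigma_i = ( sigma_max^(1/rho) + (i / (N - 1)) * ( sigma_min^(1/rho) - sigma_max^(1/rho) ) )^rho,  i = 0 .. N-1
# so the noise level decays from sigma_max at i = 0 to sigma_min at i = N-1, with rho setting the curvature;
# the F.pad above appends a final sigma = 0 so the last transition lands on the clean sample.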
794 | return sigmas 795 | 796 | @torch.no_grad() 797 | def one_unet_sample( 798 | self, 799 | unet, 800 | shape, 801 | *, 802 | unet_number, 803 | clamp = True, 804 | dynamic_threshold = True, 805 | cond_scale = 1., 806 | use_tqdm = True, 807 | inpaint_videos = None, 808 | inpaint_images = None, 809 | inpaint_masks = None, 810 | inpaint_resample_times = 5, 811 | init_images = None, 812 | skip_steps = None, 813 | sigma_min = None, 814 | sigma_max = None, 815 | **kwargs 816 | ): 817 | # video 818 | 819 | is_video = len(shape) == 5 820 | frames = shape[-3] if is_video else None 821 | resize_kwargs = dict(target_frames = frames) if exists(frames) else dict() 822 | 823 | # get specific sampling hyperparameters for unet 824 | 825 | hp = self.hparams[unet_number - 1] 826 | 827 | sigma_min = default(sigma_min, hp.sigma_min) 828 | sigma_max = default(sigma_max, hp.sigma_max) 829 | 830 | # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma 831 | 832 | sigmas = self.sample_schedule( 833 | hp.num_sample_steps, 834 | hp.rho, 835 | sigma_min, sigma_max 836 | ) 837 | 838 | gammas = torch.where( 839 | (sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax), 840 | min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1), 841 | 0. 842 | ) 843 | 844 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) 845 | 846 | # images is noise at the beginning 847 | 848 | init_sigma = sigmas[0] 849 | 850 | images = init_sigma * torch.randn(shape, device = self.device) 851 | 852 | # initializing with an image 853 | 854 | if exists(init_images): 855 | images += init_images 856 | 857 | # keeping track of x0, for self conditioning if needed 858 | 859 | x_start = None 860 | 861 | # prepare inpainting images and mask 862 | 863 | inpaint_images = default(inpaint_videos, inpaint_images) 864 | has_inpainting = exists(inpaint_images) and exists(inpaint_masks) 865 | resample_times = inpaint_resample_times if has_inpainting else 1 866 | 867 | if has_inpainting: 868 | inpaint_images = self.normalize_img(inpaint_images) 869 | inpaint_images = self.resize_to(inpaint_images, shape[-1], **resize_kwargs) 870 | inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... 
-> b 1 ...').float(), shape[-1], **resize_kwargs).bool() 871 | 872 | # unet kwargs 873 | 874 | unet_kwargs = dict( 875 | sigma_data = hp.sigma_data, 876 | clamp = clamp, 877 | dynamic_threshold = dynamic_threshold, 878 | cond_scale = cond_scale, 879 | **kwargs 880 | ) 881 | 882 | # gradually denoise 883 | 884 | initial_step = default(skip_steps, 0) 885 | sigmas_and_gammas = sigmas_and_gammas[initial_step:] 886 | 887 | total_steps = len(sigmas_and_gammas) 888 | 889 | for ind, (sigma, sigma_next, gamma) in tqdm( 890 | enumerate(sigmas_and_gammas), 891 | total = total_steps, 892 | desc = 'sampling time step', 893 | disable = not use_tqdm 894 | ): 895 | is_last_timestep = ind == (total_steps - 1) 896 | 897 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma)) 898 | 899 | for r in reversed(range(resample_times)): 900 | is_last_resample_step = r == 0 901 | 902 | eps = hp.S_noise * torch.randn(shape, device = self.device) # stochastic sampling 903 | 904 | sigma_hat = sigma + gamma * sigma 905 | added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps 906 | 907 | images_hat = images + added_noise 908 | 909 | self_cond = x_start if unet.self_cond else None 910 | 911 | if has_inpainting: 912 | images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks 913 | 914 | model_output = self.preconditioned_network_forward( 915 | unet.forward_with_cond_scale, 916 | images_hat, 917 | sigma_hat, 918 | self_cond = self_cond, 919 | **unet_kwargs 920 | ) 921 | 922 | denoised_over_sigma = (images_hat - model_output) / sigma_hat 923 | 924 | images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma 925 | 926 | # second order correction, if not the last timestep 927 | 928 | has_second_order_correction = sigma_next != 0 929 | 930 | if has_second_order_correction: 931 | self_cond = model_output if unet.self_cond else None 932 | 933 | model_output_next = self.preconditioned_network_forward( 934 | unet.forward_with_cond_scale, 935 | images_next, 936 | sigma_next, 937 | self_cond = self_cond, 938 | **unet_kwargs 939 | ) 940 | 941 | denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next 942 | images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) 943 | 944 | images = images_next 945 | 946 | if has_inpainting and not (is_last_resample_step or is_last_timestep): 947 | # renoise in repaint and then resample 948 | repaint_noise = torch.randn(shape, device = self.device) 949 | images = images + (sigma - sigma_next) * repaint_noise 950 | 951 | x_start = model_output if not has_second_order_correction else model_output_next # save model output for self conditioning 952 | 953 | images = images.clamp(-1., 1.) 954 | 955 | if has_inpainting: 956 | images = images * ~inpaint_masks + inpaint_images * inpaint_masks 957 | 958 | return self.unnormalize_img(images) 959 | 960 | @torch.no_grad() 961 | @eval_decorator 962 | def sample( 963 | self, 964 | # 1. on text condition 965 | texts: List[str] = None, 966 | text_masks = None, 967 | text_embeds = None, 968 | # 2. on condition images 969 | cond_images = None, 970 | cond_video_frames = None, 971 | post_cond_video_frames = None, 972 | # 3. 
inpaint images 973 | inpaint_videos = None, 974 | inpaint_images = None, 975 | inpaint_masks = None, 976 | inpaint_resample_times = 5, 977 | # 978 | init_images = None, 979 | skip_steps = None, 980 | sigma_min = None, 981 | sigma_max = None, 982 | video_frames = None, 983 | batch_size = 1, 984 | cond_scale = 1., 985 | lowres_sample_noise_level = None, 986 | start_at_unet_number = 1, 987 | start_image_or_video = None, 988 | stop_at_unet_number = None, 989 | return_all_unet_outputs = False, 990 | return_pil_images = False, 991 | use_tqdm = True, 992 | use_one_unet_in_gpu = True, 993 | device = None, 994 | ): 995 | # ++ 996 | if self.CKeys['Debug_Level']==Imagen_Samp_Level: 997 | print (f"Debug mode for .sample func...") 998 | 999 | device = default(device, self.device) 1000 | self.reset_unets_all_one_device(device = device) 1001 | # ++ 1002 | if self.CKeys['Debug_Level']==Imagen_Samp_Level: 1003 | print (f"device for unets: {device}") 1004 | 1005 | cond_images = maybe(cast_uint8_images_to_float)(cond_images) 1006 | # ++ 1007 | if self.CKeys['Debug_Level']==Imagen_Samp_Level: 1008 | if not cond_images==None: 1009 | print (f"input cond_images.shape: {cond_images.shape}") 1010 | else: 1011 | print (f"input cond_images: None") 1012 | 1013 | # Channel t-1: use texts directly, not text_embeds; otherwise, text_embeds will be passed in 1014 | if exists(texts) and not exists(text_embeds) and not self.unconditional: 1015 | assert all([*map(len, texts)]), 'text cannot be empty' 1016 | 1017 | with autocast(enabled = False): 1018 | text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) 1019 | 1020 | text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks)) 1021 | 1022 | if not self.unconditional: 1023 | assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training' 1024 | 1025 | text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1)) 1026 | batch_size = text_embeds.shape[0] 1027 | # ++ 1028 | if self.CKeys['Debug_Level']==Imagen_Samp_Level: 1029 | if not (text_embeds==None): 1030 | print (f"text_embeds.shape: {text_embeds.shape}") 1031 | if not (text_masks==None): 1032 | print (f"text_masks.shape: {text_masks.shape}") 1033 | 1034 | # inpainting 1035 | 1036 | inpaint_images = default(inpaint_videos, inpaint_images) 1037 | 1038 | if exists(inpaint_images): 1039 | if self.unconditional: 1040 | if batch_size == 1: # assume researcher wants to broadcast along inpainted images 1041 | batch_size = inpaint_images.shape[0] 1042 | 1043 | assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=)``' 1044 | assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on' 1045 | 1046 | assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified' 1047 | assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented' 1048 | assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' 1049 | 1050 | assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting' 
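# Note on the cascade below (descriptive comment, added for clarity): each unet after the first takes the
# previous stage's output, resizes it to its own entry in image_sizes, renoises it to lowres_sample_noise_level
# via the lowres noise schedule, and conditions on that noisy low-resolution signal while it denoises a fresh
# sample from pure noise at its own resolution.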
1051 | 1052 | outputs = [] 1053 | 1054 | is_cuda = next(self.parameters()).is_cuda 1055 | device = next(self.parameters()).device 1056 | 1057 | # will be applied to the lowres images 1058 | lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level) 1059 | 1060 | num_unets = len(self.unets) 1061 | cond_scale = cast_tuple(cond_scale, num_unets) 1062 | 1063 | # handle video and frame dimension 1064 | 1065 | if self.is_video and exists(inpaint_images): 1066 | video_frames = inpaint_images.shape[2] 1067 | 1068 | if inpaint_masks.ndim == 3: 1069 | inpaint_masks = repeat( 1070 | inpaint_masks, 1071 | # 'b h w -> b f h w', 1072 | 'b h -> b f h', 1073 | f = video_frames 1074 | ) 1075 | 1076 | assert inpaint_masks.shape[1] == video_frames 1077 | 1078 | assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video' 1079 | 1080 | # determine the frame dimensions, if needed 1081 | 1082 | all_frame_dims = calc_all_frame_dims( 1083 | self.temporal_downsample_factor, 1084 | video_frames 1085 | ) 1086 | 1087 | # initializing with an image or video 1088 | 1089 | init_images = cast_tuple(init_images, num_unets) 1090 | init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images] 1091 | # ++ 1092 | if self.CKeys['Debug_Level']==Imagen_Samp_Level: 1093 | print (f"init_images: {init_images}") 1094 | 1095 | skip_steps = cast_tuple(skip_steps, num_unets) 1096 | 1097 | sigma_min = cast_tuple(sigma_min, num_unets) 1098 | sigma_max = cast_tuple(sigma_max, num_unets) 1099 | 1100 | # handle starting at a unet greater than 1, for training only-upscaler training 1101 | 1102 | if start_at_unet_number > 1: 1103 | assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets' 1104 | assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number 1105 | assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling' 1106 | 1107 | prev_image_size = self.image_sizes[start_at_unet_number - 2] 1108 | img = self.resize_to(start_image_or_video, prev_image_size) 1109 | 1110 | # go through each unet in cascade 1111 | 1112 | for unet_number, unet, channel, image_size, frame_dims, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm( 1113 | zip( 1114 | range(1, num_unets + 1), self.unets, 1115 | self.sample_channels, self.image_sizes, 1116 | all_frame_dims, self.hparams, 1117 | self.dynamic_thresholding, cond_scale, 1118 | init_images, skip_steps, 1119 | sigma_min, sigma_max 1120 | ), 1121 | disable = not use_tqdm 1122 | ): 1123 | if unet_number < start_at_unet_number: 1124 | continue 1125 | 1126 | assert not isinstance(unet, NullUnet), 'cannot sample from null unet' 1127 | 1128 | context = self.one_unet_in_gpu(unet = unet) if is_cuda and use_one_unet_in_gpu else nullcontext() 1129 | 1130 | with context: 1131 | lowres_cond_img = lowres_noise_times = None 1132 | 1133 | # -- 1134 | # shape = (batch_size, channel, *frame_dims, image_size, image_size) 1135 | # ++ 1136 | shape = (batch_size, channel, *frame_dims, image_size) 1137 | 1138 | resize_kwargs = dict() 1139 | video_kwargs = dict() 1140 | 1141 | if self.is_video: 1142 | resize_kwargs = dict(target_frames = frame_dims[0]) 1143 | 1144 | video_kwargs = dict( 1145 | cond_video_frames = cond_video_frames, 1146 | post_cond_video_frames = post_cond_video_frames 1147 | ) 1148 | 1149 | video_kwargs 
= compact(video_kwargs) 1150 | 1151 | # handle video conditioning frames 1152 | 1153 | if self.is_video and self.resize_cond_video_frames: 1154 | downsample_scale = self.temporal_downsample_factor[unet_number - 1] 1155 | temporal_downsample_fn = partial( 1156 | scale_video_time, 1157 | downsample_scale = downsample_scale 1158 | ) 1159 | video_kwargs = maybe_transform_dict_key( 1160 | video_kwargs, 'cond_video_frames', 1161 | temporal_downsample_fn 1162 | ) 1163 | video_kwargs = maybe_transform_dict_key( 1164 | video_kwargs, 'post_cond_video_frames', 1165 | temporal_downsample_fn 1166 | ) 1167 | 1168 | # low resolution conditioning 1169 | 1170 | if unet.lowres_cond: 1171 | lowres_noise_times = self.lowres_noise_schedule.get_times( 1172 | batch_size, lowres_sample_noise_level, device = device 1173 | ) 1174 | 1175 | lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs) 1176 | lowres_cond_img = self.normalize_img(lowres_cond_img) 1177 | 1178 | lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample( 1179 | x_start = lowres_cond_img, 1180 | t = lowres_noise_times, 1181 | noise = torch.randn_like(lowres_cond_img) 1182 | ) 1183 | 1184 | if exists(unet_init_images): 1185 | unet_init_images = self.resize_to( 1186 | unet_init_images, image_size, **resize_kwargs 1187 | ) 1188 | 1189 | # -- 1190 | # shape = (batch_size, self.channels, *frame_dims, image_size, image_size) 1191 | # ++ 1192 | shape = (batch_size, self.channels, *frame_dims, image_size) 1193 | 1194 | img = self.one_unet_sample( 1195 | unet, 1196 | shape, 1197 | unet_number = unet_number, 1198 | text_embeds = text_embeds, 1199 | text_mask = text_masks, 1200 | cond_images = cond_images, 1201 | inpaint_images = inpaint_images, 1202 | inpaint_masks = inpaint_masks, 1203 | inpaint_resample_times = inpaint_resample_times, 1204 | init_images = unet_init_images, 1205 | skip_steps = unet_skip_steps, 1206 | sigma_min = unet_sigma_min, 1207 | sigma_max = unet_sigma_max, 1208 | cond_scale = unet_cond_scale, 1209 | lowres_cond_img = lowres_cond_img, 1210 | lowres_noise_times = lowres_noise_times, 1211 | dynamic_threshold = dynamic_threshold, 1212 | use_tqdm = use_tqdm, 1213 | **video_kwargs 1214 | ) 1215 | 1216 | outputs.append(img) 1217 | 1218 | if exists(stop_at_unet_number) and stop_at_unet_number == unet_number: 1219 | break 1220 | 1221 | output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs 1222 | 1223 | if not return_pil_images: 1224 | return outputs[output_index] 1225 | 1226 | if not return_all_unet_outputs: 1227 | outputs = outputs[-1:] 1228 | 1229 | # assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet' 1230 | 1231 | # pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs)) 1232 | 1233 | # return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png) 1234 | 1235 | # end of sampling =================================================================================== 1236 | 1237 | # training 1238 | 1239 | def loss_weight(self, sigma_data, sigma): 1240 | return (sigma ** 2 + sigma_data ** 2) * (sigma * sigma_data) ** -2 1241 | 1242 | def noise_distribution(self, P_mean, P_std, batch_size): 1243 | return (P_mean + P_std * torch.randn((batch_size,), device = self.device)).exp() 1244 | 1245 | def forward( 1246 | self, 1247 | images, # rename to images or video 1248 | unet: Union[Unet_OneD, Unet3D, NullUnet, 
DistributedDataParallel] = None, 1249 | texts: List[str] = None, 1250 | text_embeds = None, 1251 | text_masks = None, 1252 | unet_number = None, 1253 | cond_images = None, 1254 | **kwargs 1255 | ): 1256 | # ++ 1257 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1258 | print (f"Now, in EImagen.forward() ...") 1259 | 1260 | if self.is_video and images.ndim == 4: 1261 | # -- 1262 | # images = rearrange(images, 'b c h w -> b c 1 h w') 1263 | # ++ 1264 | images = rearrange(images, 'b c h -> b c 1 h') 1265 | kwargs.update(ignore_time = True) 1266 | 1267 | assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' 1268 | unet_number = default(unet_number, 1) 1269 | assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you can only train on unet #{self.only_train_unet_number}' 1270 | 1271 | images = cast_uint8_images_to_float(images) # do nothing if input is not uint8: float btw (0, 1) 1272 | cond_images = maybe(cast_uint8_images_to_float)(cond_images) 1273 | # ++ for one_D need adjustment 1274 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1275 | print (f"Reformat images and cond_images into float") 1276 | print (f"images.shape: {images.shape}") 1277 | print (f"images.dtype: {images.dtype}") 1278 | print (f"max and min: {torch.max(images)} and {torch.min(images)}") 1279 | if not (cond_images==None): 1280 | print (f"cond_images.shape: {cond_images.shape}") 1281 | else: 1282 | print (f"cond_images: None") 1283 | 1284 | assert images.dtype == torch.float, f'images tensor needs to be floats but {images.dtype} dtype found instead' 1285 | 1286 | unet_index = unet_number - 1 1287 | 1288 | unet = default(unet, lambda: self.get_unet(unet_number)) 1289 | 1290 | assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained' 1291 | 1292 | target_image_size = self.image_sizes[unet_index] 1293 | random_crop_size = self.random_crop_sizes[unet_index] 1294 | prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None 1295 | hp = self.hparams[unet_index] 1296 | # ++ 1297 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1298 | print (f"target_image_size: {target_image_size}") 1299 | print (f"random_crop_size: {random_crop_size}") 1300 | print (f"prev_image_size: {prev_image_size}") 1301 | print (f"hp: {hp}") 1302 | 1303 | # -- 1304 | # batch_size, c, *_, h, w, device, is_video = *images.shape, images.device, (images.ndim == 5) 1305 | # ++ 1306 | batch_size, c, *_, h, device, is_video = *images.shape, images.device, (images.ndim == 4) 1307 | # ++ 1308 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1309 | print (f"batch_size: {batch_size}") 1310 | print (f"channel c: {c}") 1311 | print (f"1d image size, h: {h} ") 1312 | 1313 | 1314 | frames = images.shape[2] if is_video else None 1315 | all_frame_dims = tuple( 1316 | safe_get_tuple_index(el, 0) for el in calc_all_frame_dims( 1317 | self.temporal_downsample_factor, frames 1318 | ) 1319 | ) 1320 | ignore_time = kwargs.get('ignore_time', False) 1321 | # ++ 1322 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1323 | print (f"frames: {frames}") 1324 | print (f"all_frame_dims: {all_frame_dims}") 1325 | print (f"ignore_time: {ignore_time}") 1326 | 1327 | target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None 1328 | prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and 
unet_index > 0 else None 1329 | frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict() 1330 | # ++ 1331 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1332 | print (f"target_frame_size: {target_frame_size}") 1333 | print (f"prev_frame_size: {prev_frame_size}") 1334 | print (f"frames_to_resize_kwargs: {frames_to_resize_kwargs}") 1335 | 1336 | assert images.shape[1] == self.channels 1337 | assert h >= target_image_size # and w >= target_image_size 1338 | 1339 | # texts provided, not text_embeds 1340 | # 1341 | if exists(texts) and not exists(text_embeds) and not self.unconditional: 1342 | assert all([*map(len, texts)]), 'text cannot be empty' 1343 | assert len(texts) == len(images), 'number of text captions does not match up with the number of images given' 1344 | 1345 | with autocast(enabled = False): 1346 | text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) 1347 | 1348 | text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks)) 1349 | # now we have text_embeds, and text_masks 1350 | 1351 | if not self.unconditional: 1352 | text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1)) 1353 | # ++ 1354 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1355 | # print (f"text_masks: \n{text_masks}") 1356 | print (f"text_masks.shape: \n{text_masks.shape}") 1357 | 1358 | assert not ( 1359 | self.condition_on_text and not exists(text_embeds) 1360 | ), 'text or text encodings must be passed into decoder if specified' 1361 | assert not ( 1362 | not self.condition_on_text and exists(text_embeds) 1363 | ), 'decoder specified not to be conditioned on text, yet it is presented' 1364 | 1365 | assert not ( 1366 | exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim 1367 | ), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' 1368 | 1369 | # handle video conditioning frames 1370 | 1371 | if self.is_video and self.resize_cond_video_frames: 1372 | downsample_scale = self.temporal_downsample_factor[unet_index] 1373 | temporal_downsample_fn = partial(scale_video_time, downsample_scale = downsample_scale) 1374 | kwargs = maybe_transform_dict_key(kwargs, 'cond_video_frames', temporal_downsample_fn) 1375 | kwargs = maybe_transform_dict_key(kwargs, 'post_cond_video_frames', temporal_downsample_fn) 1376 | 1377 | # low resolution conditioning 1378 | # this part is on if the trained one is the 2nd unet 1379 | 1380 | lowres_cond_img = lowres_aug_times = None 1381 | if exists(prev_image_size): # so, this is the 2nd unet 1382 | # ++ 1383 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1384 | print (f"prev_image_size detected. So, this is for 2nd UNet") 1385 | print (f"Create lowres_cond_img by resizing true image") 1386 | print (f" images.shape: {images.shape}") 1387 | lowres_cond_img = self.resize_to( 1388 | images, 1389 | prev_image_size, 1390 | **frames_to_resize_kwargs(prev_frame_size), 1391 | clamp_range = self.input_image_range 1392 | ) 1393 | # ++ 1394 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1395 | print (f" 1. 
resize full image to previous size: coarsening") 1396 | print (f" .resize_to(images,prev_image_size)->lowres_cond_img.shape: {lowres_cond_img.shape}") 1397 | lowres_cond_img = self.resize_to( 1398 | lowres_cond_img, target_image_size, 1399 | **frames_to_resize_kwargs(target_frame_size), 1400 | clamp_range = self.input_image_range 1401 | ) 1402 | # ++ 1403 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1404 | print (f" 2. fit into the traget size: only change size") 1405 | print (f" .resize_to(lowres_cond_img,target_image_size)->lowres_cond_img.shape: {lowres_cond_img.shape}") 1406 | 1407 | if self.per_sample_random_aug_noise_level: 1408 | lowres_aug_times = self.lowres_noise_schedule.sample_random_times( 1409 | batch_size, device = device 1410 | ) 1411 | else: # i.e., all samples in the batch use the same 1412 | lowres_aug_time = self.lowres_noise_schedule.sample_random_times( 1413 | 1, device = device 1414 | ) 1415 | lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size) 1416 | # ++ 1417 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1418 | print (f"get noise schedule for lowres_cond_img, lowres_aug_time.shape: {lowres_aug_times.shape}") 1419 | 1420 | # ++ 1421 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1422 | print (f"images.shape: {images.shape}") 1423 | images = self.resize_to( 1424 | images, 1425 | target_image_size, 1426 | **frames_to_resize_kwargs(target_frame_size) 1427 | ) 1428 | # not triggered if images.shape[-1]==target_image_size 1429 | # ++ 1430 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1431 | print (f".resize_to() -> images.shape: {images.shape}") 1432 | 1433 | # normalize to [-1, 1] 1434 | # ++ 1435 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1436 | print (f"Bef. normalize_img, min/ax of images: {torch.max(images)}, {torch.min(images)}") 1437 | print (f".normalize_img: {self.normalize_img}") 1438 | images = self.normalize_img(images) # assume images (0,1)->(-1,1) 1439 | lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) 1440 | # ++ 1441 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1442 | print (f"After normalize_img, should be [-1,1]") 1443 | print (f"images max and min: {torch.max(images)} and {torch.min(images)}") 1444 | if exists(lowres_cond_img): 1445 | print (f"lowres_cond_img max and min: {torch.max(lowres_cond_img)} and {torch.min(lowres_cond_img)}") 1446 | 1447 | # random cropping during training 1448 | # for upsamplers 1449 | 1450 | if exists(random_crop_size): 1451 | aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.) 
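# Note (descriptive): this random-crop branch only runs when the corresponding random_crop_sizes entry is set,
# i.e. for upsampler unets; the base unet asserts None above. kornia's RandomCrop is still parameterized as a
# 2D (H, W) crop here, so for 1D (batch, channel, length) data this path is presumably left disabled unless it
# is reworked for one-dimensional crops.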
1452 | 1453 | if is_video: 1454 | images, lowres_cond_img = map( 1455 | # -- 1456 | # lambda t: rearrange(t, 'b c f h w -> (b f) c h w'), 1457 | # ++ 1458 | lambda t: rearrange(t, 'b c f h -> (b f) c h'), 1459 | (images, lowres_cond_img) 1460 | ) 1461 | 1462 | # make sure low res conditioner and image both get augmented the same way 1463 | # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop 1464 | images = aug(images) 1465 | lowres_cond_img = aug(lowres_cond_img, params = aug._params) 1466 | 1467 | if is_video: 1468 | images, lowres_cond_img = map( 1469 | # -- 1470 | # lambda t: rearrange(t, '(b f) c h w -> b c f h w', f = frames), 1471 | # ++ 1472 | lambda t: rearrange(t, '(b f) c h -> b c f h', f = frames), 1473 | (images, lowres_cond_img) 1474 | ) 1475 | 1476 | # noise the lowres conditioning image 1477 | # at sample time, they then fix the noise level of 0.1 - 0.3 1478 | 1479 | lowres_cond_img_noisy = None 1480 | if exists(lowres_cond_img): 1481 | lowres_cond_img_noisy, *_ = self.lowres_noise_schedule.q_sample( 1482 | x_start = lowres_cond_img, 1483 | t = lowres_aug_times, 1484 | noise = torch.randn_like(lowres_cond_img) 1485 | ) 1486 | # ++ 1487 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1488 | print (f"add noise to lowres_cond_img...") 1489 | print (f"lowres_cond_img_noisy.shape: {lowres_cond_img_noisy.shape}") 1490 | 1491 | 1492 | # get the sigmas 1493 | 1494 | sigmas = self.noise_distribution( 1495 | hp.P_mean, hp.P_std, batch_size 1496 | ) 1497 | padded_sigmas = self.right_pad_dims_to_datatype(sigmas) 1498 | # ++ 1499 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1500 | print (f"sigmas.shape: {sigmas.shape}") 1501 | print (f"padded_sigmas.shape: {padded_sigmas.shape}") 1502 | 1503 | # noise 1504 | 1505 | noise = torch.randn_like(images) 1506 | noised_images = images + padded_sigmas * noise # alphas are 1. in the paper 1507 | # ++ 1508 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1509 | print (f"add noise into images") 1510 | 1511 | # unet kwargs 1512 | 1513 | unet_kwargs = dict( 1514 | sigma_data = hp.sigma_data, 1515 | text_embeds = text_embeds, 1516 | text_mask = text_masks, 1517 | cond_images = cond_images, 1518 | lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times), 1519 | lowres_cond_img = lowres_cond_img_noisy, 1520 | cond_drop_prob = self.cond_drop_prob, 1521 | **kwargs 1522 | ) 1523 | 1524 | # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower 1525 | 1526 | # Because 'unet' can be an instance of DistributedDataParallel coming from the 1527 | # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to 1528 | # access the member 'module' of the wrapped unet instance. 
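# Self-conditioning sketch (descriptive): when the unet exposes self_cond, roughly half of the training steps
# first run the unet once under torch.no_grad() to get a detached x0 estimate, then feed that estimate back in
# through unet_kwargs['self_cond'] for the gradient-carrying forward pass, following
# https://arxiv.org/abs/2208.04202 (at the cost of ~25% slower training, as noted above).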
1529 | self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond 1530 | 1531 | if self_cond and random() < 0.5: 1532 | # ++ 1533 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1534 | print (f"self_cond is triggered.") 1535 | print (f"get into unet.......") 1536 | 1537 | with torch.no_grad(): 1538 | pred_x0 = self.preconditioned_network_forward( 1539 | unet.forward, 1540 | noised_images, 1541 | sigmas, 1542 | **unet_kwargs 1543 | ).detach() 1544 | # ++ 1545 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1546 | print (f"get out of unet.......") 1547 | print (f"prop noised_images via the net.") 1548 | print (f"get pred_x0.shape: {pred_x0.shape}") 1549 | 1550 | unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0} 1551 | 1552 | # get prediction 1553 | # ++ 1554 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1555 | # print (f"unet_kwargs: \n{unet_kwargs}") 1556 | print (f"unet_kwargs include keys: ") 1557 | for this_key in unet_kwargs.keys(): 1558 | print (" "+this_key) 1559 | if torch.is_tensor(unet_kwargs[this_key]): 1560 | print (f" {unet_kwargs[this_key].shape}") 1561 | else: 1562 | print (f" {type(unet_kwargs[this_key])}") 1563 | 1564 | denoised_images = self.preconditioned_network_forward( 1565 | unet.forward, 1566 | noised_images, 1567 | sigmas, 1568 | **unet_kwargs 1569 | ) 1570 | # ++ 1571 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1572 | print (f"Back to EImagen, get denoised_images") 1573 | print (f"denoised_images.shape: {denoised_images.shape}") 1574 | 1575 | # losses 1576 | 1577 | losses = F.mse_loss(denoised_images, images, reduction = 'none') 1578 | losses = reduce(losses, 'b ... -> b', 'mean') 1579 | 1580 | # loss weighting 1581 | 1582 | losses = losses * self.loss_weight(hp.sigma_data, sigmas) 1583 | # ++ 1584 | if self.CKeys['Debug_Level']==Imagen_Forw_Level: 1585 | print (f"postprocess losses based on hp.sigma_data") 1586 | 1587 | # return average loss 1588 | 1589 | return losses.mean() 1590 | --------------------------------------------------------------------------------
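The file above defines the elucidated-diffusion core used by the model duo, adapted from 2D images to 1D protein signals. What follows is a minimal usage sketch, not taken from the repository's training scripts: the class name `EImagen` is inferred from the debug messages, and the `Unet_OneD` constructor arguments, tensor shapes, and keyword names are illustrative assumptions.

# Minimal sketch under the assumptions stated above (hypothetical configuration; adjust to the real APIs)
import torch
from VibeGen.imagen_x_imagen_pytorch import Unet_OneD, EImagen   # import path from the repo layout; names assumed

unet1 = Unet_OneD(dim = 64, cond_dim = 512, dim_mults = (1, 2, 4), channels = 1)   # hypothetical kwargs

eimagen = EImagen(                      # class name inferred from the debug prints
    unets = (unet1,),
    channels = 1,
    image_sizes = (128,),               # 1D sequence length handled by the single stage
    text_embed_dim = 512,
    condition_on_text = True,
)

seqs   = torch.rand(4, 1, 128)          # normalized 1D "images" in [0, 1]
embeds = torch.randn(4, 16, 512)        # (batch, tokens, text_embed_dim), e.g. protein-LM embeddings

loss = eimagen(seqs, text_embeds = embeds, unet_number = 1)    # training forward pass returns a scalar loss
loss.backward()

sampled = eimagen.sample(text_embeds = embeds, cond_scale = 3., use_tqdm = False)
# sampled: (4, 1, 128), mapped back to [0, 1] by unnormalize_img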