├── VibeGen
│   ├── __init__.py
│   ├── TrainerPack_advanced.py
│   ├── UtilityPack.py
│   ├── JointSamplingPack.py
│   ├── DataSetPack.py
│   └── imagen_x_imagen_pytorch.py
├── setup.py
├── .gitignore
├── README.md
├── requirements.txt
└── LICENSE.txt
/VibeGen/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | # VibeGen/__init__.py
3 |
4 | # from . import UtilityPack as UPack
5 |
6 | # from . import DataSetPack as DPack
7 |
8 | # from . import ModelPack as MPack
9 |
10 | # from . import TrainerPack as TPack
11 |
12 | # from . import JointSamplingPack as SPack
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
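Even with the aliases above commented out, the individual modules are importable directly once the package is installed. A minimal sketch (it assumes `pip install -e .` has been run from the repository root):

```python
# Direct module imports; these do not depend on the commented aliases
# in VibeGen/__init__.py.
from VibeGen import UtilityPack as UPack
from VibeGen import DataSetPack as DPack

# e.g., use the folder helper defined in UtilityPack.py
UPack.create_path('./wk_dir')
```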
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | # Read requirements.txt, skipping blank lines and comments
4 | def parse_requirements(filename):
5 |     with open(filename, 'r') as req_file:
6 |         return [line.split(' #')[0].strip() for line in req_file if line.strip() and not line.lstrip().startswith('#')]
7 |
8 | setup(
9 | name='VibeGen',
10 | version='0.1.0',
11 | packages=find_packages(),
12 | install_requires=parse_requirements('requirements.txt'),
13 | description='VibeGen: End-to-end de novo protein generation targeting normal mode vibrations using a language diffusion model duo',
14 | author='Bo Ni',
15 | url='https://github.com/lamm-mit/ModeShapeDiffusionDesign',
16 | )
17 |
--------------------------------------------------------------------------------
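As a quick sanity check of `parse_requirements` against the kinds of lines that actually appear in `requirements.txt` (inline `# @ file:///...` annotations, commented-out pins, and a PEP 508 direct reference for OmegaFold), a small sketch with the function above in scope:

```python
import os, tempfile

sample = (
    "absl-py==2.1.0\n"
    "aiofiles # @ file:///croot/aiofiles_1683773582346/work\n"
    "# mkl-service==2.4.0\n"
    "OmegaFold @ git+https://github.com/Bo-Ni/OmegaFold_0.git@3db771f153c247dd3686abdf4495735a4f36d933\n"
)

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write(sample)

# comments and blank lines are dropped; direct references pass through intact
print(parse_requirements(f.name))
os.unlink(f.name)
```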
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | # +++
165 | Local_Store/
166 |
167 | VibeGen/working_note.MD
168 |
169 | wk_dir/
170 |
171 | trained_duo/
172 |
173 | VibeGen_env/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VibeGen: Agentic End-to-End De Novo Protein Design for Tailored Dynamics Using a Language Diffusion Model
2 |
3 | Bo Ni<sup>1,2</sup>, Markus J. Buehler<sup>1,3,4</sup>*
4 |
5 | 1 Laboratory for Atomistic and Molecular Mechanics (LAMM), Massachusetts Institute of Technology
6 |
7 | 2 Department of Materials Science and Engineering, Carnegie Mellon University
8 |
9 | 3 Center for Computational Science and Engineering, Schwarzman College of Computing, Massachusetts Institute of Technology
10 |
11 | 4 Lead contact
12 |
13 | * Correspondence: mbuehler@MIT.EDU
14 |
15 | Proteins are dynamic molecular machines whose biological functions, spanning enzymatic catalysis, signal transduction, and structural adaptation, are intrinsically linked to their motions. We introduce VibeGen, a generative AI model based on an agentic dual-model architecture, comprising a protein designer that generates sequence candidates based on specified vibrational modes and a protein predictor that evaluates their dynamic accuracy. Via direct validation using full-atom molecular simulations, we demonstrate that the designed proteins accurately reproduce the prescribed normal mode amplitudes across the backbone while adopting various stable, functionally relevant structures. Generated sequences are de novo, exhibiting no significant similarity to natural proteins, thereby expanding the accessible protein space beyond evolutionary constraints. Our model establishes a direct, bidirectional link between sequence and vibrational behavior, unlocking new pathways for engineering biomolecules with tailored dynamical and functional properties. Our model holds broad implications for the rational design of enzymes, dynamic scaffolds, and biomaterials via dynamics-informed protein engineering.
16 |
17 | 
18 |
19 | ## Installation
20 |
21 | Create a virtual environment
22 |
23 | ```bash
24 | conda create --prefix=./VibeGen_env
25 | conda activate ./VibeGen_env
26 |
27 | ```
28 |
29 | Install:
30 | ```bash
31 | pip install git+https://github.com/lamm-mit/ModeShapeDiffusionDesign.git
32 |
33 | ```
34 | If you want to create an editable installation, clone the repository using `git`:
35 | ```bash
36 | git clone https://github.com/lamm-mit/ModeShapeDiffusionDesign.git
37 | cd ModeShapeDiffusionDesign
38 | ```
39 | Then, install:
40 | ```bash
41 | pip install -r requirements.txt
42 | pip install -e .
43 | ```
44 |
45 | ### Directory structure
46 | ```
47 | ModeShapeDiffusionDesign/
48 | │
49 | ├── VibeGen/ # Source code directory
50 | │ ├── DataSetPack.py
51 | │ ├── ModelPack.py
52 | │ ├── TrainerPack.py
53 | │ ├── UtilityPack.py
54 | │ ├── JointSamplingPack.py
55 | │ └── ...
56 | │
57 | ├── demo_1_Inferrence_with_trained_duo.ipynb # demo 1: make an inference
58 | │
59 | ├── colab_demo/ # demos for colab
60 | │ ├── Inference_demo.ipynb # demo 1: make an inference
61 | │ └── ...
62 | │
63 | ├── setup.py # The setup file for packaging
64 | ├── requirements.txt # List of dependencies
65 | ├── README.md # Documentation
66 | ├── assets/ # Support materials
67 | └── ...
68 | ```
69 |
70 | ## Usage
71 |
72 | ### Inference notebooks
73 | In the following example, for each input normal mode shape condition, we use the trained ProteinDesigner to propose 20 sequence candidates. The trained ProteinPredictor then picks the best and worst two of them based on its predictions. The chosen sequences are folded using OmegaFold and their secondary structure is analyzed (see the sketch below).
74 |
75 | ```
76 | demo_1_inference_with_trained_duo.ipynb
77 | ```
78 |
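A compact sketch of that workflow in plain Python (hedged: the real entry points live in `VibeGen/ModelPack.py` and `VibeGen/JointSamplingPack.py`, which are not reproduced here; `design_candidates` and `score_candidates` below are illustrative stand-ins, not the package API):

```python
import numpy as np

AA = list("ACDEFGHIKLMNPQRSTVWY")

def design_candidates(mode_shape, n=20):
    """Stand-in for the trained ProteinDesigner: propose n sequences."""
    rng = np.random.default_rng(0)
    return ["".join(rng.choice(AA, size=len(mode_shape))) for _ in range(n)]

def score_candidates(mode_shape, seqs):
    """Stand-in for the trained ProteinPredictor: lower = better agreement."""
    return np.random.default_rng(1).random(len(seqs))

mode_shape = np.random.default_rng(2).random(64)     # one normal-mode amplitude profile
candidates = design_candidates(mode_shape, n=20)
order = np.argsort(score_candidates(mode_shape, candidates))

best_two = [candidates[i] for i in order[:2]]        # forwarded to OmegaFold for folding
worst_two = [candidates[i] for i in order[-2:]]      # and secondary-structure analysis
```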
79 | Alternatively, a similar demo can be run on Colab.
80 |
81 | [](https://colab.research.google.com/github/lamm-mit/ModeShapeDiffusionDesign/blob/main/colab_demo/Inference_demo.ipynb)
82 |
83 | ### Pretrained models
84 | The checkpoints of the pretrained models that make up the agentic system are hosted in the [VibeGen repository](https://huggingface.co/lamm-mit/VibeGen) on Hugging Face.
85 |
86 | ### Reference
87 |
88 | ```bibtex
89 | @misc{BoBuehler2025VibeGen,
90 | title={VibeGen: Agentic End-to-End De Novo Protein Design for Tailored Dynamics Using a Language Diffusion Model},
91 | author={Bo Ni and Markus J. Buehler},
92 | year={2025},
93 | eprint={2502.10173},
94 | archivePrefix={arXiv},
95 | primaryClass={q-bio.BM},
96 | url={https://arxiv.org/abs/2502.10173},
97 | }
98 | ```
99 |
100 | Our implementation is inspired by the [imagen-pytorch](https://github.com/lucidrains/imagen-pytorch) repository by [Phil Wang](https://github.com/lucidrains).
101 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.32.1
3 | aiofiles # @ file:///croot/aiofiles_1683773582346/work
4 | aiohttp==3.9.5
5 | aiosignal==1.3.1
6 | aiosqlite # @ file:///croot/aiosqlite_1683773899903/work
7 | annotated-types==0.7.0
8 | anyio # @ file:///tmp/build/80754af9/anyio_1644481695334/work/dist
9 | argon2-cffi # @ file:///opt/conda/conda-bld/argon2-cffi_1645000214183/work
10 | argon2-cffi-bindings # @ file:///tmp/build/80754af9/argon2-cffi-bindings_1644553347904/work
11 | asttokens # @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
12 | astunparse==1.6.3
13 | async-timeout==4.0.3
14 | attrs # @ file:///croot/attrs_1695717823297/work
15 | Babel # @ file:///croot/babel_1671781930836/work
16 | backcall # @ file:///home/ktietz/src/ci/backcall_1611930011877/work
17 | beartype==0.18.5
18 | beautifulsoup4 # @ file:///croot/beautifulsoup4-split_1681493039619/work
19 | biopython==1.84
20 | bleach # @ file:///opt/conda/conda-bld/bleach_1641577558959/work
21 | Brotli # @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work
22 | certifi # @ file:///croot/certifi_1700501669400/work/certifi
23 | cffi # @ file:///croot/cffi_1700254295673/work
24 | charset-normalizer # @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
25 | comm # @ file:///croot/comm_1671231121260/work
26 | contourpy==1.2.1
27 | cryptography # @ file:///croot/cryptography_1694444244250/work
28 | cycler==0.12.1
29 | datasets==2.20.0
30 | debugpy # @ file:///croot/debugpy_1690905042057/work
31 | decorator # @ file:///opt/conda/conda-bld/decorator_1643638310831/work
32 | defusedxml # @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
33 | dill==0.3.8
34 | einops==0.8.0
35 | einops-exts==0.0.4
36 | ema-pytorch==0.5.2
37 | evaluate==0.4.2
38 | exceptiongroup # @ file:///croot/exceptiongroup_1668714342571/work
39 | executing # @ file:///opt/conda/conda-bld/executing_1646925071911/work
40 | fair-esm==2.0.0
41 | fastjsonschema # @ file:///opt/conda/conda-bld/python-fastjsonschema_1661371079312/work
42 | filelock # @ file:///croot/filelock_1700591183607/work
43 | flatbuffers==24.3.25
44 | fonttools==4.53.1
45 | frozenlist==1.4.1
46 | fsspec==2024.5.0
47 | gast==0.6.0
48 | gmpy2 # @ file:///tmp/build/80754af9/gmpy2_1645455533097/work
49 | google-pasta==0.2.0
50 | grpcio==1.65.4
51 | h5py==3.11.0
52 | huggingface-hub==0.23.4
53 | idna # @ file:///croot/idna_1666125576474/work
54 | ipykernel # @ file:///croot/ipykernel_1691121631942/work
55 | ipython # @ file:///croot/ipython_1694181358621/work
56 | ipython-genutils # @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
57 | ipywidgets # @ file:///croot/ipywidgets_1679394798311/work
58 | jedi # @ file:///tmp/build/80754af9/jedi_1644315229345/work
59 | Jinja2 # @ file:///croot/jinja2_1666908132255/work
60 | joblib==1.4.2
61 | json5 # @ file:///tmp/build/80754af9/json5_1624432770122/work
62 | jsonschema # @ file:///croot/jsonschema_1699041609003/work
63 | jsonschema-specifications # @ file:///croot/jsonschema-specifications_1699032386549/work
64 | jupyter # @ file:///tmp/abs_33h4eoipez/croots/recipe/jupyter_1659349046347/work
65 | jupyter-console # @ file:///croot/jupyter_console_1679999630278/work
66 | jupyter-events # @ file:///croot/jupyter_events_1699282461638/work
67 | jupyter-ydoc # @ file:///croot/jupyter_ydoc_1683747223142/work
68 | jupyter_client # @ file:///croot/jupyter_client_1699455897726/work
69 | jupyter_core # @ file:///croot/jupyter_core_1698937308754/work
70 | jupyter_server # @ file:///croot/jupyter_server_1699466442171/work
71 | jupyter_server_fileid # @ file:///croot/jupyter_server_fileid_1684273577568/work
72 | jupyter_server_terminals # @ file:///croot/jupyter_server_terminals_1686870725608/work
73 | jupyter_server_ydoc # @ file:///croot/jupyter_server_ydoc_1686767404829/work
74 | jupyterlab # @ file:///croot/jupyterlab_1686179668131/work
75 | jupyterlab-pygments # @ file:///croot/jupyterlab_pygments_1700168593176/work
76 | jupyterlab-widgets # @ file:///croot/jupyterlab_widgets_1700168618520/work
77 | jupyterlab_server # @ file:///croot/jupyterlab_server_1699555425460/work
78 | keras==3.4.1
79 | kiwisolver==1.4.5
80 | kornia==0.7.3
81 | kornia_rs==0.1.5
82 | libclang==18.1.1
83 | Markdown==3.6
84 | markdown-it-py==3.0.0
85 | MarkupSafe # @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work
86 | matplotlib==3.9.1
87 | matplotlib-inline # @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work
88 | mdurl==0.1.2
89 | mistune # @ file:///opt/conda/conda-bld/mistune_1661496219659/work
90 | mkl-fft # @ file:///croot/mkl_fft_1695058164594/work
91 | mkl-random # @ file:///croot/mkl_random_1695059800811/work
92 | # mkl-service==2.4.0
93 | mkl-service==2.4.2
94 | ml-dtypes==0.4.0
95 | mpmath # @ file:///croot/mpmath_1690848262763/work
96 | multidict==6.0.5
97 | multiprocess==0.70.16
98 | namex==0.0.8
99 | nbclassic # @ file:///croot/nbclassic_1699542793266/work
100 | nbclient # @ file:///croot/nbclient_1698934205032/work
101 | nbconvert # @ file:///croot/nbconvert_1699022732553/work
102 | nbformat # @ file:///croot/nbformat_1694616755618/work
103 | nest-asyncio # @ file:///croot/nest-asyncio_1672387112409/work
104 | networkx # @ file:///croot/networkx_1720002482208/work
105 | notebook # @ file:///croot/notebook_1681756172480/work
106 | notebook_shim # @ file:///croot/notebook-shim_1699455894279/work
107 | numpy # @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp310-cp310-linux_x86_64.whl#sha256=d8cd837ed43e87f77e6efaa08e8de927ca030a1c9c5d04624432d6fb9a74a5ee
108 | nvidia-cublas-cu12==12.3.4.1
109 | nvidia-cuda-cupti-cu12==12.3.101
110 | nvidia-cuda-nvcc-cu12==12.3.107
111 | nvidia-cuda-nvrtc-cu12==12.3.107
112 | nvidia-cuda-runtime-cu12==12.3.101
113 | nvidia-cudnn-cu12==8.9.7.29
114 | nvidia-cufft-cu12==11.0.12.1
115 | nvidia-curand-cu12==10.3.4.107
116 | nvidia-cusolver-cu12==11.5.4.101
117 | nvidia-cusparse-cu12==12.2.0.103
118 | nvidia-nccl-cu12==2.19.3
119 | nvidia-nvjitlink-cu12==12.3.101
120 | OmegaFold @ git+https://github.com/Bo-Ni/OmegaFold_0.git@3db771f153c247dd3686abdf4495735a4f36d933
121 | opt-einsum==3.3.0
122 | optree==0.12.1
123 | overrides # @ file:///croot/overrides_1699371140756/work
124 | packaging # @ file:///croot/packaging_1693575174725/work
125 | pandas==2.2.2
126 | pandocfilters # @ file:///opt/conda/conda-bld/pandocfilters_1643405455980/work
127 | parso # @ file:///opt/conda/conda-bld/parso_1641458642106/work
128 | PeptideBuilder==1.1.0
129 | pexpect # @ file:///tmp/build/80754af9/pexpect_1605563209008/work
130 | pickleshare # @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
131 | pillow # @ file:///croot/pillow_1714398848491/work
132 | platformdirs # @ file:///croot/platformdirs_1692205439124/work
133 | ply==3.11
134 | prometheus-client # @ file:///tmp/abs_d3zeliano1/croots/recipe/prometheus_client_1659455100375/work
135 | prompt-toolkit # @ file:///croot/prompt-toolkit_1672387306916/work
136 | protobuf==4.25.4
137 | psutil # @ file:///opt/conda/conda-bld/psutil_1656431268089/work
138 | ptyprocess # @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
139 | pure-eval # @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
140 | py3Dmol==2.2.0
141 | pyarrow==16.1.0
142 | pyarrow-hotfix==0.6
143 | pycparser # @ file:///tmp/build/80754af9/pycparser_1636541352034/work
144 | pydantic==2.8.2
145 | pydantic_core==2.20.1
146 | Pygments # @ file:///croot/pygments_1684279966437/work
147 | pyOpenSSL # @ file:///croot/pyopenssl_1690223430423/work
148 | pyparsing==3.1.2
149 | PyQt5==5.15.10
150 | PyQt5-sip # @ file:///croot/pyqt-split_1698769088074/work/pyqt_sip
151 | PySocks # @ file:///home/builder/ci_310/pysocks_1640793678128/work
152 | python-dateutil # @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
153 | python-json-logger # @ file:///croot/python-json-logger_1683823803357/work
154 | pytorch-warmup==0.1.1
155 | pytz # @ file:///croot/pytz_1695131579487/work
156 | PyYAML # @ file:///croot/pyyaml_1698096049011/work
157 | pyzmq # @ file:///croot/pyzmq_1686601365461/work
158 | qtconsole # @ file:///croot/qtconsole_1700160644874/work
159 | QtPy # @ file:///croot/qtpy_1700144840038/work
160 | referencing # @ file:///croot/referencing_1699012038513/work
161 | regex==2024.5.15
162 | requests==2.32.3
163 | rfc3339-validator # @ file:///croot/rfc3339-validator_1683077044675/work
164 | rfc3986-validator # @ file:///croot/rfc3986-validator_1683058983515/work
165 | rich==13.7.1
166 | rpds-py # @ file:///croot/rpds-py_1698945930462/work
167 | safetensors==0.4.3
168 | scikit-learn==1.5.1
169 | scipy==1.14.0
170 | seaborn==0.13.2
171 | Send2Trash # @ file:///croot/send2trash_1699371139552/work
172 | sentencepiece==0.2.0
173 | sip # @ file:///croot/sip_1698675935381/work
174 | six # @ file:///tmp/build/80754af9/six_1644875935023/work
175 | sniffio # @ file:///home/builder/ci_310/sniffio_1640794799774/work
176 | soupsieve # @ file:///croot/soupsieve_1696347547217/work
177 | stack-data # @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
178 | sympy # @ file:///croot/sympy_1701397643339/work
179 | tensorboard==2.17.0
180 | tensorboard-data-server==0.7.2
181 | tensorflow==2.17.0
182 | tensorflow-io-gcs-filesystem==0.37.1
183 | termcolor==2.4.0
184 | terminado # @ file:///croot/terminado_1671751832461/work
185 | threadpoolctl==3.5.0
186 | tinycss2 # @ file:///croot/tinycss2_1668168815555/work
187 | tokenizers==0.19.1
188 | tomli # @ file:///opt/conda/conda-bld/tomli_1657175507142/work
189 | torch==2.3.1
190 | torchaudio==2.3.1
191 | torchinfo==1.8.0
192 | torchvision==0.18.1
193 | tornado # @ file:///croot/tornado_1696936946304/work
194 | tqdm==4.66.4
195 | traitlets # @ file:///croot/traitlets_1671143879854/work
196 | transformers==4.42.4
197 | triton==2.3.1
198 | typing_extensions==4.8.0
199 | tzdata==2024.1
200 | urllib3 # @ file:///croot/urllib3_1698257533958/work
201 | wcwidth # @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
202 | webencodings==0.5.1
203 | websocket-client # @ file:///home/builder/ci_310/websocket-client_1640795866898/work
204 | Werkzeug==3.0.3
205 | widgetsnbextension # @ file:///croot/widgetsnbextension_1679313860248/work
206 | wrapt==1.16.0
207 | xxhash==3.4.1
208 | y-py # @ file:///croot/y-py_1683555143488/work
209 | yarl==1.9.4
210 | ypy-websocket # @ file:///croot/ypy-websocket_1684171737040/work
211 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/VibeGen/TrainerPack_advanced.py:
--------------------------------------------------------------------------------
1 | """
2 | Task:
3 | 1. create a trainer for ProteinDesigner
4 | 2. include train_loop, sample_loop
5 |
6 | Bo Ni, Sep 8, 2024
7 | """
8 |
9 | # //////////////////////////////////////////////////////
10 | # 0. load in packages
11 | # //////////////////////////////////////////////////////
12 |
13 | import os
14 | from math import ceil
15 | from contextlib import contextmanager, nullcontext
16 | from functools import partial, wraps
17 | from collections.abc import Iterable
18 |
19 | import torch
20 | from torch import nn
21 | import torch.nn.functional as F
22 | from torch.utils.data import random_split, DataLoader
23 | from torch.optim import Adam
24 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
25 | from torch.cuda.amp import autocast, GradScaler
26 |
27 | import pytorch_warmup as warmup
28 |
29 | from packaging import version
30 |
31 | import numpy as np
32 |
33 | from ema_pytorch import EMA
34 |
35 | from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs
36 |
37 | from fsspec.core import url_to_fs
38 | from fsspec.implementations.local import LocalFileSystem
39 |
40 | # //////////////////////////////////////////////////////////////
41 | # 2. special packages
42 | # //////////////////////////////////////////////////////////////
43 | from VibeGen.ModelPack import (
44 | ProteinDesigner_Base
45 | )
46 | from VibeGen.imagen_x_imagen_pytorch import (
47 | ElucidatedImagen_OneD, eval_decorator
48 | )
49 |
50 | # //////////////////////////////////////////////////////////////
51 | # 3. local setup parameters: for debug purpose
52 | # //////////////////////////////////////////////////////////////
53 | PT_Init_Level = 1
54 | PT_Forw_Level = 1
55 |
56 | # //////////////////////////////////////////////////////////////
57 | # 4. helper functions
58 | # //////////////////////////////////////////////////////////////
59 | def cycle(dl):
60 | while True:
61 | for data in dl:
62 | yield data
63 |
64 | def exists(val):
65 | return val is not None
66 |
67 | def default(val, d):
68 | if exists(val):
69 | return val
70 | return d() if callable(d) else d
71 |
72 | def cast_tuple(val, length = 1):
73 | if isinstance(val, list):
74 | val = tuple(val)
75 |
76 | return val if isinstance(val, tuple) else ((val,) * length)
77 |
78 | def find_first(fn, arr):
79 | for ind, el in enumerate(arr):
80 | if fn(el):
81 | return ind
82 | return -1
83 |
84 | def pick_and_pop(keys, d):
85 | values = list(map(lambda key: d.pop(key), keys))
86 | return dict(zip(keys, values))
87 |
88 | def group_dict_by_key(cond, d):
89 | return_val = [dict(),dict()]
90 | for key in d.keys():
91 | match = bool(cond(key))
92 | ind = int(not match)
93 | return_val[ind][key] = d[key]
94 | return (*return_val,)
95 |
96 | def string_begins_with(prefix, str):
97 | return str.startswith(prefix)
98 |
99 | def group_by_key_prefix(prefix, d):
100 | return group_dict_by_key(partial(string_begins_with, prefix), d)
101 |
102 | def groupby_prefix_and_trim(prefix, d):
103 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
104 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
105 | return kwargs_without_prefix, kwargs
106 |
107 | def num_to_groups(num, divisor):
108 | groups = num // divisor
109 | remainder = num % divisor
110 | arr = [divisor] * groups
111 | if remainder > 0:
112 | arr.append(remainder)
113 | return arr
114 |
115 | # url to fs, bucket, path - for checkpointing to cloud
116 |
117 | def url_to_bucket(url):
118 | if '://' not in url:
119 | return url
120 |
121 |     prefix, suffix = url.split('://')
122 |
123 | if prefix in {'gs', 's3'}:
124 | return suffix.split('/')[0]
125 | else:
126 | raise ValueError(f'storage type prefix "{prefix}" is not supported yet')
127 |
128 | # decorators
129 |
130 | def eval_decorator(fn):  # local copy, shadowing the eval_decorator imported from imagen_x_imagen_pytorch
131 | def inner(model, *args, **kwargs):
132 | was_training = model.training
133 | model.eval()
134 | out = fn(model, *args, **kwargs)
135 | model.train(was_training)
136 | return out
137 | return inner
138 |
139 | def cast_torch_tensor(fn, cast_fp16 = False):
140 | @wraps(fn)
141 | def inner(model, *args, **kwargs):
142 | device = kwargs.pop('_device', model.device)
143 | cast_device = kwargs.pop('_cast_device', True)
144 |
145 | should_cast_fp16 = cast_fp16 and model.cast_half_at_training
146 |
147 | kwargs_keys = kwargs.keys()
148 | all_args = (*args, *kwargs.values())
149 | split_kwargs_index = len(all_args) - len(kwargs_keys)
150 | all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
151 |
152 | if cast_device:
153 | all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
154 |
155 | if should_cast_fp16:
156 | all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args))
157 |
158 | args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
159 | kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
160 |
161 | out = fn(model, *args, **kwargs)
162 | return out
163 | return inner
164 |
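# Sketch of what the decorator does: numpy args become torch tensors and are
# moved to model.device before the wrapped method runs, e.g.
#
#   @partial(cast_torch_tensor, cast_fp16 = True)
#   def forward(self, x, mask = None): ...
#
#   trainer.forward(np.ones((2, 8)))   # received as a torch.Tensor on trainer.device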
165 | # gradient accumulation functions
166 |
167 | def split_iterable(it, split_size):
168 | accum = []
169 | for ind in range(ceil(len(it) / split_size)):
170 | start_index = ind * split_size
171 | accum.append(it[start_index: (start_index + split_size)])
172 | return accum
173 |
174 | def split(t, split_size = None):
175 | if not exists(split_size):
176 | return t
177 |
178 | if isinstance(t, torch.Tensor):
179 | return t.split(split_size, dim = 0)
180 |
181 | if isinstance(t, Iterable):
182 | return split_iterable(t, split_size)
183 |
184 |     raise TypeError(f'unexpected type {type(t)} passed to split')
185 |
186 | def find_first(cond, arr):  # note: shadows the index-returning find_first above; returns the first matching element (or None)
187 | for el in arr:
188 | if cond(el):
189 | return el
190 | return None
191 |
192 | def split_args_and_kwargs(*args, split_size = None, **kwargs):
193 | all_args = (*args, *kwargs.values())
194 | len_all_args = len(all_args)
195 | first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
196 | assert exists(first_tensor)
197 |
198 | batch_size = len(first_tensor)
199 | split_size = default(split_size, batch_size)
200 | num_chunks = ceil(batch_size / split_size)
201 |
202 | dict_len = len(kwargs)
203 | dict_keys = kwargs.keys()
204 | split_kwargs_index = len_all_args - dict_len
205 |
206 | split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
207 | chunk_sizes = num_to_groups(batch_size, split_size)
208 |
209 | for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
210 | chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
211 | chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
212 | chunk_size_frac = chunk_size / batch_size
213 | yield chunk_size_frac, (chunked_args, chunked_kwargs)
214 |
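# Example of the chunking behavior (a sketch): a batch of 5 rows with
# split_size = 2 yields chunk fractions 0.4, 0.4, 0.2 and row counts 2, 2, 1:
#
#   x = torch.randn(5, 3)
#   for frac, ((chunk,), _) in split_args_and_kwargs(x, split_size = 2):
#       print(frac, chunk.shape)   # 0.4 torch.Size([2, 3]) ... 0.2 torch.Size([1, 3])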
215 |
216 | # imagen trainer
217 |
218 | def imagen_sample_in_chunks(fn):
219 | @wraps(fn)
220 | def inner(self, *args, max_batch_size = None, **kwargs):
221 | if not exists(max_batch_size):
222 | return fn(self, *args, **kwargs)
223 |
224 | if self.imagen.unconditional:
225 | batch_size = kwargs.get('batch_size')
226 | batch_sizes = num_to_groups(batch_size, max_batch_size)
227 | outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
228 | else:
229 | outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
230 |
231 | if isinstance(outputs[0], torch.Tensor):
232 | return torch.cat(outputs, dim = 0)
233 |
234 | return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs))))
235 |
236 | return inner
237 |
238 |
239 | def restore_parts(state_dict_target, state_dict_from):
240 | for name, param in state_dict_from.items():
241 |
242 | if name not in state_dict_target:
243 | continue
244 |
245 | if param.size() == state_dict_target[name].size():
246 | state_dict_target[name].copy_(param)
247 | else:
248 |             print(f"layer {name} ({param.size()}) differs from target: {state_dict_target[name].size()}")
249 |
250 | return state_dict_target
251 |
252 | # //////////////////////////////////////////////////////////////
253 | # 5. Main class:
254 | # //////////////////////////////////////////////////////////////
255 |
256 | class ProteinDesigner_Trainer(nn.Module):
257 | locked = False
258 |
259 | def __init__(
260 | self,
261 | # 1. on models
262 |         ProtDesi = None, # provide an object
263 | ProtDesi_checkpoint_path = None, # provide a checkpoint path
264 | only_train_unet_number = None,
265 | # 2. on optimizer
266 | use_ema = True,
267 | lr = 1e-4,
268 | eps = 1e-8,
269 | beta1 = 0.9,
270 | beta2 = 0.99,
271 | max_grad_norm = None,
272 | group_wd_params = True,
273 | warmup_steps = None,
274 | cosine_decay_max_steps = None,
275 |
276 | fp16 = False,
277 | precision = None,
278 | split_batches = True,
279 | dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
280 | verbose = True,
281 | split_valid_fraction = 0.025,
282 | split_valid_from_train = False,
283 | split_random_seed = 42,
284 | checkpoint_path = None,
285 | checkpoint_every = None,
286 | checkpoint_fs = None,
287 | fs_kwargs: dict = None,
288 | max_checkpoints_keep = 20,
289 | # ++
290 | CKeys = {'Debug_Level':0},
291 | **kwargs
292 | ):
293 | super().__init__()
294 |
295 | # 0. asserts some
296 | # .....................................................
297 | assert not ProteinDesigner_Trainer.locked, 'ProteinDesigner_Trainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
298 |
299 |         assert exists(ProtDesi) ^ exists(ProtDesi_checkpoint_path), 'either a ProteinDesigner instance is passed into the trainer, or a checkpoint path that contains its config'
300 |
301 | # ++
302 | self.CKeys = CKeys
303 | if self.CKeys['Debug_Level']==PT_Init_Level:
304 | print (f"|||||||||||||||||||||||||||||||||||||||||||||||||||")
305 | print (f"Initialize Protein_Designer Trainer object...")
306 |
307 | # determine filesystem, using fsspec, for saving to local filesystem or cloud
308 | self.fs = checkpoint_fs
309 |
310 | if not exists(self.fs):
311 | fs_kwargs = default(fs_kwargs, {})
312 | self.fs, _ = url_to_fs(
313 | default(checkpoint_path, './'), **fs_kwargs
314 | )
315 | # ++
316 | if self.CKeys['Debug_Level']==PT_Init_Level:
317 | print (f"file system: .fs: {self.fs}")
318 |
319 | assert isinstance(ProtDesi, (ProteinDesigner_Base)), \
320 | "ProtDesi is not from ProteinDesigner_Base"
321 |
322 | ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
323 | # ++
324 | if self.CKeys['Debug_Level']==PT_Init_Level:
325 | print (f"ema_kwargs: {ema_kwargs}")
326 | print (f"kwargs: {kwargs}")
327 |
328 | self.is_elucidated = isinstance(
329 | ProtDesi.diffuser_core, ElucidatedImagen_OneD
330 | )
331 |
332 | # create accelerator instance
333 |
334 | accelerate_kwargs, kwargs = groupby_prefix_and_trim(
335 | 'accelerate_', kwargs
336 | )
337 | # ++
338 | if self.CKeys['Debug_Level']==PT_Init_Level:
339 | print (f"create acce instance...")
340 | print (f"accelerate_kwargs: {accelerate_kwargs}")
341 | print (f"kwargs: {kwargs}")
342 |
343 | assert not (fp16 and exists(precision)), \
344 | 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
345 | accelerator_mixed_precision = default(
346 | precision,
347 | 'fp16' if fp16 else 'no'
348 | )
349 |
350 | self.accelerator = Accelerator(**{
351 | 'split_batches': split_batches,
352 | 'mixed_precision': accelerator_mixed_precision,
353 | 'kwargs_handlers': [
354 | DistributedDataParallelKwargs(find_unused_parameters = True)
355 | ],
356 | **accelerate_kwargs})
357 |
358 | # .is_distributed is a self fun
359 | ProteinDesigner_Trainer.locked = self.is_distributed
360 | # ++
361 | if self.CKeys['Debug_Level']==PT_Init_Level:
362 | print (f".is_distributed or .locked: {ProteinDesigner_Trainer.locked}")
363 |
364 | # cast data to fp16 at training time if needed
365 | self.cast_half_at_training = accelerator_mixed_precision == 'fp16'
366 |
367 | # grad scaler must be managed outside of accelerator
368 | grad_scaler_enabled = fp16
369 |
370 | # ProteinDesigner, imagen, unets and ema unets
371 | self.ProtDesi = ProtDesi
372 | self.imagen = ProtDesi.diffuser_core # imagen
373 | self.num_unets = len(self.imagen.unets)
374 |
375 | self.use_ema = use_ema and self.is_main
376 | self.ema_unets = nn.ModuleList([])
377 | # ++
378 | if self.CKeys['Debug_Level']==PT_Init_Level:
379 | print (f".num_unets: {self.num_unets}")
380 | print (f".use_ema: {self.use_ema}")
381 |
382 | # keep track of what unet is being trained on
383 | # only going to allow 1 unet training at a time
384 |
385 | self.ema_unet_being_trained_index = -1
386 | # keeps track of which ema unet is being trained on
387 |
388 | # data related functions
389 |
390 | self.train_dl_iter = None
391 | self.train_dl = None
392 |
393 | self.valid_dl_iter = None
394 | self.valid_dl = None
395 |
396 | self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names
397 |
398 | # auto splitting validation from training, if dataset is passed in
399 |
400 | self.split_valid_from_train = split_valid_from_train
401 |
402 | assert 0 <= split_valid_fraction <= 1, \
403 | 'split valid fraction must be between 0 and 1'
404 | self.split_valid_fraction = split_valid_fraction
405 | self.split_random_seed = split_random_seed
406 |
407 | # be able to finely customize learning rate, weight decay
408 | # per unet
409 |
410 | # ++
411 | if self.CKeys['Debug_Level']==PT_Init_Level:
412 | print (f" Finely customize learning rate, weight decay")
413 |
414 | lr, eps, warmup_steps, cosine_decay_max_steps = map(
415 | partial(cast_tuple, length = self.num_unets),
416 | (lr, eps, warmup_steps, cosine_decay_max_steps)
417 | )
418 |
419 | for ind, (
420 | unet, unet_lr, unet_eps,
421 | unet_warmup_steps, unet_cosine_decay_max_steps
422 | ) in enumerate(
423 | zip(
424 | self.imagen.unets,
425 | lr, eps, warmup_steps, cosine_decay_max_steps
426 | )
427 | ):
428 |
429 | optimizer = Adam(
430 | unet.parameters(),
431 | lr = unet_lr,
432 | eps = unet_eps,
433 | betas = (beta1, beta2),
434 | **kwargs
435 | )
436 |
437 | if self.use_ema:
438 | self.ema_unets.append(EMA(unet, **ema_kwargs))
439 |
440 | scaler = GradScaler(enabled = grad_scaler_enabled)
441 |
442 | scheduler = warmup_scheduler = None
443 |
444 | if exists(unet_cosine_decay_max_steps):
445 | scheduler = CosineAnnealingLR(
446 | optimizer,
447 | T_max = unet_cosine_decay_max_steps
448 | )
449 |
450 | if exists(unet_warmup_steps):
451 | warmup_scheduler = warmup.LinearWarmup(
452 | optimizer,
453 | warmup_period = unet_warmup_steps
454 | )
455 |
456 | if not exists(scheduler):
457 | scheduler = LambdaLR(
458 | optimizer,
459 | lr_lambda = lambda step: 1.0
460 | )
461 |
462 | # set on object
463 |
464 | setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
465 | setattr(self, f'scaler{ind}', scaler)
466 | setattr(self, f'scheduler{ind}', scheduler)
467 | setattr(self, f'warmup{ind}', warmup_scheduler)
468 |
469 | # ++
470 | if self.CKeys['Debug_Level']==PT_Init_Level:
471 | print (f" on Unit-{ind}")
472 | print (f" scaler: {scaler}")
473 | print (f" scheduler: {scheduler}")
474 | print (f" warmup_scheduler: {warmup_scheduler}")
475 |
476 |
477 | # gradient clipping if needed
478 |
479 | self.max_grad_norm = max_grad_norm
480 |
481 | # step tracker and misc
482 |
483 | self.register_buffer('steps', torch.tensor([0] * self.num_unets))
484 |
485 | self.verbose = verbose
486 |
487 | # automatic set devices based on what accelerator decided
488 |
489 | # self.imagen.to(self.device)
490 | self.ProtDesi.to(self.device)
491 | self.to(self.device)
492 |
493 | # checkpointing
494 |
495 | assert not (exists(checkpoint_path) ^ exists(checkpoint_every))
496 | self.checkpoint_path = checkpoint_path
497 | self.checkpoint_every = checkpoint_every
498 | self.max_checkpoints_keep = max_checkpoints_keep
499 |
500 | self.can_checkpoint = self.is_local_main \
501 | if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main
502 |
503 | # ++
504 | if self.CKeys['Debug_Level']==PT_Init_Level:
505 | print (f".checkpoint_path: {self.checkpoint_path}")
506 | print (f".checkpoint_every: {self.checkpoint_every}")
507 | print (f".max_checkpoints_keep: {self.max_checkpoints_keep}")
508 | print (f".can_checkpoint: {self.can_checkpoint}")
509 |
510 | if exists(checkpoint_path) and self.can_checkpoint:
511 | bucket = url_to_bucket(checkpoint_path)
512 |
513 | if not self.fs.exists(bucket):
514 | self.fs.mkdir(bucket)
515 |
516 | self.load_from_checkpoint_folder()
517 |
518 | # only allowing training for unet
519 |
520 | self.only_train_unet_number = only_train_unet_number
521 | self.prepared = False
522 | # ++
523 | if self.CKeys['Debug_Level']==PT_Init_Level:
524 | print (f".only_train_unet_number: {self.only_train_unet_number}")
525 | print (f".prepared: {self.prepared}")
526 |
527 | # ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
528 | # computed values
529 | @property
530 | def device(self):
531 | return self.accelerator.device
532 |
533 | @property
534 | def is_distributed(self):
535 | return not (
536 | self.accelerator.distributed_type == DistributedType.NO \
537 | and self.accelerator.num_processes == 1
538 | )
539 |
540 | @property
541 | def is_main(self):
542 | return self.accelerator.is_main_process
543 |
544 | @property
545 | def is_local_main(self):
546 | return self.accelerator.is_local_main_process
547 |
548 | @property
549 | def unwrapped_unet(self):
550 | return self.accelerator.unwrap_model(self.unet_being_trained)
551 |
552 | # ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
553 | # optimizer helper functions
554 |
555 |
556 | # ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
557 |
558 | def load_from_checkpoint_folder(
559 | self,
560 | last_total_steps = -1
561 | ):
562 | if last_total_steps != -1:
563 | filepath = os.path.join(
564 | self.checkpoint_path,
565 | f'checkpoint.{last_total_steps}.pt'
566 | )
567 | self.load(filepath)
568 | return
569 |
570 | sorted_checkpoints = self.all_checkpoints_sorted
571 |
572 | if len(sorted_checkpoints) == 0:
573 | self.print(
574 | f'no checkpoints found to load from at {self.checkpoint_path}'
575 | )
576 | return
577 |
578 | last_checkpoint = sorted_checkpoints[0]
579 | self.load(last_checkpoint)
580 |
581 |
582 |
583 |
584 |
585 |
586 | # ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
587 | # Forward_Pack
588 |
589 | # validating the unet number
590 |
591 | def validate_unet_number(
592 | self,
593 | unet_number = None
594 | ):
595 | if self.num_unets == 1:
596 | unet_number = default(unet_number, 1)
597 |
598 | assert 0 < unet_number <= self.num_unets, \
599 | f'unet number should be in between 1 and {self.num_unets}'
600 |
601 | return unet_number
602 |
603 | # function for allowing only one unet from being trained at a time
604 |
605 | def validate_and_set_unet_being_trained(self, unet_number = None):
606 | if exists(unet_number):
607 | self.validate_unet_number(unet_number)
608 |
609 | assert not exists(self.only_train_unet_number) or \
610 | self.only_train_unet_number == unet_number, \
611 | 'you can only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'
612 |
613 | self.only_train_unet_number = unet_number
614 | self.imagen.only_train_unet_number = unet_number
615 |
616 | if not exists(unet_number):
617 | return
618 |
619 | self.wrap_unet(unet_number)
620 |
621 |
622 | def wrap_unet(self, unet_number):
623 | if hasattr(self, 'one_unet_wrapped'):
624 | return
625 |
626 | unet = self.imagen.get_unet(unet_number)
627 | unet_index = unet_number - 1
628 |
629 | optimizer = getattr(self, f'optim{unet_index}')
630 | scheduler = getattr(self, f'scheduler{unet_index}')
631 |
632 | if self.train_dl:
633 | self.unet_being_trained, self.train_dl, optimizer\
634 | = self.accelerator.prepare(
635 | unet, self.train_dl, optimizer
636 | )
637 | else:
638 | self.unet_being_trained, optimizer\
639 | = self.accelerator.prepare(unet, optimizer)
640 |
641 | if exists(scheduler):
642 | scheduler = self.accelerator.prepare(scheduler)
643 |
644 | setattr(self, f'optim{unet_index}', optimizer)
645 | setattr(self, f'scheduler{unet_index}', scheduler)
646 |
647 | self.one_unet_wrapped = True
648 |
649 | # hacking accelerator due to not having separate gradscaler per optimizer
650 |
651 | def set_accelerator_scaler(self, unet_number):
652 |
653 | def patch_optimizer_step(accelerated_optimizer, method):
654 | def patched_step(*args, **kwargs):
655 | accelerated_optimizer._accelerate_step_called = True
656 | return method(*args, **kwargs)
657 | return patched_step
658 |
659 | unet_number = self.validate_unet_number(unet_number)
660 | scaler = getattr(self, f'scaler{unet_number - 1}')
661 |
662 | self.accelerator.scaler = scaler
663 | for optimizer in self.accelerator._optimizers:
664 | optimizer.scaler = scaler
665 | optimizer._accelerate_step_called = False
666 | optimizer._optimizer_original_step_method = optimizer.optimizer.step
667 | optimizer._optimizer_patched_step_method = patch_optimizer_step(
668 | optimizer, optimizer.optimizer.step
669 | )
670 |
671 |
672 |
673 | @partial(cast_torch_tensor, cast_fp16 = True)
674 | def forward(
675 | self,
676 | *args,
677 | unet_number = None,
678 | max_batch_size = None,
679 | **kwargs
680 | ):
681 | # ++
682 | if self.CKeys['Debug_Level']==PT_Forw_Level:
683 | print (f"Debug mode for trainer.forward...")
684 |
685 | unet_number = self.validate_unet_number(unet_number) # check if unet_number is in the range
686 | # ++
687 | if self.CKeys['Debug_Level']==PT_Forw_Level:
688 | print (f"Train UNet number: {unet_number}")
689 |
690 | self.validate_and_set_unet_being_trained(unet_number)
691 | self.set_accelerator_scaler(unet_number)
692 |
693 | assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'
694 |
695 | total_loss = 0.
696 |
697 | for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
698 |
699 | with self.accelerator.autocast():
700 | #--
701 | # loss = self.imagen(
702 | #++
703 | loss = self.ProtDesi(
704 | *chunked_args,
705 | unet = self.unet_being_trained,
706 | unet_number = unet_number, **chunked_kwargs
707 | )
708 | loss = loss * chunk_size_frac
709 | # ++
710 | if self.CKeys['Debug_Level']==PT_Forw_Level:
711 | print (f"get loss for a fraction: {loss}")
712 |
713 | total_loss += loss.item()
714 | # ++
715 | if self.CKeys['Debug_Level']==PT_Forw_Level:
716 | print (f"update tot_loss: {total_loss}")
717 |
718 | if self.training:
719 | self.accelerator.backward(loss)
720 |
721 | return total_loss
722 |
--------------------------------------------------------------------------------
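The kwargs-routing and batching helpers above are pure functions and easy to exercise on their own. A small sketch (it assumes the package and its dependencies are installed so the module import resolves):

```python
from VibeGen.TrainerPack_advanced import (
    cast_tuple, groupby_prefix_and_trim, num_to_groups
)

# per-unet hyperparameters: a scalar is broadcast to one value per unet
print(cast_tuple(1e-4, length=2))       # (0.0001, 0.0001)

# 'ema_'-prefixed kwargs are split out and trimmed before reaching EMA(...)
ema_kwargs, rest = groupby_prefix_and_trim('ema_', {'ema_beta': 0.999, 'lr': 1e-4})
print(ema_kwargs, rest)                 # {'beta': 0.999} {'lr': 0.0001}

# gradient accumulation: a batch of 10 split into sub-batches of at most 4
print(num_to_groups(10, 4))             # [4, 4, 2]
```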
/VibeGen/UtilityPack.py:
--------------------------------------------------------------------------------
1 | # ==========================================================
2 | # Utility functions
3 | # ==========================================================
4 | import os
5 | from scipy.interpolate import CubicSpline, PchipInterpolator, Akima1DInterpolator
6 | import numpy as np
7 | import math
8 | import matplotlib.pyplot as plt
9 |
10 | from Bio.PDB import PDBParser
11 | from Bio.PDB.DSSP import DSSP
12 | from Bio.PDB import PDBList
13 |
14 | import torch
15 | from einops import rearrange
16 | import esm
17 |
18 | import json
19 |
20 | # =========================================================
21 | #
22 | def Print(this_line):
23 | # may update for multi-core case later
24 | print (this_line)
25 |
26 | def print_dict_content(this_dict):
27 |
28 | for this_key in this_dict.keys():
29 | print (f" {this_key}: {this_dict[this_key]}")
30 |
31 | # =========================================================
32 | # create a folder path if not exist
33 | def create_path(this_path):
34 | if not os.path.exists(this_path):
35 | print('Creating the given path...')
36 | os.mkdir (this_path)
37 | path_stat = 1
38 | print('Done.')
39 | else:
40 | print('The given path already exists!')
41 | path_stat = 2
42 | return path_stat
43 |
44 | # ============================================================
45 | # on esm, rebuild AA sequence from embedding
46 | # ============================================================
47 |
48 | def decode_one_ems_token_rec(this_token, esm_alphabet):
49 | # print( (this_token==esm_alphabet.cls_idx).nonzero(as_tuple=True)[0] )
50 | # print( (this_token==esm_alphabet.eos_idx).nonzero(as_tuple=True)[0] )
51 | # print( (this_token==100).nonzero(as_tuple=True)[0]==None )
52 |
53 | id_b=(this_token==esm_alphabet.cls_idx).nonzero(as_tuple=True)[0]
54 | id_e=(this_token==esm_alphabet.eos_idx).nonzero(as_tuple=True)[0]
55 |
56 |
57 | if len(id_e)==0:
58 | # no ending for this one, so id_e points to the end
59 | id_e=len(this_token)
60 | else:
61 | id_e=id_e[0]
62 | if len(id_b)==0:
63 | id_b=0
64 | else:
65 | id_b=id_b[-1]
66 |
67 | this_seq = []
68 | # this_token_used = []
69 | for ii in range(id_b+1,id_e,1):
70 | # this_token_used.append(this_token[ii])
71 | this_seq.append(
72 | esm_alphabet.get_tok(this_token[ii])
73 | )
74 |
75 | this_seq = "".join(this_seq)
76 |
77 | # print(this_seq)
78 | # print(len(this_seq))
79 | # # print(this_token[id_b+1:id_e])
80 | return this_seq
81 |
82 |
83 | def decode_many_ems_token_rec(batch_tokens, esm_alphabet):
84 | rev_y_seq = []
85 | for jj in range(len(batch_tokens)):
86 | # do for one seq: this_seq
87 | this_seq = decode_one_ems_token_rec(
88 | batch_tokens[jj], esm_alphabet
89 | )
90 | rev_y_seq.append(this_seq)
91 | return rev_y_seq
92 |
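# ++ usage sketch: rebuild strings from ESM token ids (toy example; requires the
#    fair-esm weights, and uses indices from the esm_tok_to_idx table below:
#    5='A', 6='G', 7='V')
# _, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# toy_tokens = torch.tensor([[esm_alphabet.cls_idx, 5, 6, 7, esm_alphabet.eos_idx]])
# decode_many_ems_token_rec(toy_tokens, esm_alphabet)   # -> ['AGV']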
93 | def Print_model_params (model):
94 | pytorch_total_params = sum(p.numel() for p in model.parameters())
95 | pytorch_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
96 |
97 | Print (
98 | f"Total model parameters: {pytorch_total_params}\nTrainable parameters: {pytorch_total_params_trainable}\n"
99 | )
100 |
101 | def get_model_params (model):
102 | pytorch_total_params = sum(p.numel() for p in model.parameters())
103 | pytorch_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
104 |
105 | resu = {
106 | 'tot': pytorch_total_params,
107 | 'trainable': pytorch_total_params_trainable,
108 | 'freezed': pytorch_total_params-pytorch_total_params_trainable
109 | }
110 |
111 | return resu
112 |
113 | def write_one_line_to_file(
114 | this_line,
115 | file_name,
116 | mode,
117 | accelerator=None
118 | ):
119 | with open(file_name, mode) as f:
120 | f.write(this_line)
121 |
122 | # ==============================================================
123 | #
124 | # def convert_into_tokens_using_prob(
125 | # prob_result,
126 | # pLM_Model_Name
127 | # ):
128 | # if pLM_Model_Name=='esm2_t33_650M_UR50D' \
129 | # or pLM_Model_Name=='esm2_t36_3B_UR50D' \
130 | # or pLM_Model_Name=='esm2_t30_150M_UR50D' \
131 | # or pLM_Model_Name=='esm2_t12_35M_UR50D' :
132 |
133 | # repre=rearrange(
134 | # prob_result,
135 | # 'b c l -> b l c'
136 | # )
137 | # # with torch.no_grad():
138 | # # logits=model.lm_head(repre) # (b, l, token_dim)
139 | # logits = repre
140 |
141 | # tokens=logits.max(2).indices # (b,l)
142 |
143 | # else:
144 | # print("pLM_Model is not defined...")
145 | # return tokens,logits
146 |
147 | def read_mask_from_input(
148 | # consider different type of inputs
149 | # raw data: x_data (sequences)
150 | # tokenized: x_data_tokenized
151 | tokenized_data=None, # X_train_batch,
152 | mask_value=None,
153 | seq_data=None, # Y_train_batch,
154 | max_seq_length=None,
155 | ):
156 | # # old:
157 | # mask = X_train_batch!=mask_value
158 | # new
159 | if seq_data!=None:
160 | # use the real sequence length to create mask
161 | n_seq = len(seq_data)
162 | mask = torch.zeros(n_seq, max_seq_length)
163 | for ii in range(n_seq):
164 | this_len = len(seq_data[ii])
165 | mask[ii,1:1+this_len]=1
166 | mask = mask==1
167 | #
168 | elif tokenized_data!=None:
169 | n_seq = len(tokenized_data)
170 | mask = tokenized_data!=mask_value
171 | # fix the beginning part: 0+content+00, not 00+content+00
172 | for ii in range(n_seq):
173 | # get all nonzero index
174 | id_1 = (mask[ii]==True).nonzero(as_tuple=True)[0]
175 | # correction for ForcPath,
176 | # pick up 0.0 for zero-force padding at the beginning
177 | mask[ii,1:id_1[0]]=True
178 |
179 | return mask
180 |
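# ++ usage sketch: mask built from raw sequence lengths; position 0 stays False
#    (reserved for the <cls> token), then one True per residue.
# read_mask_from_input(seq_data=['AGV', 'AG'], max_seq_length=6)
# -> tensor([[False,  True,  True,  True, False, False],
#            [False,  True,  True, False, False, False]])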
181 | # on pLM tokens
182 | # basic 20 in abr order: ARNDCEQGHILKMFPSTWYV
183 | # in esm, tot = 33
184 | # basic 20 in esm order: LAGVSERTIDPKQNFYMHWC
185 | # others (4): <cls> <pad> <eos> <unk>
186 | # special (9): X B U Z O . - <null_1> <mask>
187 | # LAGVSERTIDPKQNFYMHWC take the channels: 4-23
188 | # full dict
189 | esm_tok_to_idx = \
190 | {'<cls>': 0, '<pad>': 1, '<eos>': 2, '<unk>': 3, 'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10, 'T': 11, 'I': 12, 'D': 13, 'P': 14, 'K': 15, 'Q': 16, 'N': 17, 'F': 18, 'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23, 'X': 24, 'B': 25, 'U': 26, 'Z': 27, 'O': 28, '.': 29, '-': 30, '<null_1>': 31, '<mask>': 32}
191 |
192 | esm_idx_to_tok = \
193 | {'0': '<cls>', '1': '<pad>', '2': '<eos>', '3': '<unk>', '4': 'L', '5': 'A', '6': 'G', '7': 'V', '8': 'S', '9': 'E', '10': 'R', '11': 'T', '12': 'I', '13': 'D', '14': 'P', '15': 'K', '16': 'Q', '17': 'N', '18': 'F', '19': 'Y', '20': 'M', '21': 'H', '22': 'W', '23': 'C', '24': 'X', '25': 'B', '26': 'U', '27': 'Z', '28': 'O', '29': '.', '30': '-', '31': '<null_1>', '32': '<mask>'}
194 |
195 | common_AA_list = "LAGVSERTIDPKQNFYMHWC"
196 |
197 |
198 | # common_AA_idx_in_esm = []
199 | # for ii in range(len(common_AA_list)):
200 | # common_AA_idx_in_esm.append(
201 | # esm_tok_to_idx[
202 | # common_AA_list[ii]
203 | # ]
204 | # )
205 |
206 | common_AA_idx_in_esm = [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
207 |
208 | def keep_only_20AA_channels_in_one_pLM_logits(
209 | full_logits, # (seq_len, channel)
210 | keep_channels=common_AA_idx_in_esm
211 | ):
212 | assert full_logits.shape[-1]==33, \
213 | "Not ESM logits shape"
214 |
215 | n_channel = full_logits.shape[-1]
216 | for this_c in range(n_channel):
217 | if not (this_c in keep_channels):
218 | full_logits[:,this_c]=-float('inf')
219 |
220 | return full_logits
221 |
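# ++ usage sketch: after masking, an argmax over channels can only land on the
#    20 standard AA channels (indices 4-23 of the 33 ESM channels).
# logits = torch.randn(7, 33)                                 # (seq_len, channel)
# logits = keep_only_20AA_channels_in_one_pLM_logits(logits)  # edits in place
# toks = logits.max(-1).indices                               # every index lies in 4..23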
222 | def get_toks_list_from_Y_batch(
223 | batch_GT, # (b, seq_len)
224 | batch_mask, # (b, seq_len)
225 | ):
226 | toks_list = []
227 | seqs_list = []
228 |
229 | for ii in range(len(batch_GT)):
230 | this_GT = batch_GT[ii]
231 | this_mask = batch_mask[ii]
232 | this_GT = this_GT[this_mask==True]
233 | #
234 | toks_list.append(this_GT)
235 | this_seq = [
236 | esm_idx_to_tok[str(jj.item())] for jj in this_GT
237 | ]
238 | this_seq = "".join(this_seq)
239 | seqs_list.append(this_seq)
240 |
241 |
242 | return toks_list, seqs_list
243 |
244 | def compare_two_seq_strings(seq_PR, seq_GT):
245 | # take seq_GT as the ref,
246 | # assume len(seq_GT)>=len(seq_PR)
247 | len_comp = min( len(seq_PR), len(seq_GT))
248 | num_hit = 0
249 | for ii in range(len_comp):
250 | if seq_PR[ii]==seq_GT[ii]:
251 | num_hit += 1
252 | ratio_hit = num_hit/len_comp
253 |
254 | return ratio_hit
255 |
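# ++ e.g. compare_two_seq_strings('AGVA', 'AGVS') -> 0.75
#    (3 of the 4 aligned positions match; any extra tail residues are ignored)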
256 | def save_2d_tensor_as_np_arr_txt(
257 | X_tensor, # (a, b)
258 | mask = None, # (b)
259 | outname = None,
260 | ):
261 | assert X_tensor.dim() == 2
262 |
263 | if not (mask is None):
264 | assert mask.dim() == 1
265 |
266 | if not (mask is None):
267 | X_tensor = X_tensor[:, mask]
268 |
269 |
270 | test_one_X_arr = X_tensor.cpu().detach().numpy()
271 | if outname is None:
272 | print (test_one_X_arr)
273 | else:
274 | np.savetxt(outname, test_one_X_arr)
275 | # # to read back as a 2d np arr
276 | # test_one_X_arr_1 = np.loadtxt(test_file)
277 |
278 | # ++ read back for checking
279 | def read_2d_np_arr_from_txt(
280 | test_file
281 | ):
282 | test_one_X_arr_1 = np.loadtxt(test_file)
283 | return test_one_X_arr_1
284 |
285 | def string_diff (seq1, seq2):
286 | return sum(1 for a, b in zip(seq1, seq2) if a != b) + abs(len(seq1) - len(seq2))
287 |
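# ++ e.g. string_diff('AGV', 'AGA') -> 1 (one substitution)
#         string_diff('AGV', 'AG')  -> 1 (the length mismatch also counts)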
288 | # def write_fasta_file(
289 | # this_seq,
290 | # this_head,
291 | # this_file
292 | # ):
293 | # with open(this_file, mode = 'w') as f:
294 | # f.write (f">{this_head}\n")
295 | # f.write (f"{this_seq}")
296 |
297 | def write_fasta_file(
298 | this_seq_list,
299 | this_head_list,
300 | this_file
301 | ):
302 | n_seq = len(this_seq_list)
303 |
304 | with open(this_file, mode = 'w') as f:
305 | for i_seq in range(n_seq):
306 |
307 | f.write (f">{this_head_list[i_seq]}\n")
308 | f.write (f"{this_seq_list[i_seq]}\n")
309 |
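# ++ usage sketch: one FASTA file holding two named records
# write_fasta_file(['AGV', 'AGA'], ['design_0', 'design_1'], './demo.fasta')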
310 | # ++
311 | def read_recover_AAs_only(test_fasta_file):
312 |
313 | file1 = open(test_fasta_file, 'r')
314 | Lines = file1.readlines()
315 | # only get AA
316 | AA_GT = Lines[1].strip()
317 | AA_recon_GT = Lines[3].strip()
318 |
319 | resu = {}
320 | resu['AA_GT'] = AA_GT
321 | resu['AA_recon_GT'] = AA_recon_GT
322 |
323 | return resu
324 |
325 | # ===================================================================================
326 | # old one
327 | def fold_one_AA_to_SS_using_omegafold_for_5_Diffusionfold(
328 | sequence,
329 | num_cycle=16,
330 | device=None,
331 | # ++++++++++++++
332 | prefix=None,
333 | AA_file_path=None,
334 | PDB_file_path=None, # output file path
335 | head_note=None,
336 | ):
337 | AA_file_name = f"{AA_file_path}/{prefix}_.fasta"
338 | print ("Writing FASTA file: ", AA_file_name)
339 | head_line = f"{head_note}"
340 | with open (AA_file_name, mode ='w') as f:
341 | f.write (f'>{head_line}\n')
342 | f.write (f'{sequence}')
343 | #
344 | #
345 | PDB_result=f"{PDB_file_path}/{head_line}.pdb"
346 | if not os.path.exists(PDB_result):
347 | print (f"Now run OmegaFold.... on device={device}")
348 | # !omegafold $filename $prefix --num_cycle $num_cycle --device=$device
349 | cmd_line=F"omegafold {AA_file_name} {PDB_file_path} --num_cycle {num_cycle} --device={device}"
350 | print(os.popen(cmd_line).read())
351 |
352 | print ("Done OmegaFold")
353 |
354 | # PDB_result=f"{prefix}{OUTFILE}.PDB"
355 |
356 | print (f"Resulting PDB file...: {PDB_result}")
357 | else:
358 | print (f"PDB file already exists.")
359 |
360 | return PDB_result, AA_file_name
361 | #
362 | # ===================================================================================
363 | # new one: need to install the modified omegafold from the self-hosted repo
364 | # https://github.com/Bo-Ni/OmegaFold_0.git
365 | def get_subbatch_size(L):
366 | if L < 500: return 500
367 | if L < 1000: return 500 # 500 # 200
368 | return 150
369 |
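# ++ e.g. get_subbatch_size(300) -> 500; get_subbatch_size(1200) -> 150
#    (long sequences get a smaller OmegaFold subbatch to stay within GPU memory)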
370 | def fold_one_AA_to_SS_using_omegafold(
371 | sequence,
372 | num_cycle=16,
373 | device=None,
374 | # ++++++++++++++
375 | prefix="Temp", # None,
376 | AA_file_path="./", # None,
377 | PDB_file_path="./", # output file path
378 | head_note="Temp_", # None,
379 | ):
380 | AA_file_name = f"{AA_file_path}/{prefix}_.fasta"
381 | print ("Writing FASTA file: ", AA_file_name)
382 | head_line = f"{head_note}"
383 | with open (AA_file_name, mode ='w') as f:
384 | f.write (f'>{head_line}\n')
385 | f.write (f'{sequence}')
386 | #
387 | subbatch_size = get_subbatch_size(len(sequence))
388 | #
389 | PDB_result=f"{PDB_file_path}/{head_line}.pdb"
390 |
391 | if not os.path.exists(PDB_result):
392 | Print (f"Now run OmegaFold.... on device={device}\n\n")
393 | # !omegafold $filename $prefix --num_cycle $num_cycle --device=$device
394 | # cmd_line=F"omegafold {AA_file_name} {PDB_file_path} --num_cycle {num_cycle} --device={device}"
395 | cmd_line=F"omegafold {AA_file_name} {PDB_file_path} --subbatch_size {str(subbatch_size)} --num_cycle {num_cycle} --device={device}"
396 |
397 | Print(os.popen(cmd_line).read())
398 |
399 | Print ("Done OmegaFold")
400 |
401 | # PDB_result=f"{prefix}{OUTFILE}.PDB"
402 |
403 | Print (f"Resulting PDB file...: {PDB_result}\n\n")
404 | else:
405 | Print (f"PDB file already exists.")
406 |
407 | return PDB_result, AA_file_name
408 | #
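# ++ usage sketch (assumes the modified omegafold CLI above is on PATH):
# pdb_file, fasta_file = fold_one_AA_to_SS_using_omegafold(
#     sequence='AGVSERTIDPKQNFYMHWC',
#     num_cycle=16,
#     device='cuda:0',
#     prefix='demo', AA_file_path='./', PDB_file_path='./', head_note='demo_0',
# )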
409 | # ===================================================================================
410 | # plot
411 | import py3Dmol
412 |
413 | def plot_plddt_legend(dpi=100):
414 | thresh = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)']
415 | plt.figure(figsize=(1,0.1),dpi=dpi)
416 | ########################################
417 | for c in ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]:
418 | plt.bar(0, 0, color=c)
419 | plt.legend(thresh, frameon=False,
420 | loc='center', ncol=6,
421 | handletextpad=1,
422 | columnspacing=1,
423 | markerscale=0.5,)
424 | plt.axis(False)
425 | return plt
426 |
427 | color = "lDDT" # choose from ["chain", "lDDT", "rainbow"]
428 | show_sidechains = False #choose from {type:"boolean"}
429 | show_mainchains = False #choose from {type:"boolean"}
430 |
431 | def show_pdb(
432 | pdb_file,
433 | flag=0,
434 | show_sidechains=False,
435 | show_mainchains=False,
436 | color="lDDT"
437 | ):
438 | model_name = f"Flag_{flag}"
439 | view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
440 | view.addModel(open(pdb_file,'r').read(),'pdb')
441 |
442 | if color == "lDDT":
443 | view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
444 | elif color == "rainbow":
445 | view.setStyle({'cartoon': {'color':'spectrum'}})
446 | elif color == "chain":
447 | chains = len(queries[0][1]) + 1 if is_complex else 1 # NOTE: 'queries' and 'is_complex' are expected from the caller's scope; they are not defined in this file
448 | for n,chain,color in zip(
449 | range(chains),list("ABCDEFGH"),
450 | ["lime","cyan","magenta","yellow","salmon","white","blue","orange"]
451 | ):
452 | view.setStyle({'chain':chain},{'cartoon': {'color':color}})
453 |
454 | if show_sidechains:
455 | BB = ['C','O','N']
456 | view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
457 | {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
458 | view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
459 | {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
460 | view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
461 | {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
462 | if show_mainchains:
463 | BB = ['C','O','N','CA']
464 | view.addStyle({'atom':BB},{'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
465 |
466 | view.zoomTo()
467 | if color == "lDDT":
468 | plot_plddt_legend().show()
469 | #
470 | return view
471 | #
472 | # ===================================================================================
473 | # SecStr
474 | from Bio.PDB import PDBParser
475 | from Bio.PDB.DSSP import DSSP
476 | from Bio.PDB import PDBList
477 |
478 | Unique_SecStr_Q8_String="HET~BGIS"
479 | Unique_SecStr_Q3_String="HEC"
480 | #
481 | # =============================================
482 | # count statistics of Q8 based SecStr
483 | #
484 | def count_ratio_for_Q8(
485 | this_secstr,
486 | Unique_SecStr_Q8_String=Unique_SecStr_Q8_String,
487 | ):
488 | resu = {}
489 | seq_len = len(this_secstr)
490 | for this_char in Unique_SecStr_Q8_String:
491 | resu[this_char] = this_secstr.count(this_char)/seq_len
492 | #
493 | return resu
494 | # =============================================
495 | # count statistics of Q3 based SecStr
496 | #
497 | def count_ratio_for_Q3(
498 | this_secstr,
499 | Unique_SecStr_Q3_String=Unique_SecStr_Q3_String,
500 | ):
501 | resu = {}
502 | seq_len = len(this_secstr)
503 | for this_char in Unique_SecStr_Q3_String:
504 | resu[this_char] = this_secstr.count(this_char)/seq_len
505 | #
506 | return resu
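# ++ e.g. count_ratio_for_Q3('HHHECC') -> {'H': 0.5, 'E': ~0.17, 'C': ~0.33}
#    (per-character fraction of the Q3 string)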
507 | # ===============================================
508 | #
509 | def analyze_SS_Q8_Q3_for_df(
510 | df_smo_recon_BSDB_4P_expanded,
511 | Unique_SecStr_Q8_String=Unique_SecStr_Q8_String,
512 | Unique_SecStr_Q3_String=Unique_SecStr_Q3_String,
513 | ):
514 | #
515 | # do statistics on Q8
516 | this_key_to_add = 'stat_Q8'
517 | if not (this_key_to_add in df_smo_recon_BSDB_4P_expanded.keys()):
518 | print (f"Add new key {this_key_to_add}")
519 | df_smo_recon_BSDB_4P_expanded[this_key_to_add] = df_smo_recon_BSDB_4P_expanded.apply(
520 | # ================ change this part ===========================
521 | lambda row: count_ratio_for_Q8(
522 | row['SS_Q8'],
523 | Unique_SecStr_Q8_String=Unique_SecStr_Q8_String,
524 | ),
525 | # ================ change this part ===========================
526 | axis=1,
527 | )
528 |
529 | # do statistics on Q3
530 | this_key_to_add = 'stat_Q3'
531 | if not (this_key_to_add in df_smo_recon_BSDB_4P_expanded.keys()):
532 | print (f"Add new key {this_key_to_add}")
533 | df_smo_recon_BSDB_4P_expanded[this_key_to_add] = df_smo_recon_BSDB_4P_expanded.apply(
534 | # ================ change this part ===========================
535 | lambda row: count_ratio_for_Q3(
536 | row['SS_Q3'],
537 | Unique_SecStr_Q3_String=Unique_SecStr_Q3_String,
538 | ),
539 | # ================ change this part ===========================
540 | axis=1,
541 | )
542 | #
543 | # expand to df columns
544 | for this_char in Unique_SecStr_Q3_String:
545 | print (f"working on Q3 {this_char}")
546 | this_key_to_add = 'stat_Q3_'+this_char
547 | if not (this_key_to_add in df_smo_recon_BSDB_4P_expanded.keys()):
548 | print (f"Add new key {this_key_to_add}")
549 | df_smo_recon_BSDB_4P_expanded[this_key_to_add] = df_smo_recon_BSDB_4P_expanded.apply(
550 | # ================ change this part ===========================
551 | lambda row: row['stat_Q3'][this_char],
552 | # ================ change this part ===========================
553 | axis=1,
554 | )
555 | # expand to Q8
556 | for this_char in Unique_SecStr_Q8_String:
557 | print (f"working on Q8 {this_char}")
558 | this_key_to_add = 'stat_Q8_'+this_char
559 | if not (this_key_to_add in df_smo_recon_BSDB_4P_expanded.keys()):
560 | print (f"Add new key {this_key_to_add}")
561 | df_smo_recon_BSDB_4P_expanded[this_key_to_add] = df_smo_recon_BSDB_4P_expanded.apply(
562 | # ================ change this part ===========================
563 | lambda row: row['stat_Q8'][this_char],
564 | # ================ change this part ===========================
565 | axis=1,
566 | )
567 |
568 | return df_smo_recon_BSDB_4P_expanded
569 | # ==================================================
570 |
571 | def get_DSSP_result (fname):
572 | pdb_list = [fname]
573 |
574 | # parse structure
575 | p = PDBParser()
576 | for i in pdb_list:
577 | structure = p.get_structure(i, fname)
578 | # use only the first model
579 | model = structure[0]
580 | # calculate DSSP
581 | dssp = DSSP(model, fname, file_type='PDB' )
582 | # extract sequence and secondary structure from the DSSP tuple
583 | sequence = ''
584 | sec_structure = ''
585 | for z in range(len(dssp)):
586 | a_key = list(dssp.keys())[z]
587 | sequence += dssp[a_key][1]
588 | sec_structure += dssp[a_key][2]
589 |
590 | # print extracted sequence and structure
591 | #print(i)
592 | #print(sequence)
593 | #print(sec_structure)
594 | #
595 | # The DSSP codes for secondary structure used here are:
596 | # ===== ====
597 | # Code Structure
598 | # ===== ====
599 | # H Alpha helix (4-12)
600 | # B Isolated beta-bridge residue
601 | # E Strand
602 | # G 3-10 helix
603 | # I Pi helix
604 | # T Turn
605 | # S Bend
606 | # - None
607 | # ===== ====
608 | #
609 |
610 | sec_structure = sec_structure.replace('-', '~')
611 | sec_structure_3state=sec_structure
612 |
613 |
614 | # if desired, convert DSSP's 8-state assignments into 3-state [C - coil, E - extended (beta-strand), H - helix]
615 | sec_structure_3state = sec_structure_3state.replace('~', 'C')
616 | sec_structure_3state = sec_structure_3state.replace('I', 'C')
617 | sec_structure_3state = sec_structure_3state.replace('T', 'C')
618 | sec_structure_3state = sec_structure_3state.replace('S', 'C')
619 | sec_structure_3state = sec_structure_3state.replace('G', 'H')
620 | sec_structure_3state = sec_structure_3state.replace('B', 'E')
621 |
622 | return sec_structure,sec_structure_3state, sequence
623 |
624 | # ++ for postprocess
625 | def get_DSSP_set_result(fname):
626 | sec_structure,sec_structure_3state, sequence = get_DSSP_result (fname)
627 | resu={}
628 | resu['SecStr_Q8']=sec_structure
629 | resu['SecStr_Q3']=sec_structure_3state
630 | resu['AA_from_DSSP']=sequence
631 |
632 | return resu
633 |
634 | def write_DSSP_result_to_json(
635 | sec_structure,
636 | sec_structure_3state,
637 | sequence,
638 | filename,
639 | ):
640 | resu = {
641 | "Q8": sec_structure,
642 | "Q3": sec_structure_3state,
643 | "AA_from_DSSP": sequence
644 | }
645 | resu_json = json.dumps(resu, indent=4)
646 |
647 | with open(filename, "w") as f:
648 | f.write(resu_json)
649 |
650 | # # to read back
651 | # with open(filename, 'r') as openfile:
652 | # # Reading from json file
653 | # json_object = json.load(openfile)
654 |
655 | # print(json_object)
656 | # print(type(json_object)) # dict
657 |
658 | # ==============================================================
659 | # pick some Normal Mode from a df
660 | # For NMS vectors only
661 | def build_XCond_list_from_df(
662 | df,
663 | key_list,
664 | pick_id_list,
665 | ):
666 | n_mode = len(key_list)
667 | n_samp = len(pick_id_list)
668 | resu = []
669 | for id_samp in pick_id_list:
670 | this_X_list = []
671 | for this_key in key_list:
672 | add_one = df[this_key].values[id_samp]
673 |
674 | this_X_list.append(
675 | add_one
676 | )
677 | this_X = np.array(this_X_list)
678 | resu.append(this_X)
679 |
680 | return resu
681 |
682 | # For AA Seq only
683 | def build_AA_list_from_df(
684 | df,
685 | AA_key,
686 | pick_id_list,
687 | ):
688 | n_samp = len(pick_id_list)
689 | resu = []
690 | for id_samp in pick_id_list:
691 | resu.append(
692 | df[AA_key].values[id_samp]
693 | )
694 |
695 | return resu
696 |
697 | # ==============================================================
698 | # add for Protein Predictor
699 | def get_nms_vec_as_arr_list_from_batch_using_mask(
700 | result_mask, # (b, seq_len) # torch.tensor
701 | output_diffuser, # (b, n_mode, seq_len)
702 | NormFac_list, # (n_mode, )
703 | ):
704 | n_samp = output_diffuser.shape[0]
705 | n_mode = output_diffuser.shape[1]
706 |
707 | nms_vecs_list = []
708 | for i_samp in range(n_samp):
709 | this_mask = result_mask[i_samp] # (seq_len, )
710 | this_nms_vecs = output_diffuser[i_samp]
711 |
712 | # to take care of multi-modes
713 | this_nms_arr = []
714 | for i_mode in range(n_mode):
715 | this_add = this_nms_vecs[i_mode][this_mask==True] # only work for 1D tensor
716 | this_add = this_add * NormFac_list[i_mode] # map it back to real values
717 | this_nms_arr.append(
718 | this_add.cpu().detach().numpy()
719 | )
720 | this_nms_arr = np.array(this_nms_arr) # convert into np.arr
721 |
722 | # deliver to the list to store
723 | nms_vecs_list.append(this_nms_arr)
724 |
725 | return nms_vecs_list
726 |
727 | # compare two nms_vecs
728 | def compare_two_nms_vecs_arr(
729 | PR_nms_vecs,
730 | GT_nms_vecs,
731 | ):
732 | n_mode = GT_nms_vecs.shape[0]
733 | # calculate error for each mode and the tot
734 | # calculate rela_L2 error
735 | resu = {}
736 | for i_mode in range(n_mode):
737 | resu["rela_L2_err_Mode_"+str(i_mode)]=np.linalg.norm(PR_nms_vecs[i_mode]-GT_nms_vecs[i_mode])/np.linalg.norm(GT_nms_vecs[i_mode])
738 | #
739 | # calculate for multi-modes
740 | resu["rela_L2_err_MulMode"]=np.linalg.norm(PR_nms_vecs-GT_nms_vecs)/np.linalg.norm(GT_nms_vecs)
741 |
742 | return resu
743 |
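# ++ usage sketch with toy shapes (2 modes, 3 residues):
# GT = np.array([[1., 0., 0.], [0., 1., 0.]])
# PR = np.array([[1., 0., 0.], [0., 0., 1.]])
# compare_two_nms_vecs_arr(PR, GT)
# -> {'rela_L2_err_Mode_0': 0.0, 'rela_L2_err_Mode_1': 1.414..., 'rela_L2_err_MulMode': 1.0}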
744 | # ======================================================
745 |
746 | def translate_seqs_list_into_idx_tensor_w_pLM(
747 | # 1. model converter
748 | esm_batch_converter,
749 | AA_seq_max_len,
750 | # 2. on input
751 | raw_condition_list,
752 | device
753 | ):
754 |
755 | seqs_ext=[]
756 | # add a fake one to make sure the padding length
757 | dummy_seq = 'A'*(AA_seq_max_len-2)
758 | seqs_ext.append(
759 | (" ", dummy_seq)
760 | )
761 |
762 | for i in range(len(raw_condition_list)):
763 | seqs_ext.append(
764 | (" ", raw_condition_list[i])
765 | )
766 | # batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext)
767 | _, y_strs, y_data = esm_batch_converter(seqs_ext)
768 | # y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1)
769 | # print(batch_tokens.shape)
770 | #
771 | # ++ remove the dummy one
772 | y_data = y_data[1:]
773 | seqs_ext = seqs_ext[1:]
774 |
775 | y_data = y_data.to(device)
776 |
777 | return y_data
778 |
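# ++ usage sketch: a full-length dummy 'A' sequence is prepended inside the
#    function so the whole batch pads to AA_seq_max_len, then stripped again.
# y = translate_seqs_list_into_idx_tensor_w_pLM(
#     esm_batch_converter,            # assumed built via alphabet.get_batch_converter(...)
#     AA_seq_max_len=128,
#     raw_condition_list=['AGV', 'AGA'],
#     device='cpu',
# )                                   # -> token tensor of shape (2, 128)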
779 | # ==================================================
780 |
781 | # def cal_err_list_using_
--------------------------------------------------------------------------------
/VibeGen/JointSamplingPack.py:
--------------------------------------------------------------------------------
1 | """
2 | Task:
3 | 1. joint sampling with ProteinDesigner (PD) and ProteinPredictor (PP)
4 | 2. PD proposes candidate sequences; PP screens and ranks them
5 |
6 | Bo Ni, Sep 8, 2024
7 | """
8 |
9 | # //////////////////////////////////////////////////////
10 | # 0. load in packages
11 | # //////////////////////////////////////////////////////
12 |
13 | import os
14 | from math import ceil
15 | from contextlib import contextmanager, nullcontext
16 | from functools import partial, wraps
17 | from collections.abc import Iterable
18 |
19 | import torch
20 | from torch import nn
21 | import torch.nn.functional as F
22 | from torch.utils.data import random_split, DataLoader
23 | from torch.optim import Adam
24 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
25 | from torch.cuda.amp import autocast, GradScaler
26 |
27 | import pytorch_warmup as warmup
28 |
29 | from packaging import version
30 |
31 | import numpy as np
32 | import math
33 | import pandas as pd
34 |
35 | from ema_pytorch import EMA
36 |
37 | from einops import rearrange
38 |
39 | from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs
40 |
41 | from fsspec.core import url_to_fs
42 | from fsspec.implementations.local import LocalFileSystem
43 |
44 | import shutil
45 | import matplotlib.pyplot as plt
46 |
47 | from sklearn.metrics import r2_score
48 | from scipy.stats import spearmanr, pearsonr
49 |
50 | # //////////////////////////////////////////////////////////////
51 | # 2. special packages
52 | # //////////////////////////////////////////////////////////////
53 | from VibeGen.DataSetPack import (
54 | pad_a_np_arr_esm_for_NMS
55 | )
56 | from VibeGen.ModelPack import (
57 | ProteinDesigner_Base,
58 | ProteinPredictor_Base
59 | )
60 | from VibeGen.imagen_x_imagen_pytorch import (
61 | ElucidatedImagen_OneD, eval_decorator
62 | )
63 | #
64 | from VibeGen.UtilityPack import (
65 | Print, Print_model_params,
66 | create_path,
67 | get_toks_list_from_Y_batch,
68 | save_2d_tensor_as_np_arr_txt,
69 | write_fasta_file,
70 | compare_two_seq_strings,
71 | fold_one_AA_to_SS_using_omegafold,
72 | show_pdb,
73 | get_DSSP_result,
74 | write_DSSP_result_to_json,
75 | write_one_line_to_file,
76 | decode_many_ems_token_rec,
77 | get_nms_vec_as_arr_list_from_batch_using_mask,
78 | compare_two_nms_vecs_arr,
79 | translate_seqs_list_into_idx_tensor_w_pLM
80 | )
81 |
82 | # //////////////////////////////////////////////////////////////
83 | # 3. local setup parameters: for debug purpose
84 | # //////////////////////////////////////////////////////////////
85 | PT_Init_Level = 1
86 | PT_Forw_Level = 1
87 |
88 | Local_Debug_Level = 0
89 | # //////////////////////////////////////////////////////////////
90 | # 4. helper functions
91 | # //////////////////////////////////////////////////////////////
92 | def merge_two_topk(
93 | y_goo,
94 | y_bad,
95 | ):
96 | y={}
97 | y['indices']=torch.concatenate(
98 | (y_goo.indices,y_bad.indices)
99 | )
100 | y['values']= torch.concatenate(
101 | (y_goo.values,y_bad.values)
102 | )
103 |
104 | len_goo = len(y_goo.indices)
105 | len_bad = len(y_bad.indices)
106 | name_list_goo = [f'min_err_{ii}' for ii in range(len_goo)]
107 | name_list_bad = [f'max_err_{ii}' for ii in range(len_bad)]
108 | name_list = name_list_goo+name_list_bad
109 |
110 | y['name_type']=name_list
111 |
112 | # indices=torch.concatenate(
113 | # (y1.indices,y2.indices)
114 | # )
115 | # values= torch.concatenate(
116 | # (y1.values,y2.values)
117 | # )
118 | # y=torch.return_types.topk_out(
119 | # values=values,
120 | # indices=indices
121 | # )
122 | return y
123 |
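# ++ usage sketch: keep the 2 lowest-error designs plus the 2 highest-error
#    references out of a batch of per-design errors.
# err = torch.tensor([0.3, 0.1, 0.9, 0.5])
# picked = merge_two_topk(
#     y_goo=torch.topk(err, k=2, largest=False),
#     y_bad=torch.topk(err, k=2, largest=True),
# )
# picked['indices']    # -> tensor([1, 0, 2, 3])
# picked['name_type']  # -> ['min_err_0', 'min_err_1', 'max_err_0', 'max_err_1']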
124 | # //////////////////////////////////////////////////////////////
125 | # 5. Main class/functions: a base trainer wrap for ProteinDesigner
126 | # //////////////////////////////////////////////////////////////
127 |
128 |
129 | # ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
130 |
131 | # joint PD & PP for sampling
132 |
133 | def joint_sampling_w_PD_and_PP(
134 | # 1. model
135 | PD_wk_ProteinDesigner,
136 | PP_wk_ProteinPredictor,
137 | # 2. data
138 | PD_test_set_condition_list, # input as a list of NMS vecs
139 | PD_test_set_AA_list=None, # whether GT is provided
140 | PD_DataKeys=None,
141 | # 3. control param
142 | n_try_w_PD = 100, # For PD, try this number times as a batch
143 | n_keep_w_PP_goo = 2, # Use PP to pick the top this_number samples
144 | # ++
145 |     n_keep_w_PP_bad=2, # also keep the worst ones as reference
146 | #
147 | PD_cond_scal = 7.5,
148 | PP_cond_scal = 7.5,
149 | # 4. outputs
150 | joint_sampling_dir = None,
151 |     joint_sampling_prefix = f"TestSet_",
152 |     # 3. on postprocessing
153 | IF_plot_PP = True,
154 | IF_showfig = True,
155 | IF_save_pred_pack = True,
156 | IF_plot_PD = True,
157 | IF_fold_seq = True,
158 | IF_show_foldding = True,
159 | IF_DSSP = True,
160 | # others
161 | device = None,
162 | ):
163 |
164 | if not (PD_test_set_AA_list is None):
165 | assert len(PD_test_set_condition_list)==len(PD_test_set_AA_list), \
166 | "the input Conditioning and GT don't have the same length..."
167 | else:
168 | Print(f"Only input Conditioning is provided...")
169 |
170 | # prepare wk dir
171 | if not os.path.exists(joint_sampling_dir):
172 | Print(f"Create joint sampling path...")
173 | create_path(joint_sampling_dir)
174 | Print(f"Done.")
175 | else:
176 | Print(f"Dir exists. Use caution...")
177 |
178 | # model status
179 | PD_wk_ProteinDesigner.turn_on_eval_mode()
180 | PP_wk_ProteinPredictor.turn_on_eval_mode()
181 |
182 |
183 | # prepare
184 | # on PD
185 | text_len_max = PD_wk_ProteinDesigner.text_max_len
186 | img_len_max = PD_wk_ProteinDesigner.diffuser_core.image_sizes[0]
187 | len_in = min(text_len_max, img_len_max) # depend on problem statement, may change
188 |
189 | text_embed_input_dim = PD_wk_ProteinDesigner.text_embed_input_dim
190 | cond_img_channels = PD_wk_ProteinDesigner.diffuser_core.unets[0].cond_images_channels
191 | mode_in = min(text_embed_input_dim, cond_img_channels)
192 |
193 | print (len_in)
194 | print (mode_in)
195 |
196 | # on PP
197 | AA_seq_max_len = PP_wk_ProteinPredictor.seq_obj_max_size
198 | esm_batch_converter = PP_wk_ProteinPredictor.pLM_alphabet.get_batch_converter(
199 | truncation_seq_length=AA_seq_max_len-2
200 | )
201 | PP_wk_ProteinPredictor.pLM.eval()
202 |
203 | PR_err_for_PP = torch.nn.MSELoss(reduction='none')
204 |
205 | n_keep_w_PP = n_keep_w_PP_goo+n_keep_w_PP_bad
206 |
207 | # ++
208 | # ++ get GT for PD if exists
209 | #    (translate the GT sequences into a token batch, mirroring step 6 below)
210 | if not (PD_test_set_AA_list is None):
211 | GT_test_set_AA_batch_for_PD = \
212 | translate_seqs_list_into_idx_tensor_w_pLM(
213 | # 1. model converter
214 | esm_batch_converter,
215 | AA_seq_max_len,
216 | # 2. on input
217 | raw_condition_list=PD_test_set_AA_list,
218 | # 3. on output
219 | device=device
220 | ) # (batch, seq_len)
221 | #
222 | # get len info as shape mask
223 | mask_from_Y_all = PD_wk_ProteinDesigner.read_mask_from_seq_toks_using_pLM(
224 | GT_test_set_AA_batch_for_PD
225 | )
226 | #
227 | GT_idx_list, GT_seqs_list = get_toks_list_from_Y_batch(
228 | GT_test_set_AA_batch_for_PD,
229 | mask_from_Y_all
230 | )
231 | else:
232 | GT_test_set_AA_batch_for_PD = None
233 | mask_from_Y_all = None
234 | GT_idx_list = None
235 | GT_seqs_list = None
236 |
237 | # ++ for picking up from the previous runs
238 | # .............................................................
239 | reco_csv = joint_sampling_dir+'/'+joint_sampling_prefix+\
240 | f'Try_{n_try_w_PD}_Pick_{n_keep_w_PP}'+'_reco.csv'
241 |
242 | Print (f"Use reco file: \n{reco_csv}")
243 | if not os.path.isfile(reco_csv):
244 | # first time
245 | Print (f"First run of the sampling...\n\n")
246 | # write the top line
247 | csv_top_line = f"root_path,error_L2,r2"
248 | write_one_line_to_file(
249 | this_line=csv_top_line+'\n',
250 | file_name=reco_csv,
251 | mode='w',
252 | )
253 | n_pick_finished = 0
254 | n_samp_finished = 0
255 | else:
256 | df_reco = pd.read_csv(reco_csv)
257 | n_pick_finished = len(df_reco)//n_keep_w_PP
258 | n_samp_finished = len(df_reco)%n_keep_w_PP
259 | Print (f"Previously, finished input #: {n_pick_finished}")
260 | Print (f"finished samp #: {n_samp_finished}\n\n")
261 |
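# ++ e.g. with n_keep_w_PP = 4 and 10 rows already in the reco csv:
#    n_pick_finished = 10 // 4 = 2 inputs fully finished,
#    n_samp_finished = 10 %  4 = 2 designs into the third input.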
262 |
263 |
264 | X_file_list = []
265 |
266 | # pick one sample
267 | for i_pick in range(len(PD_test_set_condition_list)):
268 |
269 | if i_pick > n_pick_finished-1: # pick up from the previous
270 |
271 | Print (f"\n\nWorking on Input #: {i_pick}\n\n")
272 |
273 | # i_pick = 1
274 |
275 | # 1. get X data padded
276 | X_arr = np.zeros(
277 | (mode_in, len_in)
278 | ) # (n_mode, seq_len)
279 |
280 | for j in range(mode_in):
281 | X_arr[j, :] = pad_a_np_arr_esm_for_NMS(
282 | PD_test_set_condition_list[i_pick][j, :],
283 | 0,
284 | len_in
285 | )
286 | print (X_arr.shape)
287 |
288 | # 2. get X normalized and formated
289 | for j in range(mode_in):
290 | X_arr[j, :] = X_arr[j, :]/PD_DataKeys['Xnormfac'][j]
291 |
292 | X_train = torch.from_numpy(X_arr).float() # (c, seq_len)
293 |
294 | # 3. expand into a batch
295 | X_train = X_train.unsqueeze(0).repeat(n_try_w_PD,1,1)
296 |
297 | print (X_train.shape)
298 |
299 | X_train_batch = X_train.to(device)
300 |
301 | # 4. prep the GT for NMS vecs
302 | seq_len_pick = PD_test_set_condition_list[i_pick].shape[1]
303 | GT_NMS_tensor_pick = torch.from_numpy(
304 | PD_test_set_condition_list[i_pick]
305 | ).float() # (n_mode, this_seq_len)
306 | GT_NMS_tensor = GT_NMS_tensor_pick.unsqueeze(0).repeat(
307 | n_try_w_PD,1,1
308 | ) # (batch, n_mode, this_seq_len)
309 | GT_NMS_tensor = GT_NMS_tensor.to(device)
310 |
311 | # 5. make prediction w. PD
312 | print (f"\n\nPD making {str(n_try_w_PD)} designs ...\n\n")
313 |
314 | PR_toks_list, PR_seqs_list, result_mask = \
315 | PD_wk_ProteinDesigner.sample_to_pLM_idx_seq(
316 | #
317 | common_AA_only=True, # False,
318 | mask_from_Y=None, # mask_from_Y
319 | # if none, will use mask from X, cond_img then text
320 | #
321 | text_con_input = X_train_batch,
322 | cond_images = X_train_batch,
323 | #
324 | cond_scale = PD_cond_scal,
325 | )
326 | # result_mask: (batch, seq_len)
327 |
328 | # 6. translate back to a batch for PP
329 | print (f"\n\nPP predicting performances ...\n\n")
330 |
331 | y_data_for_PP = \
332 | translate_seqs_list_into_idx_tensor_w_pLM(
333 | # 1. model converter
334 | esm_batch_converter,
335 | AA_seq_max_len,
336 | # 2. on input
337 | raw_condition_list=PR_seqs_list,
338 | # 3. on output
339 | device=device
340 | ) # (batch, seq_len)
341 |
342 | print (y_data_for_PP.shape)
343 |
344 | # 7. make prediction w PP
345 | PR_NMS_arr_list = PP_wk_ProteinPredictor.sample_to_NMS_list(
346 | # mask
347 | mask_from_Y = result_mask,
348 | NormFac_list = PD_DataKeys['Xnormfac'],
349 | #
350 | text_con_input = y_data_for_PP,
351 | cond_images = y_data_for_PP,
352 | #
353 | cond_scale = PP_cond_scal,
354 | ) # list (n_mode, seq_len)
355 | # make the list into a tensor
356 | PR_NMS_tensor = torch.from_numpy(
357 | np.stack(PR_NMS_arr_list, axis=0) # (b, n_mode, this_seq_len)
358 | ).float()
359 | PR_NMS_tensor = PR_NMS_tensor.to(device)
360 |
361 | # 8. calc the error for NMS vecs
362 |
363 | PR_NMS_err_batch = PR_err_for_PP(
364 | PR_NMS_tensor,
365 | GT_NMS_tensor,
366 | ) # (b, n_mode, this_seq_len)
367 | PR_NMS_err_batch = torch.sum(
368 | PR_NMS_err_batch,
369 | dim=(1,2)
370 | ) # (b, )
371 |
372 | print (f"Pick the best {n_keep_w_PP_goo}...")
373 |
374 | idxs_vals_to_pick_goo = torch.topk(
375 | PR_NMS_err_batch,
376 | k=n_keep_w_PP_goo,
377 | largest=False,
378 | )
379 | # have indices and values
380 | # (n_keep_w_PP, )
381 | print (f"Pick the worst {n_keep_w_PP_bad}...")
382 | idxs_vals_to_pick_bad = torch.topk(
383 | PR_NMS_err_batch,
384 | k=n_keep_w_PP_bad,
385 | largest=True,
386 | )
387 |
388 | print (f"\n\nNow, fold the picked best {n_keep_w_PP_goo} and worst {n_keep_w_PP_bad} samples...")
389 | idxs_vals_to_pick = merge_two_topk(
390 | y_goo=idxs_vals_to_pick_goo,
391 | y_bad=idxs_vals_to_pick_bad,
392 | )
393 |
394 |
395 | # 9. postprocess
396 | for i_in_k in range(n_keep_w_PP):
397 |
398 | if i_in_k > n_samp_finished-1: # pick up from the previous
399 |
400 | Print (f"\n\nProcessing Picked #: Input {i_pick+1} -- Design {i_in_k+1}\n\n")
401 | # ... (source lines 401-717, the per-design postprocessing body of this loop, are missing from this dump) ...
718 |
719 | else:
720 | pass # this record is already finished
721 |
722 |
723 |
--------------------------------------------------------------------------------
/VibeGen/DataSetPack.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.preprocessing import text, sequence
2 | from tensorflow.keras.preprocessing.text import Tokenizer
3 |
4 | from torch.utils.data import DataLoader,Dataset
5 | import pandas as pd
6 | import seaborn as sns
7 |
8 | import torchvision
9 |
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 |
13 | from torch import nn
14 | from torch import optim
15 | import torch.nn.functional as F
16 | from torchvision import datasets, transforms, models
17 |
18 | import torch.optim as optim
19 | from torch.optim.lr_scheduler import ExponentialLR, StepLR
20 | from functools import partial, wraps
21 |
22 | from sklearn.model_selection import train_test_split
23 | from sklearn.preprocessing import QuantileTransformer
24 | from sklearn.preprocessing import RobustScaler
25 |
26 | from matplotlib.ticker import MaxNLocator
27 |
28 | import torch
29 |
30 | import esm
31 |
32 | # special packages
33 |
34 | import VibeGen.UtilityPack as UPack
35 | from VibeGen.UtilityPack import (
36 | decode_one_ems_token_rec,
37 | decode_many_ems_token_rec
38 | )
39 |
40 | #
41 | DPack_Random = 123456
42 |
43 | class RegressionDataset(Dataset):
44 |
45 | def __init__(self, X_data, y_data):
46 | self.X_data = X_data
47 | self.y_data = y_data
48 |
49 | def __getitem__(self, index):
50 | return self.X_data[index], self.y_data[index]
51 |
52 | def __len__ (self):
53 | return len(self.X_data)
54 |
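# ++ usage sketch:
# ds = RegressionDataset(X_train, y_train)   # paired (X, y) tensors/arrays (hypothetical names)
# loader = DataLoader(ds, batch_size=16, shuffle=True)
# X_batch, y_batch = next(iter(loader))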
55 | # ============================================================
56 | # handle NMA result
57 | #
58 | # 1. screen the dataset
59 | # ============================================================
60 | def screen_dataset_MD_NMS_MultiModes(
61 | # # --
62 | # file_path,
63 | # ++
64 | csv_file=None,
65 | pk_file =None,
66 | PKeys=None,
67 | CKeys=None,
68 | ):
69 | # unload the parameters
70 |
71 | store_path = PKeys['data_dir']
72 | IF_SaveFig = CKeys['SilentRun']
73 | min_AASeq_len = PKeys['min_AA_seq_len']
74 | max_AASeq_len = PKeys['max_AA_seq_len']
75 | max_used_Seg_Num = PKeys['max_used_Seg_Num']
76 |
77 | # max_used_Smo_F = PKeys['max_Force_cap']
78 |
79 | # working part
80 | if csv_file != None:
81 | # not used for now
82 | # functions
83 | print('=============================================')
84 | print('1. read in the csv file...')
85 | print('=============================================')
86 | arr_key = PKeys['arr_key']
87 |
88 | df_raw = pd.read_csv(csv_file)
89 | UPack.Print("Raw df has keys:")
90 | UPack.Print(df_raw.keys())
91 |
92 | # convert string array back to array
93 | for this_key in arr_key:
94 | # np.array(list(map(float, one_record.split(" "))))
95 | df_raw[this_key] = df_raw[this_key].apply(lambda x: np.array(list(map(float, x.split(" ")))))
96 | # =====================================================
97 | # adjust if needed
98 | # patch up
99 | df_raw.rename(columns={"sample_FORCEpN_data":"sample_FORCE_data"}, inplace=True)
100 | print('Updated keys: \n', df_raw.keys())
101 |
102 | elif pk_file != None:
103 | # functions
104 | print('=============================================')
105 | print('1. read in the pk file...')
106 | print('=============================================')
107 | #
108 | df_raw = pd.read_pickle(pk_file)
109 |
110 | UPack.Print("Raw df has keys:")
111 | UPack.Print(df_raw.keys())
112 |
113 | # ..............................................................................
114 | # --
115 | fig = plt.figure(figsize=(24,16),dpi=200)
116 | fig, ax0 = plt.subplots()
117 | for ii in range(len( df_raw )):
118 | if df_raw['AA_Eff_Len'][ii]<=6400:
119 | # # +
120 | # ax0.plot(
121 | # df_disp_forc_smo['normalized_pull_gap_data'][ii],
122 | # df_disp_forc_smo['forc_data'][ii],
123 | # color="blue",label='full data'
124 | # )
125 | # #
126 | ax0.plot(
127 | df_raw['Norm_Resi_Ind_List'][ii],
128 | # df_raw['sample_FORCEpN_data'][ii],
129 | df_raw['Mode7_NormDisAmp'][ii],
130 | alpha=0.1,
131 | # color="green",label='simplified data',
132 | # linestyle='None',marker='^'
133 | )
134 | # ============================================
135 | # # too slow to do this
136 | # ax0.scatter(
137 | # df_raw['NormResiIndx_At_MaxVibrAmp_Mode7'][ii],
138 | # df_raw['NormDisAmp_At_MaxVibrAmp_Mode7'][ii],
139 | # )
140 | else:
141 | print(df_raw['pdb_id'][ii])
142 | # mistakes seen here: 1. wrong AA length; 2. wrong residue count at the beginning and end
143 | plt.xlabel('Normalized residue index')
144 | plt.ylabel('Normalized vibrational disp. amp.')
145 | outname = store_path+'CSV_0_NMS_Mode7_Dist.jpg'
146 | if IF_SaveFig==1:
147 | plt.savefig(outname, dpi=200)
148 | else:
149 | plt.show()
150 | plt.close()
151 |
152 | print('=============================================')
153 | print('2. screen the entries...')
154 | print('=============================================')
155 | #
156 | df_isnull = pd.DataFrame(
157 | round(
158 | (df_raw.isnull().sum().sort_values(ascending=False)/df_raw.shape[0])*100,
159 | 1
160 | )
161 | ).reset_index()
162 | df_isnull.style.format({'% of Missing Data': lambda x:'{:.1%}'.format(abs(x))})
163 | cm = sns.light_palette("skyblue", as_cmap=True)
164 | df_isnull = df_isnull.style.background_gradient(cmap=cm)
165 | print('Check null...')
166 | print( df_isnull )
167 |
168 | print('Working on a dataframe with useful keywords')
169 | # suppose to be a smaller one
170 | # Focus on mode 7 For the moment
171 | # Expand to modes 7,8,9
172 | protein_df = pd.DataFrame().assign(
173 | pdb_id=df_raw['pdb_id'],
174 | AA=df_raw['AA_Full'],
175 | seq_len=df_raw['AA_Eff_Len'],
176 | AA_Seg_Num=df_raw['AA_Seg_Num'],
177 | Norm_Resi_Ind_List=df_raw['Norm_Resi_Ind_List'],
178 | # on mode 7
179 | Mode7_NormDisAmp=df_raw['Mode7_NormDisAmp'],
180 | ScaFac_7=df_raw['ScaFac_7'],
181 | Mode7_NormDis=df_raw['Mode7_NormDis'],
182 | Mode7_Freq=df_raw['Mode7_Freq'],
183 | NormResiIndx_At_MaxVibrAmp_Mode7=df_raw['NormResiIndx_At_MaxVibrAmp_Mode7'],
184 | NormDisAmp_At_MaxVibrAmp_Mode7=df_raw['NormDisAmp_At_MaxVibrAmp_Mode7'],
185 | Mode7_FixLen_NormDisAmp=df_raw['Mode7_FixLen_NormDisAmp'],
186 | # on mode 8
187 | Mode8_NormDisAmp=df_raw['Mode8_NormDisAmp'],
188 | ScaFac_8=df_raw['ScaFac_8'],
189 | Mode8_NormDis=df_raw['Mode8_NormDis'],
190 | Mode8_Freq=df_raw['Mode8_Freq'],
191 | NormResiIndx_At_MaxVibrAmp_Mode8=df_raw['NormResiIndx_At_MaxVibrAmp_Mode8'],
192 | NormDisAmp_At_MaxVibrAmp_Mode8=df_raw['NormDisAmp_At_MaxVibrAmp_Mode8'],
193 | # Mode8_FixLen_NormDisAmp=df_raw['Mode8_FixLen_NormDisAmp'],
194 | # on mode 9
195 | Mode9_NormDisAmp=df_raw['Mode9_NormDisAmp'],
196 | ScaFac_9=df_raw['ScaFac_9'],
197 | Mode9_NormDis=df_raw['Mode9_NormDis'],
198 | Mode9_Freq=df_raw['Mode9_Freq'],
199 | NormResiIndx_At_MaxVibrAmp_Mode9=df_raw['NormResiIndx_At_MaxVibrAmp_Mode9'],
200 | NormDisAmp_At_MaxVibrAmp_Mode9=df_raw['NormDisAmp_At_MaxVibrAmp_Mode9'],
201 | # Mode9_FixLen_NormDisAmp=df_raw['Mode9_FixLen_NormDisAmp'],
202 | )
203 | # ++ add new keys on energy if needed
204 |
205 | # screen using AA length
206 | print('a. screen using sequence length...')
207 | print('original sequences #: ', len(protein_df))
208 | #
209 | protein_df.drop(
210 | protein_df[protein_df['seq_len']>max_AASeq_len-2].index,
211 | inplace = True
212 | )
213 | protein_df.drop(
214 |         protein_df[protein_df['seq_len']<min_AASeq_len].index,
215 |         inplace = True
216 |     )
217 |     print('after dropping, sequences #: ', len(protein_df))
218 |
219 |     # b. screen using segment number
220 |     # (the source was truncated here; these lines are reconstructed from the pattern above)
221 |     print('b. screen using segment number...')
222 |     #
223 |     protein_df.drop(
224 |         protein_df[protein_df['AA_Seg_Num'] > max_used_Seg_Num].index,
225 | inplace = True
226 | )
227 | # protein_df.drop(
228 |     # protein_df[protein_df['seq_len']<min_AASeq_len].index,

... (source lines ~229-596 of this file are missing from this dump) ...

597 | # pad a 1D np arr up to a fixed length, esm style (slot 0 is left for <cls>)
598 | def pad_a_np_arr_esm_for_NMS(x0, add_x, n_len):  # (signature recovered from its import/use in JointSamplingPack.py)
599 | x1 = x0.copy()
600 | x1 = np.insert(x1,0,add_x)
601 | n0 = len(x1)
602 |     # (the source was truncated here; the padding branch below is reconstructed)
603 |     if n0 < n_len:
604 |         for ii in range(n_len-n0):
605 |             x1 = np.append(x1, add_x)
606 |     else:
607 |         print('No padding is needed')
608 |     return x1

... (source lines ~609-615 are missing from this dump) ...

616 | # # x1 = [add_x]+x1 # somehow, this one doesn't work
617 | # # print(x1)
618 | # # print('x1 len: ',len(x1) )
619 | # n0 = len(x1)
620 | # #
621 | # if n0<n_len:

... (source lines ~622-677, the rest of this commented-out draft and the start of the padding-statistics printout, are missing from this dump) ...

678 |     for jj in range(max_AA_len):              # (reconstructed to mirror the loop below)
679 |         if np.fabs(this_print_arr[jj])>0:
680 | jj_0 = jj
681 | break
682 | for jj in range(max_AA_len):
683 | if np.fabs(this_print_arr[-(jj+1)])>0:
684 | jj_1 = jj
685 | break
686 | jj_data_len = max_AA_len-jj_0-jj_1
687 | print (f" Begin_padding: {jj_0}")
688 | print (f" End_padding: {jj_1}")
689 | print (f" Data_len: {jj_data_len}")
690 |
691 | UPack.Print (f"Now, calculate Normalization Factor for each mode")
692 | UPack.Print (f"Upper bound of the NFs: {np.amax(X)}")
693 |
694 | X_NF_List = []
695 | for ii, this_X_Key in enumerate(X_Keys):
696 | this_x_max = np.amax(
697 | X[:,ii,:]
698 | )
699 | X_NF_List.append(this_x_max)
700 | # normalization
701 | X[:,ii,:] = X[:,ii,:]/this_x_max
702 |
703 | UPack.Print (f"X_NF_List: {X_NF_List}")
704 |
705 | UPack.Print("======================================================")
706 | UPack.Print("2. work on Y data: AA Sequence")
707 | UPack.Print("======================================================")
708 | # take care of the y part: AA encoding
709 | # create the tokenizer for the AA sequences (via the ESM alphabet below)
710 | seqs = protein_df.AA.values
711 | # ++ for pLM: esm
712 | # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
713 | print("pLM model: ", PKeys['ESM-2_Model'])
714 |
715 | if PKeys['ESM-2_Model']=='esm2_t33_650M_UR50D':
716 | # print('Debug block')
717 | # embed dim: 1280
718 | esm_model, esm_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
719 | len_toks=len(esm_alphabet.all_toks)
720 | elif PKeys['ESM-2_Model']=='esm2_t12_35M_UR50D':
721 | # embed dim: 480
722 | esm_model, esm_alphabet = esm.pretrained.esm2_t12_35M_UR50D()
723 | len_toks=len(esm_alphabet.all_toks)
724 | elif PKeys['ESM-2_Model']=='esm2_t36_3B_UR50D':
725 | # embed dim: 2560
726 | esm_model, esm_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
727 | len_toks=len(esm_alphabet.all_toks)
728 | elif PKeys['ESM-2_Model']=='esm2_t30_150M_UR50D':
729 | # embed dim: 640
730 | esm_model, esm_alphabet = esm.pretrained.esm2_t30_150M_UR50D()
731 | len_toks=len(esm_alphabet.all_toks)
732 | else:
733 | print("protein language model is not defined.")
734 | #
735 | # for check
736 | print("esm_alphabet.use_msa: ", esm_alphabet.use_msa)
737 | print("# of tokens in AA alphabet: ", len_toks)
738 | # need to save 2 positions for <cls> and <eos>
739 | esm_batch_converter = esm_alphabet.get_batch_converter(
740 | truncation_seq_length=PKeys['max_AA_seq_len']-2
741 | )
742 | esm_model.eval() # disables dropout for deterministic results
743 | # prepare seqs for the "esm_batch_converter..."
744 | # add dummy labels
745 | seqs_ext=[]
746 | for i in range(len(seqs)):
747 | seqs_ext.append(
748 | (" ", seqs[i])
749 | )
750 | # batch_labels, batch_strs, batch_tokens = esm_batch_converter(seqs_ext)
751 | _, y_strs, y_data = esm_batch_converter(seqs_ext)
752 | y_strs_lens = (y_data != esm_alphabet.padding_idx).sum(1)
753 | # print(batch_tokens.shape)
754 | print ("y_data.dim: ", y_data.dtype)
755 |
756 | fig_handle = sns.histplot(
757 | data=pd.DataFrame({'AA code': np.array(y_data).flatten()}),
758 | x='AA code',
759 | bins=np.array([i-0.5 for i in range(0,33+3,1)]), # np.array([i-0.5 for i in range(0,20+3,1)])
760 | # binwidth=1,
761 | )
762 | fig = fig_handle.get_figure()
763 | fig_handle.set_xlim(-1, 33+1)
764 | # fig_handle.set_ylim(0, 100000)
765 | outname=store_path+'CSV_5_DataSet_AACode_dist.jpg'
766 | if IF_SaveFig==1:
767 | plt.savefig(outname, dpi=200)
768 | else:
769 | plt.show()
770 | plt.close()
771 |
772 | # -----------------------------------------------------------
773 | # print ("#################################")
774 | # print ("DICTIONARY y_data")
775 | # dictt=tokenizer_y.get_config()
776 | # print (dictt)
777 | # num_words = len(tokenizer_y.word_index) + 1
778 | # print ("################## y max token: ",num_words )
779 | # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
780 | print ("#################################")
781 | print ("DICTIONARY y_data: esm-", PKeys['ESM-2_Model'])
782 | print ("################## y max token: ",len_toks )
783 |
784 | # reverse
785 | print ("TEST REVERSE: ")
786 |
787 | # # --------------------------------------------------------------
788 | # y_data_reversed=tokenizer_y.sequences_to_texts (y_data)
789 |
790 | # for iii in range (len(y_data_reversed)):
791 | # y_data_reversed[iii]=y_data_reversed[iii].upper().strip().replace(" ", "")
792 |
793 | # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
794 | # assume y_data is reversiable
795 | y_data_reversed = decode_many_ems_token_rec(y_data, esm_alphabet)
796 |
797 |
798 | print ("Element 0", y_data_reversed[0])
799 | print ("Number of y samples",len (y_data_reversed) )
800 |
801 | for iii in [0,2,6]:
802 | print("Ori and REVERSED SEQ: ", iii)
803 | print(seqs[iii])
804 | print(y_data_reversed[iii])
805 |
806 | # print ("Original: ", y_data[:3,:])
807 | # print ("REVERSED TEXT 0..2: ", y_data_reversed[0:3])
808 |
809 | print ("Len 0 as example: ", len (y_data_reversed[0]) )
810 | print ("Check ori: ", len (seqs[0]) )
811 | print ("Len 2 as example: ", len (y_data_reversed[2]) )
812 | print ("Check ori: ", len (seqs[2]) )
813 |
814 | # placeholder
815 | tokenizer_X = None
816 | tokenizer_Y = None
817 |
818 | return X, X_NF_List, y_data, y_data_reversed,tokenizer_X, tokenizer_Y
819 |
820 | # =============================================================
821 | # build loaders
822 | def build_dataloaders(
823 | X,
824 | y_data,
825 | protein_df,
826 | PKeys=None,
827 | CKeys=None,
828 | ):
829 | # unload the parameters
830 | store_path = PKeys['data_dir']
831 | IF_SaveFig = CKeys['SilentRun']
832 |
833 | batch_size = PKeys['batch_size']
834 | TestSet_ratio = PKeys['testset_ratio']
835 | maxdata=PKeys['maxdata']
836 |
837 |
838 |     if maxdata<len(y_data):  # (condition reconstructed from the truncated source line)

... (the remainder of DataSetPack.py and source lines 1-141 of the next file are missing from this dump) ...

--------------------------------------------------------------------------------
/VibeGen/imagen_x_imagen_pytorch.py:
--------------------------------------------------------------------------------
142 | # map image values between (0,1) and (-1,1)
143 | def normalize_neg_one_to_one(img):
144 | return img * 2 - 1
145 |
146 | def unnormalize_zero_to_one(normed_img):
147 | return (normed_img + 1) * 0.5
148 |
149 | def compact(input_dict):
150 | return {key: value for key, value in input_dict.items() if exists(value)}
151 |
152 | def maybe_transform_dict_key(input_dict, key, fn):
153 | if key not in input_dict:
154 | return input_dict
155 |
156 | copied_dict = input_dict.copy()
157 | copied_dict[key] = fn(copied_dict[key])
158 | return copied_dict
159 |
160 | # tensor helpers
161 |
162 | def log(t, eps: float = 1e-12):
163 | return torch.log(t.clamp(min = eps))
164 |
165 | #
166 | # ===========================================================
167 | # ===========================================================
168 | # ===========================================================
169 | # main class: ElucidatedImagen
170 | # ===========================================================
171 | # ===========================================================
172 | # ===========================================================
173 | #
174 | # on diffusion scheduler
175 | #
176 | # gaussian diffusion with continuous time helper functions and classes
177 | # large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py
178 |
179 | @torch.jit.script
180 | def beta_linear_log_snr(t):
181 | return -torch.log(expm1(1e-4 + 10 * (t ** 2)))
182 |
183 | @torch.jit.script
184 | def alpha_cosine_log_snr(t, s: float = 0.008):
185 | return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version
186 |
187 | def log_snr_to_alpha_sigma(log_snr):
188 | return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))
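# ++ note: sigmoid(x) + sigmoid(-x) = 1, so alpha**2 + sigma**2 == 1 at every
#    log-SNR (a variance-preserving parameterization); e.g. log_snr = 0 gives
#    alpha = sigma = sqrt(0.5).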
189 | #
190 | class GaussianDiffusionContinuousTimes(nn.Module):
191 | def __init__(
192 | self,
193 | *,
194 | noise_schedule,
195 | timesteps = 1000,
196 | ):
197 | super().__init__()
198 |
199 | if noise_schedule == "linear":
200 | self.log_snr = beta_linear_log_snr
201 | elif noise_schedule == "cosine":
202 | self.log_snr = alpha_cosine_log_snr
203 | else:
204 | raise ValueError(f'invalid noise schedule {noise_schedule}')
205 |
206 | self.num_timesteps = timesteps
207 |
208 | def get_times(
209 | self,
210 | batch_size,
211 | noise_level,
212 | *,
213 | device
214 | ):
215 | return torch.full(
216 | (batch_size,),
217 | noise_level,
218 | device = device,
219 | dtype = torch.float32
220 | )
221 |
222 | def sample_random_times(
223 | self,
224 | batch_size,
225 | *,
226 | device
227 | ):
228 | return torch.zeros(
229 | (batch_size,),
230 | device = device
231 | ).float().uniform_(0, 1)
232 |
233 | def get_condition(self, times):
234 | return maybe(self.log_snr)(times)
235 |
236 | def get_sampling_timesteps(
237 | self,
238 | batch,
239 | *,
240 | device
241 | ):
242 | times = torch.linspace(
243 | 1.,
244 | 0.,
245 | self.num_timesteps + 1,
246 | device = device
247 | )
248 | times = repeat(times, 't -> b t', b = batch)
249 | times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
250 | times = times.unbind(dim = -1)
251 | return times
252 |
253 | def q_posterior(
254 | self,
255 | x_start,
256 | x_t,
257 | t,
258 | *,
259 | t_next = None
260 | ):
261 | t_next = default(
262 | t_next,
263 | lambda: (t - 1. / self.num_timesteps).clamp(min = 0.)
264 | )
265 |
266 | """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
267 | log_snr = self.log_snr(t)
268 | log_snr_next = self.log_snr(t_next)
269 | log_snr, log_snr_next = map(
270 | partial(right_pad_dims_to, x_t),
271 | (log_snr, log_snr_next)
272 | )
273 |
274 | alpha, sigma = log_snr_to_alpha_sigma(log_snr)
275 | alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)
276 |
277 | # c - as defined near eq 33
278 | c = -expm1(log_snr - log_snr_next)
279 | posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)
280 |
281 | # following (eq. 33)
282 | posterior_variance = (sigma_next ** 2) * c
283 | posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
284 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
285 |
286 | def q_sample(
287 | self,
288 | x_start,
289 | t,
290 | noise = None
291 | ):
292 | dtype = x_start.dtype
293 |
294 | if isinstance(t, float):
295 | batch = x_start.shape[0]
296 | t = torch.full((batch,), t, device = x_start.device, dtype = dtype)
297 |
298 | noise = default(noise, lambda: torch.randn_like(x_start))
299 | log_snr = self.log_snr(t).type(dtype)
300 | log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
301 | alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)
302 |
303 | return alpha * x_start + sigma * noise, log_snr, alpha, sigma
304 |
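# ++ usage sketch: one forward-diffusion draw on a toy (b, c, l) signal
# sched = GaussianDiffusionContinuousTimes(noise_schedule='cosine', timesteps=1000)
# x0 = torch.randn(2, 3, 64)
# t = sched.sample_random_times(2, device=x0.device)
# x_t, log_snr, alpha, sigma = sched.q_sample(x0, t)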
305 | def q_sample_from_to(
306 | self,
307 | x_from,
308 | from_t,
309 | to_t,
310 | noise = None
311 | ):
312 | shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
313 | batch = shape[0]
314 |
315 | if isinstance(from_t, float):
316 | from_t = torch.full((batch,), from_t, device = device, dtype = dtype)
317 |
318 | if isinstance(to_t, float):
319 | to_t = torch.full((batch,), to_t, device = device, dtype = dtype)
320 |
321 | noise = default(noise, lambda: torch.randn_like(x_from))
322 |
323 | log_snr = self.log_snr(from_t)
324 | log_snr_padded_dim = right_pad_dims_to(x_from, log_snr)
325 | alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)
326 |
327 | log_snr_to = self.log_snr(to_t)
328 | log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to)
329 | alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to)
330 |
331 | return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha
332 |
333 | def predict_start_from_v(self, x_t, t, v):
334 | log_snr = self.log_snr(t)
335 | log_snr = right_pad_dims_to(x_t, log_snr)
336 | alpha, sigma = log_snr_to_alpha_sigma(log_snr)
337 | return alpha * x_t - sigma * v
338 |
339 | def predict_start_from_noise(self, x_t, t, noise):
340 | log_snr = self.log_snr(t)
341 | log_snr = right_pad_dims_to(x_t, log_snr)
342 | alpha, sigma = log_snr_to_alpha_sigma(log_snr)
343 | return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)
344 |
345 | # ===========================================================
346 | # constants
347 |
348 | Hparams_fields = [
349 | 'num_sample_steps',
350 | 'sigma_min',
351 | 'sigma_max',
352 | 'sigma_data',
353 | 'rho',
354 | 'P_mean',
355 | 'P_std',
356 | 'S_churn',
357 | 'S_tmin',
358 | 'S_tmax',
359 | 'S_noise'
360 | ]
361 |
362 | Hparams = namedtuple('Hparams', Hparams_fields)
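# NOTE: Hparams bundles the elucidated-diffusion (EDM) knobs, one instance per
# unet. A minimal usage sketch with the default values used further below:
#   hp = Hparams(num_sample_steps = 32, sigma_min = 0.002, sigma_max = 80,
#                sigma_data = 0.5, rho = 7, P_mean = -1.2, P_std = 1.2,
#                S_churn = 80, S_tmin = 0.05, S_tmax = 50, S_noise = 1.003)
#   hp.sigma_min  # fields are then accessed by name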
363 |
364 |
365 | # ===========================================================
366 | # ===========================================================
367 | # add for OneD data format
368 | #
369 | class ElucidatedImagen_OneD(nn.Module):
370 | def __init__(
371 | self,
372 |         # 1. unets: most UNet settings are carried by the UNet instances themselves
373 | unets,
374 | *,
375 | channels = 3,
376 |         # 2. input/output image size
377 | image_sizes, # for cascading ddpm, image size at each stage
378 | # 3. on text conditioning
379 | text_encoder_name = None, # TBU: DEFAULT_T5_NAME,
380 | text_embed_dim = None,
381 | cond_drop_prob = 0.1,
382 | condition_on_text = True,
383 | #
384 | random_crop_sizes = None,
385 | resize_mode = 'nearest',
386 | temporal_downsample_factor = 1,
387 | resize_cond_video_frames = True,
388 | lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
389 | per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
390 | auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
391 | dynamic_thresholding = True,
392 |         dynamic_thresholding_percentile = 0.95,     # basis unclear even on perusal of the paper
393 | only_train_unet_number = None,
394 | lowres_noise_schedule = 'linear',
395 | num_sample_steps = 32, # number of sampling steps
396 | sigma_min = 0.002, # min noise level
397 | sigma_max = 80, # max noise level
398 | sigma_data = 0.5, # standard deviation of data distribution
399 | rho = 7, # controls the sampling schedule
400 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
401 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
402 |         S_churn = 80,                               # parameters for stochastic sampling - dataset-dependent, see Table 5 in the paper
403 | S_tmin = 0.05,
404 | S_tmax = 50,
405 | S_noise = 1.003,
406 | # ++
407 |         CKeys = {'Debug_Level':0},                  # for debugging: 0 -- silent mode
408 | ):
409 | super().__init__()
410 |
411 | # ++ for debug
412 | self.CKeys = CKeys
413 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
414 | print ("Debug mode: Initialization of EImagen...\n")
415 |
416 | self.only_train_unet_number = only_train_unet_number
417 | # ++
418 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
419 | print (f".only_train_unet_number: {self.only_train_unet_number}")
420 |
421 | # conditioning hparams
422 |
423 | self.condition_on_text = condition_on_text
424 | self.unconditional = not condition_on_text
425 | # ++
426 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
427 | print (f".condition_on_text: {self.condition_on_text}")
428 | print (f".unconditional: {self.unconditional}")
429 |
430 | # channels
431 |
432 | self.channels = channels
433 | # ++
434 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
435 | print (f".channels: {self.channels}")
436 |
437 | # automatically take care of ensuring that first unet is unconditional
438 | # while the rest of the unets are conditioned on the low resolution image produced by previous unet
439 |
440 | unets = cast_tuple(unets)
441 | num_unets = len(unets)
442 |
443 | # randomly cropping for upsampler training
444 |
445 | self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
446 | assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'
447 | # may get rid of this when moving to 1d case
448 | # ++
449 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
450 | print (f".random_crop_sizes: {self.random_crop_sizes}")
451 |
452 | # lowres augmentation noise schedule
453 |
454 | self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(
455 | noise_schedule = lowres_noise_schedule
456 | )
457 | # ++
458 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
459 | print (f".lowres_noise_schedule: {self.lowres_noise_schedule}")
460 |
461 | # get text encoder
462 |
463 | self.text_encoder_name = text_encoder_name
464 | # --: a dull one is enough
465 | # self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))
466 | # ++
467 | self.text_embed_dim = text_embed_dim
468 | # ++
469 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
470 | print (f".text_encoder_name: {self.text_encoder_name}")
471 | print (f".text_embed_dim: {self.text_embed_dim}")
472 |
473 | # -- text channel is not updated yet
474 | # self.encode_text = partial(t5_encode_text, name = text_encoder_name)
475 | # ++: TBU if needed
476 | self.encode_text = None
477 | # ++
478 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
479 | print (f".encode_text: {self.encode_text}")
480 |
481 | # construct unets
482 |
483 | self.unets = nn.ModuleList([])
484 | self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment
485 |
486 | for ind, one_unet in enumerate(unets):
487 |             # check the class of the unet: accept Unet_OneD, Unet3D, NullUnet
488 | assert isinstance(one_unet, (Unet_OneD, Unet3D, NullUnet))
489 | is_first = ind == 0
490 |
491 | one_unet = one_unet.cast_model_parameters(
492 | lowres_cond = not is_first, # may open this channel
493 | cond_on_text = self.condition_on_text,
494 | text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
495 | channels = self.channels,
496 | channels_out = self.channels,
497 | )
498 |
499 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
500 | print (f"Add one UNet: ")
501 | print (one_unet)
502 | print (f"======================================== ")
503 |
504 | self.unets.append(one_unet)
505 |
506 | # determine whether we are training on images or video
507 |
508 | is_video = any([isinstance(unet, Unet3D) for unet in self.unets])
509 | self.is_video = is_video
510 | # ++
511 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
512 | print (f".is_video: {self.is_video}")
513 |
514 | self.right_pad_dims_to_datatype = partial(
515 | rearrange,
516 | # --
517 | # pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1')
518 | # ++ may think adding one for video of 1d data
519 | pattern = ('b -> b 1 1' if not is_video else 'b -> b 1 1 1 1')
520 | )
521 | # ++
522 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
523 | print (f".right_pad_dims_to_datatype: {self.right_pad_dims_to_datatype}")
524 |
525 | self.resize_to = resize_video_to if is_video else resize_2d_image_to
526 |         # only triggered when the last dimension doesn't match the target one
527 | # input: (mini-batch, channels, width) # assume it works for 1d
528 | # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
529 |
530 | self.resize_to = partial(self.resize_to, mode = resize_mode)
531 | # ++
532 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
533 | print (f".resize_to: {self.resize_to}")
534 |
535 | # unet image sizes
536 |
537 | self.image_sizes = cast_tuple(image_sizes)
538 | assert num_unets == len(self.image_sizes), \
539 | f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}'
540 | # ++
541 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
542 | print (f".image_size: {self.image_sizes}")
543 |
544 | self.sample_channels = cast_tuple(self.channels, num_unets)
545 | # ++
546 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
547 | print (f".sample_channels: {self.sample_channels}")
548 |
549 | # cascading ddpm related stuff
550 |
551 | lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
552 | assert lowres_conditions == (False, *((True,) * (num_unets - 1))), \
553 | 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
554 |
555 | self.lowres_sample_noise_level = lowres_sample_noise_level
556 | self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level
557 | # ++
558 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
559 | print (f".lowres_sample_noise_level: {self.lowres_sample_noise_level}")
560 | print (f".per_sample_random_aug_noise_level: {self.per_sample_random_aug_noise_level}")
561 |
562 | # classifier free guidance
563 |
564 | self.cond_drop_prob = cond_drop_prob
565 | self.can_classifier_guidance = cond_drop_prob > 0.
566 | # ++
567 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
568 | print (f".cond_drop_prob: {self.cond_drop_prob}")
569 | print (f".can_classifier_guidance: {self.can_classifier_guidance}")
570 |
571 | # normalize and unnormalize image functions
572 |
573 | self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
574 | self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
575 | self.input_image_range = (0. if auto_normalize_img else -1., 1.)
576 | # ++
577 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
578 | print (f".normalize_img: {self.normalize_img}")
579 | print (f".unnormalize_img: {self.unnormalize_img}")
580 | print (f".input_image_range: {self.input_image_range}")
581 |
582 | # dynamic thresholding
583 |
584 | self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
585 | self.dynamic_thresholding_percentile = dynamic_thresholding_percentile
586 | # ++
587 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
588 | print (f".dynamic_thresholding: {self.dynamic_thresholding}")
589 | print (f".dynamic_thresholding_percentile: {self.dynamic_thresholding_percentile}")
590 |
591 | # temporal interpolations
592 |
593 | temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets)
594 | self.temporal_downsample_factor = temporal_downsample_factor
595 |
596 | self.resize_cond_video_frames = resize_cond_video_frames
597 | self.temporal_downsample_divisor = temporal_downsample_factor[0]
598 | # ++
599 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
600 | print (f".temporal_downsample_factor: {self.temporal_downsample_factor}")
601 | print (f".resize_cond_video_frames: {self.resize_cond_video_frames}")
602 | print (f".temporal_downsample_divisor: {self.temporal_downsample_divisor}")
603 |
604 | assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
605 |         assert tuple(sorted(temporal_downsample_factor, reverse = True)) == temporal_downsample_factor, 'temporal downsample factors must be in descending order'
606 |
607 | # elucidating parameters
608 |
609 | hparams = [
610 | num_sample_steps,
611 | sigma_min,
612 | sigma_max,
613 | sigma_data,
614 | rho,
615 | P_mean,
616 | P_std,
617 | S_churn,
618 | S_tmin,
619 | S_tmax,
620 | S_noise,
621 | ]
622 |
623 | hparams = [cast_tuple(hp, num_unets) for hp in hparams]
624 | self.hparams = [Hparams(*unet_hp) for unet_hp in zip(*hparams)]
625 | # ++
626 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
627 | print (f".hparams: {self.hparams}")
628 |
629 | # one temp parameter for keeping track of device
630 |
631 | self.register_buffer('_temp', torch.tensor([0.]), persistent = False)
632 |
633 | # default to device of unets passed in
634 |
635 | self.to(next(self.unets.parameters()).device)
636 | # ++
637 | if self.CKeys['Debug_Level']==Imagen_Init_Level:
638 | print (f".device: {next(self.unets.parameters()).device}")
639 |
640 | def force_unconditional_(self):
641 | self.condition_on_text = False
642 | self.unconditional = True
643 |
644 | for unet in self.unets:
645 | unet.cond_on_text = False
646 |
647 | @property
648 | def device(self):
649 | return self._temp.device
650 |
651 | def get_unet(self, unet_number):
652 | assert 0 < unet_number <= len(self.unets)
653 | index = unet_number - 1
654 |
655 | if isinstance(self.unets, nn.ModuleList):
656 | unets_list = [unet for unet in self.unets]
657 | delattr(self, 'unets')
658 | self.unets = unets_list
659 |
660 | if index != self.unet_being_trained_index:
661 | for unet_index, unet in enumerate(self.unets):
662 | unet.to(self.device if unet_index == index else 'cpu')
663 |
664 | self.unet_being_trained_index = index
665 | return self.unets[index]
666 |
667 | def reset_unets_all_one_device(self, device = None):
668 | device = default(device, self.device)
669 | self.unets = nn.ModuleList([*self.unets])
670 | self.unets.to(device)
671 |
672 | self.unet_being_trained_index = -1
673 |
674 | @contextmanager
675 | def one_unet_in_gpu(self, unet_number = None, unet = None):
676 | assert exists(unet_number) ^ exists(unet)
677 |
678 | if exists(unet_number):
679 | unet = self.unets[unet_number - 1]
680 |
681 | cpu = torch.device('cpu')
682 |
683 | devices = [module_device(unet) for unet in self.unets]
684 |
685 | self.unets.to(cpu)
686 | unet.to(self.device)
687 |
688 | yield
689 |
690 | for unet, device in zip(self.unets, devices):
691 | unet.to(device)
692 |
693 | # overriding state dict functions
694 |
695 | def state_dict(self, *args, **kwargs):
696 | self.reset_unets_all_one_device()
697 | return super().state_dict(*args, **kwargs)
698 |
699 | def load_state_dict(self, *args, **kwargs):
700 | self.reset_unets_all_one_device()
701 | return super().load_state_dict(*args, **kwargs)
702 |
703 | # dynamic thresholding
704 |
705 | def threshold_x_start(self, x_start, dynamic_threshold = True):
706 | if not dynamic_threshold:
707 | return x_start.clamp(-1., 1.)
708 |
709 | s = torch.quantile(
710 | rearrange(x_start, 'b ... -> b (...)').abs(),
711 | self.dynamic_thresholding_percentile,
712 | dim = -1
713 | )
714 |
715 | s.clamp_(min = 1.)
716 | s = right_pad_dims_to(x_start, s)
717 | return x_start.clamp(-s, s) / s
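# NOTE: this is the dynamic thresholding trick from the Imagen paper: instead of
# hard-clamping x_0 to [-1, 1], take s as the dynamic_thresholding_percentile of
# |x_0| per sample (floored at 1), clamp to [-s, s], then rescale by 1 / s so
# large classifier-free guidance scales do not saturate the prediction.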
718 |
719 | # derived preconditioning params - Table 1
720 |
721 | def c_skip(self, sigma_data, sigma):
722 | return (sigma_data ** 2) / (sigma ** 2 + sigma_data ** 2)
723 |
724 | def c_out(self, sigma_data, sigma):
725 | return sigma * sigma_data * (sigma_data ** 2 + sigma ** 2) ** -0.5
726 |
727 | def c_in(self, sigma_data, sigma):
728 | return 1 * (sigma ** 2 + sigma_data ** 2) ** -0.5
729 |
730 | def c_noise(self, sigma):
731 | return log(sigma) * 0.25
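# NOTE: c_skip / c_out / c_in / c_noise are the EDM preconditioners of
# Karras et al. 2022 (Table 1), combined in eq. (7) below as
#   D(x; sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma))
# so the raw network F only ever sees roughly unit-variance inputs and predicts
# a bounded residual, independent of the noise level.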
732 |
733 | # preconditioned network output
734 | # equation (7) in the paper
735 |
736 | def preconditioned_network_forward(
737 | self,
738 | unet_forward,
739 | noised_images,
740 | sigma,
741 | *,
742 | sigma_data,
743 | clamp = False,
744 | dynamic_threshold = True,
745 | **kwargs
746 | ):
747 | batch, device = noised_images.shape[0], noised_images.device
748 |
749 | if isinstance(sigma, float):
750 | sigma = torch.full((batch,), sigma, device = device)
751 |
752 | padded_sigma = self.right_pad_dims_to_datatype(sigma)
753 |
754 | net_out = unet_forward(
755 | self.c_in(sigma_data, padded_sigma) * noised_images,
756 | self.c_noise(sigma),
757 | **kwargs
758 | )
759 |
760 | out = self.c_skip(sigma_data, padded_sigma) * noised_images \
761 | + self.c_out(sigma_data, padded_sigma) * net_out
762 |
763 | if not clamp:
764 | return out
765 |
766 | return self.threshold_x_start(out, dynamic_threshold)
767 |
768 | # sampling
769 |
770 | # sample schedule
771 | # equation (5) in the paper
772 |
773 | def sample_schedule(
774 | self,
775 | num_sample_steps,
776 | rho,
777 | sigma_min,
778 | sigma_max
779 | ):
780 | N = num_sample_steps
781 | inv_rho = 1 / rho
782 |
783 | steps = torch.arange(
784 | num_sample_steps,
785 | device = self.device,
786 | dtype = torch.float32
787 | )
788 | sigmas = (
789 | sigma_max ** inv_rho \
790 | + steps / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)
791 | ) ** rho
792 |
793 | sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
794 | return sigmas
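# NOTE: eq. (5) interpolates from sigma_max down to sigma_min in rho-warped
# space. A minimal numeric sketch, assuming N = 3, rho = 7, sigma_min = 0.002,
# sigma_max = 80:
#   i = 0, 1, 2 -> sigma_i = (80 ** (1/7) + (i / 2) * (0.002 ** (1/7) - 80 ** (1/7))) ** 7
#               -> roughly 80, 2.5, 0.002, with a final 0. appended by F.pad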
795 |
796 | @torch.no_grad()
797 | def one_unet_sample(
798 | self,
799 | unet,
800 | shape,
801 | *,
802 | unet_number,
803 | clamp = True,
804 | dynamic_threshold = True,
805 | cond_scale = 1.,
806 | use_tqdm = True,
807 | inpaint_videos = None,
808 | inpaint_images = None,
809 | inpaint_masks = None,
810 | inpaint_resample_times = 5,
811 | init_images = None,
812 | skip_steps = None,
813 | sigma_min = None,
814 | sigma_max = None,
815 | **kwargs
816 | ):
817 | # video
818 |
819 |         is_video = len(shape) == 5 # kept from the 2D/video original; a video of 1D data would give len(shape) == 4
820 | frames = shape[-3] if is_video else None
821 | resize_kwargs = dict(target_frames = frames) if exists(frames) else dict()
822 |
823 | # get specific sampling hyperparameters for unet
824 |
825 | hp = self.hparams[unet_number - 1]
826 |
827 | sigma_min = default(sigma_min, hp.sigma_min)
828 | sigma_max = default(sigma_max, hp.sigma_max)
829 |
830 |         # get the sigma schedule, then pair each sigma with the next sigma and its gamma
831 |
832 | sigmas = self.sample_schedule(
833 | hp.num_sample_steps,
834 | hp.rho,
835 | sigma_min, sigma_max
836 | )
837 |
838 | gammas = torch.where(
839 | (sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax),
840 | min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1),
841 | 0.
842 | )
843 |
844 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))
845 |
846 | # images is noise at the beginning
847 |
848 | init_sigma = sigmas[0]
849 |
850 | images = init_sigma * torch.randn(shape, device = self.device)
851 |
852 | # initializing with an image
853 |
854 | if exists(init_images):
855 | images += init_images
856 |
857 | # keeping track of x0, for self conditioning if needed
858 |
859 | x_start = None
860 |
861 | # prepare inpainting images and mask
862 |
863 | inpaint_images = default(inpaint_videos, inpaint_images)
864 | has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
865 | resample_times = inpaint_resample_times if has_inpainting else 1
866 |
867 | if has_inpainting:
868 | inpaint_images = self.normalize_img(inpaint_images)
869 | inpaint_images = self.resize_to(inpaint_images, shape[-1], **resize_kwargs)
870 | inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1], **resize_kwargs).bool()
871 |
872 | # unet kwargs
873 |
874 | unet_kwargs = dict(
875 | sigma_data = hp.sigma_data,
876 | clamp = clamp,
877 | dynamic_threshold = dynamic_threshold,
878 | cond_scale = cond_scale,
879 | **kwargs
880 | )
881 |
882 | # gradually denoise
883 |
884 | initial_step = default(skip_steps, 0)
885 | sigmas_and_gammas = sigmas_and_gammas[initial_step:]
886 |
887 | total_steps = len(sigmas_and_gammas)
888 |
889 | for ind, (sigma, sigma_next, gamma) in tqdm(
890 | enumerate(sigmas_and_gammas),
891 | total = total_steps,
892 | desc = 'sampling time step',
893 | disable = not use_tqdm
894 | ):
895 | is_last_timestep = ind == (total_steps - 1)
896 |
897 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))
898 |
899 | for r in reversed(range(resample_times)):
900 | is_last_resample_step = r == 0
901 |
902 | eps = hp.S_noise * torch.randn(shape, device = self.device) # stochastic sampling
903 |
904 | sigma_hat = sigma + gamma * sigma
905 | added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps
906 |
907 | images_hat = images + added_noise
908 |
909 | self_cond = x_start if unet.self_cond else None
910 |
911 | if has_inpainting:
912 | images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks
913 |
914 | model_output = self.preconditioned_network_forward(
915 | unet.forward_with_cond_scale,
916 | images_hat,
917 | sigma_hat,
918 | self_cond = self_cond,
919 | **unet_kwargs
920 | )
921 |
922 | denoised_over_sigma = (images_hat - model_output) / sigma_hat
923 |
924 | images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma
925 |
926 | # second order correction, if not the last timestep
927 |
928 | has_second_order_correction = sigma_next != 0
929 |
930 | if has_second_order_correction:
931 | self_cond = model_output if unet.self_cond else None
932 |
933 | model_output_next = self.preconditioned_network_forward(
934 | unet.forward_with_cond_scale,
935 | images_next,
936 | sigma_next,
937 | self_cond = self_cond,
938 | **unet_kwargs
939 | )
940 |
941 | denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
942 | images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)
943 |
944 | images = images_next
945 |
946 | if has_inpainting and not (is_last_resample_step or is_last_timestep):
947 | # renoise in repaint and then resample
948 | repaint_noise = torch.randn(shape, device = self.device)
949 | images = images + (sigma - sigma_next) * repaint_noise
950 |
951 | x_start = model_output if not has_second_order_correction else model_output_next # save model output for self conditioning
952 |
953 | images = images.clamp(-1., 1.)
954 |
955 | if has_inpainting:
956 | images = images * ~inpaint_masks + inpaint_images * inpaint_masks
957 |
958 | return self.unnormalize_img(images)
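# NOTE: the loop above is the stochastic (Heun) sampler of Karras et al. 2022,
# Algorithm 2: each step first "churns" the sample up to sigma_hat = sigma * (1 + gamma),
# takes an Euler step toward sigma_next using the denoised estimate, and, whenever
# sigma_next != 0, applies a second-order correction averaging the slopes at
# sigma_hat and sigma_next. The inpainting branch follows the RePaint recipe:
# renoise, then resample each step inpaint_resample_times times.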
959 |
960 | @torch.no_grad()
961 | @eval_decorator
962 | def sample(
963 | self,
964 | # 1. on text condition
965 | texts: List[str] = None,
966 | text_masks = None,
967 | text_embeds = None,
968 | # 2. on condition images
969 | cond_images = None,
970 | cond_video_frames = None,
971 | post_cond_video_frames = None,
972 | # 3. inpaint images
973 | inpaint_videos = None,
974 | inpaint_images = None,
975 | inpaint_masks = None,
976 | inpaint_resample_times = 5,
977 | #
978 | init_images = None,
979 | skip_steps = None,
980 | sigma_min = None,
981 | sigma_max = None,
982 | video_frames = None,
983 | batch_size = 1,
984 | cond_scale = 1.,
985 | lowres_sample_noise_level = None,
986 | start_at_unet_number = 1,
987 | start_image_or_video = None,
988 | stop_at_unet_number = None,
989 | return_all_unet_outputs = False,
990 | return_pil_images = False,
991 | use_tqdm = True,
992 | use_one_unet_in_gpu = True,
993 | device = None,
994 | ):
995 | # ++
996 | if self.CKeys['Debug_Level']==Imagen_Samp_Level:
997 | print (f"Debug mode for .sample func...")
998 |
999 | device = default(device, self.device)
1000 | self.reset_unets_all_one_device(device = device)
1001 | # ++
1002 | if self.CKeys['Debug_Level']==Imagen_Samp_Level:
1003 | print (f"device for unets: {device}")
1004 |
1005 | cond_images = maybe(cast_uint8_images_to_float)(cond_images)
1006 | # ++
1007 | if self.CKeys['Debug_Level']==Imagen_Samp_Level:
1008 |             if cond_images is not None:
1009 | print (f"input cond_images.shape: {cond_images.shape}")
1010 | else:
1011 | print (f"input cond_images: None")
1012 |
1013 |         # text conditioning, case 1: texts are given directly (no text_embeds); otherwise, text_embeds are passed in
1014 | if exists(texts) and not exists(text_embeds) and not self.unconditional:
1015 | assert all([*map(len, texts)]), 'text cannot be empty'
1016 |
1017 | with autocast(enabled = False):
1018 | text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
1019 |
1020 | text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
1021 |
1022 | if not self.unconditional:
1023 |             assert exists(text_embeds), 'text must be passed in since this network was trained with text conditioning; set `condition_on_text = False` at training time for unconditional use'
1024 |
1025 | text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
1026 | batch_size = text_embeds.shape[0]
1027 | # ++
1028 | if self.CKeys['Debug_Level']==Imagen_Samp_Level:
1029 |             if text_embeds is not None:
1030 |                 print (f"text_embeds.shape: {text_embeds.shape}")
1031 |             if text_masks is not None:
1032 |                 print (f"text_masks.shape: {text_masks.shape}")
1033 |
1034 | # inpainting
1035 |
1036 | inpaint_images = default(inpaint_videos, inpaint_images)
1037 |
1038 | if exists(inpaint_images):
1039 | if self.unconditional:
1040 | if batch_size == 1: # assume researcher wants to broadcast along inpainted images
1041 | batch_size = inpaint_images.shape[0]
1042 |
1043 |             assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the batch size specified in sample(batch_size=)'
1044 | assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'
1045 |
1046 | assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
1047 | assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
1048 | assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
1049 |
1050 | assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting'
1051 |
1052 | outputs = []
1053 |
1054 | is_cuda = next(self.parameters()).is_cuda
1055 | device = next(self.parameters()).device
1056 |
1057 | # will be applied to the lowres images
1058 | lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)
1059 |
1060 | num_unets = len(self.unets)
1061 | cond_scale = cast_tuple(cond_scale, num_unets)
1062 |
1063 | # handle video and frame dimension
1064 |
1065 | if self.is_video and exists(inpaint_images):
1066 | video_frames = inpaint_images.shape[2]
1067 |
1068 | if inpaint_masks.ndim == 3:
1069 | inpaint_masks = repeat(
1070 | inpaint_masks,
1071 | # 'b h w -> b f h w',
1072 | 'b h -> b f h',
1073 | f = video_frames
1074 | )
1075 |
1076 | assert inpaint_masks.shape[1] == video_frames
1077 |
1078 | assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'
1079 |
1080 | # determine the frame dimensions, if needed
1081 |
1082 | all_frame_dims = calc_all_frame_dims(
1083 | self.temporal_downsample_factor,
1084 | video_frames
1085 | )
1086 |
1087 | # initializing with an image or video
1088 |
1089 | init_images = cast_tuple(init_images, num_unets)
1090 | init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]
1091 | # ++
1092 | if self.CKeys['Debug_Level']==Imagen_Samp_Level:
1093 | print (f"init_images: {init_images}")
1094 |
1095 | skip_steps = cast_tuple(skip_steps, num_unets)
1096 |
1097 | sigma_min = cast_tuple(sigma_min, num_unets)
1098 | sigma_max = cast_tuple(sigma_max, num_unets)
1099 |
1100 |         # handle starting at a unet greater than 1, for upscaler-only training
1101 |
1102 | if start_at_unet_number > 1:
1103 |             assert start_at_unet_number <= num_unets, 'start_at_unet_number must not exceed the total number of unets'
1104 | assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
1105 | assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'
1106 |
1107 | prev_image_size = self.image_sizes[start_at_unet_number - 2]
1108 | img = self.resize_to(start_image_or_video, prev_image_size)
1109 |
1110 | # go through each unet in cascade
1111 |
1112 | for unet_number, unet, channel, image_size, frame_dims, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(
1113 | zip(
1114 | range(1, num_unets + 1), self.unets,
1115 | self.sample_channels, self.image_sizes,
1116 | all_frame_dims, self.hparams,
1117 | self.dynamic_thresholding, cond_scale,
1118 | init_images, skip_steps,
1119 | sigma_min, sigma_max
1120 | ),
1121 | disable = not use_tqdm
1122 | ):
1123 | if unet_number < start_at_unet_number:
1124 | continue
1125 |
1126 | assert not isinstance(unet, NullUnet), 'cannot sample from null unet'
1127 |
1128 | context = self.one_unet_in_gpu(unet = unet) if is_cuda and use_one_unet_in_gpu else nullcontext()
1129 |
1130 | with context:
1131 | lowres_cond_img = lowres_noise_times = None
1132 |
1133 | # --
1134 | # shape = (batch_size, channel, *frame_dims, image_size, image_size)
1135 | # ++
1136 | shape = (batch_size, channel, *frame_dims, image_size)
1137 |
1138 | resize_kwargs = dict()
1139 | video_kwargs = dict()
1140 |
1141 | if self.is_video:
1142 | resize_kwargs = dict(target_frames = frame_dims[0])
1143 |
1144 | video_kwargs = dict(
1145 | cond_video_frames = cond_video_frames,
1146 | post_cond_video_frames = post_cond_video_frames
1147 | )
1148 |
1149 | video_kwargs = compact(video_kwargs)
1150 |
1151 | # handle video conditioning frames
1152 |
1153 | if self.is_video and self.resize_cond_video_frames:
1154 | downsample_scale = self.temporal_downsample_factor[unet_number - 1]
1155 | temporal_downsample_fn = partial(
1156 | scale_video_time,
1157 | downsample_scale = downsample_scale
1158 | )
1159 | video_kwargs = maybe_transform_dict_key(
1160 | video_kwargs, 'cond_video_frames',
1161 | temporal_downsample_fn
1162 | )
1163 | video_kwargs = maybe_transform_dict_key(
1164 | video_kwargs, 'post_cond_video_frames',
1165 | temporal_downsample_fn
1166 | )
1167 |
1168 | # low resolution conditioning
1169 |
1170 | if unet.lowres_cond:
1171 | lowres_noise_times = self.lowres_noise_schedule.get_times(
1172 | batch_size, lowres_sample_noise_level, device = device
1173 | )
1174 |
1175 | lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs)
1176 | lowres_cond_img = self.normalize_img(lowres_cond_img)
1177 |
1178 | lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(
1179 | x_start = lowres_cond_img,
1180 | t = lowres_noise_times,
1181 | noise = torch.randn_like(lowres_cond_img)
1182 | )
1183 |
1184 | if exists(unet_init_images):
1185 | unet_init_images = self.resize_to(
1186 | unet_init_images, image_size, **resize_kwargs
1187 | )
1188 |
1189 | # --
1190 | # shape = (batch_size, self.channels, *frame_dims, image_size, image_size)
1191 | # ++
1192 | shape = (batch_size, self.channels, *frame_dims, image_size)
1193 |
1194 | img = self.one_unet_sample(
1195 | unet,
1196 | shape,
1197 | unet_number = unet_number,
1198 | text_embeds = text_embeds,
1199 | text_mask = text_masks,
1200 | cond_images = cond_images,
1201 | inpaint_images = inpaint_images,
1202 | inpaint_masks = inpaint_masks,
1203 | inpaint_resample_times = inpaint_resample_times,
1204 | init_images = unet_init_images,
1205 | skip_steps = unet_skip_steps,
1206 | sigma_min = unet_sigma_min,
1207 | sigma_max = unet_sigma_max,
1208 | cond_scale = unet_cond_scale,
1209 | lowres_cond_img = lowres_cond_img,
1210 | lowres_noise_times = lowres_noise_times,
1211 | dynamic_threshold = dynamic_threshold,
1212 | use_tqdm = use_tqdm,
1213 | **video_kwargs
1214 | )
1215 |
1216 | outputs.append(img)
1217 |
1218 | if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
1219 | break
1220 |
1221 | output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs
1222 |
1223 | if not return_pil_images:
1224 | return outputs[output_index]
1225 |
1226 | if not return_all_unet_outputs:
1227 | outputs = outputs[-1:]
1228 |
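# NOTE: the PIL conversion below is kept from the 2D original but left disabled
# for 1D data, so `return_pil_images = True` currently falls through and returns None.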
1229 | # assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet'
1230 |
1231 | # pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs))
1232 |
1233 | # return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
1234 |
1235 | # end of sampling ===================================================================================
1236 |
1237 | # training
1238 |
1239 | def loss_weight(self, sigma_data, sigma):
1240 | return (sigma ** 2 + sigma_data ** 2) * (sigma * sigma_data) ** -2
1241 |
1242 | def noise_distribution(self, P_mean, P_std, batch_size):
1243 | return (P_mean + P_std * torch.randn((batch_size,), device = self.device)).exp()
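# NOTE: training noise levels are drawn log-normally, sigma = exp(P_mean + P_std * n)
# with n ~ N(0, 1), and loss_weight above is the EDM weighting
#   lambda(sigma) = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2
# of Karras et al. 2022, which keeps the effective loss magnitude roughly uniform
# across noise levels.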
1244 |
1245 | def forward(
1246 | self,
1247 | images, # rename to images or video
1248 | unet: Union[Unet_OneD, Unet3D, NullUnet, DistributedDataParallel] = None,
1249 | texts: List[str] = None,
1250 | text_embeds = None,
1251 | text_masks = None,
1252 | unet_number = None,
1253 | cond_images = None,
1254 | **kwargs
1255 | ):
1256 | # ++
1257 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1258 | print (f"Now, in EImagen.forward() ...")
1259 |
1260 | if self.is_video and images.ndim == 4:
1261 | # --
1262 | # images = rearrange(images, 'b c h w -> b c 1 h w')
1263 | # ++
1264 | images = rearrange(images, 'b c h -> b c 1 h')
1265 | kwargs.update(ignore_time = True)
1266 |
1267 | assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
1268 | unet_number = default(unet_number, 1)
1269 |         assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train on unet #{self.only_train_unet_number}'
1270 |
1271 |         images = cast_uint8_images_to_float(images) # no-op unless the input is uint8, which is cast to floats between (0, 1)
1272 | cond_images = maybe(cast_uint8_images_to_float)(cond_images)
1273 |         # ++ may need adjustment for the 1D case
1274 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1275 | print (f"Reformat images and cond_images into float")
1276 | print (f"images.shape: {images.shape}")
1277 | print (f"images.dtype: {images.dtype}")
1278 | print (f"max and min: {torch.max(images)} and {torch.min(images)}")
1279 |             if cond_images is not None:
1280 | print (f"cond_images.shape: {cond_images.shape}")
1281 | else:
1282 | print (f"cond_images: None")
1283 |
1284 | assert images.dtype == torch.float, f'images tensor needs to be floats but {images.dtype} dtype found instead'
1285 |
1286 | unet_index = unet_number - 1
1287 |
1288 | unet = default(unet, lambda: self.get_unet(unet_number))
1289 |
1290 | assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'
1291 |
1292 | target_image_size = self.image_sizes[unet_index]
1293 | random_crop_size = self.random_crop_sizes[unet_index]
1294 | prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
1295 | hp = self.hparams[unet_index]
1296 | # ++
1297 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1298 | print (f"target_image_size: {target_image_size}")
1299 | print (f"random_crop_size: {random_crop_size}")
1300 | print (f"prev_image_size: {prev_image_size}")
1301 | print (f"hp: {hp}")
1302 |
1303 | # --
1304 | # batch_size, c, *_, h, w, device, is_video = *images.shape, images.device, (images.ndim == 5)
1305 | # ++
1306 | batch_size, c, *_, h, device, is_video = *images.shape, images.device, (images.ndim == 4)
1307 | # ++
1308 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1309 | print (f"batch_size: {batch_size}")
1310 | print (f"channel c: {c}")
1311 | print (f"1d image size, h: {h} ")
1312 |
1313 |
1314 | frames = images.shape[2] if is_video else None
1315 | all_frame_dims = tuple(
1316 | safe_get_tuple_index(el, 0) for el in calc_all_frame_dims(
1317 | self.temporal_downsample_factor, frames
1318 | )
1319 | )
1320 | ignore_time = kwargs.get('ignore_time', False)
1321 | # ++
1322 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1323 | print (f"frames: {frames}")
1324 | print (f"all_frame_dims: {all_frame_dims}")
1325 | print (f"ignore_time: {ignore_time}")
1326 |
1327 | target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
1328 | prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and unet_index > 0 else None
1329 | frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()
1330 | # ++
1331 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1332 | print (f"target_frame_size: {target_frame_size}")
1333 | print (f"prev_frame_size: {prev_frame_size}")
1334 | print (f"frames_to_resize_kwargs: {frames_to_resize_kwargs}")
1335 |
1336 | assert images.shape[1] == self.channels
1337 | assert h >= target_image_size # and w >= target_image_size
1338 |
1339 | # texts provided, not text_embeds
1340 | #
1341 | if exists(texts) and not exists(text_embeds) and not self.unconditional:
1342 | assert all([*map(len, texts)]), 'text cannot be empty'
1343 | assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
1344 |
1345 | with autocast(enabled = False):
1346 | text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
1347 |
1348 | text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
1349 | # now we have text_embeds, and text_masks
1350 |
1351 | if not self.unconditional:
1352 | text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
1353 | # ++
1354 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1355 | # print (f"text_masks: \n{text_masks}")
1356 | print (f"text_masks.shape: \n{text_masks.shape}")
1357 |
1358 | assert not (
1359 | self.condition_on_text and not exists(text_embeds)
1360 |         ), 'text or text encodings must be passed into imagen if it is conditioned on text'
1361 |         assert not (
1362 |             not self.condition_on_text and exists(text_embeds)
1363 |         ), 'imagen specified not to be conditioned on text, yet text embeddings were provided'
1364 |
1365 | assert not (
1366 | exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim
1367 | ), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
1368 |
1369 | # handle video conditioning frames
1370 |
1371 | if self.is_video and self.resize_cond_video_frames:
1372 | downsample_scale = self.temporal_downsample_factor[unet_index]
1373 | temporal_downsample_fn = partial(scale_video_time, downsample_scale = downsample_scale)
1374 | kwargs = maybe_transform_dict_key(kwargs, 'cond_video_frames', temporal_downsample_fn)
1375 | kwargs = maybe_transform_dict_key(kwargs, 'post_cond_video_frames', temporal_downsample_fn)
1376 |
1377 | # low resolution conditioning
1378 |         # this path is active when the unet being trained is the 2nd (upscaler) unet
1379 |
1380 | lowres_cond_img = lowres_aug_times = None
1381 | if exists(prev_image_size): # so, this is the 2nd unet
1382 | # ++
1383 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1384 | print (f"prev_image_size detected. So, this is for 2nd UNet")
1385 | print (f"Create lowres_cond_img by resizing true image")
1386 | print (f" images.shape: {images.shape}")
1387 | lowres_cond_img = self.resize_to(
1388 | images,
1389 | prev_image_size,
1390 | **frames_to_resize_kwargs(prev_frame_size),
1391 | clamp_range = self.input_image_range
1392 | )
1393 | # ++
1394 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1395 | print (f" 1. resize full image to previous size: coarsening")
1396 | print (f" .resize_to(images,prev_image_size)->lowres_cond_img.shape: {lowres_cond_img.shape}")
1397 | lowres_cond_img = self.resize_to(
1398 | lowres_cond_img, target_image_size,
1399 | **frames_to_resize_kwargs(target_frame_size),
1400 | clamp_range = self.input_image_range
1401 | )
1402 | # ++
1403 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1404 | print (f" 2. fit into the traget size: only change size")
1405 | print (f" .resize_to(lowres_cond_img,target_image_size)->lowres_cond_img.shape: {lowres_cond_img.shape}")
1406 |
1407 | if self.per_sample_random_aug_noise_level:
1408 | lowres_aug_times = self.lowres_noise_schedule.sample_random_times(
1409 | batch_size, device = device
1410 | )
1411 |             else: # i.e., all samples in the batch use the same aug noise level
1412 | lowres_aug_time = self.lowres_noise_schedule.sample_random_times(
1413 | 1, device = device
1414 | )
1415 | lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size)
1416 | # ++
1417 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1418 | print (f"get noise schedule for lowres_cond_img, lowres_aug_time.shape: {lowres_aug_times.shape}")
1419 |
1420 | # ++
1421 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1422 | print (f"images.shape: {images.shape}")
1423 | images = self.resize_to(
1424 | images,
1425 | target_image_size,
1426 | **frames_to_resize_kwargs(target_frame_size)
1427 | )
1428 | # not triggered if images.shape[-1]==target_image_size
1429 | # ++
1430 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1431 | print (f".resize_to() -> images.shape: {images.shape}")
1432 |
1433 | # normalize to [-1, 1]
1434 | # ++
1435 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1436 | print (f"Bef. normalize_img, min/ax of images: {torch.max(images)}, {torch.min(images)}")
1437 | print (f".normalize_img: {self.normalize_img}")
1438 | images = self.normalize_img(images) # assume images (0,1)->(-1,1)
1439 | lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
1440 | # ++
1441 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1442 | print (f"After normalize_img, should be [-1,1]")
1443 | print (f"images max and min: {torch.max(images)} and {torch.min(images)}")
1444 | if exists(lowres_cond_img):
1445 | print (f"lowres_cond_img max and min: {torch.max(lowres_cond_img)} and {torch.min(lowres_cond_img)}")
1446 |
1447 | # random cropping during training
1448 | # for upsamplers
1449 |
1450 | if exists(random_crop_size):
1451 | aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
1452 |
1453 | if is_video:
1454 | images, lowres_cond_img = map(
1455 | # --
1456 | # lambda t: rearrange(t, 'b c f h w -> (b f) c h w'),
1457 | # ++
1458 | lambda t: rearrange(t, 'b c f h -> (b f) c h'),
1459 | (images, lowres_cond_img)
1460 | )
1461 |
1462 | # make sure low res conditioner and image both get augmented the same way
1463 | # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
1464 | images = aug(images)
1465 | lowres_cond_img = aug(lowres_cond_img, params = aug._params)
1466 |
1467 | if is_video:
1468 | images, lowres_cond_img = map(
1469 | # --
1470 | # lambda t: rearrange(t, '(b f) c h w -> b c f h w', f = frames),
1471 | # ++
1472 | lambda t: rearrange(t, '(b f) c h -> b c f h', f = frames),
1473 | (images, lowres_cond_img)
1474 | )
1475 |
1476 | # noise the lowres conditioning image
1477 |         # at sample time, they then fix the noise level to 0.1 - 0.3
1478 |
1479 | lowres_cond_img_noisy = None
1480 | if exists(lowres_cond_img):
1481 | lowres_cond_img_noisy, *_ = self.lowres_noise_schedule.q_sample(
1482 | x_start = lowres_cond_img,
1483 | t = lowres_aug_times,
1484 | noise = torch.randn_like(lowres_cond_img)
1485 | )
1486 | # ++
1487 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1488 | print (f"add noise to lowres_cond_img...")
1489 | print (f"lowres_cond_img_noisy.shape: {lowres_cond_img_noisy.shape}")
1490 |
1491 |
1492 | # get the sigmas
1493 |
1494 | sigmas = self.noise_distribution(
1495 | hp.P_mean, hp.P_std, batch_size
1496 | )
1497 | padded_sigmas = self.right_pad_dims_to_datatype(sigmas)
1498 | # ++
1499 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1500 | print (f"sigmas.shape: {sigmas.shape}")
1501 | print (f"padded_sigmas.shape: {padded_sigmas.shape}")
1502 |
1503 | # noise
1504 |
1505 | noise = torch.randn_like(images)
1506 | noised_images = images + padded_sigmas * noise # alphas are 1. in the paper
1507 | # ++
1508 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1509 | print (f"add noise into images")
1510 |
1511 | # unet kwargs
1512 |
1513 | unet_kwargs = dict(
1514 | sigma_data = hp.sigma_data,
1515 | text_embeds = text_embeds,
1516 | text_mask = text_masks,
1517 | cond_images = cond_images,
1518 | lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
1519 | lowres_cond_img = lowres_cond_img_noisy,
1520 | cond_drop_prob = self.cond_drop_prob,
1521 | **kwargs
1522 | )
1523 |
1524 | # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower
1525 |
1526 | # Because 'unet' can be an instance of DistributedDataParallel coming from the
1527 | # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to
1528 | # access the member 'module' of the wrapped unet instance.
1529 | self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond
1530 |
1531 | if self_cond and random() < 0.5:
1532 | # ++
1533 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1534 | print (f"self_cond is triggered.")
1535 | print (f"get into unet.......")
1536 |
1537 | with torch.no_grad():
1538 | pred_x0 = self.preconditioned_network_forward(
1539 | unet.forward,
1540 | noised_images,
1541 | sigmas,
1542 | **unet_kwargs
1543 | ).detach()
1544 | # ++
1545 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1546 | print (f"get out of unet.......")
1547 | print (f"prop noised_images via the net.")
1548 | print (f"get pred_x0.shape: {pred_x0.shape}")
1549 |
1550 | unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0}
1551 |
1552 | # get prediction
1553 | # ++
1554 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1555 | # print (f"unet_kwargs: \n{unet_kwargs}")
1556 | print (f"unet_kwargs include keys: ")
1557 | for this_key in unet_kwargs.keys():
1558 | print (" "+this_key)
1559 | if torch.is_tensor(unet_kwargs[this_key]):
1560 | print (f" {unet_kwargs[this_key].shape}")
1561 | else:
1562 | print (f" {type(unet_kwargs[this_key])}")
1563 |
1564 | denoised_images = self.preconditioned_network_forward(
1565 | unet.forward,
1566 | noised_images,
1567 | sigmas,
1568 | **unet_kwargs
1569 | )
1570 | # ++
1571 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1572 | print (f"Back to EImagen, get denoised_images")
1573 | print (f"denoised_images.shape: {denoised_images.shape}")
1574 |
1575 | # losses
1576 |
1577 | losses = F.mse_loss(denoised_images, images, reduction = 'none')
1578 | losses = reduce(losses, 'b ... -> b', 'mean')
1579 |
1580 | # loss weighting
1581 |
1582 | losses = losses * self.loss_weight(hp.sigma_data, sigmas)
1583 | # ++
1584 | if self.CKeys['Debug_Level']==Imagen_Forw_Level:
1585 | print (f"postprocess losses based on hp.sigma_data")
1586 |
1587 | # return average loss
1588 |
1589 | return losses.mean()
1590 |
--------------------------------------------------------------------------------