├── .gitignore
├── LICENSE
├── README.md
├── ckpt
│   └── put_your_model_here.txt
├── configs
│   ├── inference
│   │   ├── config_16z.yaml
│   │   ├── config_16z_cap.yaml
│   │   ├── config_4z.yaml
│   │   └── config_4z_cap.yaml
│   └── train
│       ├── config_16z.yaml
│       ├── config_16z_cap.yaml
│       ├── config_16z_joint.yaml
│       ├── config_4z.yaml
│       ├── config_4z_cap.yaml
│       └── config_4z_joint.yaml
├── data
│   ├── dataset.py
│   └── lightning_data.py
├── docs
│   ├── case1
│   │   ├── fkanimal2.gif
│   │   └── gtanimal2.gif
│   ├── case2
│   │   ├── fkcloseshot1.gif
│   │   └── gtcloseshot1.gif
│   ├── case3
│   │   ├── fkface.gif
│   │   └── gtface.gif
│   ├── case4
│   │   ├── fkmotion4.gif
│   │   └── gtmotion4.gif
│   ├── case5
│   │   ├── fkview7.gif
│   │   └── gtview7.gif
│   └── sota-table.png
├── evaluation
│   ├── compute_metrics.py
│   └── compute_metrics_img.py
├── examples
│   ├── images
│   │   ├── gt
│   │   │   ├── 00000091.jpg
│   │   │   ├── 00000103.jpg
│   │   │   ├── 00000110.jpg
│   │   │   ├── 00000212.jpg
│   │   │   ├── 00000268.jpg
│   │   │   ├── 00000592.jpg
│   │   │   ├── 00006871.jpg
│   │   │   ├── 00007252.jpg
│   │   │   ├── 00007826.jpg
│   │   │   └── 00008868.jpg
│   │   └── recon
│   │       ├── 00000091.jpeg
│   │       ├── 00000110.jpeg
│   │       ├── 00000212.jpeg
│   │       ├── 00000268.jpeg
│   │       ├── 00000592.jpeg
│   │       ├── 00006871.jpeg
│   │       ├── 00007252.jpeg
│   │       ├── 00007826.jpeg
│   │       └── 00008868.jpeg
│   └── videos
│       ├── gt
│       │   ├── 40.mp4
│       │   ├── 40.txt
│       │   ├── 8.mp4
│       │   ├── 8.txt
│       │   ├── animal.mp4
│       │   ├── animal.txt
│       │   ├── closeshot.mp4
│       │   ├── closeshot.txt
│       │   ├── face.mp4
│       │   ├── face.txt
│       │   ├── view.mp4
│       │   └── view.txt
│       └── recon
│           ├── 40_reconstructed.mp4
│           ├── 8_reconstructed.mp4
│           ├── animal_reconstructed.mp4
│           ├── closeshot_reconstructed.mp4
│           ├── face_reconstructed.mp4
│           └── view_reconstructed.mp4
├── inference_image.py
├── inference_video.py
├── requirements.txt
├── scripts
│   ├── evaluation_image.sh
│   ├── evaluation_video.sh
│   ├── run_inference_image.sh
│   ├── run_inference_video.sh
│   └── run_train.sh
├── src
│   ├── distributions.py
│   ├── models
│   │   ├── autoencoder.py
│   │   ├── autoencoder2plus1d_1dcnn.py
│   │   └── autoencoder_temporal.py
│   └── modules
│       ├── ae_modules.py
│       ├── attention_temporal_videoae.py
│       ├── losses
│       │   ├── __init__.py
│       │   └── contperceptual.py
│       ├── t5.py
│       └── utils.py
├── train.py
└── utils
    ├── callbacks.py
    ├── common_utils.py
    ├── save_video.py
    └── train_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pyc
3 |
4 | .vscode
5 | .DS_Store
6 | .idea
7 | .git
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial-NoDerivatives 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial-NoDerivatives 4.0
58 | International Public License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial-NoDerivatives 4.0 International Public
63 | License ("Public License"). To the extent this Public License may be
64 | interpreted as a contract, You are granted the Licensed Rights in
65 | consideration of Your acceptance of these terms and conditions, and the
66 | Licensor grants You such rights in consideration of benefits the
67 | Licensor receives from making the Licensed Material available under
68 | these terms and conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Copyright and Similar Rights means copyright and/or similar rights
84 | closely related to copyright including, without limitation,
85 | performance, broadcast, sound recording, and Sui Generis Database
86 | Rights, without regard to how the rights are labeled or
87 | categorized. For purposes of this Public License, the rights
88 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
89 | Rights.
90 |
91 | c. Effective Technological Measures means those measures that, in the
92 | absence of proper authority, may not be circumvented under laws
93 | fulfilling obligations under Article 11 of the WIPO Copyright
94 | Treaty adopted on December 20, 1996, and/or similar international
95 | agreements.
96 |
97 | d. Exceptions and Limitations means fair use, fair dealing, and/or
98 | any other exception or limitation to Copyright and Similar Rights
99 | that applies to Your use of the Licensed Material.
100 |
101 | e. Licensed Material means the artistic or literary work, database,
102 | or other material to which the Licensor applied this Public
103 | License.
104 |
105 | f. Licensed Rights means the rights granted to You subject to the
106 | terms and conditions of this Public License, which are limited to
107 | all Copyright and Similar Rights that apply to Your use of the
108 | Licensed Material and that the Licensor has authority to license.
109 |
110 | g. Licensor means the individual(s) or entity(ies) granting rights
111 | under this Public License.
112 |
113 | h. NonCommercial means not primarily intended for or directed towards
114 | commercial advantage or monetary compensation. For purposes of
115 | this Public License, the exchange of the Licensed Material for
116 | other material subject to Copyright and Similar Rights by digital
117 | file-sharing or similar means is NonCommercial provided there is
118 | no payment of monetary compensation in connection with the
119 | exchange.
120 |
121 | i. Share means to provide material to the public by any means or
122 | process that requires permission under the Licensed Rights, such
123 | as reproduction, public display, public performance, distribution,
124 | dissemination, communication, or importation, and to make material
125 | available to the public including in ways that members of the
126 | public may access the material from a place and at a time
127 | individually chosen by them.
128 |
129 | j. Sui Generis Database Rights means rights other than copyright
130 | resulting from Directive 96/9/EC of the European Parliament and of
131 | the Council of 11 March 1996 on the legal protection of databases,
132 | as amended and/or succeeded, as well as other essentially
133 | equivalent rights anywhere in the world.
134 |
135 | k. You means the individual or entity exercising the Licensed Rights
136 | under this Public License. Your has a corresponding meaning.
137 |
138 |
139 | Section 2 -- Scope.
140 |
141 | a. License grant.
142 |
143 | 1. Subject to the terms and conditions of this Public License,
144 | the Licensor hereby grants You a worldwide, royalty-free,
145 | non-sublicensable, non-exclusive, irrevocable license to
146 | exercise the Licensed Rights in the Licensed Material to:
147 |
148 | a. reproduce and Share the Licensed Material, in whole or
149 | in part, for NonCommercial purposes only; and
150 |
151 | b. produce and reproduce, but not Share, Adapted Material
152 | for NonCommercial purposes only.
153 |
154 | 2. Exceptions and Limitations. For the avoidance of doubt, where
155 | Exceptions and Limitations apply to Your use, this Public
156 | License does not apply, and You do not need to comply with
157 | its terms and conditions.
158 |
159 | 3. Term. The term of this Public License is specified in Section
160 | 6(a).
161 |
162 | 4. Media and formats; technical modifications allowed. The
163 | Licensor authorizes You to exercise the Licensed Rights in
164 | all media and formats whether now known or hereafter created,
165 | and to make technical modifications necessary to do so. The
166 | Licensor waives and/or agrees not to assert any right or
167 | authority to forbid You from making technical modifications
168 | necessary to exercise the Licensed Rights, including
169 | technical modifications necessary to circumvent Effective
170 | Technological Measures. For purposes of this Public License,
171 | simply making modifications authorized by this Section 2(a)
172 | (4) never produces Adapted Material.
173 |
174 | 5. Downstream recipients.
175 |
176 | a. Offer from the Licensor -- Licensed Material. Every
177 | recipient of the Licensed Material automatically
178 | receives an offer from the Licensor to exercise the
179 | Licensed Rights under the terms and conditions of this
180 | Public License.
181 |
182 | b. No downstream restrictions. You may not offer or impose
183 | any additional or different terms or conditions on, or
184 | apply any Effective Technological Measures to, the
185 | Licensed Material if doing so restricts exercise of the
186 | Licensed Rights by any recipient of the Licensed
187 | Material.
188 |
189 | 6. No endorsement. Nothing in this Public License constitutes or
190 | may be construed as permission to assert or imply that You
191 | are, or that Your use of the Licensed Material is, connected
192 | with, or sponsored, endorsed, or granted official status by,
193 | the Licensor or others designated to receive attribution as
194 | provided in Section 3(a)(1)(A)(i).
195 |
196 | b. Other rights.
197 |
198 | 1. Moral rights, such as the right of integrity, are not
199 | licensed under this Public License, nor are publicity,
200 | privacy, and/or other similar personality rights; however, to
201 | the extent possible, the Licensor waives and/or agrees not to
202 | assert any such rights held by the Licensor to the limited
203 | extent necessary to allow You to exercise the Licensed
204 | Rights, but not otherwise.
205 |
206 | 2. Patent and trademark rights are not licensed under this
207 | Public License.
208 |
209 | 3. To the extent possible, the Licensor waives any right to
210 | collect royalties from You for the exercise of the Licensed
211 | Rights, whether directly or through a collecting society
212 | under any voluntary or waivable statutory or compulsory
213 | licensing scheme. In all other cases the Licensor expressly
214 | reserves any right to collect such royalties, including when
215 | the Licensed Material is used other than for NonCommercial
216 | purposes.
217 |
218 |
219 | Section 3 -- License Conditions.
220 |
221 | Your exercise of the Licensed Rights is expressly made subject to the
222 | following conditions.
223 |
224 | a. Attribution.
225 |
226 | 1. If You Share the Licensed Material, You must:
227 |
228 | a. retain the following if it is supplied by the Licensor
229 | with the Licensed Material:
230 |
231 | i. identification of the creator(s) of the Licensed
232 | Material and any others designated to receive
233 | attribution, in any reasonable manner requested by
234 | the Licensor (including by pseudonym if
235 | designated);
236 |
237 | ii. a copyright notice;
238 |
239 | iii. a notice that refers to this Public License;
240 |
241 | iv. a notice that refers to the disclaimer of
242 | warranties;
243 |
244 | v. a URI or hyperlink to the Licensed Material to the
245 | extent reasonably practicable;
246 |
247 | b. indicate if You modified the Licensed Material and
248 | retain an indication of any previous modifications; and
249 |
250 | c. indicate the Licensed Material is licensed under this
251 | Public License, and include the text of, or the URI or
252 | hyperlink to, this Public License.
253 |
254 | For the avoidance of doubt, You do not have permission under
255 | this Public License to Share Adapted Material.
256 |
257 | 2. You may satisfy the conditions in Section 3(a)(1) in any
258 | reasonable manner based on the medium, means, and context in
259 | which You Share the Licensed Material. For example, it may be
260 | reasonable to satisfy the conditions by providing a URI or
261 | hyperlink to a resource that includes the required
262 | information.
263 |
264 | 3. If requested by the Licensor, You must remove any of the
265 | information required by Section 3(a)(1)(A) to the extent
266 | reasonably practicable.
267 |
268 |
269 | Section 4 -- Sui Generis Database Rights.
270 |
271 | Where the Licensed Rights include Sui Generis Database Rights that
272 | apply to Your use of the Licensed Material:
273 |
274 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
275 | to extract, reuse, reproduce, and Share all or a substantial
276 | portion of the contents of the database for NonCommercial purposes
277 | only and provided You do not Share Adapted Material;
278 |
279 | b. if You include all or a substantial portion of the database
280 | contents in a database in which You have Sui Generis Database
281 | Rights, then the database in which You have Sui Generis Database
282 | Rights (but not its individual contents) is Adapted Material; and
283 |
284 | c. You must comply with the conditions in Section 3(a) if You Share
285 | all or a substantial portion of the contents of the database.
286 |
287 | For the avoidance of doubt, this Section 4 supplements and does not
288 | replace Your obligations under this Public License where the Licensed
289 | Rights include other Copyright and Similar Rights.
290 |
291 |
292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293 |
294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304 |
305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314 |
315 | c. The disclaimer of warranties and limitation of liability provided
316 | above shall be interpreted in a manner that, to the extent
317 | possible, most closely approximates an absolute disclaimer and
318 | waiver of all liability.
319 |
320 |
321 | Section 6 -- Term and Termination.
322 |
323 | a. This Public License applies for the term of the Copyright and
324 | Similar Rights licensed here. However, if You fail to comply with
325 | this Public License, then Your rights under this Public License
326 | terminate automatically.
327 |
328 | b. Where Your right to use the Licensed Material has terminated under
329 | Section 6(a), it reinstates:
330 |
331 | 1. automatically as of the date the violation is cured, provided
332 | it is cured within 30 days of Your discovery of the
333 | violation; or
334 |
335 | 2. upon express reinstatement by the Licensor.
336 |
337 | For the avoidance of doubt, this Section 6(b) does not affect any
338 | right the Licensor may have to seek remedies for Your violations
339 | of this Public License.
340 |
341 | c. For the avoidance of doubt, the Licensor may also offer the
342 | Licensed Material under separate terms or conditions or stop
343 | distributing the Licensed Material at any time; however, doing so
344 | will not terminate this Public License.
345 |
346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347 | License.
348 |
349 |
350 | Section 7 -- Other Terms and Conditions.
351 |
352 | a. The Licensor shall not be bound by any additional or different
353 | terms or conditions communicated by You unless expressly agreed.
354 |
355 | b. Any arrangements, understandings, or agreements regarding the
356 | Licensed Material not stated herein are separate from and
357 | independent of the terms and conditions of this Public License.
358 |
359 |
360 | Section 8 -- Interpretation.
361 |
362 | a. For the avoidance of doubt, this Public License does not, and
363 | shall not be interpreted to, reduce, limit, restrict, or impose
364 | conditions on any use of the Licensed Material that could lawfully
365 | be made without permission under this Public License.
366 |
367 | b. To the extent possible, if any provision of this Public License is
368 | deemed unenforceable, it shall be automatically reformed to the
369 | minimum extent necessary to make it enforceable. If the provision
370 | cannot be reformed, it shall be severed from this Public License
371 | without affecting the enforceability of the remaining terms and
372 | conditions.
373 |
374 | c. No term or condition of this Public License will be waived and no
375 | failure to comply consented to unless expressly agreed to by the
376 | Licensor.
377 |
378 | d. Nothing in this Public License constitutes or may be interpreted
379 | as a limitation upon, or waiver of, any privileges and immunities
380 | that apply to the Licensor or You, including from the legal
381 | processes of any jurisdiction or authority.
382 |
383 | =======================================================================
384 |
385 | Creative Commons is not a party to its public
386 | licenses. Notwithstanding, Creative Commons may elect to apply one of
387 | its public licenses to material it publishes and in those instances
388 | will be considered the "Licensor." The text of the Creative Commons
389 | public licenses is dedicated to the public domain under the CC0 Public
390 | Domain Dedication. Except for the limited purpose of indicating that
391 | material is shared under a Creative Commons public license or as
392 | otherwise permitted by the Creative Commons policies published at
393 | creativecommons.org/policies, Creative Commons does not authorize the
394 | use of the trademark "Creative Commons" or any other trademark or logo
395 | of Creative Commons without its prior written consent including,
396 | without limitation, in connection with any unauthorized modifications
397 | to any of its public licenses or any other arrangements,
398 | understandings, or agreements concerning use of licensed material. For
399 | the avoidance of doubt, this paragraph does not form part of the
400 | public licenses.
401 |
402 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VideoVAE+: Large Motion Video Autoencoding with Cross-modal Video VAE
2 |
3 | | Ground Truth (GT) | Reconstructed |
4 | |-------------------|---------------|
5 | | ![Ground truth (case 1)](docs/case1/gtanimal2.gif) | ![Reconstruction (case 1)](docs/case1/fkanimal2.gif) |
6 | | ![Ground truth (case 2)](docs/case2/gtcloseshot1.gif) | ![Reconstruction (case 2)](docs/case2/fkcloseshot1.gif) |
7 | | ![Ground truth (case 3)](docs/case3/gtface.gif) | ![Reconstruction (case 3)](docs/case3/fkface.gif) |
8 | | ![Ground truth (case 4)](docs/case4/gtmotion4.gif) | ![Reconstruction (case 4)](docs/case4/fkmotion4.gif) |
9 | | ![Ground truth (case 5)](docs/case5/gtview7.gif) | ![Reconstruction (case 5)](docs/case5/fkview7.gif) |
10 |
11 | [Yazhou Xing](https://yzxing87.github.io)\*, [Yang Fei](https://sunfly04.github.io)\*, [Yingqing He](https://yingqinghe.github.io)\*†, [Jingye Chen](https://jingyechen.github.io), [Jiaxin Xie](https://jiaxinxie97.github.io/Jiaxin-Xie), [Xiaowei Chi](https://scholar.google.com/citations?user=Vl1X_-sAAAAJ&hl=zh-CN), [Qifeng Chen](https://cqf.io/)† (*equal contribution, †corresponding author)
12 | *The Hong Kong University of Science and Technology*
13 |
14 | #### [Project Page](https://yzxing87.github.io/vae/) | [Paper](https://arxiv.org/abs/2412.17805) | [High-Res Demo](https://www.youtube.com/embed/Kb4rn9z9xAA)
15 |
16 | A state-of-the-art **Video Variational Autoencoder (VAE)** designed for high-fidelity video reconstruction. This project leverages cross-modal and joint video-image training to enhance reconstruction quality.
17 |
18 | ---
19 |
20 | ## ✨ Features
21 |
22 | - **High-Fidelity Reconstruction**: Achieve superior image and video reconstruction quality.
23 | - **Cross-Modal Reconstruction**: Utilize captions to guide the reconstruction process.
24 | - **State-of-the-Art Performance**: Set new benchmarks in video reconstruction tasks.
25 |
26 | ![Comparison with state-of-the-art video VAEs](docs/sota-table.png)
27 | ---
28 |
29 | ## 📰 News
30 | - [Jan 2025] 🏋️ Released training code and an improved pretrained 4z-text checkpoint
31 | - [Dec 2024] 🚀 Released inference code and pretrained models
32 | - [Dec 2024] 📝 Released paper on [arXiv](https://arxiv.org/abs/2412.17805)
33 | - [Dec 2024] 💡 Project page is live at [VideoVAE+](https://yzxing87.github.io/vae/)
34 |
35 | ---
36 |
37 | ## ⏰ Todo
38 |
39 | - [x] **Release Pretrained Model Weights**
40 | - [x] **Release Inference Code**
41 | - [x] **Release Training Code**
42 |
43 | ---
44 |
45 | ## 🚀 Get Started
46 |
47 | Follow these steps to set up your environment and run the code:
48 |
49 | ### 1. Clone the Repository
50 |
51 | ```bash
52 | git clone https://github.com/VideoVerses/VideoVAEPlus.git
53 | cd VideoVAEPlus
54 | ```
55 |
56 | ### 2. Set Up the Environment
57 |
58 | Create a Conda environment and install dependencies:
59 |
60 | ```bash
61 | conda create --name vae python=3.10 -y
62 | conda activate vae
63 | pip install -r requirements.txt
64 | ```
65 |
66 | ---
67 |
68 | ## 📦 Pretrained Models
69 |
70 | | Model Name | Latent Channels | Download Link |
71 | |-----------------|-----------------|------------------|
72 | | sota-4z | 4 | [Download](https://drive.google.com/file/d/1WEKBdRFjEUxwcBgX_thckXklD8s6dDTj/view?usp=drive_link) |
73 | | sota-4z-text | 4 | [Download](https://drive.google.com/file/d/1QfqrKIWu5zG10U-xRgeF8Dhp__njC8OH/view?usp=sharing) |
74 | | sota-16z | 16 | [Download](https://drive.google.com/file/d/13v2Pq6dG1jo7RNImxNOXr9-WizgMiJ7M/view?usp=sharing) |
75 | | sota-16z-text | 16 | [Download](https://drive.google.com/file/d/1iYCAtmdaOX0V41p0vbt_6g8kRS1EK56p/view?usp=sharing) |
76 |
77 | - **Note**: '4z' and '16z' indicate the number of latent channels in the VAE model. Models with 'text' support text guidance.
78 |
79 | ---
80 |
81 | ## 📁 Data Preparation
82 |
83 | To reconstruct videos and images using our VAE model, organize your data in the following structure:
84 |
85 | ### Videos
86 |
87 | Place your videos and optional captions in the `examples/videos/gt` directory.
88 |
89 | #### Directory Structure:
90 |
91 | ```
92 | examples/videos/
93 | ├── gt/
94 | │ ├── video1.mp4
95 | │ ├── video1.txt # Optional caption
96 | │ ├── video2.mp4
97 | │ ├── video2.txt
98 | │ └── ...
99 | ├── recon/
100 | └── (reconstructed videos will be saved here)
101 | ```
102 |
103 | - **Captions**: For cross-modal reconstruction, include a `.txt` file with the same name as the video containing its caption.
104 |
105 | ### Images
106 |
107 | Place your images in the `examples/images/gt` directory.
108 |
109 | #### Directory Structure:
110 |
111 | ```
112 | examples/images/
113 | ├── gt/
114 | │ ├── image1.jpg
115 | │ ├── image2.png
116 | │ └── ...
117 | ├── recon/
118 | └── (reconstructed images will be saved here)
119 | ```
120 |
121 | - **Note**: The images dataset does not require captions.
122 |
123 | ---
124 |
125 | ## 🔧 Inference
126 |
127 | Our video VAE supports both image and video reconstruction.
128 |
129 | Please ensure that the `ckpt_path` in all your configuration files is set to the actual path of your checkpoint.
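
Each config is a plain OmegaConf YAML file, so you can also verify or override `ckpt_path` in code before building the model. A minimal sketch using the repo's `instantiate_from_config` helper (assuming you run it from the repository root; the checkpoint filename is the one referenced in `configs/inference/config_16z.yaml`):

```python
from omegaconf import OmegaConf
from utils.common_utils import instantiate_from_config  # helper shipped with this repo

config = OmegaConf.load("configs/inference/config_16z.yaml")
# Point the config at your downloaded checkpoint before building the model.
config.model.params.ckpt_path = "ckpt/sota-4-16z.ckpt"
model = instantiate_from_config(config.model).eval()
```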
130 |
131 | ### Video Reconstruction
132 |
133 | Run video reconstruction using:
134 |
135 | ```bash
136 | bash scripts/run_inference_video.sh
137 | ```
138 |
139 | This is equivalent to:
140 |
141 | ```bash
142 | python inference_video.py \
143 | --data_root 'examples/videos/gt' \
144 | --out_root 'examples/videos/recon' \
145 | --config_path 'configs/inference/config_16z.yaml' \
146 | --chunk_size 8 \
147 | --resolution 720 1280
148 | ```
149 |
150 | - If the chunk size is too large, you may run out of GPU memory; in that case, reduce the `chunk_size` parameter. Ensure `chunk_size` stays divisible by 4 (see the sketch below).
151 |
152 | - To enable cross-modal reconstruction using captions, set `config_path` to `'configs/inference/config_16z_cap.yaml'` for the 16-channel model with caption guidance.
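
The divisible-by-4 requirement matches the temporal compression set in the configs (`ppconfig.temporal_scale_factor: 4`), so each chunk covers a whole number of latent frames. A rough illustration of the chunking constraint (illustrative only, not the actual `inference_video.py` logic):

```python
import torch

def split_into_chunks(frames: torch.Tensor, chunk_size: int = 8):
    """Split a [C, T, H, W] clip into temporal chunks of length chunk_size."""
    # chunk_size should stay divisible by 4: the configs compress time by a factor
    # of 4, so every chunk then maps to a whole number of latent frames.
    assert chunk_size % 4 == 0, "chunk_size must be divisible by 4"
    return [frames[:, t:t + chunk_size] for t in range(0, frames.shape[1], chunk_size)]

# Example: a 32-frame clip becomes four chunks of 8 frames each.
chunks = split_into_chunks(torch.randn(3, 32, 128, 128), chunk_size=8)
print([c.shape[1] for c in chunks])  # [8, 8, 8, 8]
```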
153 |
154 | ### Image Reconstruction
155 |
156 | Run image reconstruction using:
157 |
158 | ```bash
159 | bash scripts/run_inference_image.sh
160 | ```
161 |
162 | This is equivalent to:
163 |
164 | ```bash
165 | python inference_image.py \
166 | --data_root 'examples/images/gt' \
167 | --out_root 'examples/images/recon' \
168 | --config_path 'configs/inference/config_16z.yaml' \
169 | --batch_size 1
170 | ```
171 |
172 | - **Note**: The batch size is set to 1 because the images in the example folder have varying resolutions. If you have a batch of images with the same resolution, you can increase the batch size to accelerate inference.
173 |
174 | ---
175 |
176 | ## 🏋️ Training
177 |
178 | ### Quick Start
179 |
180 | To start training, use the following command:
181 |
182 | ```bash
183 | bash scripts/run_train.sh config_16z
184 | ```
185 |
186 | This default command trains the 16-channel model with video reconstruction on a single GPU.
187 |
188 | ### Configuration Options
189 |
190 | You can modify the training configuration by changing the config parameter:
191 |
192 | - `config_4z`: 4-channel model
193 | - `config_4z_joint`: 4-channel model trained jointly on both image and video data
194 | - `config_4z_cap`: 4-channel model with text guidance
195 | - `config_16z`: Default 16-channel model
196 | - `config_16z_joint`: 16-channel model trained jointly on both image and video data
197 | - `config_16z_cap`: 16-channel model with text guidance
198 |
199 | Note: Do not include the `.yaml` extension when specifying the config.
200 |
201 | ### Data Preparation
202 |
203 | #### Dataset Structure
204 | The training data should be organized in a CSV file with the following format:
205 |
206 | ```csv
207 | path,text
208 | /absolute/path/to/video1.mp4,A person walking on the beach
209 | /absolute/path/to/video2.mp4,A car driving down the road
210 | ```
211 |
212 | #### Requirements:
213 | - Use absolute paths for video files
214 | - Include two columns: `path` and `text`
215 | - For training without text guidance, leave the `text` column empty but keep the two-column CSV structure
216 |
217 | #### Example CSV:
218 | ```csv
219 | # With captions
220 | /data/videos/clip1.mp4,A dog playing in the park
221 | /data/videos/clip2.mp4,Sunset over the ocean
222 |
223 | # Without captions
224 | /data/videos/clip1.mp4,
225 | /data/videos/clip2.mp4,
226 | ```
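
If your captions live in sidecar `.txt` files next to the videos (the same layout as `examples/videos/gt`), a short script can assemble this CSV. A minimal sketch; the folder and output paths below are placeholders:

```python
import csv
from pathlib import Path

video_dir = Path("/data/videos")  # placeholder: folder containing your .mp4 files
rows = []
for video in sorted(video_dir.glob("*.mp4")):
    caption_file = video.with_suffix(".txt")
    caption = caption_file.read_text().strip() if caption_file.exists() else ""
    rows.append({"path": str(video.resolve()), "text": caption})  # absolute path + optional caption

with open("train_data.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["path", "text"])
    writer.writeheader()
    writer.writerows(rows)
```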
227 |
228 | ---
229 |
230 | ## 📊 Evaluation
231 |
232 | Use the provided scripts to evaluate reconstruction quality using **PSNR**, **SSIM**, and **LPIPS** metrics.
233 |
234 | ### Evaluate Image Reconstruction
235 |
236 | ```bash
237 | bash scripts/evaluation_image.sh
238 | ```
239 |
240 | ### Evaluate Video Reconstruction
241 |
242 | ```bash
243 | bash scripts/evaluation_video.sh
244 | ```
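
The metric implementations live in `evaluation/`. For a quick sanity check you can also compute PSNR between a ground-truth frame and its reconstruction directly; a minimal sketch, independent of the repo's `compute_metrics.py`:

```python
import numpy as np

def psnr(gt: np.ndarray, recon: np.ndarray, max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio between two frames of the same shape and value range."""
    mse = np.mean((gt.astype(np.float64) - recon.astype(np.float64)) ** 2)
    return float("inf") if mse == 0 else 10.0 * np.log10(max_val**2 / mse)
```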
245 |
246 | ---
247 |
248 | ## 📝 License
249 |
250 | This project is released under the [CC BY-NC-ND 4.0](./LICENSE) license.
251 |
252 | ## Star History
253 |
254 | [](https://star-history.com/#VideoVerses/VideoVAEPlus&Date)
255 |
--------------------------------------------------------------------------------
/ckpt/put_your_model_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/ckpt/put_your_model_here.txt
--------------------------------------------------------------------------------
/configs/inference/config_16z.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | scale_lr: False
4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn
5 | params:
6 | monitor: "val/rec_loss"
7 | video_key: video
8 | image_key: video
9 | ckpt_path: ckpt/sota-4-16z.ckpt
10 | input_dim: 5
11 | ignore_keys_3d: ['loss']
12 | caption_guide: False
13 | use_quant_conv: False
14 | img_video_joint_train: False
15 |
16 | lossconfig:
17 | target: src.modules.losses.LPIPSWithDiscriminator3D
18 | params:
19 | disc_start: 50001
20 | kl_weight: 0
21 | disc_weight: 0.5
22 |
23 | ddconfig:
24 | double_z: True
25 | z_channels: 16
26 | resolution: 216
27 | in_channels: 3
28 | out_ch: 3
29 | ch: 128
30 | ch_mult: [ 1,2,4,4 ]
31 | temporal_down_factor: 1
32 | num_res_blocks: 2
33 | attn_resolutions: []
34 | dropout: 0.0
35 |
36 | ppconfig:
37 | temporal_scale_factor: 4
38 | z_channels: 16
39 | out_ch: 16
40 | ch: 16
41 | attn_temporal_factor: []
--------------------------------------------------------------------------------
/configs/inference/config_16z_cap.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4 #5.80e-04
3 | scale_lr: False
4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn
5 | params:
6 | monitor: "val/rec_loss"
7 | video_key: video
8 | image_key: video
9 | img_video_joint_train: False
10 | caption_guide: True
11 | use_quant_conv: False
12 | t5_model_max_length: 100
13 |
14 | ckpt_path: ckpt/sota-4-16z-text.ckpt
15 | input_dim: 5
16 | ignore_keys_3d: ['loss']
17 |
18 | lossconfig:
19 | target: src.modules.losses.LPIPSWithDiscriminator3D
20 | params:
21 | disc_start: 50001
22 | kl_weight: 0
23 | disc_weight: 0.5
24 |
25 | ddconfig:
26 | double_z: True
27 | z_channels: 16
28 | resolution: 216
29 | in_channels: 3
30 | out_ch: 3
31 | ch: 128
32 | ch_mult: [ 1,2,4,4 ]
33 | temporal_down_factor: 1
34 | num_res_blocks: 2
35 | attn_resolutions: [27, 54, 108, 216]
36 | dropout: 0.0
37 |
38 | ppconfig:
39 | temporal_scale_factor: 4
40 | z_channels: 16
41 | out_ch: 16
42 | ch: 16
43 | attn_temporal_factor: [2,4]
--------------------------------------------------------------------------------
/configs/inference/config_4z.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4 #5.80e-04
3 | scale_lr: False
4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn
5 | params:
6 | monitor: "val/rec_loss"
7 | embed_dim: 4
8 | video_key: video
9 | image_key: video #jpg
10 | ckpt_path: ckpt/sota-4-4z.ckpt
11 | input_dim: 5
12 | ignore_keys_3d: ['loss']
13 |
14 | img_video_joint_train: False
15 | caption_guide: False
16 | use_quant_conv: True
17 |
18 | lossconfig:
19 | target: src.modules.losses.LPIPSWithDiscriminator3D
20 | params:
21 | disc_start: 50001
22 | kl_weight: 0
23 | disc_weight: 0.5
24 |
25 | ddconfig:
26 | double_z: True
27 | z_channels: 4
28 | resolution: 216
29 | in_channels: 3
30 | out_ch: 3
31 | ch: 128
32 | ch_mult: [ 1,2,4,4 ]
33 | temporal_down_factor: 1
34 | num_res_blocks: 2
35 | attn_resolutions: [ ]
36 | dropout: 0.0
37 |
38 | ppconfig:
39 | temporal_scale_factor: 4
40 | z_channels: 4
41 | out_ch: 4
42 | ch: 4
43 | attn_temporal_factor: []
--------------------------------------------------------------------------------
/configs/inference/config_4z_cap.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | scale_lr: False
4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn
5 | params:
6 | monitor: "val/rec_loss"
7 | video_key: video
8 | image_key: video
9 | img_video_joint_train: False
10 | caption_guide: True
11 | use_quant_conv: True
12 | t5_model_max_length: 100
13 |
14 | ckpt_path: ckpt/sota-4-4z-text.ckpt
15 | input_dim: 5
16 |
17 | lossconfig:
18 | target: src.modules.losses.LPIPSWithDiscriminator3D
19 | params:
20 | disc_start: 50001
21 | kl_weight: 0
22 | disc_weight: 0.5
23 |
24 | ddconfig:
25 | double_z: True
26 | z_channels: 4
27 | resolution: 216
28 | in_channels: 3
29 | out_ch: 3
30 | ch: 128
31 | ch_mult: [ 1,2,4,4 ]
32 | temporal_down_factor: 1
33 | num_res_blocks: 2
34 | attn_resolutions: [27, 54, 108, 216]
35 | dropout: 0.0
36 |
37 | ppconfig:
38 | temporal_scale_factor: 4
39 | z_channels: 4
40 | out_ch: 4
41 | ch: 4
42 | attn_temporal_factor: [2,4]
--------------------------------------------------------------------------------
/configs/train/config_16z.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4 #5.80e-04
3 | scale_lr: False
4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn
5 | params:
6 | monitor: "val/rec_loss"
7 | video_key: video
8 | image_key: video
9 | img_video_joint_train: False
10 | caption_guide: False
11 | use_quant_conv: False
12 |
13 | ignore_keys_3d: ['loss']
14 | ckpt_path: ckpt/sota-4-16z.ckpt
15 | input_dim: 5
16 |
17 | lossconfig:
18 | target: src.modules.losses.LPIPSWithDiscriminator3D
19 | params:
20 | disc_start: 50001
21 | kl_weight: 0.000001
22 | disc_weight: 0.5
23 |
24 | ddconfig:
25 | double_z: True
26 | z_channels: 16
27 | resolution: 216
28 | in_channels: 3
29 | out_ch: 3
30 | ch: 128
31 | ch_mult: [ 1,2,4,4 ]
32 | temporal_down_factor: 1
33 | num_res_blocks: 2
34 | attn_resolutions: [ ]
35 | dropout: 0.0
36 |
37 | ppconfig:
38 | temporal_scale_factor: 4
39 | z_channels: 16
40 | out_ch: 16
41 | ch: 16
42 | attn_temporal_factor: []
43 |
44 | data:
45 | target: data.lightning_data.DataModuleFromConfig
46 | params:
47 | img_video_joint_train: False
48 | batch_size: 1
49 | num_workers: 32
50 | wrap: false
51 | train:
52 | target: data.dataset.DatasetVideoLoader
53 | params:
54 | csv_file: path/to/your.csv
55 | resolution: [216, 216]
56 | video_length: 16
57 | subset_split: train
58 | validation:
59 | target: data.dataset.DatasetVideoLoader
60 | params:
61 | csv_file: path/to/your.csv
62 | resolution: [216, 216]
63 | video_length: 16
64 | subset_split: val
65 |
66 | lightning:
67 | find_unused_parameters: True
68 | callbacks:
69 | image_logger:
70 | target: utils.callbacks.ImageLogger
71 | params:
72 | batch_frequency: 1009
73 | max_images: 8
74 | metrics_over_trainsteps_checkpoint:
75 | target: pytorch_lightning.callbacks.ModelCheckpoint
76 | params:
77 | filename: '{epoch:06}-{step:09}'
78 | save_weights_only: False
79 | every_n_train_steps: 5000
80 | trainer:
81 | benchmark: True
82 | accumulate_grad_batches: 2
83 | batch_size: 1
84 | num_workers: 32
85 | max_epochs: 3000
86 | modelcheckpoint:
87 | target: pytorch_lightning.callbacks.ModelCheckpoint
88 | params:
89 | every_n_train_steps: 3000
90 | filename: "{epoch:04}-{step:06}"
91 |
--------------------------------------------------------------------------------
/configs/train/config_16z_cap.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | scale_lr: False
4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn
5 | params:
6 | monitor: "val/rec_loss"
7 | video_key: video
8 | image_key: video
9 | img_video_joint_train: False
10 | caption_guide: True
11 | use_quant_conv: False
12 | t5_model_max_length: 100
13 |
14 | ckpt_path: ckpt/sota-4-16z-text.ckpt
15 |
16 | input_dim: 5
17 | ignore_keys_3d: ['loss']
18 |
19 | lossconfig:
20 | target: src.modules.losses.LPIPSWithDiscriminator3D
21 | params:
22 | disc_start: 50001
23 | kl_weight: 0.000001
24 | disc_weight: 0.5
25 |
26 | ddconfig:
27 | double_z: True
28 | z_channels: 16
29 | resolution: 216
30 | in_channels: 3
31 | out_ch: 3
32 | ch: 128
33 | ch_mult: [ 1,2,4,4 ]
34 | temporal_down_factor: 1
35 | num_res_blocks: 2
36 | attn_resolutions: [27, 54, 108, 216]
37 | dropout: 0.0
38 |
39 | ppconfig:
40 | temporal_scale_factor: 4
41 | z_channels: 16
42 | out_ch: 16
43 | ch: 16
44 | attn_temporal_factor: [2, 4]
45 |
46 | data:
47 | target: data.lightning_data.DataModuleFromConfig
48 | params:
49 | batch_size: 1
50 | num_workers: 32
51 | wrap: false
52 | train:
53 | target: data.dataset.DatasetVideoLoader
54 | params:
55 | csv_file: path/to/your.csv
56 | resolution: [216, 216]
57 | video_length: 16
58 | subset_split: train
59 | validation:
60 | target: data.dataset.DatasetVideoLoader
61 | params:
62 | csv_file: path/to/your.csv
63 | resolution: [216, 216]
64 | video_length: 16
65 | subset_split: val
66 |
67 | lightning:
68 | find_unused_parameters: True
69 | callbacks:
70 | image_logger:
71 | target: utils.callbacks.ImageLogger
72 | params:
73 | batch_frequency: 509
74 | max_images: 8
75 | metrics_over_trainsteps_checkpoint:
76 | target: pytorch_lightning.callbacks.ModelCheckpoint
77 | params:
78 | filename: '{epoch:06}-{step:09}'
79 | save_weights_only: True
80 | every_n_train_steps: 5000
81 | trainer:
82 | benchmark: True
83 | accumulate_grad_batches: 2
84 | batch_size: 1
85 | num_workers: 32
86 | max_epochs: 3000
87 | modelcheckpoint:
88 | target: pytorch_lightning.callbacks.ModelCheckpoint
89 | params:
90 | every_n_train_steps: 3000
91 | filename: "{epoch:04}-{step:06}"
--------------------------------------------------------------------------------
/configs/train/config_16z_joint.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4 #5.80e-04
3 | scale_lr: False
4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn
5 | params:
6 | monitor: "val/rec_loss"
7 | video_key: video
8 | image_key: video
9 | img_video_joint_train: True
10 | caption_guide: False
11 | use_quant_conv: False
12 |
13 | ignore_keys_3d: ['loss']
14 | ckpt_path: ckpt/sota-4-16z.ckpt
15 | input_dim: 5
16 |
17 | lossconfig:
18 | target: src.modules.losses.LPIPSWithDiscriminator
19 | params:
20 | disc_start: 50001
21 | kl_weight: 0.000001
22 | disc_weight: 0.5
23 |
24 | ddconfig:
25 | double_z: True
26 | z_channels: 16
27 | resolution: 216
28 | in_channels: 3
29 | out_ch: 3
30 | ch: 128
31 | ch_mult: [ 1,2,4,4 ]
32 | temporal_down_factor: 1
33 | num_res_blocks: 2
34 | attn_resolutions: [ ]
35 | dropout: 0.0
36 |
37 | ppconfig:
38 | temporal_scale_factor: 4
39 | z_channels: 16
40 | out_ch: 16
41 | ch: 16
42 | attn_temporal_factor: []
43 |
44 | data:
45 | target: data.lightning_data.DataModuleFromConfig
46 | params:
47 | img_video_joint_train: True
48 | batch_size: 1
49 | num_workers: 20
50 | wrap: false
51 | train:
52 | target: data.dataset.DatasetVideoLoader
53 | params:
54 | csv_file: path/to/your.csv
55 | resolution: [216, 216]
56 | video_length: 16
57 | subset_split: train
58 | validation:
59 | target: data.dataset.DatasetVideoLoader
60 | params:
61 | csv_file: path/to/your.csv
62 | resolution: [216, 216]
63 | video_length: 16
64 | subset_split: val
65 |
66 | lightning:
67 | find_unused_parameters: True
68 | callbacks:
69 | image_logger:
70 | target: utils.callbacks.ImageLogger
71 | params:
72 | batch_frequency: 1009
73 | max_images: 8
74 | metrics_over_trainsteps_checkpoint:
75 | target: pytorch_lightning.callbacks.ModelCheckpoint
76 | params:
77 | filename: '{epoch:06}-{step:09}'
78 | save_weights_only: False
79 | every_n_train_steps: 5000
80 | trainer:
81 | benchmark: True
82 | accumulate_grad_batches: 2
83 | batch_size: 1
84 | num_workers: 20
85 | max_epochs: 3000
86 | modelcheckpoint:
87 | target: pytorch_lightning.callbacks.ModelCheckpoint
88 | params:
89 | every_n_train_steps: 3000
90 | filename: "{epoch:04}-{step:06}"
91 |
--------------------------------------------------------------------------------
/configs/train/config_4z.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4 #5.80e-04
3 | scale_lr: False
4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn
5 | params:
6 | monitor: "val/rec_loss"
7 | embed_dim: 4
8 | video_key: video
9 | image_key: video #jpg
10 | ckpt_path: ckpt/sota-4-4z.ckpt
11 | input_dim: 5
12 | ignore_keys_3d: ['loss']
13 |
14 | img_video_joint_train: False
15 | caption_guide: False
16 | use_quant_conv: True
17 |
18 | lossconfig:
19 | target: src.modules.losses.LPIPSWithDiscriminator3D
20 | params:
21 | disc_start: 50001
22 | kl_weight: 0.000001
23 | disc_weight: 0.5
24 |
25 | ddconfig:
26 | double_z: True
27 | z_channels: 4
28 | resolution: 216
29 | in_channels: 3
30 | out_ch: 3
31 | ch: 128
32 | ch_mult: [ 1,2,4,4 ]
33 | temporal_down_factor: 1
34 | num_res_blocks: 2
35 | attn_resolutions: [ ]
36 | dropout: 0.0
37 |
38 | ppconfig:
39 | temporal_scale_factor: 4
40 | z_channels: 4
41 | out_ch: 4
42 | ch: 4 # 16*4
43 | attn_temporal_factor: []
44 |
45 | data:
46 | target: data.lightning_data.DataModuleFromConfig
47 | params:
48 | img_video_joint_train: False
49 | batch_size: 1
50 | num_workers: 32
51 | wrap: false
52 | train:
53 | target: data.dataset.DatasetVideoLoader
54 | params:
55 | csv_file: path/to/your.csv
56 | resolution: [216, 216]
57 | video_length: 16
58 | subset_split: train
59 | validation:
60 | target: data.dataset.DatasetVideoLoader
61 | params:
62 | csv_file: path/to/your.csv
63 | resolution: [216, 216]
64 | video_length: 16
65 | subset_split: val
66 |
67 | lightning:
68 | find_unused_parameters: True
69 | callbacks:
70 | image_logger:
71 | target: utils.callbacks.ImageLogger
72 | params:
73 | batch_frequency: 1009
74 | max_images: 8
75 | metrics_over_trainsteps_checkpoint:
76 | target: pytorch_lightning.callbacks.ModelCheckpoint
77 | params:
78 | filename: '{epoch:06}-{step:09}'
79 | save_weights_only: False
80 | every_n_train_steps: 5000
81 | trainer:
82 | benchmark: True
83 | accumulate_grad_batches: 2
84 | batch_size: 1
85 | num_workers: 32
86 | max_epochs: 3000
87 | modelcheckpoint:
88 | target: pytorch_lightning.callbacks.ModelCheckpoint
89 | params:
90 | every_n_train_steps: 3000
91 | filename: "{epoch:04}-{step:06}"
92 |
--------------------------------------------------------------------------------
/configs/train/config_4z_cap.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | scale_lr: False
4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn
5 | params:
6 | embed_dim: 4
7 | monitor: "val/rec_loss"
8 | video_key: video
9 | image_key: video
10 | img_video_joint_train: False
11 | caption_guide: True
12 | use_quant_conv: True
13 | t5_model_max_length: 100
14 |
15 | ckpt_path: ckpt/sota-4-4z-text.ckpt
16 | input_dim: 5
17 |
18 | ignore_keys_3d: ['loss']
19 |
20 | lossconfig:
21 | target: src.modules.losses.LPIPSWithDiscriminator3D
22 | params:
23 | disc_start: 50001
24 | kl_weight: 0.000001
25 | disc_weight: 0.5
26 |
27 | ddconfig:
28 | double_z: True
29 | z_channels: 4
30 | resolution: 216
31 | in_channels: 3
32 | out_ch: 3
33 | ch: 128
34 | ch_mult: [ 1,2,4,4 ]
35 | temporal_down_factor: 1
36 | num_res_blocks: 2
37 | attn_resolutions: [27, 54, 108, 216]
38 | dropout: 0.0
39 |
40 | ppconfig:
41 | temporal_scale_factor: 4
42 | z_channels: 4
43 | out_ch: 4
44 | ch: 4
45 | attn_temporal_factor: [2,4]
46 |
47 | data:
48 | target: data.lightning_data.DataModuleFromConfig
49 | params:
50 | batch_size: 1
51 | num_workers: 32
52 | wrap: false
53 | train:
54 | target: data.dataset.DatasetVideoLoader
55 | params:
56 | csv_file: path/to/your.csv
57 | resolution: [216, 216]
58 | video_length: 16
59 | subset_split: train
60 | validation:
61 | target: data.dataset.DatasetVideoLoader
62 | params:
63 | csv_file: path/to/your.csv
64 | resolution: [216, 216]
65 | video_length: 16
66 | subset_split: val
67 |
68 | lightning:
69 | find_unused_parameters: True
70 | callbacks:
71 | image_logger:
72 | target: utils.callbacks.ImageLogger
73 | params:
74 | batch_frequency: 509
75 | max_images: 8
76 | metrics_over_trainsteps_checkpoint:
77 | target: pytorch_lightning.callbacks.ModelCheckpoint
78 | params:
79 | filename: '{epoch:06}-{step:09}'
80 | save_weights_only: True
81 | every_n_train_steps: 5000
82 | trainer:
83 | benchmark: True
84 | accumulate_grad_batches: 2
85 | batch_size: 1
86 | num_workers: 32
87 | max_epochs: 3000
88 | modelcheckpoint:
89 | target: pytorch_lightning.callbacks.ModelCheckpoint
90 | params:
91 | every_n_train_steps: 3000
92 | filename: "{epoch:04}-{step:06}"
--------------------------------------------------------------------------------
/configs/train/config_4z_joint.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4
3 | scale_lr: False
4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn
5 | params:
6 | monitor: "val/rec_loss"
7 | embed_dim: 4
8 | video_key: video
9 | image_key: video
10 | ckpt_path: ckpt/sota-4-4z.ckpt
11 | input_dim: 5
12 |
13 | ignore_keys_3d: ['loss']
14 |
15 | img_video_joint_train: True
16 | caption_guide: False
17 | use_quant_conv: True
18 |
19 | lossconfig:
20 | target: src.modules.losses.LPIPSWithDiscriminator
21 | params:
22 | disc_start: 50001
23 | kl_weight: 0.000001
24 | disc_weight: 0.5
25 |
26 | ddconfig:
27 | double_z: True
28 | z_channels: 4
29 | resolution: 216
30 | in_channels: 3
31 | out_ch: 3
32 | ch: 128
33 | ch_mult: [ 1,2,4,4 ]
34 | temporal_down_factor: 1
35 | num_res_blocks: 2
36 | attn_resolutions: [ ]
37 | dropout: 0.0
38 |
39 | ppconfig:
40 | temporal_scale_factor: 4
41 | z_channels: 4
42 | out_ch: 4
43 | ch: 4
44 | attn_temporal_factor: []
45 |
46 | data:
47 | target: data.lightning_data.DataModuleFromConfig
48 | params:
49 | img_video_joint_train: True
50 | batch_size: 1
51 | num_workers: 20
52 | wrap: false
53 | train:
54 | target: data.dataset.DatasetVideoLoader
55 | params:
56 | csv_file: path/to/your.csv
57 | resolution: [216, 216]
58 | video_length: 16
59 | subset_split: train
60 | validation:
61 | target: data.dataset.DatasetVideoLoader
62 | params:
63 | csv_file: path/to/your.csv
64 | resolution: [216, 216]
65 | video_length: 16
66 | subset_split: val
67 |
68 | lightning:
69 | find_unused_parameters: True
70 | callbacks:
71 | image_logger:
72 | target: utils.callbacks.ImageLogger
73 | params:
74 | batch_frequency: 1009
75 | max_images: 8
76 | metrics_over_trainsteps_checkpoint:
77 | target: pytorch_lightning.callbacks.ModelCheckpoint
78 | params:
79 | filename: '{epoch:06}-{step:09}'
80 | save_weights_only: False
81 | every_n_train_steps: 5000
82 | trainer:
83 | benchmark: True
84 | accumulate_grad_batches: 2
85 | batch_size: 1
86 | num_workers: 20
87 | max_epochs: 3000
88 | modelcheckpoint:
89 | target: pytorch_lightning.callbacks.ModelCheckpoint
90 | params:
91 | every_n_train_steps: 3000
92 | filename: "{epoch:04}-{step:06}"
93 |
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | from torch.utils.data import Dataset
5 | from decord import VideoReader, cpu
6 | import pandas as pd
7 |
8 |
9 | class DatasetVideoLoader(Dataset):
10 | """
11 | Dataset for loading videos and captions from a CSV file.
12 | CSV file contains two columns: 'path' and 'text', where:
13 | - 'path' is the path to the video file
14 | - 'text' is the caption for the video.
15 | """
16 |
17 | def __init__(
18 | self,
19 | csv_file,
20 | resolution,
21 | video_length,
22 | frame_stride=4,
23 | subset_split="all",
24 | clip_length=1.0,
25 | random_stride=False,
26 | mode="video",
27 | ):
28 | self.csv_file = csv_file
29 | self.resolution = resolution
30 | self.video_length = video_length
31 | self.subset_split = subset_split
32 | self.frame_stride = frame_stride
33 | self.clip_length = clip_length
34 | self.random_stride = random_stride
35 | self.mode = mode
36 |
37 | assert self.subset_split in ["train", "test", "val", "all"]
38 | self.exts = ["avi", "mp4", "webm"]
39 |
40 | if isinstance(self.resolution, int):
41 | self.resolution = [self.resolution, self.resolution]
42 |
43 | # Load dataset from CSV file
44 | self._make_dataset()
45 |
46 | def _make_dataset(self):
47 | """
48 | Load video paths and captions from the CSV file.
49 | """
50 | self.videos = pd.read_csv(self.csv_file)
51 | print(f"Loaded {len(self.videos)} videos from {self.csv_file}")
52 |
53 | if self.subset_split == "val":
54 | self.videos = self.videos[-300:]
55 | elif self.subset_split == "train":
56 | self.videos = self.videos[:-300]
57 | elif self.subset_split == "test":
58 | self.videos = self.videos[-30:]
59 |
60 | print(f"Number of videos = {len(self.videos)}")
61 |
62 | # Create video indices for image mode
63 | self.video_indices = list(range(len(self.videos)))
64 |
65 | def set_mode(self, mode):
66 | self.mode = mode
67 |
68 | def _get_video_path(self, index):
69 | return self.videos.iloc[index]["path"]
70 |
71 | def __getitem__(self, index):
72 | if self.mode == "image":
73 | return self.__getitem__images(index)
74 | else:
75 | return self.__getitem__video(index)
76 |
77 | def __getitem__video(self, index):
78 | while True:
79 | video_path = self.videos.iloc[index]["path"]
80 | caption = self.videos.iloc[index]["text"]
81 |
82 | try:
83 | video_reader = VideoReader(
84 | video_path,
85 | ctx=cpu(0),
86 | width=self.resolution[1],
87 | height=self.resolution[0],
88 | )
89 | if len(video_reader) < self.video_length:
90 | index = (index + 1) % len(self.videos)
91 | continue
92 | else:
93 | break
94 | except Exception as e:
95 | print(f"Load video failed! path = {video_path}, error: {str(e)}")
96 | index = (index + 1) % len(self.videos)
97 | continue
98 |
99 | if self.random_stride:
100 | self.frame_stride = random.choice([4, 8, 12, 16])
101 |
102 | all_frames = list(range(0, len(video_reader), self.frame_stride))
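        # If strided sampling yields fewer frames than video_length, fall back to stride 1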
103 | if len(all_frames) < self.video_length:
104 | all_frames = list(range(0, len(video_reader), 1))
105 |
106 | # Select random clip
107 | rand_idx = random.randint(0, len(all_frames) - self.video_length)
108 | frame_indices = all_frames[rand_idx : rand_idx + self.video_length]
109 | frames = video_reader.get_batch(frame_indices)
110 | assert (
111 | frames.shape[0] == self.video_length
112 | ), f"{len(frames)}, self.video_length={self.video_length}"
113 |
114 | frames = (
115 | torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
116 | ) # [t,h,w,c] -> [c,t,h,w]
117 | assert (
118 | frames.shape[2] == self.resolution[0]
119 | and frames.shape[3] == self.resolution[1]
120 | ), f"frames={frames.shape}, self.resolution={self.resolution}"
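        # Normalize pixel values from [0, 255] to [-1, 1]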
121 | frames = (frames / 255 - 0.5) * 2
122 |
123 | return {"video": frames, "caption": caption, "is_video": True}
124 |
125 | def __getitem__images(self, index):
126 | frames_list = []
127 | for i in range(self.video_length):
128 | # Get a unique video for each frame
129 | video_index = (index + i) % len(self.video_indices)
130 | video_path = self._get_video_path(video_index)
131 |
132 | try:
133 | video_reader = VideoReader(
134 | video_path,
135 | ctx=cpu(0),
136 | width=self.resolution[1],
137 | height=self.resolution[0],
138 | )
139 | except Exception as e:
140 | print(f"Load video failed! path = {video_path}, error = {e}")
141 | # Skip this video and try the next one
142 | return self.__getitem__images((index + 1) % len(self.video_indices))
143 |
144 | # Randomly select a frame from the video
145 | rand_idx = random.randint(0, len(video_reader) - 1)
146 | frame = video_reader[rand_idx]
147 | frame_tensor = (
148 | torch.tensor(frame.asnumpy()).permute(2, 0, 1).float().unsqueeze(0)
149 | ) # [h,w,c] -> [c,h,w] -> [1, c, h, w]
150 |
151 | frames_list.append(frame_tensor)
152 |
153 | frames = torch.cat(frames_list, dim=0)
154 | frames = (frames / 255 - 0.5) * 2
155 | frames = frames.permute(1, 0, 2, 3)
156 | assert (
157 | frames.shape[2] == self.resolution[0]
158 | and frames.shape[3] == self.resolution[1]
159 | ), f"frame={frames.shape}, self.resolution={self.resolution}"
160 |
161 | data = {"video": frames, "is_video": False}
162 | return data
163 |
164 | def __len__(self):
165 | return len(self.videos)
166 |
--------------------------------------------------------------------------------
/data/lightning_data.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import numpy as np
3 |
4 | import torch
5 | import pytorch_lightning as pl
6 | from torch.utils.data import DataLoader, Dataset
7 |
8 | import argparse, os, sys, glob
9 |
10 | os.chdir(sys.path[0])
11 | sys.path.append("..")
12 |
13 | from utils.common_utils import instantiate_from_config
14 |
15 |
16 | def worker_init_fn(_):
17 | worker_info = torch.utils.data.get_worker_info()
18 |
19 | dataset = worker_info.dataset
20 | worker_id = worker_info.id
21 |
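    # For joint image-video training: roughly the first 20% of dataloader workers
    # serve single frames ("image" mode), the rest serve video clips ("video" mode).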
22 | mode = "image" if worker_id < worker_info.num_workers * 0.2 else "video"
23 | print(f"Mode is {mode}")
24 | dataset.set_mode(mode)
25 |
26 | return np.random.seed(np.random.get_state()[1][0] + worker_id)
27 |
28 |
29 | class WrappedDataset(Dataset):
30 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
31 |
32 | def __init__(self, dataset):
33 | self.data = dataset
34 |
35 | def __len__(self):
36 | return len(self.data)
37 |
38 | def __getitem__(self, idx):
39 | return self.data[idx]
40 |
41 |
42 | class DataModuleFromConfig(pl.LightningDataModule):
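    """Lightning data module that builds train/val/test/predict dataloaders from the dataset configs in the training YAML."""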
43 | def __init__(
44 | self,
45 | batch_size,
46 | train=None,
47 | validation=None,
48 | test=None,
49 | predict=None,
50 | wrap=False,
51 | num_workers=None,
52 | shuffle_test_loader=False,
53 | img_video_joint_train=False,
54 | shuffle_val_dataloader=False,
55 | train_img=None,
56 | test_max_n_samples=None,
57 | ):
58 | super().__init__()
59 | self.batch_size = batch_size
60 | self.dataset_configs = dict()
61 | self.num_workers = num_workers if num_workers is not None else batch_size * 2
62 | self.use_worker_init_fn = img_video_joint_train
63 | if train is not None:
64 | self.dataset_configs["train"] = train
65 | self.train_dataloader = self._train_dataloader
66 | if validation is not None:
67 | self.dataset_configs["validation"] = validation
68 | self.val_dataloader = partial(
69 | self._val_dataloader, shuffle=shuffle_val_dataloader
70 | )
71 | if test is not None:
72 | self.dataset_configs["test"] = test
73 | self.test_dataloader = partial(
74 | self._test_dataloader, shuffle=shuffle_test_loader
75 | )
76 | if predict is not None:
77 | self.dataset_configs["predict"] = predict
78 | self.predict_dataloader = self._predict_dataloader
79 | # train 2 dataset
80 | # if img_loader is not None:
81 | # img_data = instantiate_from_config(img_loader)
82 | # img_data.setup()
83 | if train_img is not None:
84 | if train_img["params"]["batch_size"] == -1:
85 | train_img["params"]["batch_size"] = (
86 | batch_size * train["params"]["video_length"]
87 | )
88 | print(
89 | "Set train_img batch_size to {}".format(
90 | train_img["params"]["batch_size"]
91 | )
92 | )
93 | img_data = instantiate_from_config(train_img)
94 | self.img_loader = img_data.train_dataloader()
95 | else:
96 | self.img_loader = None
97 | self.wrap = wrap
98 | self.test_max_n_samples = test_max_n_samples
99 | self.collate_fn = None
100 |
101 | def prepare_data(self):
102 | # for data_cfg in self.dataset_configs.values():
103 | # instantiate_from_config(data_cfg)
104 | pass
105 |
106 | def setup(self, stage=None):
107 | self.datasets = dict(
108 | (k, instantiate_from_config(self.dataset_configs[k]))
109 | for k in self.dataset_configs
110 | )
111 | if self.wrap:
112 | for k in self.datasets:
113 | self.datasets[k] = WrappedDataset(self.datasets[k])
114 |
115 | def _train_dataloader(self):
116 | if self.use_worker_init_fn:
117 | init_fn = worker_init_fn
118 | else:
119 | init_fn = None
120 | loader = DataLoader(
121 | self.datasets["train"],
122 | batch_size=self.batch_size,
123 | num_workers=self.num_workers,
124 | shuffle=True,
125 | worker_init_fn=init_fn,
126 | collate_fn=self.collate_fn,
127 | )
128 | if self.img_loader is not None:
129 | return {"loader_video": loader, "loader_img": self.img_loader}
130 | else:
131 | return loader
132 |
133 | def _val_dataloader(self, shuffle=False):
134 | if self.use_worker_init_fn:
135 | init_fn = worker_init_fn
136 | else:
137 | init_fn = None
138 | return DataLoader(
139 | self.datasets["validation"],
140 | batch_size=self.batch_size,
141 | num_workers=self.num_workers,
142 | worker_init_fn=init_fn,
143 | shuffle=shuffle,
144 | collate_fn=self.collate_fn,
145 | )
146 |
147 | def _test_dataloader(self, shuffle=False):
148 | if self.use_worker_init_fn:
149 | init_fn = worker_init_fn
150 | else:
151 | init_fn = None
152 |
153 | if self.test_max_n_samples is not None:
154 | dataset = torch.utils.data.Subset(
155 | self.datasets["test"], list(range(self.test_max_n_samples))
156 | )
157 | else:
158 | dataset = self.datasets["test"]
159 | return DataLoader(
160 | dataset,
161 | batch_size=self.batch_size,
162 | num_workers=self.num_workers,
163 | worker_init_fn=init_fn,
164 | shuffle=shuffle,
165 | collate_fn=self.collate_fn,
166 | )
167 |
168 | def _predict_dataloader(self, shuffle=False):
169 | if self.use_worker_init_fn:
170 | init_fn = worker_init_fn
171 | else:
172 | init_fn = None
173 | return DataLoader(
174 | self.datasets["predict"],
175 | batch_size=self.batch_size,
176 | num_workers=self.num_workers,
177 | worker_init_fn=init_fn,
178 | collate_fn=self.collate_fn,
179 | )
180 |
--------------------------------------------------------------------------------
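DataModuleFromConfig is wired entirely through instantiate_from_config: each of train/validation/test/predict is itself a target/params block, and train_img optionally adds a second, image-only loader whose batch size defaults to batch_size * video_length. A minimal sketch of building it by hand; the inner train target and its params below are placeholders standing in for the real entries in configs/train/*.yaml:

from omegaconf import OmegaConf
from utils.common_utils import instantiate_from_config

cfg = OmegaConf.create({
    "target": "data.lightning_data.DataModuleFromConfig",
    "params": {
        "batch_size": 1,
        "num_workers": 4,
        "train": {  # placeholder target/params; see configs/train/*.yaml for the real ones
            "target": "data.dataset.YourVideoDataset",
            "params": {"resolution": [256, 256], "video_length": 16},
        },
    },
})

data = instantiate_from_config(cfg)
data.setup()
loader = data.train_dataloader()  # a DataLoader, or {"loader_video": ..., "loader_img": ...} when train_img is set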
/docs/case1/fkanimal2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case1/fkanimal2.gif
--------------------------------------------------------------------------------
/docs/case1/gtanimal2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case1/gtanimal2.gif
--------------------------------------------------------------------------------
/docs/case2/fkcloseshot1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case2/fkcloseshot1.gif
--------------------------------------------------------------------------------
/docs/case2/gtcloseshot1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case2/gtcloseshot1.gif
--------------------------------------------------------------------------------
/docs/case3/fkface.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case3/fkface.gif
--------------------------------------------------------------------------------
/docs/case3/gtface.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case3/gtface.gif
--------------------------------------------------------------------------------
/docs/case4/fkmotion4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case4/fkmotion4.gif
--------------------------------------------------------------------------------
/docs/case4/gtmotion4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case4/gtmotion4.gif
--------------------------------------------------------------------------------
/docs/case5/fkview7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case5/fkview7.gif
--------------------------------------------------------------------------------
/docs/case5/gtview7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case5/gtview7.gif
--------------------------------------------------------------------------------
/docs/sota-table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/sota-table.png
--------------------------------------------------------------------------------
/evaluation/compute_metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | import math
5 | from glob import glob
6 | from skimage.metrics import structural_similarity as compare_ssim
7 | import imageio
8 | import lpips
9 | import torch
10 | from tqdm import tqdm
11 | import logging
12 |
13 | # Configure logging
14 | logging.basicConfig(
15 | level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
16 | )
17 |
18 | # Argument parser
19 | parser = argparse.ArgumentParser(
20 | description="Compute PSNR, SSIM, and LPIPS for videos."
21 | )
22 | parser.add_argument(
23 | "--root1",
24 | "-r1",
25 | type=str,
26 | required=True,
27 | help="Directory for the first set of videos.",
28 | )
29 | parser.add_argument(
30 | "--root2",
31 | "-r2",
32 | type=str,
33 | required=True,
34 | help="Directory for the second set of videos.",
35 | )
36 | parser.add_argument("--ssim", action="store_true", default=False, help="Compute SSIM.")
37 | parser.add_argument("--psnr", action="store_true", default=False, help="Compute PSNR.")
38 | parser.add_argument(
39 | "--lpips", action="store_true", default=False, help="Compute LPIPS."
40 | )
41 |
42 | args = parser.parse_args()
43 |
44 | # Define metric functions
45 |
46 |
47 | def compute_psnr(img1, img2):
48 | mse = np.mean((img1 / 255.0 - img2 / 255.0) ** 2)
49 | if mse < 1.0e-10:
50 | return 100
51 | PIXEL_MAX = 1
52 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
53 |
54 |
55 | def compute_ssim(img1, img2):
56 | if np.all(img1 == img1[0, 0, 0]) or np.all(img2 == img2[0, 0, 0]):
57 | return 1.0
58 | return compare_ssim(img1, img2, data_range=img1.max() - img1.min(), channel_axis=-1)
59 |
60 |
61 | def compute_lpips(img1, img2, loss_fn):
62 | img1_tensor = (
63 | torch.from_numpy(img1 / 255.0)
64 | .float()
65 | .permute(2, 0, 1)
66 | .unsqueeze(0)
67 | .to("cuda:0")
68 | )
69 | img2_tensor = (
70 | torch.from_numpy(img2 / 255.0)
71 | .float()
72 | .permute(2, 0, 1)
73 | .unsqueeze(0)
74 | .to("cuda:0")
75 | )
76 |
77 | img1_tensor = img1_tensor * 2 - 1 # Normalize to [-1, 1]
78 | img2_tensor = img2_tensor * 2 - 1
79 |
80 | return loss_fn(img1_tensor, img2_tensor).item()
81 |
82 |
83 | def read_video(file_path):
84 | try:
85 | video = imageio.get_reader(file_path)
86 | frames = [frame for frame in video]
87 | video.close()
88 | return frames
89 | except Exception as e:
90 | logging.error(f"Error reading video {file_path}: {e}")
91 | return []
92 |
93 |
94 | def save_results(results, root1, root2, output_file="metrics.txt"):
95 | with open(output_file, "a") as f:
96 | f.write("\n")
97 | f.write(f"Root1: {root1}\n")
98 | f.write(f"Root2: {root2}\n")
99 | for metric, value in results.items():
100 | f.write(f"{metric}: {value}\n")
101 | f.write("\n")
102 | logging.info(f"Results saved to {output_file}")
103 |
104 |
105 | def main():
106 | # Load video paths
107 | all_videos1 = sorted(glob(os.path.join(args.root1, "*mp4")))
108 | all_videos2 = sorted(glob(os.path.join(args.root2, "*mp4")))
109 |
110 | assert len(all_videos1) == len(
111 | all_videos2
112 | ), f"Number of files mismatch: {len(all_videos1)} in {args.root1}, {len(all_videos2)} in {args.root2}"
113 |
114 | # Metrics storage
115 | metric_psnr = []
116 | metric_ssim = []
117 | metric_lpips = []
118 |
119 | # Initialize LPIPS model if needed
120 | lpips_model = None
121 | if args.lpips:
122 | lpips_model = lpips.LPIPS(net="alex").to("cuda:0")
123 | logging.info("Initialized LPIPS model (AlexNet).")
124 |
125 | for vid1_path, vid2_path in tqdm(
126 | zip(all_videos1, all_videos2), total=len(all_videos1), desc="Processing videos"
127 | ):
128 | vid1_frames = read_video(vid1_path)
129 | vid2_frames = read_video(vid2_path)
130 |
131 | if not vid1_frames or not vid2_frames:
132 | logging.error(
133 | f"Skipping video pair due to read failure: {vid1_path}, {vid2_path}"
134 | )
135 | continue
136 |
137 | assert len(vid1_frames) == len(
138 | vid2_frames
139 | ), f"Frame count mismatch: {len(vid1_frames)} in {vid1_path}, {len(vid2_frames)} in {vid2_path}"
140 |
141 | # Process each pair of frames
142 | for f1, f2 in zip(vid1_frames, vid2_frames):
143 | if args.psnr:
144 | try:
145 | psnr_value = compute_psnr(f1, f2)
146 | metric_psnr.append(psnr_value)
147 | except Exception as e:
148 | logging.error(f"Error computing PSNR for frames: {e}")
149 |
150 | if args.ssim:
151 | try:
152 | ssim_value = compute_ssim(f1, f2)
153 | metric_ssim.append(ssim_value)
154 | except Exception as e:
155 | logging.error(f"Error computing SSIM for frames: {e}")
156 |
157 | if args.lpips:
158 | try:
159 | lpips_value = compute_lpips(f1, f2, lpips_model)
160 | metric_lpips.append(lpips_value)
161 | except Exception as e:
162 | logging.error(f"Error computing LPIPS for frames: {e}")
163 |
164 | # Compute average metrics
165 | results = {}
166 | if args.psnr and metric_psnr:
167 | results["PSNR"] = sum(metric_psnr) / len(metric_psnr)
168 | if args.ssim and metric_ssim:
169 | results["SSIM"] = sum(metric_ssim) / len(metric_ssim)
170 | if args.lpips and metric_lpips:
171 | results["LPIPS"] = sum(metric_lpips) / len(metric_lpips)
172 |
173 | # Print and save results
174 | logging.info(f"Results: {results}")
175 | save_results(results, args.root1, args.root2)
176 |
177 |
178 | if __name__ == "__main__":
179 | main()
180 |
--------------------------------------------------------------------------------
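For reference, compute_psnr above implements the usual definition on frames rescaled to [0, 1] (so MAX = 1), with a 100 dB cap when the MSE underflows:

$$\mathrm{PSNR} = 20 \log_{10}\!\left(\frac{\mathrm{MAX}}{\sqrt{\mathrm{MSE}}}\right), \qquad \mathrm{MSE} = \frac{1}{N}\sum_{i}\left(\frac{x_i}{255} - \frac{y_i}{255}\right)^2$$

SSIM and LPIPS are delegated to scikit-image and the lpips package respectively, and all three metrics are averaged over every frame pair across the whole video set.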
/evaluation/compute_metrics_img.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | import math
5 | from glob import glob
6 | from skimage.metrics import structural_similarity as compare_ssim
7 | import imageio
8 | import lpips
9 | import torch
10 | from tqdm import tqdm
11 | import logging
12 |
13 | # Configure logging
14 | logging.basicConfig(
15 | level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
16 | )
17 |
18 | # Argument parser
19 | parser = argparse.ArgumentParser(
20 | description="Calculate PSNR, SSIM, and LPIPS between two sets of images."
21 | )
22 | parser.add_argument(
23 | "--root1",
24 | "-r1",
25 | type=str,
26 | required=True,
27 | help="Directory for the first set of images.",
28 | )
29 | parser.add_argument(
30 | "--root2",
31 | "-r2",
32 | type=str,
33 | required=True,
34 | help="Directory for the second set of images.",
35 | )
36 | parser.add_argument("--ssim", action="store_true", default=False, help="Compute SSIM.")
37 | parser.add_argument("--psnr", action="store_true", default=False, help="Compute PSNR.")
38 | parser.add_argument(
39 | "--lpips", action="store_true", default=False, help="Compute LPIPS."
40 | )
41 |
42 | args = parser.parse_args()
43 |
44 | # Define metric functions
45 |
46 |
47 | def compute_psnr(img1, img2):
48 | mse = np.mean((img1 / 255.0 - img2 / 255.0) ** 2)
49 | if mse < 1.0e-10:
50 | return 100
51 | PIXEL_MAX = 1
52 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
53 |
54 |
55 | def compute_ssim(img1, img2):
56 | return compare_ssim(img1, img2, data_range=img1.max() - img1.min(), channel_axis=-1)
57 |
58 |
59 | def compute_lpips(img1, img2, loss_fn):
60 | img1_tensor = (
61 | torch.from_numpy(img1 / 255.0)
62 | .float()
63 | .permute(2, 0, 1)
64 | .unsqueeze(0)
65 | .to("cuda:0")
66 | )
67 | img2_tensor = (
68 | torch.from_numpy(img2 / 255.0)
69 | .float()
70 | .permute(2, 0, 1)
71 | .unsqueeze(0)
72 | .to("cuda:0")
73 | )
74 |
75 | img1_tensor = img1_tensor * 2 - 1 # Normalize to [-1, 1]
76 | img2_tensor = img2_tensor * 2 - 1
77 |
78 | return loss_fn(img1_tensor, img2_tensor).item()
79 |
80 |
81 | def read_image(file_path):
82 | try:
83 | return imageio.imread(file_path)
84 | except Exception as e:
85 | logging.error(f"Error reading image {file_path}: {e}")
86 | return None
87 |
88 |
89 | def save_results(results, root1, root2, output_file="metrics.txt"):
90 | with open(output_file, "a") as f:
91 | f.write("\n")
92 | f.write(f"Root1: {root1}\n")
93 | f.write(f"Root2: {root2}\n")
94 | for metric, value in results.items():
95 | f.write(f"{metric}: {value}\n")
96 | f.write("\n")
97 | logging.info(f"Results saved to {output_file}")
98 |
99 |
100 | def main():
101 | # Load image paths
102 |     all_images1 = sorted(glob(os.path.join(args.root1, "*.jpg")) + glob(os.path.join(args.root1, "*.jpeg")))
103 |     all_images2 = sorted(glob(os.path.join(args.root2, "*.jpg")) + glob(os.path.join(args.root2, "*.jpeg")))
104 |
105 | assert len(all_images1) == len(
106 | all_images2
107 | ), f"Number of files mismatch: {len(all_images1)} in {args.root1}, {len(all_images2)} in {args.root2}"
108 |
109 | # Metrics storage
110 | metric_psnr = []
111 | metric_ssim = []
112 | metric_lpips = []
113 |
114 | lpips_model = None
115 | if args.lpips:
116 | lpips_model = lpips.LPIPS(net="alex").to("cuda:0")
117 | logging.info("Initialized LPIPS model (AlexNet).")
118 |
119 | # Compute metrics for each pair of images
120 | for i, (img1_path, img2_path) in enumerate(
121 | tqdm(
122 | zip(all_images1, all_images2),
123 | total=len(all_images1),
124 | desc="Processing images",
125 | )
126 | ):
127 | img1 = read_image(img1_path)
128 | img2 = read_image(img2_path)
129 | if img1 is None or img2 is None:
130 | logging.warning(f"Skipping pair: {img1_path}, {img2_path}")
131 | continue
132 |
133 | if args.psnr:
134 | try:
135 | psnr_value = compute_psnr(img1, img2)
136 | metric_psnr.append(psnr_value)
137 | except Exception as e:
138 | logging.error(f"Error computing PSNR for {img1_path}, {img2_path}: {e}")
139 |
140 | if args.ssim:
141 | try:
142 | ssim_value = compute_ssim(img1, img2)
143 | metric_ssim.append(ssim_value)
144 | except Exception as e:
145 | logging.error(f"Error computing SSIM for {img1_path}, {img2_path}: {e}")
146 |
147 | if args.lpips:
148 | try:
149 | lpips_value = compute_lpips(img1, img2, lpips_model)
150 | metric_lpips.append(lpips_value)
151 | except Exception as e:
152 | logging.error(
153 | f"Error computing LPIPS for {img1_path}, {img2_path}: {e}"
154 | )
155 |
156 | results = {}
157 | if args.psnr and metric_psnr:
158 | results["PSNR"] = sum(metric_psnr) / len(metric_psnr)
159 | if args.ssim and metric_ssim:
160 | results["SSIM"] = sum(metric_ssim) / len(metric_ssim)
161 | if args.lpips and metric_lpips:
162 | results["LPIPS"] = sum(metric_lpips) / len(metric_lpips)
163 |
164 | # Print and save results
165 | logging.info(f"Results: {results}")
166 | save_results(results, args.root1, args.root2)
167 |
168 |
169 | if __name__ == "__main__":
170 | main()
171 |
--------------------------------------------------------------------------------
/examples/images/gt/00000091.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000091.jpg
--------------------------------------------------------------------------------
/examples/images/gt/00000103.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000103.jpg
--------------------------------------------------------------------------------
/examples/images/gt/00000110.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000110.jpg
--------------------------------------------------------------------------------
/examples/images/gt/00000212.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000212.jpg
--------------------------------------------------------------------------------
/examples/images/gt/00000268.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000268.jpg
--------------------------------------------------------------------------------
/examples/images/gt/00000592.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000592.jpg
--------------------------------------------------------------------------------
/examples/images/gt/00006871.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00006871.jpg
--------------------------------------------------------------------------------
/examples/images/gt/00007252.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00007252.jpg
--------------------------------------------------------------------------------
/examples/images/gt/00007826.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00007826.jpg
--------------------------------------------------------------------------------
/examples/images/gt/00008868.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00008868.jpg
--------------------------------------------------------------------------------
/examples/images/recon/00000091.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00000091.jpeg
--------------------------------------------------------------------------------
/examples/images/recon/00000110.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00000110.jpeg
--------------------------------------------------------------------------------
/examples/images/recon/00000212.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00000212.jpeg
--------------------------------------------------------------------------------
/examples/images/recon/00000268.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00000268.jpeg
--------------------------------------------------------------------------------
/examples/images/recon/00000592.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00000592.jpeg
--------------------------------------------------------------------------------
/examples/images/recon/00006871.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00006871.jpeg
--------------------------------------------------------------------------------
/examples/images/recon/00007252.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00007252.jpeg
--------------------------------------------------------------------------------
/examples/images/recon/00007826.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00007826.jpeg
--------------------------------------------------------------------------------
/examples/images/recon/00008868.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00008868.jpeg
--------------------------------------------------------------------------------
/examples/videos/gt/40.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/40.mp4
--------------------------------------------------------------------------------
/examples/videos/gt/40.txt:
--------------------------------------------------------------------------------
1 | The video features a gray tabby cat lying on a wooden deck, grooming itself in the sunlight. The cat is seen licking its paw and then using it to clean its face and fur. The background includes some greenery, indicating an outdoor setting.
--------------------------------------------------------------------------------
/examples/videos/gt/8.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/8.mp4
--------------------------------------------------------------------------------
/examples/videos/gt/8.txt:
--------------------------------------------------------------------------------
1 | The video showcases a glassblowing process, featuring a skilled artisan working with molten glass in a furnace. The scene is set in a dark environment with a bright orange glow from the furnace illuminating the glass and the artisan's tools. The glassblowing process involves shaping and manipulating the molten glass to create various forms, highlighting the intricate and artistic nature of this traditional craft.
--------------------------------------------------------------------------------
/examples/videos/gt/animal.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/animal.mp4
--------------------------------------------------------------------------------
/examples/videos/gt/animal.txt:
--------------------------------------------------------------------------------
1 | The video features a juvenile Black-crowned Night Heron perched on a branch. The heron is primarily gray and brown with white speckling on its feathers. It has a long, pointed beak and is actively preening its feathers, using its beak to smooth and clean them. The background is a soft, blurred green, suggesting foliage and possibly other trees. The lighting is natural and warm, indicating that the shot was likely taken during the day. The overall tone of the video is peaceful and focused on the natural behavior of the bird. A small logo "8 EARTH" is visible in the bottom left corner.
--------------------------------------------------------------------------------
/examples/videos/gt/closeshot.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/closeshot.mp4
--------------------------------------------------------------------------------
/examples/videos/gt/closeshot.txt:
--------------------------------------------------------------------------------
1 | The video features a detailed macro view of a 50mm prime lens, likely for a DSLR or mirrorless camera. The image highlights the lens's build quality, showcasing the textured grip and the aperture markings ranging from f/1.8 to f/16. The distance scale is also visible, indicating focus points. A red accent ring around the lens barrel provides a subtle visual highlight. The lens's glass reflects light, creating a soft, out-of-focus bokeh effect. This shot emphasizes the precision and design of professional camera equipment.
--------------------------------------------------------------------------------
/examples/videos/gt/face.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/face.mp4
--------------------------------------------------------------------------------
/examples/videos/gt/face.txt:
--------------------------------------------------------------------------------
1 | The video captures a man with a beard and glasses talking to the camera. He is gesturing with his hands as he explains something. Behind him, there are Star Trek action figures in their boxes and other items on a wooden cabinet. The man is likely discussing or explaining a process or topic.
--------------------------------------------------------------------------------
/examples/videos/gt/view.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/view.mp4
--------------------------------------------------------------------------------
/examples/videos/gt/view.txt:
--------------------------------------------------------------------------------
1 | The video captures the vibrant and unique atmosphere of Fremont Street in Las Vegas. The combination of the large mechanical praying mantis, the fire effects, and the text overlay suggests a lively and engaging walking tour experience. The video is likely the title card or intro to a video about Fremont Street.
--------------------------------------------------------------------------------
/examples/videos/recon/40_reconstructed.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/40_reconstructed.mp4
--------------------------------------------------------------------------------
/examples/videos/recon/8_reconstructed.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/8_reconstructed.mp4
--------------------------------------------------------------------------------
/examples/videos/recon/animal_reconstructed.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/animal_reconstructed.mp4
--------------------------------------------------------------------------------
/examples/videos/recon/closeshot_reconstructed.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/closeshot_reconstructed.mp4
--------------------------------------------------------------------------------
/examples/videos/recon/face_reconstructed.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/face_reconstructed.mp4
--------------------------------------------------------------------------------
/examples/videos/recon/view_reconstructed.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/view_reconstructed.mp4
--------------------------------------------------------------------------------
/inference_image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging
4 | from glob import glob
5 | import argparse
6 | from omegaconf import OmegaConf
7 | from utils.common_utils import instantiate_from_config
8 | import torchvision.transforms as transforms
9 | import numpy as np
10 | from PIL import Image
11 |
12 | logging.basicConfig(
13 | level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
14 | )
15 |
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser(description="Image Inference Script")
19 | parser.add_argument(
20 | "--data_root",
21 | type=str,
22 | required=True,
23 | help="Path to the folder containing input images.",
24 | )
25 | parser.add_argument(
26 | "--out_root", type=str, required=True, help="Path to save reconstructed images."
27 | )
28 | parser.add_argument(
29 | "--config_path",
30 | type=str,
31 | required=True,
32 | help="Path to the model configuration file.",
33 | )
34 | parser.add_argument(
35 | "--batch_size", type=int, default=16, help="Batch size for image processing."
36 | )
37 | parser.add_argument(
38 | "--device",
39 | type=str,
40 | default="cuda:0",
41 | help="Device to run inference on (e.g., 'cpu', 'cuda:0').",
42 | )
43 | return parser.parse_args()
44 |
45 |
46 | def data_processing(img_path):
47 | try:
48 | img = Image.open(img_path).convert("RGB")
49 | transform = transforms.Compose(
50 | [
51 | transforms.ToTensor(),
52 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
53 | ]
54 | )
55 | return transform(img)
56 | except Exception as e:
57 | logging.error(f"Error processing image {img_path}: {e}")
58 | return None
59 |
60 |
61 | def save_img(tensor, save_path):
62 | try:
63 | tensor = (tensor + 1) / 2 # Denormalize
64 | tensor = tensor.clamp(0, 1).detach().cpu()
65 | to_pil = transforms.ToPILImage()
66 | img = to_pil(tensor)
67 | img.save(save_path, format="JPEG")
68 | logging.info(f"Image saved to {save_path}")
69 | except Exception as e:
70 | logging.error(f"Error saving image to {save_path}: {e}")
71 |
72 |
73 | def process_batch(image_list, img_name_list, model, device, out_root):
74 | try:
75 | frames = torch.stack(image_list) # [batch_size, c, h, w]
76 | frames = frames.unsqueeze(1) # [batch_size, 1, c, h, w]
77 | frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, c, 1, h, w]
78 |
79 | with torch.no_grad():
80 | frames = frames.to(device)
81 | dec, _ = model.forward(frames, sample_posterior=False, mask_temporal=True)
82 | dec = dec.squeeze(2) # [batch_size, c, h, w]
83 |
84 | for i in range(len(image_list)):
85 | output_img = dec[i]
86 | save_img(output_img, os.path.join(out_root, img_name_list[i] + ".jpeg"))
87 | except Exception as e:
88 | logging.error(f"Error processing batch: {e}")
89 |
90 |
91 | def main():
92 | args = parse_args()
93 |
94 | os.makedirs(args.out_root, exist_ok=True)
95 |
96 | config = OmegaConf.load(args.config_path)
97 | model = instantiate_from_config(config.model)
98 | model = model.to(args.device)
99 | model.eval()
100 |
101 | # Load all image paths
102 |     all_images = sorted(glob(os.path.join(args.data_root, "*.jpg")) + glob(os.path.join(args.data_root, "*.jpeg")))
103 | if not all_images:
104 | logging.error(f"No images found in {args.data_root}")
105 | return
106 |
107 | batch_size = args.batch_size
108 | image_list = []
109 | img_name_list = []
110 |
111 | logging.info(f"Starting inference on {len(all_images)} images...")
112 |
113 | for img_path in all_images:
114 | img = data_processing(img_path) # [c, h, w]
115 | if img is None:
116 | logging.warning(f"Skipping invalid image {img_path}")
117 | continue
118 |
119 | img_name = os.path.basename(img_path).split(".")[0]
120 | image_list.append(img)
121 | img_name_list.append(img_name)
122 |
123 | # Process a batch when full
124 | if len(image_list) == batch_size:
125 | logging.info(f"Processing batch of {batch_size} images...")
126 | process_batch(image_list, img_name_list, model, args.device, args.out_root)
127 |
128 | # Clear lists for next batch
129 | image_list = []
130 | img_name_list = []
131 |
132 | # Process any remaining images
133 | if len(image_list) > 0:
134 | logging.info(f"Processing remaining {len(image_list)} images...")
135 | process_batch(image_list, img_name_list, model, args.device, args.out_root)
136 |
137 | logging.info("Inference completed successfully!")
138 |
139 |
140 | if __name__ == "__main__":
141 | main()
142 |
--------------------------------------------------------------------------------
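process_batch treats each image as a one-frame video: the [B, c, h, w] batch is expanded to [B, c, 1, h, w] before entering the VAE (with mask_temporal=True) and squeezed back afterwards. The reshaping in isolation, as a small sketch:

import torch

frames = torch.randn(4, 3, 256, 256)    # [B, c, h, w] batch of images
frames = frames.unsqueeze(1)            # [B, 1, c, h, w]
frames = frames.permute(0, 2, 1, 3, 4)  # [B, c, 1, h, w]: a one-frame "video"
assert frames.shape == (4, 3, 1, 256, 256)
# After model.forward(...), dec.squeeze(2) drops the singleton time axis again.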
/inference_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import logging
5 | from decord import VideoReader, cpu
6 | from glob import glob
7 | from omegaconf import OmegaConf
8 | import numpy as np
9 | import imageio
10 | from tqdm import tqdm
11 | from utils.common_utils import instantiate_from_config
12 | from src.modules.t5 import T5Embedder
13 | import torchvision
14 |
15 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
16 | logging.basicConfig(
17 | level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
18 | )
19 |
20 |
21 | def parse_args():
22 | """Parse command-line arguments."""
23 | parser = argparse.ArgumentParser(description="Video VAE Inference Script")
24 | parser.add_argument(
25 | "--data_root",
26 | type=str,
27 | required=True,
28 | help="Path to the folder containing input videos.",
29 | )
30 | parser.add_argument(
31 | "--out_root", type=str, required=True, help="Path to save reconstructed videos."
32 | )
33 | parser.add_argument(
34 | "--config_path",
35 | type=str,
36 | required=True,
37 | help="Path to the model configuration file.",
38 | )
39 | parser.add_argument(
40 | "--device",
41 | type=str,
42 | default="cuda:0",
43 | help="Device to run inference on (e.g., 'cpu', 'cuda:0').",
44 | )
45 | parser.add_argument(
46 | "--chunk_size",
47 | type=int,
48 | default=16,
49 | help="Number of frames per chunk for processing.",
50 | )
51 | parser.add_argument(
52 | "--resolution",
53 | type=int,
54 | nargs=2,
55 | default=[720, 1280],
56 | help="Resolution to process videos (height, width).",
57 | )
58 | return parser.parse_args()
59 |
60 |
61 | def data_processing(video_path, resolution):
62 | """Load and preprocess video data."""
63 | try:
64 | video_reader = VideoReader(video_path, ctx=cpu(0))
65 | video_resolution = video_reader[0].shape
66 |
67 | # Rescale resolution to match specified limits
68 | resolution = [
69 | min(video_resolution[0], resolution[0]),
70 | min(video_resolution[1], resolution[1]),
71 | ]
72 | video_reader = VideoReader(
73 | video_path, ctx=cpu(0), width=resolution[1], height=resolution[0]
74 | )
75 |
76 | video_length = len(video_reader)
77 | vid_fps = video_reader.get_avg_fps()
78 | frame_indices = list(range(0, video_length))
79 | frames = video_reader.get_batch(frame_indices)
80 | assert (
81 | frames.shape[0] == video_length
82 | ), f"Frame mismatch: {len(frames)} != {video_length}"
83 |
84 | frames = (
85 | torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
86 | ) # [t, h, w, c] -> [c, t, h, w]
87 | frames = (frames / 255 - 0.5) * 2 # Normalize to [-1, 1]
88 | return frames, vid_fps
89 | except Exception as e:
90 | logging.error(f"Error processing video {video_path}: {e}")
91 | return None, None
92 |
93 |
94 | def save_video(tensor, save_path, fps: float):
95 | """Save video tensor to a file."""
96 | try:
97 | tensor = torch.clamp((tensor + 1) / 2, 0, 1) * 255
98 | arr = tensor.detach().cpu().squeeze().to(torch.uint8)
99 | c, t, h, w = arr.shape
100 |
101 | torchvision.io.write_video(save_path, arr.permute(1, 2, 3, 0), fps=fps, options={'codec': 'libx264', 'crf': '15'})
102 | logging.info(f"Video saved to {save_path}")
103 | except Exception as e:
104 | logging.error(f"Error saving video {save_path}: {e}")
105 |
106 |
107 | def process_in_chunks(
108 | video_data,
109 | model,
110 | chunk_size,
111 | text_embeddings=None,
112 | text_attn_mask=None,
113 | device="cuda:0",
114 | ):
115 | try:
116 | assert chunk_size % 4 == 0, "Chunk size must be a multiple of 4."
117 | num_frames = video_data.size(2)
118 | padding_frames = 0
119 | output_chunks = []
120 |
121 | # Pad video to make the frame count divisible by 4
122 | if num_frames % 4 != 0:
123 | padding_frames = 4 - (num_frames % 4)
124 | padding = video_data[:, :, -1:, :, :].repeat(1, 1, padding_frames, 1, 1)
125 | video_data = torch.cat((video_data, padding), dim=2)
126 | num_frames = video_data.size(2)
127 |
128 | start = 0
129 |
130 | while start < num_frames:
131 | end = min(start + chunk_size, num_frames)
132 | chunk = video_data[:, :, start:end, :, :]
133 |
134 | with torch.no_grad():
135 | chunk = chunk.to(device)
136 | if text_embeddings is not None and text_attn_mask is not None:
137 | recon_chunk, _ = model.forward(
138 | chunk,
139 | text_embeddings=text_embeddings,
140 | text_attn_mask=text_attn_mask,
141 | sample_posterior=False,
142 | )
143 | else:
144 | recon_chunk, _ = model.forward(chunk, sample_posterior=False)
145 | recon_chunk = recon_chunk.cpu().float()
146 | output_chunks.append(recon_chunk)
147 | start += chunk_size
148 |
149 | ret = torch.cat(output_chunks, dim=2)
150 | if padding_frames > 0:
151 | ret = ret[:, :, :-padding_frames, :, :]
152 | return ret
153 | except Exception as e:
154 | logging.error(f"Error processing chunks: {e}")
155 | return None
156 |
157 |
158 | def main():
159 | """Main function for video VAE inference."""
160 | args = parse_args()
161 |
162 | os.makedirs(args.out_root, exist_ok=True)
163 | config = OmegaConf.load(args.config_path)
164 |
165 | # Initialize model
166 | model = instantiate_from_config(config.model)
167 | is_t5 = getattr(model, "caption_guide", False)
168 | model = model.to(args.device)
169 | model.eval()
170 |
171 | # Initialize text embedder if T5 is used
172 | text_embedder = None
173 | if is_t5:
174 | text_embedder = T5Embedder(
175 | device=args.device, model_max_length=model.t5_model_max_length
176 | )
177 |
178 | # Get all videos
179 | all_videos = sorted(glob(os.path.join(args.data_root, "*.mp4")))
180 | if not all_videos:
181 | logging.error(f"No videos found in {args.data_root}")
182 | return
183 |
184 | # Process each video
185 | for video_path in tqdm(all_videos, desc="Processing videos", unit="video"):
186 | logging.info(f"Processing video: {video_path}")
187 | frames, vid_fps = data_processing(video_path, args.resolution)
188 | if frames is None:
189 | continue
190 |
191 | video_name = os.path.basename(video_path).split(".")[0]
192 | frames = torch.unsqueeze(frames, dim=0) # Add batch dimension
193 |
194 | with torch.no_grad():
195 | if is_t5:
196 | # Load caption if available
197 | text_path = os.path.join(args.data_root, f"{video_name}.txt")
198 | try:
199 | with open(text_path, "r") as f:
200 | caption = [f.read()]
201 | except Exception as e:
202 | logging.warning(f"Caption file not found for {video_name}: {e}")
203 | caption = [""]
204 |
205 | text_embedding, text_attn_mask = text_embedder.get_text_embeddings(
206 | caption
207 | )
208 | text_embedding = text_embedding.to(args.device, dtype=model.dtype)
209 | text_attn_mask = text_attn_mask.to(args.device, dtype=model.dtype)
210 |
211 | video_recon = process_in_chunks(
212 | frames,
213 | model,
214 | args.chunk_size,
215 | text_embedding,
216 | text_attn_mask,
217 | device=args.device,
218 | )
219 | else:
220 | video_recon = process_in_chunks(
221 | frames, model, args.chunk_size, device=args.device
222 | )
223 |
224 | if video_recon is not None:
225 | save_path = os.path.join(
226 | args.out_root, f"{video_name}_reconstructed.mp4"
227 | )
228 | save_video(video_recon, save_path, vid_fps)
229 |
230 |
231 | if __name__ == "__main__":
232 | main()
233 |
--------------------------------------------------------------------------------
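process_in_chunks pads the clip by repeating its last frame until the frame count is divisible by 4, encodes and decodes fixed-size chunks, and trims the padding off the concatenated result. The padding and trimming arithmetic on its own (model call omitted):

import torch

video = torch.randn(1, 3, 10, 64, 64)                    # [b, c, t, h, w], t = 10
pad = (4 - video.size(2) % 4) % 4                         # 2 extra frames needed here
padding = video[:, :, -1:, :, :].repeat(1, 1, pad, 1, 1)  # repeat the last frame
padded = torch.cat((video, padding), dim=2)               # t = 12, divisible by 4
assert padded.size(2) % 4 == 0
trimmed = padded[:, :, :-pad, :, :] if pad > 0 else padded
assert trimmed.shape == video.shape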
/requirements.txt:
--------------------------------------------------------------------------------
1 | av==12.0.0
2 | accelerate==0.34.2
3 | academictorrents==2.3.3
4 | albumentations==1.4.16
5 | apex==0.9.10dev
6 | beautifulsoup4==4.12.3
7 | decord==0.6.0
8 | diffusers==0.30.3
9 | einops==0.8.0
10 | fairscale==0.4.13
11 | ftfy==6.2.3
12 | huggingface-hub==0.23.2
13 | imageio==2.33.1
14 | kornia==0.7.3
15 | moviepy==1.0.3
16 | more_itertools==10.5.0
17 | numpy==1.26.3
18 | nvidia-cublas-cu12==12.1.3.1
19 | nvidia-cuda-cupti-cu12==12.1.105
20 | nvidia-cuda-nvrtc-cu12==12.1.105
21 | nvidia-cuda-runtime-cu12==12.1.105
22 | nvidia-cudnn-cu12==8.9.2.26
23 | nvidia-cufft-cu12==11.0.2.54
24 | nvidia-curand-cu12==10.3.2.106
25 | nvidia-cusolver-cu12==11.4.5.107
26 | nvidia-cusparse-cu12==12.1.0.106
27 | nvidia-nccl-cu12==2.19.3
28 | nvidia-nvjitlink-cu12==12.1.105
29 | nvidia-nvtx-cu12==12.1.105
30 | omegaconf==2.3.0
31 | opencv-python==4.9.0.80
32 | packaging==24.0
33 | pandas==2.2.1
34 | psutil==6.0.0
35 | Pillow==10.4.0
36 | pytorch_lightning==1.9.4
37 | PyYAML==6.0.1
38 | protobuf==3.20.*
39 | Requests==2.32.3
40 | safetensors==0.4.5
41 | scipy==1.14.1
42 | sentencepiece==0.2.0
43 | tensorboard==2.18.0
44 | taming-transformers==0.0.1
45 | tensorboardX==2.6.2.2
46 | timm==1.0.9
47 | torch==2.2.0
48 | torchaudio==2.2.0
49 | torchmetrics==1.3.1
50 | torchvision==0.17.0
51 | tokenizers==0.13.3
52 | tqdm==4.66.2
53 | transformers==4.25.1
54 | typing_extensions==4.12.2
55 | xformers==0.0.24
56 | lpips
--------------------------------------------------------------------------------
/scripts/evaluation_image.sh:
--------------------------------------------------------------------------------
1 | python evaluation/compute_metrics_img.py \
2 | --root1 "examples/images/gt" \
3 | --root2 "examples/images/recon" \
4 | --ssim \
5 | --psnr \
6 | --lpips
--------------------------------------------------------------------------------
/scripts/evaluation_video.sh:
--------------------------------------------------------------------------------
1 | python evaluation/compute_metrics.py \
2 | --root1 "examples/videos/gt" \
3 | --root2 "examples/videos/recon" \
4 | --ssim \
5 | --psnr \
6 | --lpips
--------------------------------------------------------------------------------
/scripts/run_inference_image.sh:
--------------------------------------------------------------------------------
1 | python inference_image.py \
2 | --data_root 'examples/images/gt' \
3 | --out_root 'examples/images/recon' \
4 | --config_path 'configs/inference/config_16z.yaml' \
5 | --batch_size 1
--------------------------------------------------------------------------------
/scripts/run_inference_video.sh:
--------------------------------------------------------------------------------
1 | python inference_video.py \
2 | --data_root 'examples/videos/gt' \
3 | --out_root 'examples/videos/recon' \
4 | --config_path 'configs/inference/config_16z.yaml' \
5 | --chunk_size 8 --resolution 720 1280
6 |
--------------------------------------------------------------------------------
/scripts/run_train.sh:
--------------------------------------------------------------------------------
1 | yaml="configs/train/$1.yaml"
2 | exp_name="VideoVAEPlus_$1"
3 |
4 | n_HOST=1
5 | elastic=1
6 | GPUName="A"
7 | current_time=$(date +%Y%m%d%H%M%S)
8 |
9 | out_dir_name="${exp_name}_${n_HOST}nodes_e${elastic}_${GPUName}_$current_time"
10 | res_root="./debug"
11 |
12 | mkdir -p $res_root/$out_dir_name
13 |
14 | torchrun \
15 | --nproc_per_node=1 --nnodes=1 --master_port=16666 \
16 | train.py \
17 | --base $yaml \
18 | -t --devices 0, \
19 | lightning.trainer.num_nodes=1 \
20 | --name ${out_dir_name} \
21 | --logdir $res_root \
22 | --auto_resume True
--------------------------------------------------------------------------------
/src/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(
34 | device=self.parameters.device
35 | )
36 |
37 | def sample(self, noise=None):
38 | if noise is None:
39 | noise = torch.randn(self.mean.shape)
40 |
41 | x = self.mean + self.std * noise.to(device=self.parameters.device)
42 | return x
43 |
44 | def kl(self, other=None):
45 | if self.deterministic:
46 | return torch.Tensor([0.0])
47 | else:
48 | if other is None:
49 | return 0.5 * torch.sum(
50 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
51 | dim=[1, 2, 3],
52 | )
53 | else:
54 | return 0.5 * torch.sum(
55 | torch.pow(self.mean - other.mean, 2) / other.var
56 | + self.var / other.var
57 | - 1.0
58 | - self.logvar
59 | + other.logvar,
60 | dim=[1, 2, 3],
61 | )
62 |
63 | def nll(self, sample, dims=[1, 2, 3]):
64 | if self.deterministic:
65 | return torch.Tensor([0.0])
66 | logtwopi = np.log(2.0 * np.pi)
67 | return 0.5 * torch.sum(
68 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
69 | dim=dims,
70 | )
71 |
72 | def mode(self):
73 | return self.mean
74 |
75 |
76 | def normal_kl(mean1, logvar1, mean2, logvar2):
77 | """
78 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
79 | Compute the KL divergence between two gaussians.
80 | Shapes are automatically broadcasted, so batches can be compared to
81 | scalars, among other use cases.
82 | """
83 | tensor = None
84 | for obj in (mean1, logvar1, mean2, logvar2):
85 | if isinstance(obj, torch.Tensor):
86 | tensor = obj
87 | break
88 | assert tensor is not None, "at least one argument must be a Tensor"
89 |
90 | # Force variances to be Tensors. Broadcasting helps convert scalars to
91 | # Tensors, but it does not work for torch.exp().
92 | logvar1, logvar2 = [
93 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
94 | for x in (logvar1, logvar2)
95 | ]
96 |
97 | return 0.5 * (
98 | -1.0
99 | + logvar2
100 | - logvar1
101 | + torch.exp(logvar1 - logvar2)
102 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
103 | )
104 |
--------------------------------------------------------------------------------
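For reference, DiagonalGaussianDistribution.kl with other=None is the closed-form KL divergence between the diagonal-Gaussian posterior and a standard normal prior, summed over all non-batch dimensions:

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = \frac{1}{2}\sum_{i}\left(\mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2\right)$$

The two-argument form generalizes this to the KL between two diagonal Gaussians, and nll is the corresponding Gaussian negative log-likelihood.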
/src/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 |
4 | import torch.nn.functional as F
5 | import pytorch_lightning as pl
6 |
7 | from src.modules.ae_modules import Encoder, Decoder
8 | from src.distributions import DiagonalGaussianDistribution
9 | from utils.common_utils import instantiate_from_config
10 |
11 |
12 | class AutoencoderKL(pl.LightningModule):
13 | def __init__(
14 | self,
15 | ddconfig,
16 | lossconfig,
17 | embed_dim,
18 | use_quant_conv=True,
19 | ckpt_path=None,
20 | ignore_keys=[],
21 | image_key="image",
22 | colorize_nlabels=None,
23 | monitor=None,
24 | test=False,
25 | logdir=None,
26 | input_dim=4,
27 | test_args=None,
28 | ):
29 | super().__init__()
30 | self.image_key = image_key
31 | self.encoder = Encoder(**ddconfig)
32 | self.decoder = Decoder(**ddconfig)
33 | self.loss = instantiate_from_config(lossconfig)
34 | assert ddconfig["double_z"]
35 |
36 | if use_quant_conv:
37 | self.quant_conv = torch.nn.Conv2d(
38 | 2 * ddconfig["z_channels"], 2 * embed_dim, 1
39 | )
40 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
41 | self.embed_dim = embed_dim
42 |
43 | self.use_quant_conv = use_quant_conv
44 |
45 | self.input_dim = input_dim
46 | self.test = test
47 | self.test_args = test_args
48 | self.logdir = logdir
49 | if colorize_nlabels is not None:
50 | assert type(colorize_nlabels) == int
51 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
52 | if monitor is not None:
53 | self.monitor = monitor
54 | if ckpt_path is not None:
55 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
56 |
57 | def init_from_ckpt(self, path, ignore_keys=list()):
58 | sd = torch.load(path, map_location="cpu")
59 | try:
60 | self._cur_epoch = sd["epoch"]
61 | sd = sd["state_dict"]
62 | except:
63 | self._cur_epoch = "null"
64 | keys = list(sd.keys())
65 | for k in keys:
66 | for ik in ignore_keys:
67 | if k.startswith(ik):
68 | # print("Deleting key {} from state_dict.".format(k))
69 | del sd[k]
70 | self.load_state_dict(sd, strict=False)
71 | # self.load_state_dict(sd, strict=True)
72 | print(f"Restored from {path}")
73 |
74 | def encode(self, x, **kwargs):
75 |
76 | h = self.encoder(x)
77 | moments = h
78 | if self.use_quant_conv:
79 | moments = self.quant_conv(h)
80 | posterior = DiagonalGaussianDistribution(moments)
81 | return posterior
82 |
83 | def decode(self, z, **kwargs):
84 | if self.use_quant_conv:
85 | z = self.post_quant_conv(z)
86 | dec = self.decoder(z)
87 | return dec
88 |
89 | def forward(self, input, sample_posterior=True):
90 | posterior = self.encode(input)
91 | if sample_posterior:
92 | z = posterior.sample()
93 | else:
94 | z = posterior.mode()
95 | dec = self.decode(z)
96 | return dec, posterior
97 |
98 | def get_input(self, batch, k):
99 | x = batch[k]
100 | if x.dim() == 5 and self.input_dim == 4:
101 | b, c, t, h, w = x.shape
102 | self.b = b
103 | self.t = t
104 | x = rearrange(x, "b c t h w -> (b t) c h w")
105 |
106 | return x
107 |
108 | def training_step(self, batch, batch_idx, optimizer_idx):
109 | inputs = self.get_input(batch, self.image_key)
110 | reconstructions, posterior = self(inputs)
111 |
112 | if optimizer_idx == 0:
113 | # train encoder+decoder+logvar
114 | aeloss, log_dict_ae = self.loss(
115 | inputs,
116 | reconstructions,
117 | posterior,
118 | optimizer_idx,
119 | self.global_step,
120 | last_layer=self.get_last_layer(),
121 | split="train",
122 | )
123 | self.log(
124 | "aeloss",
125 | aeloss,
126 | prog_bar=True,
127 | logger=True,
128 | on_step=True,
129 | on_epoch=True,
130 | )
131 | self.log_dict(
132 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
133 | )
134 | return aeloss
135 |
136 | if optimizer_idx == 1:
137 | # train the discriminator
138 | discloss, log_dict_disc = self.loss(
139 | inputs,
140 | reconstructions,
141 | posterior,
142 | optimizer_idx,
143 | self.global_step,
144 | last_layer=self.get_last_layer(),
145 | split="train",
146 | )
147 |
148 | self.log(
149 | "discloss",
150 | discloss,
151 | prog_bar=True,
152 | logger=True,
153 | on_step=True,
154 | on_epoch=True,
155 | )
156 | self.log_dict(
157 | log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
158 | )
159 | return discloss
160 |
161 | def validation_step(self, batch, batch_idx):
162 | inputs = self.get_input(batch, self.image_key)
163 | reconstructions, posterior = self(inputs)
164 | aeloss, log_dict_ae = self.loss(
165 | inputs,
166 | reconstructions,
167 | posterior,
168 | 0,
169 | self.global_step,
170 | last_layer=self.get_last_layer(),
171 | split="val",
172 | )
173 |
174 | discloss, log_dict_disc = self.loss(
175 | inputs,
176 | reconstructions,
177 | posterior,
178 | 1,
179 | self.global_step,
180 | last_layer=self.get_last_layer(),
181 | split="val",
182 | )
183 |
184 |         reconstructions = reconstructions.cpu().detach()
185 |
186 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
187 | self.log_dict(log_dict_ae)
188 | self.log_dict(log_dict_disc)
189 | return self.log_dict
190 |
191 | def configure_optimizers(self):
192 | lr = self.learning_rate
193 | opt_ae = torch.optim.Adam(
194 | list(self.encoder.parameters())
195 | + list(self.decoder.parameters())
196 | + list(self.quant_conv.parameters())
197 | + list(self.post_quant_conv.parameters()),
198 | lr=lr,
199 | betas=(0.5, 0.9),
200 | )
201 | opt_disc = torch.optim.Adam(
202 | self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
203 | )
204 | return [opt_ae, opt_disc], []
205 |
206 | def get_last_layer(self):
207 | return self.decoder.conv_out.weight
208 |
209 | @torch.no_grad()
210 | def log_images(self, batch, only_inputs=False, **kwargs):
211 | log = dict()
212 | x = self.get_input(batch, self.image_key)
213 | x = x.to(self.device)
214 | if not only_inputs:
215 | xrec, posterior = self(x)
216 | if x.shape[1] > 3:
217 | # colorize with random projection
218 | assert xrec.shape[1] > 3
219 | x = self.to_rgb(x)
220 | xrec = self.to_rgb(xrec)
221 |
222 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
223 | xrec = xrec.cpu().detach()
224 | log["reconstructions"] = xrec
225 |
226 | x = x.cpu().detach()
227 | log["inputs"] = x
228 | return log
229 |
230 | def to_rgb(self, x):
231 | assert self.image_key == "segmentation"
232 | if not hasattr(self, "colorize"):
233 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
234 | x = F.conv2d(x, weight=self.colorize)
235 | x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
236 | return x
237 |
238 |
239 | class IdentityFirstStage(torch.nn.Module):
240 | def __init__(self, *args, vq_interface=False, **kwargs):
241 | # TODO: Should be true by default but check to not break older stuff
242 | self.vq_interface = vq_interface
243 | super().__init__()
244 |
245 | def encode(self, x, *args, **kwargs):
246 | return x
247 |
248 | def decode(self, x, *args, **kwargs):
249 | return x
250 |
251 | def quantize(self, x, *args, **kwargs):
252 | if self.vq_interface:
253 | return x, None, [None, None, None]
254 | return x
255 |
256 | def forward(self, x, *args, **kwargs):
257 | return x
258 |
--------------------------------------------------------------------------------
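
A note on the training loop in /src/models/autoencoder.py above: it follows PyTorch Lightning's pre-2.0 multi-optimizer convention, where configure_optimizers returns [opt_ae, opt_disc] and training_step is called once per batch for each optimizer with the matching optimizer_idx (0 for the autoencoder, 1 for the discriminator). Below is a minimal standalone sketch of the equivalent manual alternation; the toy linear modules and the hinge loss are illustrative stand-ins, not the classes from this repository.

import torch
import torch.nn as nn

# toy stand-ins for the autoencoder (generator side) and the patch discriminator
gen = nn.Linear(8, 8)
disc = nn.Linear(8, 1)
opt_ae = torch.optim.Adam(gen.parameters(), lr=1e-4, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(disc.parameters(), lr=1e-4, betas=(0.5, 0.9))

for step in range(2):
    x = torch.randn(4, 8)

    # optimizer_idx == 0: update the autoencoder (reconstruction + generator term)
    recon = gen(x)
    aeloss = (recon - x).abs().mean() - disc(recon).mean()
    opt_ae.zero_grad()
    aeloss.backward()
    opt_ae.step()

    # optimizer_idx == 1: update the discriminator on detached reconstructions (hinge loss)
    recon = gen(x).detach()
    discloss = torch.relu(1.0 - disc(x)).mean() + torch.relu(1.0 + disc(recon)).mean()
    opt_disc.zero_grad()
    discloss.backward()
    opt_disc.step()
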
/src/models/autoencoder_temporal.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | from src.modules.attention_temporal_videoae import *
6 | from einops import rearrange, reduce, repeat
7 |
8 | try:
9 | import xformers
10 | import xformers.ops as xops
11 |
12 |     XFORMERS_IS_AVAILABLE = True
13 | except ImportError:
14 |     XFORMERS_IS_AVAILABLE = False
15 |
16 |
17 | def silu(x):
18 | # swish
19 | return x * torch.sigmoid(x)
20 |
21 |
22 | class SiLU(nn.Module):
23 | def __init__(self):
24 | super(SiLU, self).__init__()
25 |
26 | def forward(self, x):
27 | return silu(x)
28 |
29 |
30 | def Normalize(in_channels, norm_type="group"):
31 | assert norm_type in ["group", "batch"]
32 | if norm_type == "group":
33 | return torch.nn.GroupNorm(
34 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
35 | )
36 | elif norm_type == "batch":
37 | return torch.nn.SyncBatchNorm(in_channels)
38 |
39 |
40 | # Does not support dilation
41 |
42 |
43 | class SamePadConv3d(nn.Module):
44 | def __init__(
45 | self,
46 | in_channels,
47 | out_channels,
48 | kernel_size,
49 | stride=1,
50 | bias=True,
51 | padding_type="replicate",
52 | ):
53 | super().__init__()
54 | if isinstance(kernel_size, int):
55 | kernel_size = (kernel_size,) * 3
56 | if isinstance(stride, int):
57 | stride = (stride,) * 3
58 |
59 | # assumes that the input shape is divisible by stride
60 | total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
61 | pad_input = []
62 | for p in total_pad[::-1]: # reverse since F.pad starts from last dim
63 | pad_input.append((p // 2 + p % 2, p // 2))
64 | pad_input = sum(pad_input, tuple())
65 | self.pad_input = pad_input
66 | self.padding_type = padding_type
67 |
68 | self.conv = nn.Conv3d(
69 | in_channels, out_channels, kernel_size, stride=stride, padding=0, bias=bias
70 | )
71 |
72 | def forward(self, x):
73 | # print(x.dtype)
74 | return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))
75 |
76 |
77 | class SamePadConvTranspose3d(nn.Module):
78 | def __init__(
79 | self,
80 | in_channels,
81 | out_channels,
82 | kernel_size,
83 | stride=1,
84 | bias=True,
85 | padding_type="replicate",
86 | ):
87 | super().__init__()
88 | if isinstance(kernel_size, int):
89 | kernel_size = (kernel_size,) * 3
90 | if isinstance(stride, int):
91 | stride = (stride,) * 3
92 |
93 | total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
94 | pad_input = []
95 | for p in total_pad[::-1]: # reverse since F.pad starts from last dim
96 | pad_input.append((p // 2 + p % 2, p // 2))
97 | pad_input = sum(pad_input, tuple())
98 | self.pad_input = pad_input
99 | self.padding_type = padding_type
100 |
101 | self.convt = nn.ConvTranspose3d(
102 | in_channels,
103 | out_channels,
104 | kernel_size,
105 | stride=stride,
106 | bias=bias,
107 | padding=tuple([k - 1 for k in kernel_size]),
108 | )
109 |
110 | def forward(self, x):
111 | return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))
112 |
113 |
114 | class ResBlock(nn.Module):
115 | def __init__(
116 | self,
117 | in_channels,
118 | out_channels=None,
119 | conv_shortcut=False,
120 | dropout=0.0,
121 | norm_type="group",
122 | padding_type="replicate",
123 | ):
124 | super().__init__()
125 | self.in_channels = in_channels
126 | out_channels = in_channels if out_channels is None else out_channels
127 | self.out_channels = out_channels
128 | self.use_conv_shortcut = conv_shortcut
129 |
130 | self.norm1 = Normalize(in_channels, norm_type)
131 | self.conv1 = SamePadConv3d(
132 | in_channels, out_channels, kernel_size=3, padding_type=padding_type
133 | )
134 | self.dropout = torch.nn.Dropout(dropout)
135 |         self.norm2 = Normalize(out_channels, norm_type)  # conv1 outputs out_channels
136 | self.conv2 = SamePadConv3d(
137 | out_channels, out_channels, kernel_size=3, padding_type=padding_type
138 | )
139 | if self.in_channels != self.out_channels:
140 | self.conv_shortcut = SamePadConv3d(
141 | in_channels, out_channels, kernel_size=3, padding_type=padding_type
142 | )
143 |
144 | def forward(self, x):
145 | h = x
146 | h = self.norm1(h)
147 | h = silu(h)
148 | h = self.conv1(h)
149 | h = self.norm2(h)
150 | h = silu(h)
151 | h = self.conv2(h)
152 |
153 | if self.in_channels != self.out_channels:
154 | x = self.conv_shortcut(x)
155 |
156 | return x + h
157 |
158 |
159 | class SpatialCrossAttention(nn.Module):
160 | def __init__(
161 | self,
162 | query_dim,
163 | patch_size=1,
164 | context_dim=None,
165 | heads=8,
166 | dim_head=64,
167 | dropout=0.0,
168 | ):
169 | super().__init__()
170 | inner_dim = dim_head * heads
171 | context_dim = default(context_dim, query_dim)
172 |
173 | self.scale = dim_head**-0.5
174 | self.heads = heads
175 | self.dim_head = dim_head
176 |
177 | # print(f"query dimension is {query_dim}")
178 |
179 | self.patch_size = patch_size
180 | patch_dim = query_dim * patch_size * patch_size
181 | self.norm = nn.LayerNorm(patch_dim)
182 |
183 | self.to_q = nn.Linear(patch_dim, inner_dim, bias=False)
184 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
185 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
186 |
187 | self.to_out = nn.Sequential(
188 | nn.Linear(inner_dim, patch_dim), nn.Dropout(dropout)
189 | )
190 | self.attention_op: Optional[Any] = None
191 |
192 | def forward(self, x, context=None, mask=None):
193 | b, c, t, height, width = x.shape
194 |
195 | # patch: [patch_size, patch_size]
196 | divide_factor_height = height // self.patch_size
197 | divide_factor_width = width // self.patch_size
198 | x = rearrange(
199 | x,
200 | "b c t (df1 ph) (df2 pw) -> (b t) (df1 df2) (ph pw c)",
201 | df1=divide_factor_height,
202 | df2=divide_factor_width,
203 | ph=self.patch_size,
204 | pw=self.patch_size,
205 | )
206 | x = self.norm(x)
207 |
208 | context = default(context, x)
209 | context = repeat(context, "b n d -> (b t) n d", b=b, t=t)
210 |
211 | q = self.to_q(x)
212 | k = self.to_k(context)
213 | v = self.to_v(context)
214 |
215 | q, k, v = map(
216 | lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=self.heads), (q, k, v)
217 | )
218 |
219 | if exists(mask):
220 | mask = rearrange(mask, "b ... -> b (...)")
221 | mask = repeat(mask, "b j -> (b t h) () j", t=t, h=self.heads)
222 |
223 |         if XFORMERS_IS_AVAILABLE:
224 | if exists(mask):
225 | mask = mask.to(q.dtype)
226 | max_neg_value = -torch.finfo(q.dtype).max
227 |
228 | attn_bias = torch.zeros_like(mask)
229 | attn_bias.masked_fill_(mask <= 0.5, max_neg_value)
230 |
231 | mask = mask.detach().cpu()
232 | attn_bias = attn_bias.expand(-1, q.shape[1], -1)
233 |
234 | attn_bias_expansion_q = (attn_bias.shape[1] + 7) // 8 * 8
235 | attn_bias_expansion_k = (attn_bias.shape[2] + 7) // 8 * 8
236 |
237 | attn_bias_expansion = torch.zeros(
238 | (attn_bias.shape[0], attn_bias_expansion_q, attn_bias_expansion_k),
239 | dtype=attn_bias.dtype,
240 | device=attn_bias.device,
241 | )
242 | attn_bias_expansion[:, : attn_bias.shape[1], : attn_bias.shape[2]] = (
243 | attn_bias
244 | )
245 |
246 | attn_bias = attn_bias.detach().cpu()
247 |
248 | out = xops.memory_efficient_attention(
249 | q,
250 | k,
251 | v,
252 | attn_bias=attn_bias_expansion[
253 | :, : attn_bias.shape[1], : attn_bias.shape[2]
254 | ],
255 | scale=self.scale,
256 | )
257 | else:
258 | out = xops.memory_efficient_attention(q, k, v, scale=self.scale)
259 | else:
260 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
261 | if exists(mask):
262 | max_neg_value = -torch.finfo(sim.dtype).max
263 | sim.masked_fill_(~(mask > 0.5), max_neg_value)
264 | attn = sim.softmax(dim=-1)
265 | out = einsum("b i j, b j d -> b i d", attn, v)
266 |
267 | out = rearrange(out, "(b h) n d -> b n (h d)", h=self.heads)
268 |
269 | ret = self.to_out(out)
270 | ret = rearrange(
271 | ret,
272 | "(b t) (df1 df2) (ph pw c) -> b c t (df1 ph) (df2 pw)",
273 | b=b,
274 | t=t,
275 | df1=divide_factor_height,
276 | df2=divide_factor_width,
277 | ph=self.patch_size,
278 | pw=self.patch_size,
279 | )
280 | return ret
281 |
282 |
283 | # ----------------------------------------------------------------------------------------------------
284 |
285 |
286 | class EncoderTemporal1DCNN(nn.Module):
287 | def __init__(
288 | self,
289 | *,
290 | ch,
291 | out_ch,
292 | attn_temporal_factor=[],
293 | temporal_scale_factor=4,
294 | hidden_channel=128,
295 | **ignore_kwargs
296 | ):
297 | super().__init__()
298 |
299 | self.ch = ch
300 | self.temb_ch = 0
301 | self.temporal_scale_factor = temporal_scale_factor
302 |
303 | # conv_in + resblock + down_block + resblock + down_block + final_block
304 | self.conv_in = SamePadConv3d(
305 | ch, hidden_channel, kernel_size=3, padding_type="replicate"
306 | )
307 |
308 | self.mid_blocks = nn.ModuleList()
309 |
310 | num_ds = int(math.log2(temporal_scale_factor))
311 | norm_type = "group"
312 |
313 | curr_temporal_factor = 1
314 | for i in range(num_ds):
315 | block = nn.Module()
316 | # compute in_ch, out_ch, stride
317 | in_channels = hidden_channel * 2**i
318 | out_channels = hidden_channel * 2 ** (i + 1)
319 | temporal_stride = 2
320 | curr_temporal_factor = curr_temporal_factor * 2
321 |
322 | block.down = SamePadConv3d(
323 | in_channels,
324 | out_channels,
325 | kernel_size=3,
326 | stride=(temporal_stride, 1, 1),
327 | padding_type="replicate",
328 | )
329 | block.res = ResBlock(out_channels, out_channels, norm_type=norm_type)
330 |
331 | block.attn = nn.ModuleList()
332 | if curr_temporal_factor in attn_temporal_factor:
333 | block.attn.append(
334 | SpatialCrossAttention(query_dim=out_channels, context_dim=1024)
335 | )
336 |
337 | self.mid_blocks.append(block)
338 | # n_times_downsample -= 1
339 |
340 | self.final_block = nn.Sequential(
341 | Normalize(out_channels, norm_type),
342 | SiLU(),
343 | SamePadConv3d(
344 | out_channels, out_ch * 2, kernel_size=3, padding_type="replicate"
345 | ),
346 | )
347 |
348 | self.initialize_weights()
349 |
350 | def initialize_weights(self):
351 | # Initialize transformer layers:
352 | def _basic_init(module):
353 | if isinstance(module, nn.Linear):
354 |                 if module.weight.requires_grad:
355 | torch.nn.init.xavier_uniform_(module.weight)
356 | if module.bias is not None:
357 | nn.init.constant_(module.bias, 0)
358 | if isinstance(module, nn.Conv3d):
359 | torch.nn.init.xavier_uniform_(module.weight)
360 | if module.bias is not None:
361 | nn.init.constant_(module.bias, 0)
362 |
363 | self.apply(_basic_init)
364 |
365 | def forward(self, x, text_embeddings=None, text_attn_mask=None):
366 | # x: [b c t h w]
367 | # x: [1, 4, 16, 32, 32]
368 | # timestep embedding
369 | h = self.conv_in(x)
370 | for block in self.mid_blocks:
371 | h = block.down(h)
372 | h = block.res(h)
373 | if len(block.attn) > 0:
374 | for attn in block.attn:
375 | h = attn(h, context=text_embeddings, mask=text_attn_mask) + h
376 |
377 | h = self.final_block(h)
378 |
379 | return h
380 |
381 |
382 | class TemporalUpsample(nn.Module):
383 | def __init__(
384 | self, size=None, scale_factor=None, mode="nearest", align_corners=None
385 | ):
386 | super(TemporalUpsample, self).__init__()
387 | self.size = size
388 | self.scale_factor = scale_factor
389 | self.mode = mode
390 | self.align_corners = align_corners
391 |
392 | def forward(self, x):
393 | return F.interpolate(
394 | x,
395 | size=self.size,
396 | scale_factor=self.scale_factor,
397 | mode=self.mode,
398 | align_corners=self.align_corners,
399 | )
400 |
401 |
402 | class DecoderTemporal1DCNN(nn.Module):
403 | def __init__(
404 | self,
405 | *,
406 | ch,
407 | out_ch,
408 | attn_temporal_factor=[],
409 | temporal_scale_factor=4,
410 | hidden_channel=128,
411 | **ignore_kwargs
412 | ):
413 | super().__init__()
414 |
415 | self.ch = ch
416 | self.temb_ch = 0
417 | self.temporal_scale_factor = temporal_scale_factor
418 |
419 | num_us = int(math.log2(temporal_scale_factor))
420 | norm_type = "group"
421 |
422 | # conv_in, mid_blocks, final_block
423 | # out channel of encoder, before the last conv layer
424 | enc_out_channels = hidden_channel * 2**num_us
425 | self.conv_in = SamePadConv3d(
426 | ch, enc_out_channels, kernel_size=3, padding_type="replicate"
427 | )
428 |
429 | self.mid_blocks = nn.ModuleList()
430 | curr_temporal_factor = self.temporal_scale_factor
431 |
432 | for i in range(num_us):
433 | block = nn.Module()
434 | in_channels = (
435 | enc_out_channels if i == 0 else hidden_channel * 2 ** (num_us - i + 1)
436 | ) # max_us: 3
437 | out_channels = hidden_channel * 2 ** (num_us - i)
438 | temporal_stride = 2
439 | # block.up = SamePadConvTranspose3d(in_channels, out_channels, kernel_size=3, stride=(temporal_stride, 1, 1))
440 | block.up = torch.nn.ConvTranspose3d(
441 | in_channels,
442 | out_channels,
443 | kernel_size=(3, 3, 3),
444 | stride=(2, 1, 1),
445 | padding=(1, 1, 1),
446 | output_padding=(1, 0, 0),
447 | )
448 | block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
449 | block.attn1 = nn.ModuleList()
450 |
451 | if curr_temporal_factor in attn_temporal_factor:
452 | block.attn1.append(
453 | SpatialCrossAttention(query_dim=out_channels, context_dim=1024)
454 | )
455 |
456 | block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
457 |
458 | block.attn2 = nn.ModuleList()
459 | if curr_temporal_factor in attn_temporal_factor:
460 | block.attn2.append(
461 | SpatialCrossAttention(query_dim=out_channels, context_dim=1024)
462 | )
463 |
464 | curr_temporal_factor = curr_temporal_factor / 2
465 | self.mid_blocks.append(block)
466 |
467 | self.conv_last = SamePadConv3d(out_channels, out_ch, kernel_size=3)
468 |
469 | self.initialize_weights()
470 |
471 | def initialize_weights(self):
472 | # Initialize transformer layers:
473 | def _basic_init(module):
474 | if isinstance(module, nn.Linear):
475 |                 if module.weight.requires_grad:
476 | torch.nn.init.xavier_uniform_(module.weight)
477 | if module.bias is not None:
478 | nn.init.constant_(module.bias, 0)
479 | if isinstance(module, nn.Conv3d):
480 | torch.nn.init.xavier_uniform_(module.weight)
481 | if module.bias is not None:
482 | nn.init.constant_(module.bias, 0)
483 | if isinstance(module, nn.ConvTranspose3d):
484 | torch.nn.init.xavier_uniform_(module.weight)
485 | if module.bias is not None:
486 | nn.init.constant_(module.bias, 0)
487 |
488 | self.apply(_basic_init)
489 |
490 | def forward(self, x, text_embeddings=None, text_attn_mask=None):
491 | # x: [b c t h w]
492 | h = self.conv_in(x)
493 | for i, block in enumerate(self.mid_blocks):
494 | h = block.up(h)
495 | h = block.res1(h)
496 | if len(block.attn1) > 0:
497 | for attn in block.attn1:
498 | h = attn(h, context=text_embeddings, mask=text_attn_mask) + h
499 |
500 | h = block.res2(h)
501 | if len(block.attn2) > 0:
502 | for attn in block.attn2:
503 | h = attn(h, context=text_embeddings, mask=text_attn_mask) + h
504 |
505 | h = self.conv_last(h)
506 |
507 | return h
508 |
--------------------------------------------------------------------------------
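
A note on SamePadConv3d in /src/models/autoencoder_temporal.py above: it emulates "same" padding for strided 3D convolutions by padding each dimension with kernel - stride elements (the extra element on the leading side) and then applying a padding-free nn.Conv3d, so each output length equals the input length divided by the stride whenever the input is divisible by it. A small standalone sketch of that arithmetic; the [1, 4, 16, 32, 32] input shape simply mirrors the comment in EncoderTemporal1DCNN.forward.

import torch
import torch.nn.functional as F

# per dimension: total pad = kernel - stride, split as (ceil half, floor half),
# assembled in reverse order (W, H, T) because F.pad pads the last dimension first
kernel, stride = (3, 3, 3), (2, 1, 1)
total_pad = [k - s for k, s in zip(kernel, stride)]   # [1, 2, 2] for (T, H, W)
pad_input = []
for p in total_pad[::-1]:
    pad_input += [p // 2 + p % 2, p // 2]             # -> [1, 1, 1, 1, 1, 0]

x = torch.randn(1, 4, 16, 32, 32)                     # [b, c, t, h, w]
x_padded = F.pad(x, tuple(pad_input), mode="replicate")
conv = torch.nn.Conv3d(4, 8, kernel, stride=stride, padding=0)
print(conv(x_padded).shape)                           # torch.Size([1, 8, 8, 32, 32])
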
/src/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from src.modules.losses.contperceptual import (
2 | LPIPSWithDiscriminator,
3 | MSEWithDiscriminator,
4 | LPIPSWithDiscriminator3D,
5 | )
6 |
--------------------------------------------------------------------------------
/src/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 | import functools
6 |
7 |
8 | class LPIPSWithDiscriminator(nn.Module):
9 | def __init__(
10 | self,
11 | disc_start,
12 | logvar_init=0.0,
13 | kl_weight=1.0,
14 | pixelloss_weight=1.0,
15 | disc_num_layers=3,
16 | disc_in_channels=3,
17 | disc_factor=1.0,
18 | disc_weight=1.0,
19 | perceptual_weight=1.0,
20 | use_actnorm=False,
21 | disc_conditional=False,
22 | disc_loss="hinge",
23 | max_bs=None,
24 | ):
25 |
26 | super().__init__()
27 | assert disc_loss in ["hinge", "vanilla"]
28 | self.kl_weight = kl_weight
29 | self.pixel_weight = pixelloss_weight
30 | self.perceptual_loss = LPIPS().eval()
31 | self.perceptual_weight = perceptual_weight
32 | # output log variance
33 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
34 |
35 | self.discriminator = NLayerDiscriminator(
36 | input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
37 | ).apply(weights_init)
38 | self.discriminator_iter_start = disc_start
39 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
40 | self.disc_factor = disc_factor
41 | self.discriminator_weight = disc_weight
42 | self.disc_conditional = disc_conditional
43 | self.max_bs = max_bs
44 |
45 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
46 | if last_layer is not None:
47 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
48 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
49 | else:
50 | nll_grads = torch.autograd.grad(
51 | nll_loss, self.last_layer[0], retain_graph=True
52 | )[0]
53 | g_grads = torch.autograd.grad(
54 | g_loss, self.last_layer[0], retain_graph=True
55 | )[0]
56 |
57 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
58 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
59 | d_weight = d_weight * self.discriminator_weight
60 | return d_weight
61 |
62 | def forward(
63 | self,
64 | inputs,
65 | reconstructions,
66 | posteriors,
67 | optimizer_idx,
68 | global_step,
69 | last_layer=None,
70 | cond=None,
71 | split="train",
72 | weights=None,
73 | ):
74 | if inputs.dim() == 5:
75 | inputs = rearrange(inputs, "b c t h w -> (b t) c h w")
76 | if reconstructions.dim() == 5:
77 | reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w")
78 |
79 | # print('loss shape: ', inputs.shape, reconstructions.shape)
80 | # exit()
81 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
82 | if self.perceptual_weight > 0:
83 | if self.max_bs is not None and self.max_bs < inputs.shape[0]:
84 | input_list = torch.split(inputs, self.max_bs, dim=0)
85 | reconstruction_list = torch.split(reconstructions, self.max_bs, dim=0)
86 | p_losses = [
87 | self.perceptual_loss(
88 | inputs.contiguous(), reconstructions.contiguous()
89 | )
90 | for inputs, reconstructions in zip(input_list, reconstruction_list)
91 | ]
92 | p_loss = torch.cat(p_losses, dim=0)
93 | else:
94 | p_loss = self.perceptual_loss(
95 | inputs.contiguous(), reconstructions.contiguous()
96 | )
97 | rec_loss = rec_loss + self.perceptual_weight * p_loss
98 |
99 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
100 | weighted_nll_loss = nll_loss
101 | if weights is not None:
102 | weighted_nll_loss = weights * nll_loss
103 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
104 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
105 |
106 | kl_loss = posteriors.kl()
107 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
108 |
109 | if global_step < self.discriminator_iter_start:
110 | loss = weighted_nll_loss + self.kl_weight * kl_loss
111 | log = {
112 | "{}/total_loss".format(split): loss.clone().detach().mean(),
113 | "{}/logvar".format(split): self.logvar.detach(),
114 | "{}/kl_loss".format(split): kl_loss.detach().mean(),
115 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
116 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
117 | }
118 |
119 | return loss, log
120 |
121 | # now the GAN part
122 | if optimizer_idx == 0:
123 | # generator update
124 | if cond is None:
125 | assert not self.disc_conditional
126 | logits_fake = self.discriminator(reconstructions.contiguous())
127 | else:
128 | assert self.disc_conditional
129 | logits_fake = self.discriminator(
130 | torch.cat((reconstructions.contiguous(), cond), dim=1)
131 | )
132 | g_loss = -torch.mean(logits_fake)
133 |
134 | if self.disc_factor > 0.0:
135 | try:
136 | d_weight = self.calculate_adaptive_weight(
137 | nll_loss, g_loss, last_layer=last_layer
138 | )
139 | except RuntimeError:
140 | assert not self.training
141 | d_weight = torch.tensor(0.0)
142 | else:
143 | d_weight = torch.tensor(0.0)
144 |
145 | disc_factor = adopt_weight(
146 | self.disc_factor, global_step, threshold=self.discriminator_iter_start
147 | )
148 | loss = (
149 | weighted_nll_loss
150 | + self.kl_weight * kl_loss
151 | + d_weight * disc_factor * g_loss
152 | )
153 |
154 | log = {
155 | "{}/total_loss".format(split): loss.clone().detach().mean(),
156 | "{}/logvar".format(split): self.logvar.detach(),
157 | "{}/kl_loss".format(split): kl_loss.detach().mean(),
158 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
159 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
160 | "{}/d_weight".format(split): d_weight.detach(),
161 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
162 | "{}/g_loss".format(split): g_loss.detach().mean(),
163 | }
164 | return loss, log
165 |
166 | if optimizer_idx == 1:
167 | # second pass for discriminator update
168 | if cond is None:
169 | logits_real = self.discriminator(inputs.contiguous().detach())
170 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
171 | else:
172 | logits_real = self.discriminator(
173 | torch.cat((inputs.contiguous().detach(), cond), dim=1)
174 | )
175 | logits_fake = self.discriminator(
176 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
177 | )
178 |
179 | disc_factor = adopt_weight(
180 | self.disc_factor, global_step, threshold=self.discriminator_iter_start
181 | )
182 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
183 |
184 | log = {
185 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
186 | "{}/logits_real".format(split): logits_real.detach().mean(),
187 | "{}/logits_fake".format(split): logits_fake.detach().mean(),
188 | }
189 | return d_loss, log
190 |
191 |
192 | ### Modified for the 1D-CNN autoencoder: LPIPS perceptual loss replaced with MSE
193 |
194 |
195 | class MSEWithDiscriminator(nn.Module):
196 | def __init__(
197 | self,
198 | disc_start,
199 | logvar_init=0.0,
200 | kl_weight=1.0,
201 | pixelloss_weight=1.0,
202 | disc_num_layers=3,
203 | disc_in_channels=4,
204 | disc_factor=1.0,
205 | disc_weight=1.0,
206 | perceptual_weight=1.0,
207 | use_actnorm=False,
208 | disc_conditional=False,
209 | disc_loss="hinge",
210 | max_bs=None,
211 | ):
212 |
213 | super().__init__()
214 | assert disc_loss in ["hinge", "vanilla"]
215 | self.kl_weight = kl_weight
216 | self.pixel_weight = pixelloss_weight
217 | self.perceptual_loss = nn.MSELoss()
218 | self.perceptual_weight = perceptual_weight
219 | # output log variance
220 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
221 |
222 | self.discriminator = NLayerDiscriminator(
223 | input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
224 | ).apply(weights_init)
225 | self.discriminator_iter_start = disc_start
226 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
227 | self.disc_factor = disc_factor
228 | self.discriminator_weight = disc_weight
229 | self.disc_conditional = disc_conditional
230 | self.max_bs = max_bs
231 |
232 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
233 | if last_layer is not None:
234 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
235 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
236 | else:
237 | nll_grads = torch.autograd.grad(
238 | nll_loss, self.last_layer[0], retain_graph=True
239 | )[0]
240 | g_grads = torch.autograd.grad(
241 | g_loss, self.last_layer[0], retain_graph=True
242 | )[0]
243 |
244 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
245 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
246 | d_weight = d_weight * self.discriminator_weight
247 | return d_weight
248 |
249 | def forward(
250 | self,
251 | inputs,
252 | reconstructions,
253 | posteriors,
254 | optimizer_idx,
255 | global_step,
256 | last_layer=None,
257 | cond=None,
258 | split="train",
259 | weights=None,
260 | ):
261 | if inputs.dim() == 5:
262 | inputs = rearrange(inputs, "b c t h w -> (b t) c h w")
263 | if reconstructions.dim() == 5:
264 | reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w")
265 |
266 | # print('loss shape: ', inputs.shape, reconstructions.shape)
267 | # exit()
268 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
269 | if self.perceptual_weight > 0:
270 | p_loss = self.perceptual_loss(
271 | inputs.contiguous(), reconstructions.contiguous()
272 | )
273 | rec_loss = rec_loss + self.perceptual_weight * p_loss
274 |
275 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
276 | weighted_nll_loss = nll_loss
277 | if weights is not None:
278 | weighted_nll_loss = weights * nll_loss
279 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
280 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
281 |
282 | kl_loss = posteriors.kl()
283 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
284 |
285 | if global_step < self.discriminator_iter_start:
286 | loss = weighted_nll_loss + self.kl_weight * kl_loss
287 | log = {
288 | "{}/total_loss".format(split): loss.clone().detach().mean(),
289 | "{}/logvar".format(split): self.logvar.detach(),
290 | "{}/kl_loss".format(split): kl_loss.detach().mean(),
291 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
292 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
293 | }
294 |
295 | return loss, log
296 |
297 | # now the GAN part
298 | if optimizer_idx == 0:
299 | # generator update
300 | if cond is None:
301 | assert not self.disc_conditional
302 | logits_fake = self.discriminator(reconstructions.contiguous())
303 | else:
304 | assert self.disc_conditional
305 | logits_fake = self.discriminator(
306 | torch.cat((reconstructions.contiguous(), cond), dim=1)
307 | )
308 | g_loss = -torch.mean(logits_fake)
309 |
310 | if self.disc_factor > 0.0:
311 | try:
312 | d_weight = self.calculate_adaptive_weight(
313 | nll_loss, g_loss, last_layer=last_layer
314 | )
315 | except RuntimeError:
316 | assert not self.training
317 | d_weight = torch.tensor(0.0)
318 | else:
319 | d_weight = torch.tensor(0.0)
320 |
321 | disc_factor = adopt_weight(
322 | self.disc_factor, global_step, threshold=self.discriminator_iter_start
323 | )
324 | loss = (
325 | weighted_nll_loss
326 | + self.kl_weight * kl_loss
327 | + d_weight * disc_factor * g_loss
328 | )
329 |
330 | log = {
331 | "{}/total_loss".format(split): loss.clone().detach().mean(),
332 | "{}/logvar".format(split): self.logvar.detach(),
333 | "{}/kl_loss".format(split): kl_loss.detach().mean(),
334 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
335 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
336 | "{}/d_weight".format(split): d_weight.detach(),
337 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
338 | "{}/g_loss".format(split): g_loss.detach().mean(),
339 | }
340 | return loss, log
341 |
342 | if optimizer_idx == 1:
343 | # second pass for discriminator update
344 | if cond is None:
345 | logits_real = self.discriminator(inputs.contiguous().detach())
346 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
347 | else:
348 | logits_real = self.discriminator(
349 | torch.cat((inputs.contiguous().detach(), cond), dim=1)
350 | )
351 | logits_fake = self.discriminator(
352 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
353 | )
354 |
355 | disc_factor = adopt_weight(
356 | self.disc_factor, global_step, threshold=self.discriminator_iter_start
357 | )
358 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
359 |
360 | log = {
361 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
362 | "{}/logits_real".format(split): logits_real.detach().mean(),
363 | "{}/logits_fake".format(split): logits_fake.detach().mean(),
364 | }
365 | return d_loss, log
366 |
367 |
368 | class NLayerDiscriminator3D(nn.Module):
369 | """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
370 |
371 | def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False):
372 | """
373 | Construct a 3D PatchGAN discriminator
374 |
375 | Parameters:
376 | input_nc (int) -- the number of channels in input volumes
377 | ndf (int) -- the number of filters in the last conv layer
378 | n_layers (int) -- the number of conv layers in the discriminator
379 | use_actnorm (bool) -- flag to use actnorm instead of batchnorm
380 | """
381 | super(NLayerDiscriminator3D, self).__init__()
382 | if not use_actnorm:
383 | norm_layer = nn.BatchNorm3d
384 | else:
385 | raise NotImplementedError("Not implemented.")
386 | if type(norm_layer) == functools.partial:
387 | use_bias = norm_layer.func != nn.BatchNorm3d
388 | else:
389 | use_bias = norm_layer != nn.BatchNorm3d
390 |
391 | kw = 3
392 | padw = 1
393 | sequence = [
394 | nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
395 | nn.LeakyReLU(0.2, True),
396 | ]
397 | nf_mult = 1
398 | nf_mult_prev = 1
399 | for n in range(1, n_layers): # gradually increase the number of filters
400 | nf_mult_prev = nf_mult
401 | nf_mult = min(2**n, 8)
402 | sequence += [
403 | nn.Conv3d(
404 | ndf * nf_mult_prev,
405 | ndf * nf_mult,
406 | kernel_size=(kw, kw, kw),
407 | stride=(2 if n == 1 else 1, 2, 2),
408 | padding=padw,
409 | bias=use_bias,
410 | ),
411 | norm_layer(ndf * nf_mult),
412 | nn.LeakyReLU(0.2, True),
413 | ]
414 |
415 | nf_mult_prev = nf_mult
416 | nf_mult = min(2**n_layers, 8)
417 | sequence += [
418 | nn.Conv3d(
419 | ndf * nf_mult_prev,
420 | ndf * nf_mult,
421 | kernel_size=(kw, kw, kw),
422 | stride=1,
423 | padding=padw,
424 | bias=use_bias,
425 | ),
426 | norm_layer(ndf * nf_mult),
427 | nn.LeakyReLU(0.2, True),
428 | ]
429 |
430 | sequence += [
431 | nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
432 | ] # output 1 channel prediction map
433 | self.main = nn.Sequential(*sequence)
434 |
435 | def forward(self, input):
436 | """Standard forward."""
437 | return self.main(input)
438 |
439 |
440 | class LPIPSWithDiscriminator3D(nn.Module):
441 | def __init__(
442 | self,
443 | disc_start,
444 | logvar_init=0.0,
445 | kl_weight=1.0,
446 | pixelloss_weight=1.0,
447 | perceptual_weight=1.0,
448 | # --- Discriminator Loss ---
449 | disc_num_layers=3,
450 | disc_in_channels=3,
451 | disc_factor=1.0,
452 | disc_weight=1.0,
453 | use_actnorm=False,
454 | disc_conditional=False,
455 | disc_loss="hinge",
456 | ):
457 |
458 | super().__init__()
459 | assert disc_loss in ["hinge", "vanilla"]
460 | self.kl_weight = kl_weight
461 | self.pixel_weight = pixelloss_weight
462 | self.perceptual_loss = LPIPS().eval()
463 | self.perceptual_weight = perceptual_weight
464 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
465 |
466 | self.discriminator = NLayerDiscriminator3D(
467 | input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
468 | ).apply(weights_init)
469 | self.discriminator_iter_start = disc_start
470 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
471 | self.disc_factor = disc_factor
472 | self.discriminator_weight = disc_weight
473 | self.disc_conditional = disc_conditional
474 |
475 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
476 | if last_layer is not None:
477 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
478 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
479 | else:
480 | nll_grads = torch.autograd.grad(
481 | nll_loss, self.last_layer[0], retain_graph=True
482 | )[0]
483 | g_grads = torch.autograd.grad(
484 | g_loss, self.last_layer[0], retain_graph=True
485 | )[0]
486 |
487 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
488 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
489 | d_weight = d_weight * self.discriminator_weight
490 | return d_weight
491 |
492 | def forward(
493 | self,
494 | inputs,
495 | reconstructions,
496 | posteriors,
497 | optimizer_idx,
498 | global_step,
499 | split="train",
500 | weights=None,
501 | last_layer=None,
502 | cond=None,
503 | ):
504 | t = inputs.shape[2]
505 | inputs = rearrange(inputs, "b c t h w -> (b t) c h w")
506 | reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w")
507 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
508 | if self.perceptual_weight > 0:
509 | p_loss = self.perceptual_loss(
510 | inputs.contiguous(), reconstructions.contiguous()
511 | )
512 | rec_loss = rec_loss + self.perceptual_weight * p_loss
513 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
514 | weighted_nll_loss = nll_loss
515 | if weights is not None:
516 | weighted_nll_loss = weights * nll_loss
517 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
518 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
519 | kl_loss = posteriors.kl()
520 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
521 |
522 | if global_step < self.discriminator_iter_start:
523 | loss = weighted_nll_loss + self.kl_weight * kl_loss
524 | log = {
525 | "{}/total_loss".format(split): loss.clone().detach().mean(),
526 | "{}/logvar".format(split): self.logvar.detach(),
527 | "{}/kl_loss".format(split): kl_loss.detach().mean(),
528 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
529 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
530 | }
531 |
532 | return loss, log
533 |
534 | inputs = rearrange(inputs, "(b t) c h w -> b c t h w", t=t)
535 | reconstructions = rearrange(reconstructions, "(b t) c h w -> b c t h w", t=t)
536 | # GAN Part
537 | if optimizer_idx == 0:
538 | # generator update
539 | if cond is None:
540 | assert not self.disc_conditional
541 | logits_fake = self.discriminator(reconstructions.contiguous())
542 | else:
543 | assert self.disc_conditional
544 | logits_fake = self.discriminator(
545 | torch.cat((reconstructions.contiguous(), cond), dim=1)
546 | )
547 | g_loss = -torch.mean(logits_fake)
548 |
549 | if self.disc_factor > 0.0:
550 | try:
551 | d_weight = self.calculate_adaptive_weight(
552 | nll_loss, g_loss, last_layer=last_layer
553 | )
554 | except RuntimeError as e:
555 |                     assert not self.training, str(e)
556 | d_weight = torch.tensor(0.0)
557 | else:
558 | d_weight = torch.tensor(0.0)
559 |
560 | disc_factor = adopt_weight(
561 | self.disc_factor, global_step, threshold=self.discriminator_iter_start
562 | )
563 | loss = (
564 | weighted_nll_loss
565 | + self.kl_weight * kl_loss
566 | + d_weight * disc_factor * g_loss
567 | )
568 | log = {
569 | "{}/total_loss".format(split): loss.clone().detach().mean(),
570 | "{}/logvar".format(split): self.logvar.detach(),
571 | "{}/kl_loss".format(split): kl_loss.detach().mean(),
572 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
573 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
574 | "{}/d_weight".format(split): d_weight.detach(),
575 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
576 | "{}/g_loss".format(split): g_loss.detach().mean(),
577 | }
578 | return loss, log
579 |
580 | if optimizer_idx == 1:
581 | if cond is None:
582 | logits_real = self.discriminator(inputs.contiguous().detach())
583 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
584 | else:
585 | logits_real = self.discriminator(
586 | torch.cat((inputs.contiguous().detach(), cond), dim=1)
587 | )
588 | logits_fake = self.discriminator(
589 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1)
590 | )
591 |
592 | disc_factor = adopt_weight(
593 | self.disc_factor, global_step, threshold=self.discriminator_iter_start
594 | )
595 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
596 |
597 | log = {
598 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
599 | "{}/logits_real".format(split): logits_real.detach().mean(),
600 | "{}/logits_fake".format(split): logits_fake.detach().mean(),
601 | }
602 | return d_loss, log
603 |
--------------------------------------------------------------------------------
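
A note on calculate_adaptive_weight in /src/modules/losses/contperceptual.py above: it balances the GAN term against the reconstruction term by comparing gradient norms at the decoder's last layer, d_weight = norm(grad(nll_loss)) / (norm(grad(g_loss)) + 1e-4), clamped to [0, 1e4] and then scaled by discriminator_weight. A standalone toy sketch of that computation; the linear "last layer" and both losses here are stand-ins rather than the repository's modules.

import torch

# toy stand-in for decoder.conv_out.weight (the "last layer" used as the gradient probe)
last_layer = torch.randn(4, 4, requires_grad=True)
x = torch.randn(8, 4)
recon = x @ last_layer

nll_loss = (recon - x).abs().mean()      # reconstruction / NLL term
g_loss = -recon.mean()                   # stand-in for -mean(logits_fake)

nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()  # then multiplied by discriminator_weight
print(d_weight)  # the GAN term is rescaled so its gradient magnitude tracks the NLL term
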
/src/modules/t5.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import re
4 | import html
5 | import urllib.parse as ul
6 |
7 | import ftfy
8 | import torch
9 | from bs4 import BeautifulSoup
10 | from transformers import T5EncoderModel, AutoTokenizer
11 | from huggingface_hub import hf_hub_download
12 |
13 |
14 | class T5Embedder:
15 | available_models = ["flan-t5-large"]
16 | bad_punct_regex = re.compile(
17 | r"["
18 | + "#®•©™&@·º½¾¿¡§~"
19 | + "\)"
20 | + "\("
21 | + "\]"
22 | + "\["
23 | + "\}"
24 | + "\{"
25 | + "\|"
26 | + "\\"
27 | + "\/"
28 | + "\*"
29 | + r"]{1,}"
30 | ) # noqa
31 |
32 | def __init__(
33 | self,
34 | device,
35 | dir_or_name="flan-t5-large",
36 | *,
37 | local_cache=False,
38 | cache_dir=None,
39 | hf_token=None,
40 | use_text_preprocessing=True,
41 | t5_model_kwargs=None,
42 | torch_dtype=None,
43 | use_offload_folder=None,
44 | model_max_length=180,
45 | ):
46 | self.device = torch.device(device)
47 | print(f"T5 embedder is on {self.device}")
48 | self.torch_dtype = torch_dtype or torch.bfloat16
49 | if t5_model_kwargs is None:
50 | t5_model_kwargs = {
51 | "low_cpu_mem_usage": True,
52 | "torch_dtype": self.torch_dtype,
53 | }
54 | if use_offload_folder is not None:
55 | t5_model_kwargs["offload_folder"] = use_offload_folder
56 | t5_model_kwargs["device_map"] = {
57 | "shared": self.device,
58 | "encoder.embed_tokens": self.device,
59 | "encoder.block.0": self.device,
60 | "encoder.block.1": self.device,
61 | "encoder.block.2": self.device,
62 | "encoder.block.3": self.device,
63 | "encoder.block.4": self.device,
64 | "encoder.block.5": self.device,
65 | "encoder.block.6": self.device,
66 | "encoder.block.7": self.device,
67 | "encoder.block.8": self.device,
68 | "encoder.block.9": self.device,
69 | "encoder.block.10": self.device,
70 | "encoder.block.11": self.device,
71 | "encoder.block.12": "disk",
72 | "encoder.block.13": "disk",
73 | "encoder.block.14": "disk",
74 | "encoder.block.15": "disk",
75 | "encoder.block.16": "disk",
76 | "encoder.block.17": "disk",
77 | "encoder.block.18": "disk",
78 | "encoder.block.19": "disk",
79 | "encoder.block.20": "disk",
80 | "encoder.block.21": "disk",
81 | "encoder.block.22": "disk",
82 | "encoder.block.23": "disk",
83 | "encoder.final_layer_norm": "disk",
84 | "encoder.dropout": "disk",
85 | }
86 | else:
87 | t5_model_kwargs["device_map"] = {
88 | "shared": self.device,
89 | "encoder": self.device,
90 | }
91 |
92 | self.use_text_preprocessing = use_text_preprocessing
93 | self.hf_token = hf_token
94 | self.cache_dir = cache_dir or os.path.expanduser("~/.cache/IF_")
95 | self.dir_or_name = dir_or_name
96 | tokenizer_path, path = dir_or_name, dir_or_name
97 | if local_cache:
98 | cache_dir = os.path.join(self.cache_dir, dir_or_name)
99 | tokenizer_path, path = cache_dir, cache_dir
100 | elif dir_or_name in self.available_models:
101 | cache_dir = os.path.join(self.cache_dir, dir_or_name)
102 | for filename in [
103 | "config.json",
104 | "special_tokens_map.json",
105 | "spiece.model",
106 | "tokenizer_config.json",
107 | "pytorch_model.bin",
108 | ]:
109 |
110 | hf_hub_download(
111 | repo_id=f"google/{dir_or_name}",
112 | filename=filename,
113 | cache_dir=cache_dir,
114 | force_filename=filename,
115 | token=self.hf_token,
116 | )
117 | tokenizer_path, path = cache_dir, cache_dir
118 | else:
119 | cache_dir = os.path.join(self.cache_dir, "flan-t5-large")
120 | for filename in [
121 | "config.json",
122 | "special_tokens_map.json",
123 | "spiece.model",
124 | "tokenizer_config.json",
125 | ]:
126 | hf_hub_download(
127 | repo_id="google/flan-t5-large",
128 | filename=filename,
129 | cache_dir=cache_dir,
130 | force_filename=filename,
131 | token=self.hf_token,
132 | )
133 | tokenizer_path = cache_dir
134 |
135 | print(tokenizer_path)
136 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
137 | self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
138 | self.model_max_length = model_max_length
139 |
140 | def get_text_embeddings(self, texts):
141 | texts = [self.text_preprocessing(text) for text in texts]
142 |
143 | # print(self.model_max_length)
144 |
145 | text_tokens_and_mask = self.tokenizer(
146 | texts,
147 | max_length=self.model_max_length,
148 | padding="max_length",
149 | truncation=True,
150 | return_attention_mask=True,
151 | add_special_tokens=True,
152 | return_tensors="pt",
153 | )
154 |
155 | text_tokens_and_mask["input_ids"] = text_tokens_and_mask["input_ids"]
156 | text_tokens_and_mask["attention_mask"] = text_tokens_and_mask["attention_mask"]
157 |
158 | with torch.no_grad():
159 | text_encoder_embs = self.model(
160 | input_ids=text_tokens_and_mask["input_ids"].to(self.device),
161 | attention_mask=text_tokens_and_mask["attention_mask"].to(self.device),
162 | )["last_hidden_state"].detach()
163 | return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device)
164 |
165 | def text_preprocessing(self, text):
166 | if self.use_text_preprocessing:
167 | # The exact text cleaning as was in the training stage:
168 | text = self.clean_caption(text)
169 | text = self.clean_caption(text)
170 | return text
171 | else:
172 | return text.lower().strip()
173 |
174 | @staticmethod
175 | def basic_clean(text):
176 | text = ftfy.fix_text(text)
177 | text = html.unescape(html.unescape(text))
178 | return text.strip()
179 |
180 | def clean_caption(self, caption):
181 | caption = str(caption)
182 | caption = ul.unquote_plus(caption)
183 | caption = caption.strip().lower()
184 |         caption = re.sub("<person>", "person", caption)
185 | # urls:
186 | caption = re.sub(
187 | r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
188 | "",
189 | caption,
190 | ) # regex for urls
191 | caption = re.sub(
192 | r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
193 | "",
194 | caption,
195 | ) # regex for urls
196 | # html:
197 | caption = BeautifulSoup(caption, features="html.parser").text
198 |
199 | # @
200 | caption = re.sub(r"@[\w\d]+\b", "", caption)
201 |
202 | # 31C0—31EF CJK Strokes
203 | # 31F0—31FF Katakana Phonetic Extensions
204 | # 3200—32FF Enclosed CJK Letters and Months
205 | # 3300—33FF CJK Compatibility
206 | # 3400—4DBF CJK Unified Ideographs Extension A
207 | # 4DC0—4DFF Yijing Hexagram Symbols
208 | # 4E00—9FFF CJK Unified Ideographs
209 | caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
210 | caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
211 | caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
212 | caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
213 | caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
214 | caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
215 | caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
216 | #######################################################
217 |
218 | # все виды тире / all types of dash --> "-"
219 | caption = re.sub(
220 | r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
221 | "-",
222 | caption,
223 | )
224 |
225 | # кавычки к одному стандарту
226 | caption = re.sub(r"[`´«»“”¨]", '"', caption)
227 | caption = re.sub(r"[‘’]", "'", caption)
228 |
229 |         # &quot;
230 |         caption = re.sub(r"&quot;?", "", caption)
231 |         # &amp
232 |         caption = re.sub(r"&amp", "", caption)
233 |
234 | # ip adresses:
235 | caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
236 |
237 | # article ids:
238 | caption = re.sub(r"\d:\d\d\s+$", "", caption)
239 |
240 | # \n
241 | caption = re.sub(r"\\n", " ", caption)
242 |
243 | # "#123"
244 | caption = re.sub(r"#\d{1,3}\b", "", caption)
245 | # "#12345.."
246 | caption = re.sub(r"#\d{5,}\b", "", caption)
247 | # "123456.."
248 | caption = re.sub(r"\b\d{6,}\b", "", caption)
249 | # filenames:
250 | caption = re.sub(
251 | r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption
252 | )
253 |
254 | #
255 | caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
256 | caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
257 |
258 | caption = re.sub(
259 | self.bad_punct_regex, r" ", caption
260 | ) # ***AUSVERKAUFT***, #AUSVERKAUFT
261 | caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
262 |
263 | # this-is-my-cute-cat / this_is_my_cute_cat
264 | regex2 = re.compile(r"(?:\-|\_)")
265 | if len(re.findall(regex2, caption)) > 3:
266 | caption = re.sub(regex2, " ", caption)
267 |
268 | caption = self.basic_clean(caption)
269 |
270 | caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
271 | caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
272 | caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
273 |
274 | caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
275 | caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
276 | caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
277 | caption = re.sub(
278 | r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption
279 | )
280 | caption = re.sub(r"\bpage\s+\d+\b", "", caption)
281 |
282 | caption = re.sub(
283 | r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption
284 | ) # j2d1a2a...
285 |
286 | caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
287 |
288 | caption = re.sub(r"\b\s+\:\s+", r": ", caption)
289 | caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
290 | caption = re.sub(r"\s+", " ", caption)
291 |
292 |         caption = caption.strip()
293 |
294 | caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
295 | caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
296 | caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
297 | caption = re.sub(r"^\.\S+$", "", caption)
298 |
299 | return caption.strip()
300 |
301 | def find_phrase_indices(self, sentence, phrase):
302 | sentence_tokens = self.tokenizer.tokenize(sentence)
303 | phrase_tokens = self.tokenizer.tokenize(phrase)
304 |
305 | phrase_len = len(phrase_tokens)
306 | for i in range(len(sentence_tokens) - phrase_len + 1):
307 | if sentence_tokens[i : i + phrase_len] == phrase_tokens:
308 | return i + 1, i + phrase_len + 1
309 | return None
310 |
--------------------------------------------------------------------------------
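
A minimal usage sketch for T5Embedder in /src/modules/t5.py above, assuming the repository root is on PYTHONPATH, a CUDA device is available, and google/flan-t5-large can be downloaded or is already cached. For flan-t5-large the hidden size is 1024, which matches the context_dim=1024 used by SpatialCrossAttention in autoencoder_temporal.py; the sequence length follows model_max_length.

from src.modules.t5 import T5Embedder

embedder = T5Embedder(device="cuda", dir_or_name="flan-t5-large", model_max_length=180)
embs, mask = embedder.get_text_embeddings(["a close-up shot of a fox walking through snow"])
# expected shapes: embs [1, 180, 1024] (last_hidden_state), mask [1, 180]
print(embs.shape, mask.shape)
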
/src/modules/utils.py:
--------------------------------------------------------------------------------
1 | # adapted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 | import torch.nn as nn
11 | from utils.common_utils import instantiate_from_config
12 |
13 | import math
14 | from inspect import isfunction
15 | import torch
16 | from torch import nn
17 | import torch.distributed as dist
18 |
19 |
20 | def gather_data(data, return_np=True):
21 | """gather data from multiple processes to one list"""
22 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
23 | dist.all_gather(data_list, data) # gather not supported with NCCL
24 | if return_np:
25 | data_list = [data.cpu().numpy() for data in data_list]
26 | return data_list
27 |
28 |
29 | def autocast(f):
30 | def do_autocast(*args, **kwargs):
31 | with torch.cuda.amp.autocast(
32 | enabled=True,
33 | dtype=torch.get_autocast_gpu_dtype(),
34 | cache_enabled=torch.is_autocast_cache_enabled(),
35 | ):
36 | return f(*args, **kwargs)
37 |
38 | return do_autocast
39 |
40 |
41 | def extract_into_tensor(a, t, x_shape):
42 | b, *_ = t.shape
43 | out = a.gather(-1, t)
44 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
45 |
46 |
47 | def noise_like(shape, device, repeat=False):
48 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
49 | shape[0], *((1,) * (len(shape) - 1))
50 | )
51 | noise = lambda: torch.randn(shape, device=device)
52 | return repeat_noise() if repeat else noise()
53 |
54 |
55 | def default(val, d):
56 | if exists(val):
57 | return val
58 | return d() if isfunction(d) else d
59 |
60 |
61 | def exists(val):
62 | return val is not None
63 |
64 |
65 | def identity(*args, **kwargs):
66 | return nn.Identity()
67 |
68 |
69 | def uniq(arr):
70 | return {el: True for el in arr}.keys()
71 |
72 |
73 | def mean_flat(tensor):
74 | """
75 | Take the mean over all non-batch dimensions.
76 | """
77 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
78 |
79 |
80 | def ismap(x):
81 | if not isinstance(x, torch.Tensor):
82 | return False
83 | return (len(x.shape) == 4) and (x.shape[1] > 3)
84 |
85 |
86 | def isimage(x):
87 | if not isinstance(x, torch.Tensor):
88 | return False
89 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
90 |
91 |
92 | def max_neg_value(t):
93 | return -torch.finfo(t.dtype).max
94 |
95 |
96 | def shape_to_str(x):
97 | shape_str = "x".join([str(x) for x in x.shape])
98 | return shape_str
99 |
100 |
101 | def init_(tensor):
102 | dim = tensor.shape[-1]
103 | std = 1 / math.sqrt(dim)
104 | tensor.uniform_(-std, std)
105 | return tensor
106 |
107 |
108 | ckpt = torch.utils.checkpoint.checkpoint
109 |
110 |
111 | def checkpoint(func, inputs, params, flag):
112 | """
113 | Evaluate a function without caching intermediate activations, allowing for
114 | reduced memory at the expense of extra compute in the backward pass.
115 | :param func: the function to evaluate.
116 | :param inputs: the argument sequence to pass to `func`.
117 | :param params: a sequence of parameters `func` depends on but does not
118 | explicitly take as arguments.
119 | :param flag: if False, disable gradient checkpointing.
120 | """
121 | if flag:
122 | return ckpt(func, *inputs)
123 | else:
124 | return func(*inputs)
125 |
126 |
127 | def disabled_train(self, mode=True):
128 | """Overwrite model.train with this function to make sure train/eval mode
129 | does not change anymore."""
130 | return self
131 |
132 |
133 | def zero_module(module):
134 | """
135 | Zero out the parameters of a module and return it.
136 | """
137 | for p in module.parameters():
138 | p.detach().zero_()
139 | return module
140 |
141 |
142 | def scale_module(module, scale):
143 | """
144 | Scale the parameters of a module and return it.
145 | """
146 | for p in module.parameters():
147 | p.detach().mul_(scale)
148 | return module
149 |
150 |
151 | def conv_nd(dims, *args, **kwargs):
152 | """
153 | Create a 1D, 2D, or 3D convolution module.
154 | """
155 | if dims == 1:
156 | return nn.Conv1d(*args, **kwargs)
157 | elif dims == 2:
158 | return nn.Conv2d(*args, **kwargs)
159 | elif dims == 3:
160 | return nn.Conv3d(*args, **kwargs)
161 | raise ValueError(f"unsupported dimensions: {dims}")
162 |
163 |
164 | def linear(*args, **kwargs):
165 | """
166 | Create a linear module.
167 | """
168 | return nn.Linear(*args, **kwargs)
169 |
170 |
171 | def avg_pool_nd(dims, *args, **kwargs):
172 | """
173 | Create a 1D, 2D, or 3D average pooling module.
174 | """
175 | if dims == 1:
176 | return nn.AvgPool1d(*args, **kwargs)
177 | elif dims == 2:
178 | return nn.AvgPool2d(*args, **kwargs)
179 | elif dims == 3:
180 | return nn.AvgPool3d(*args, **kwargs)
181 | raise ValueError(f"unsupported dimensions: {dims}")
182 |
183 |
184 | def nonlinearity(type="silu"):
185 | if type == "silu":
186 | return nn.SiLU()
187 | elif type == "leaky_relu":
188 | return nn.LeakyReLU()
189 |
190 |
191 | class GroupNormSpecific(nn.GroupNorm):
192 | def forward(self, x):
193 | if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
194 | return super().forward(x).type(x.dtype)
195 | else:
196 | return super().forward(x.float()).type(x.dtype)
197 |
198 |
199 | def normalization(channels, num_groups=32):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNormSpecific(num_groups, channels)
206 |
207 |
208 | class HybridConditioner(nn.Module):
209 |
210 | def __init__(self, c_concat_config, c_crossattn_config):
211 | super().__init__()
212 | self.concat_conditioner = instantiate_from_config(c_concat_config)
213 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
214 |
215 | def forward(self, c_concat, c_crossattn):
216 | c_concat = self.concat_conditioner(c_concat)
217 | c_crossattn = self.crossattn_conditioner(c_crossattn)
218 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
219 |
--------------------------------------------------------------------------------
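
A note on the checkpoint helper in /src/modules/utils.py above: it wraps torch.utils.checkpoint.checkpoint so that gradient checkpointing can be toggled by a flag, trading extra compute in the backward pass for lower activation memory. A minimal standalone sketch, assuming the repository root is on PYTHONPATH; the toy block is illustrative, not a module from this repository.

import torch
import torch.nn as nn

from src.modules.utils import checkpoint

block = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
x = torch.randn(2, 16, requires_grad=True)

# flag=True: run under torch.utils.checkpoint (activations recomputed in backward);
# flag=False: plain forward call with cached activations
y = checkpoint(block, (x,), list(block.parameters()), flag=True)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 16])
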
/train.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, datetime
2 | from omegaconf import OmegaConf
3 | from transformers import logging as transf_logging
4 |
5 | import torch
6 | import pytorch_lightning as pl
7 | from pytorch_lightning import seed_everything
8 | from pytorch_lightning.trainer import Trainer
9 |
10 | sys.path.insert(0, os.getcwd())
11 | from utils.common_utils import instantiate_from_config
12 | from utils.train_utils import (
13 | get_trainer_callbacks,
14 | get_trainer_logger,
15 | get_trainer_strategy,
16 | )
17 | from utils.train_utils import (
18 | set_logger,
19 | init_workspace,
20 | load_checkpoints,
21 | get_autoresume_path,
22 | )
23 |
24 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
25 |
26 |
27 | def get_parser(**parser_kwargs):
28 | parser = argparse.ArgumentParser(**parser_kwargs)
29 | parser.add_argument(
30 | "--seed", "-s", type=int, default=20230211, help="seed for seed_everything"
31 | )
32 | parser.add_argument(
33 | "--name", "-n", type=str, default="", help="experiment name, as saving folder"
34 | )
35 |
36 | parser.add_argument(
37 | "--base",
38 | "-b",
39 | nargs="*",
40 | metavar="base_config.yaml",
41 | help="paths to base configs. Loaded from left-to-right. "
42 | "Parameters can be overwritten or added with command-line options of the form `--key value`.",
43 | default=list(),
44 | )
45 |
46 | parser.add_argument(
47 | "--train", "-t", action="store_true", default=False, help="train"
48 | )
49 | parser.add_argument("--val", "-v", action="store_true", default=False, help="val")
50 | parser.add_argument("--test", action="store_true", default=False, help="test")
51 |
52 | parser.add_argument(
53 | "--logdir",
54 | "-l",
55 | type=str,
56 | default="logs",
57 |         help="directory for logs",
58 | )
59 | parser.add_argument(
60 | "--auto_resume",
61 | action="store_true",
62 | default=False,
63 | help="resume from full-info checkpoint",
64 | )
65 | parser.add_argument(
66 | "--debug",
67 | "-d",
68 | action="store_true",
69 | default=False,
70 | help="enable post-mortem debugging",
71 | )
72 |
73 | return parser
74 |
75 |
76 | def get_nondefault_trainer_args(args):
77 | parser = argparse.ArgumentParser()
78 | parser = Trainer.add_argparse_args(parser)
79 | default_trainer_args = parser.parse_args([])
80 | return sorted(
81 | k
82 | for k in vars(default_trainer_args)
83 | if getattr(args, k) != getattr(default_trainer_args, k)
84 | )
85 |
86 |
87 | if __name__ == "__main__":
88 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
89 | try:
90 | local_rank = int(os.environ.get("LOCAL_RANK"))
91 | global_rank = int(os.environ.get("RANK"))
92 | num_rank = int(os.environ.get("WORLD_SIZE"))
93 |     except (TypeError, ValueError):
94 | local_rank, global_rank, num_rank = 0, 0, 1
95 | # print(f'local_rank: {local_rank} | global_rank:{global_rank} | num_rank:{num_rank}')
96 |
97 | parser = get_parser()
98 |     ## extend the existing parser with PyTorch Lightning's default Trainer arguments
99 | parser = Trainer.add_argparse_args(parser)
100 | args, unknown = parser.parse_known_args()
101 | ## disable transformer warning
102 | transf_logging.set_verbosity_error()
103 | seed_everything(args.seed)
104 |
105 | ## yaml configs: "model" | "data" | "lightning"
106 | configs = [OmegaConf.load(cfg) for cfg in args.base]
107 | cli = OmegaConf.from_dotlist(unknown)
108 | config = OmegaConf.merge(*configs, cli)
109 | lightning_config = config.pop("lightning", OmegaConf.create())
110 | trainer_config = lightning_config.get("trainer", OmegaConf.create())
111 |
112 | ## setup workspace directories
113 | workdir, ckptdir, cfgdir, loginfo = init_workspace(
114 | args.name, args.logdir, config, lightning_config, global_rank
115 | )
116 | logger = set_logger(
117 | logfile=os.path.join(loginfo, "log_%d:%s.txt" % (global_rank, now))
118 | )
119 | logger.info("@lightning version: %s [>=1.8 required]" % (pl.__version__))
120 |
121 | ## MODEL CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
122 |     logger.info("***** Configuring Model *****")
123 | config.model.params.logdir = workdir
124 | model = instantiate_from_config(config.model)
125 |
126 | if args.auto_resume:
127 |         ## the saved checkpoint must be a full-info checkpoint
128 | resume_ckpt_path = get_autoresume_path(workdir)
129 | if resume_ckpt_path is not None:
130 | args.resume_from_checkpoint = resume_ckpt_path
131 | logger.info("Resuming from checkpoint: %s" % args.resume_from_checkpoint)
132 |             ## just in case: train only the empty (newly added) parameters
133 | else:
134 | model = load_checkpoints(model, config.model)
135 |             logger.warning("Auto-resuming skipped as no checkpoint was found!")
136 | else:
137 | model = load_checkpoints(model, config.model)
138 |
139 | ## update trainer config
140 | for k in get_nondefault_trainer_args(args):
141 | trainer_config[k] = getattr(args, k)
142 |
143 | print(trainer_config)
144 | num_nodes = trainer_config.num_nodes
145 | ngpu_per_node = trainer_config.devices
146 | logger.info(f"Running on {num_rank}={num_nodes}x{ngpu_per_node} GPUs")
147 |
148 | ## setup learning rate
149 | base_lr = config.model.base_learning_rate
150 | bs = config.data.params.batch_size
151 | if getattr(config.model, "scale_lr", True):
152 | model.learning_rate = num_rank * bs * base_lr
153 | else:
154 | model.learning_rate = base_lr
155 |
156 | ## DATA CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
157 |     logger.info("***** Configuring Data *****")
158 | data = instantiate_from_config(config.data)
159 | data.setup()
160 | for k in data.datasets:
161 | logger.info(
162 | f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
163 | )
164 |
165 | ## TRAINER CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
166 |     logger.info("***** Configuring Trainer *****")
167 | if "accelerator" not in trainer_config:
168 | trainer_config["accelerator"] = "gpu"
169 |
170 | torch.set_float32_matmul_precision("medium")
171 |
172 | ## setup trainer args: pl-logger and callbacks
173 | trainer_kwargs = dict()
174 | trainer_kwargs["num_sanity_val_steps"] = 0
175 | logger_cfg = get_trainer_logger(lightning_config, workdir, args.debug)
176 | trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
177 |
178 | ## setup callbacks
179 | callbacks_cfg = get_trainer_callbacks(
180 | lightning_config, config, workdir, ckptdir, logger
181 | )
182 | trainer_kwargs["callbacks"] = [
183 | instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
184 | ]
185 | strategy_cfg = get_trainer_strategy(lightning_config)
186 | trainer_kwargs["strategy"] = (
187 | strategy_cfg
188 |         if isinstance(strategy_cfg, str)
189 | else instantiate_from_config(strategy_cfg)
190 | )
191 | trainer_kwargs["precision"] = lightning_config.get("precision", "bf16")
192 | trainer_kwargs["sync_batchnorm"] = False
193 |
194 | ## trainer config: others
195 |     if (
196 |         ("train" in config.data.params
197 |          and config.data.params.train.target == "lvdm.data.hdvila.HDVila")
198 |         or (
199 |             "validation" in config.data.params
200 |             and config.data.params.validation.target == "lvdm.data.hdvila.HDVila"
201 |         )
202 |     ):
203 | trainer_kwargs["replace_sampler_ddp"] = False
204 |
205 | ## for debug
206 | # trainer_kwargs["fast_dev_run"] = 10
207 | # trainer_kwargs["limit_train_batches"] = 1./32
208 | # trainer_kwargs["limit_val_batches"] = 0.01
209 | # trainer_kwargs["val_check_interval"] = 20 #float: epoch ratio | integer: batch num
210 |
211 | trainer_args = argparse.Namespace(**trainer_config)
212 | trainer = Trainer.from_argparse_args(trainer_args, **trainer_kwargs)
213 |
214 | ## allow checkpointing via USR1
215 | def melk(*args, **kwargs):
216 | ## run all checkpoint hooks
217 | if trainer.global_rank == 0:
218 | print("Summoning checkpoint.")
219 | ckpt_path = os.path.join(ckptdir, "last_summoning.ckpt")
220 | trainer.save_checkpoint(ckpt_path)
221 |
222 | def divein(*args, **kwargs):
223 | if trainer.global_rank == 0:
224 | import pudb
225 |
226 | pudb.set_trace()
227 |
228 | import signal
229 |
230 | signal.signal(signal.SIGUSR1, melk)
231 | signal.signal(signal.SIGUSR2, divein)
232 |
233 | ## Running LOOP >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
234 | logger.info("***** Running the Loop *****")
235 | if args.train:
236 | try:
237 | if "strategy" in lightning_config:
238 |                 logger.info("Training with the strategy specified in the config")
239 | ## deepspeed
240 | with torch.cuda.amp.autocast():
241 | trainer.fit(model, data)
242 | else:
243 |                 logger.info("Training with the default strategy")
244 | ## ddpshare
245 | trainer.fit(model, data)
246 | except Exception:
247 | # melk()
248 | raise
249 | if args.val:
250 | trainer.validate(model, data)
251 |     if args.test and not trainer.interrupted:
252 | trainer.test(model, data)
253 |
--------------------------------------------------------------------------------
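train.py builds a single OmegaConf config by merging the YAML files passed via --base from left to right, then overlaying any extra command-line options parsed as a dotlist. A small, self-contained illustration of that merging behaviour (the keys below are made up for the example):

    from omegaconf import OmegaConf

    base = OmegaConf.create({"model": {"base_learning_rate": 1.0e-4},
                             "data": {"params": {"batch_size": 2}}})
    override = OmegaConf.create({"data": {"params": {"batch_size": 8}}})
    cli = OmegaConf.from_dotlist(["model.base_learning_rate=5e-5"])

    config = OmegaConf.merge(base, override, cli)
    print(config.data.params.batch_size)    # 8, taken from the later config
    print(config.model.base_learning_rate)  # overridden by the CLI dotlist

Note that with scale_lr enabled (the default above), the effective learning rate becomes world_size * batch_size * base_learning_rate, e.g. 8 GPUs x batch 2 x 1e-4 = 1.6e-3.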
/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 |
5 | mainlogger = logging.getLogger("mainlogger")
6 |
7 | import torch
8 | import torchvision
9 | import pytorch_lightning as pl
10 | from pytorch_lightning.callbacks import Callback
11 | from pytorch_lightning.utilities import rank_zero_only
12 | from pytorch_lightning.utilities import rank_zero_info
13 | from utils.save_video import log_local, prepare_to_log
14 |
15 |
16 | class ImageLogger(Callback):
17 | def __init__(
18 | self,
19 | batch_frequency,
20 | max_images=8,
21 | clamp=True,
22 | rescale=True,
23 | save_dir=None,
24 | to_local=False,
25 | log_images_kwargs=None,
26 | ):
27 | super().__init__()
28 | self.rescale = rescale
29 | self.batch_freq = batch_frequency
30 | self.max_images = max_images
31 | self.to_local = to_local
32 | self.clamp = clamp
33 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
34 | if self.to_local:
35 | ## default save dir
36 | self.save_dir = os.path.join(save_dir, "images")
37 | os.makedirs(os.path.join(self.save_dir, "train"), exist_ok=True)
38 | os.makedirs(os.path.join(self.save_dir, "val"), exist_ok=True)
39 |
40 | def log_to_tensorboard(self, pl_module, batch_logs, filename, split, save_fps=10):
41 | """log images and videos to tensorboard"""
42 | global_step = pl_module.global_step
43 | for key in batch_logs:
44 | value = batch_logs[key]
45 | tag = "gs%d-%s/%s-%s" % (global_step, split, filename, key)
46 | if isinstance(value, list) and isinstance(value[0], str):
47 | captions = " |------| ".join(value)
48 | pl_module.logger.experiment.add_text(
49 | tag, captions, global_step=global_step
50 | )
51 | elif isinstance(value, torch.Tensor) and value.dim() == 5:
52 | video = value
53 | n = video.shape[0]
54 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
55 | frame_grids = [
56 | torchvision.utils.make_grid(framesheet, nrow=int(n))
57 | for framesheet in video
58 |                 ] # [3, h, n*w]
59 | grid = torch.stack(
60 | frame_grids, dim=0
61 |                 ) # stack in temporal dim [t, 3, h, n*w]
62 | grid = (grid + 1.0) / 2.0
63 | grid = grid.unsqueeze(dim=0)
64 | pl_module.logger.experiment.add_video(
65 | tag, grid, fps=save_fps, global_step=global_step
66 | )
67 | elif isinstance(value, torch.Tensor) and value.dim() == 4:
68 | img = value
69 |                 grid = torchvision.utils.make_grid(img, nrow=int(img.shape[0]))  # use the image batch size; `n` from the video branch may be unset here
70 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
71 | pl_module.logger.experiment.add_image(
72 | tag, grid, global_step=global_step
73 | )
74 | else:
75 | pass
76 |
77 | @rank_zero_only
78 | def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"):
79 | """generate images, then save and log to tensorboard"""
80 | skip_freq = self.batch_freq if split == "train" else 5
81 | if (batch_idx + 1) % skip_freq == 0:
82 | is_train = pl_module.training
83 | if is_train:
84 | pl_module.eval()
85 |
86 | with torch.no_grad():
87 | log_func = pl_module.log_images
88 | batch_logs = log_func(batch, split=split, **self.log_images_kwargs)
89 |
90 | ## process: move to CPU and clamp
91 | batch_logs = prepare_to_log(batch_logs, self.max_images, self.clamp)
92 | torch.cuda.empty_cache()
93 |
94 | filename = "ep{}_idx{}_rank{}".format(
95 | pl_module.current_epoch, batch_idx, pl_module.global_rank
96 | )
97 | if self.to_local:
98 | mainlogger.info("Log [%s] batch <%s> to local ..." % (split, filename))
99 | filename = "gs{}_".format(pl_module.global_step) + filename
100 | log_local(
101 | batch_logs,
102 | os.path.join(self.save_dir, split),
103 | filename,
104 | save_fps=10,
105 | )
106 | else:
107 | mainlogger.info(
108 | "Log [%s] batch <%s> to tensorboard ..." % (split, filename)
109 | )
110 | self.log_to_tensorboard(
111 | pl_module, batch_logs, filename, split, save_fps=10
112 | )
113 | mainlogger.info("Finish!")
114 |
115 | if is_train:
116 | pl_module.train()
117 |
118 | def on_train_batch_end(
119 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
120 | ):
121 | if self.batch_freq != -1 and pl_module.logdir:
122 | self.log_batch_imgs(pl_module, batch, batch_idx, split="train")
123 |
124 | def on_validation_batch_end(
125 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
126 | ):
127 |         ## unlike validation_step(), which saves the whole validation set and only keeps the latest results,
128 |         ## this callback records every validation run (without overwriting) by keeping only a subset
129 | if self.batch_freq != -1 and pl_module.logdir:
130 | self.log_batch_imgs(pl_module, batch, batch_idx, split="val")
131 | if hasattr(pl_module, "calibrate_grad_norm"):
132 | if (
133 | pl_module.calibrate_grad_norm and batch_idx % 25 == 0
134 | ) and batch_idx > 0:
135 | self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
136 |
137 |
138 | """
139 | class DataModeSwitcher(Callback):
140 | def on_epoch_start(self, trainer, pl_module):
141 | mode = 'image' if random.random() <= 0.3 else 'video'
142 | trainer.datamodule.dataset.set_mode(mode)
143 | if trainer.global_rank == 0:
144 | torch.distributed.barrier()
145 | """
146 |
147 |
148 | class CUDACallback(Callback):
149 | # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
150 | def on_train_epoch_start(self, trainer, pl_module):
151 | # Reset the memory use counter
152 | # lightning update
153 | if int((pl.__version__).split(".")[1]) >= 7:
154 | gpu_index = trainer.strategy.root_device.index
155 | else:
156 | gpu_index = trainer.root_gpu
157 | torch.cuda.reset_peak_memory_stats(gpu_index)
158 | torch.cuda.synchronize(gpu_index)
159 | self.start_time = time.time()
160 |
161 | def on_train_epoch_end(self, trainer, pl_module):
162 | if int((pl.__version__).split(".")[1]) >= 7:
163 | gpu_index = trainer.strategy.root_device.index
164 | else:
165 | gpu_index = trainer.root_gpu
166 | torch.cuda.synchronize(gpu_index)
167 | max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2**20
168 | epoch_time = time.time() - self.start_time
169 |
170 | try:
171 | max_memory = trainer.training_type_plugin.reduce(max_memory)
172 | epoch_time = trainer.training_type_plugin.reduce(epoch_time)
173 |
174 | rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
175 | rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
176 | except AttributeError:
177 | pass
178 |
--------------------------------------------------------------------------------
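The callbacks above assume the LightningModule exposes a log_images(batch, split, ...) method and a logdir attribute. A minimal wiring sketch, assuming the repo root is on PYTHONPATH; the toy module and random data below are placeholders just to show where ImageLogger plugs in:

    import torch
    import pytorch_lightning as pl
    from utils.callbacks import ImageLogger

    class ToyModule(pl.LightningModule):
        def __init__(self, logdir):
            super().__init__()
            self.logdir = logdir                      # ImageLogger checks pl_module.logdir
            self.net = torch.nn.Linear(3 * 8 * 8, 1)

        def log_images(self, batch, split="train", **kwargs):
            return {"inputs": batch}                  # tensors are expected in [-1, 1]

        def training_step(self, batch, batch_idx):
            return self.net(batch.flatten(1)).mean()

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-4)

    data = torch.utils.data.DataLoader(torch.rand(64, 3, 8, 8) * 2 - 1, batch_size=8)
    logger_cb = ImageLogger(batch_frequency=2, max_images=4, to_local=True, save_dir="logs/demo")
    trainer = pl.Trainer(accelerator="auto", devices=1, callbacks=[logger_cb], max_epochs=1)
    trainer.fit(ToyModule(logdir="logs/demo"), data)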
/utils/common_utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import numpy as np
3 | import cv2, os
4 | import torch
5 | import torch.distributed as dist
6 |
7 |
8 | def count_params(model, verbose=False):
9 | total_params = sum(p.numel() for p in model.parameters())
10 | if verbose:
11 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
12 | return total_params
13 |
14 |
15 | def check_istarget(name, para_list):
16 | """
17 |     name: full name of the source parameter
18 |     para_list: list of partial names of target parameters
19 | """
20 | istarget = False
21 | for para in para_list:
22 | if para in name:
23 | return True
24 | return istarget
25 |
26 |
27 | def instantiate_from_config(config):
28 |     if "target" not in config:
29 | if config == "__is_first_stage__":
30 | return None
31 | elif config == "__is_unconditional__":
32 | return None
33 | raise KeyError("Expected key `target` to instantiate.")
34 |
35 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
36 |
37 |
38 | def get_obj_from_str(string, reload=False):
39 | module, cls = string.rsplit(".", 1)
40 | if reload:
41 | module_imp = importlib.import_module(module)
42 | importlib.reload(module_imp)
43 | return getattr(importlib.import_module(module, package=None), cls)
44 |
45 |
46 | def load_npz_from_dir(data_dir):
47 | data = [
48 | np.load(os.path.join(data_dir, data_name))["arr_0"]
49 | for data_name in os.listdir(data_dir)
50 | ]
51 | data = np.concatenate(data, axis=0)
52 | return data
53 |
54 |
55 | def load_npz_from_paths(data_paths):
56 | data = [np.load(data_path)["arr_0"] for data_path in data_paths]
57 | data = np.concatenate(data, axis=0)
58 | return data
59 |
60 |
61 | def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
62 | h, w = image.shape[:2]
63 | if resize_short_edge is not None:
64 | k = resize_short_edge / min(h, w)
65 | else:
66 | k = max_resolution / (h * w)
67 | k = k**0.5
68 | h = int(np.round(h * k / 64)) * 64
69 | w = int(np.round(w * k / 64)) * 64
70 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
71 | return image
72 |
73 |
74 | def setup_dist(args):
75 | if dist.is_initialized():
76 | return
77 | torch.cuda.set_device(args.local_rank)
78 | torch.distributed.init_process_group("nccl", init_method="env://")
79 |
--------------------------------------------------------------------------------
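A quick usage sketch of instantiate_from_config: any importable class can be built from a {"target": ..., "params": ...} mapping, which is how the YAML configs under configs/ are resolved into model, data and callback objects (assuming the repo root is on PYTHONPATH; the class used here is just torch.nn.Linear for illustration):

    from utils.common_utils import instantiate_from_config

    cfg = {
        "target": "torch.nn.Linear",          # dotted import path of the class
        "params": {"in_features": 16, "out_features": 4},
    }
    layer = instantiate_from_config(cfg)      # equivalent to torch.nn.Linear(16, 4)
    print(type(layer).__name__)               # Linear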
/utils/save_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from PIL import Image
5 | from einops import rearrange
6 |
7 | import torch
8 | import torchvision
9 | from torch import Tensor
10 | from torchvision.utils import make_grid
11 | from torchvision.transforms.functional import to_tensor
12 | from PIL import ImageDraw, ImageFont
13 |
14 |
15 | def save_video_tensor_to_mp4(video, path, fps):
16 | # b,c,t,h,w
17 | video = video.detach().cpu()
18 | video = torch.clamp(video.float(), -1.0, 1.0)
19 | n = video.shape[0]
20 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
21 | frame_grids = [
22 | torchvision.utils.make_grid(framesheet, nrow=int(n)) for framesheet in video
23 | ] # [3, 1*h, n*w]
24 |     grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, 1*h, n*w]
25 | grid = (grid + 1.0) / 2.0
26 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
27 | torchvision.io.write_video(
28 | path, grid, fps=fps, video_codec="h264", options={"crf": "10"}
29 | )
30 |
31 |
32 | def save_video_tensor_to_frames(video, dir):
33 | os.makedirs(dir, exist_ok=True)
34 | # b,c,t,h,w
35 | video = video.detach().cpu()
36 | video = torch.clamp(video.float(), -1.0, 1.0)
37 | n = video.shape[0]
38 | assert n == 1
39 | video = video[0] # cthw
40 | video = video.permute(1, 2, 3, 0) # thwc
41 | # video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
42 | video = (video + 1.0) / 2.0 * 255
43 | video = video.to(torch.uint8).numpy()
44 | for i in range(video.shape[0]):
45 | img = video[i] # hwc
46 | image = Image.fromarray(img)
47 |         image.save(os.path.join(dir, f"frame{i:03d}.jpg"), quality=95)
48 |
49 |
50 | def frames_to_mp4(frame_dir, output_path, fps):
51 | def read_first_n_frames(d: os.PathLike, num_frames: int):
52 | if num_frames:
53 | images = [
54 | Image.open(os.path.join(d, f))
55 | for f in sorted(os.listdir(d))[:num_frames]
56 | ]
57 | else:
58 | images = [Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))]
59 | images = [to_tensor(x) for x in images]
60 | return torch.stack(images)
61 |
62 | videos = read_first_n_frames(frame_dir, num_frames=None)
63 | videos = videos.mul(255).to(torch.uint8).permute(0, 2, 3, 1)
64 | torchvision.io.write_video(
65 | output_path, videos, fps=fps, video_codec="h264", options={"crf": "10"}
66 | )
67 |
68 |
69 | def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
70 | """
71 | video: torch.Tensor, b,c,t,h,w, 0-1
72 | if -1~1, enable rescale=True
73 | """
74 | n = video.shape[0]
75 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
76 | nrow = int(np.sqrt(n)) if nrow is None else nrow
77 | frame_grids = [
78 | torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video
79 | ] # [3, grid_h, grid_w]
80 | grid = torch.stack(
81 | frame_grids, dim=0
82 | ) # stack in temporal dim [T, 3, grid_h, grid_w]
83 | grid = torch.clamp(grid.float(), -1.0, 1.0)
84 | if rescale:
85 | grid = (grid + 1.0) / 2.0
86 | grid = (
87 | (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
88 | ) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
89 | # print(f'Save video to {savepath}')
90 | torchvision.io.write_video(
91 | savepath, grid, fps=fps, video_codec="h264", options={"crf": "10"}
92 | )
93 |
94 |
95 | def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True):
96 |
97 | assert video.dim() == 5 # b,c,t,h,w
98 | assert isinstance(video, torch.Tensor)
99 |
100 | video = video.detach().cpu()
101 | if clamp:
102 | video = torch.clamp(video, -1.0, 1.0)
103 | n = video.shape[0]
104 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
105 | frame_grids = [
106 | torchvision.utils.make_grid(framesheet, nrow=int(np.sqrt(n)))
107 | for framesheet in video
108 | ] # [3, grid_h, grid_w]
109 | grid = torch.stack(
110 | frame_grids, dim=0
111 | ) # stack in temporal dim [T, 3, grid_h, grid_w]
112 | if rescale:
113 | grid = (grid + 1.0) / 2.0
114 | grid = (
115 | (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
116 | ) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
117 | path = os.path.join(root, filename)
118 | # print('Save video ...')
119 | torchvision.io.write_video(
120 | path, grid, fps=fps, video_codec="h264", options={"crf": "10"}
121 | )
122 | # print('Finish!')
123 |
124 |
125 | def log_txt_as_img(wh, xc, size=10):
126 | # wh a tuple of (width, height)
127 | # xc a list of captions to plot
128 | b = len(xc)
129 | txts = list()
130 | for bi in range(b):
131 | txt = Image.new("RGB", wh, color="white")
132 | draw = ImageDraw.Draw(txt)
133 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
134 | nc = int(40 * (wh[0] / 256))
135 | lines = "\n".join(
136 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
137 | )
138 |
139 | try:
140 | draw.text((0, 0), lines, fill="black", font=font)
141 | except UnicodeEncodeError:
142 | print("Cant encode string for logging. Skipping.")
143 |
144 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
145 | txts.append(txt)
146 | txts = np.stack(txts)
147 | txts = torch.tensor(txts)
148 | return txts
149 |
150 |
151 | def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True):
152 |     """save images and videos from the batch_logs dict"""
153 |     if batch_logs is None:
154 |         return None
155 |
156 | def save_img_grid(grid, path, rescale):
157 | if rescale:
158 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
159 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
160 | grid = grid.numpy()
161 | grid = (grid * 255).astype(np.uint8)
162 | os.makedirs(os.path.split(path)[0], exist_ok=True)
163 | Image.fromarray(grid).save(path)
164 |
165 | for key in batch_logs:
166 | value = batch_logs[key]
167 | if isinstance(value, list) and isinstance(value[0], str):
168 | ## a batch of captions
169 | path = os.path.join(save_dir, "%s-%s.txt" % (key, filename))
170 | with open(path, "w") as f:
171 | for i, txt in enumerate(value):
172 | f.write(f"idx={i}, txt={txt}\n")
173 |                 ## file is closed automatically by the context manager
174 | elif isinstance(value, torch.Tensor) and value.dim() == 5:
175 | ## save video grids
176 | video = value # b,c,t,h,w
177 | ## only save grayscale or rgb mode
178 | if video.shape[1] != 1 and video.shape[1] != 3:
179 | continue
180 | n = video.shape[0]
181 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
182 | frame_grids = [
183 | torchvision.utils.make_grid(framesheet, nrow=int(1))
184 | for framesheet in video
185 | ] # [3, n*h, 1*w]
186 | grid = torch.stack(
187 | frame_grids, dim=0
188 | ) # stack in temporal dim [t, 3, n*h, w]
189 | if rescale:
190 | grid = (grid + 1.0) / 2.0
191 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
192 | path = os.path.join(save_dir, "%s-%s.mp4" % (key, filename))
193 | torchvision.io.write_video(
194 | path, grid, fps=save_fps, video_codec="h264", options={"crf": "10"}
195 | )
196 |
197 | ## save frame sheet
198 | img = value
199 | video_frames = rearrange(img, "b c t h w -> (b t) c h w")
200 | t = img.shape[2]
201 | grid = torchvision.utils.make_grid(video_frames, nrow=t)
202 | path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
203 | # save_img_grid(grid, path, rescale)
204 | elif isinstance(value, torch.Tensor) and value.dim() == 4:
205 | ## save image grids
206 | img = value
207 | ## only save grayscale or rgb mode
208 | if img.shape[1] != 1 and img.shape[1] != 3:
209 | continue
210 | n = img.shape[0]
211 | grid = torchvision.utils.make_grid(img, nrow=1)
212 | path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
213 | save_img_grid(grid, path, rescale)
214 | else:
215 | pass
216 |
217 |
218 | def prepare_to_log(batch_logs, max_images=100000, clamp=True):
219 | if batch_logs is None:
220 | return None
221 | # process
222 | for key in batch_logs:
223 | if batch_logs[key] is not None:
224 | N = (
225 | batch_logs[key].shape[0]
226 | if hasattr(batch_logs[key], "shape")
227 | else len(batch_logs[key])
228 | )
229 | N = min(N, max_images)
230 | batch_logs[key] = batch_logs[key][:N]
231 | ## in batch_logs: images & caption
232 | if isinstance(batch_logs[key], torch.Tensor):
233 | batch_logs[key] = batch_logs[key].detach().cpu()
234 | if clamp:
235 | try:
236 | batch_logs[key] = torch.clamp(
237 | batch_logs[key].float(), -1.0, 1.0
238 | )
239 | except RuntimeError:
240 | print("clamp_scalar_cpu not implemented for Half")
241 | return batch_logs
242 |
243 |
244 | # ----------------------------------------------------------------------------------------------
245 |
246 |
247 | def fill_with_black_squares(video, desired_len: int) -> Tensor:
248 | if len(video) >= desired_len:
249 | return video
250 |
251 | return torch.cat(
252 | [
253 | video,
254 | torch.zeros_like(video[0])
255 | .unsqueeze(0)
256 | .repeat(desired_len - len(video), 1, 1, 1),
257 | ],
258 | dim=0,
259 | )
260 |
261 |
262 | # ----------------------------------------------------------------------------------------------
263 | def load_num_videos(data_path, num_videos):
264 |     # first argument can be either a data_path or an np array
265 | if isinstance(data_path, str):
266 | videos = np.load(data_path)["arr_0"] # NTHWC
267 | elif isinstance(data_path, np.ndarray):
268 | videos = data_path
269 | else:
270 | raise Exception
271 |
272 | if num_videos is not None:
273 | videos = videos[:num_videos, :, :, :, :]
274 | return videos
275 |
276 |
277 | def npz_to_video_grid(
278 | data_path, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True
279 | ):
280 | # videos = torch.tensor(np.load(data_path)['arr_0']).permute(0,1,4,2,3).div_(255).mul_(2) - 1.0 # NTHWC->NTCHW, np int -> torch tensor 0-1
281 | if isinstance(data_path, str):
282 | videos = load_num_videos(data_path, num_videos)
283 | elif isinstance(data_path, np.ndarray):
284 | videos = data_path
285 | else:
286 | raise Exception
287 | n, t, h, w, c = videos.shape
288 | videos_th = []
289 | for i in range(n):
290 | video = videos[i, :, :, :, :]
291 | images = [video[j, :, :, :] for j in range(t)]
292 | images = [to_tensor(img) for img in images]
293 | video = torch.stack(images)
294 | videos_th.append(video)
295 | if verbose:
296 | videos = [
297 | fill_with_black_squares(v, num_frames)
298 | for v in tqdm(videos_th, desc="Adding empty frames")
299 | ] # NTCHW
300 | else:
301 | videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW
302 |
303 | frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W]
304 | if nrow is None:
305 | nrow = int(np.ceil(np.sqrt(n)))
306 | if verbose:
307 | frame_grids = [
308 | make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc="Making grids")
309 | ]
310 | else:
311 | frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]
312 |
313 | if os.path.dirname(out_path) != "":
314 | os.makedirs(os.path.dirname(out_path), exist_ok=True)
315 | frame_grids = (
316 | (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1)
317 | ) # [T, H, W, C]
318 | torchvision.io.write_video(
319 | out_path, frame_grids, fps=fps, video_codec="h264", options={"crf": "10"}
320 | )
321 |
--------------------------------------------------------------------------------
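A short sketch of saving a random video batch with tensor_to_mp4; the (b, c, t, h, w) layout and value range follow the docstring above, the file name is arbitrary, and a torchvision build with video (ffmpeg/PyAV) support is required, as for the other helpers in this file:

    import torch
    from utils.save_video import tensor_to_mp4

    video = torch.rand(4, 3, 16, 64, 64) * 2.0 - 1.0            # b,c,t,h,w in [-1, 1]
    tensor_to_mp4(video, "demo_grid.mp4", fps=8, rescale=True)  # 2x2 grid, 16 frames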
/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import OmegaConf
3 | from collections import OrderedDict
4 | import logging
5 |
6 | mainlogger = logging.getLogger("mainlogger")
7 |
8 | import torch
9 |
10 |
11 |
12 | def init_workspace(name, logdir, model_config, lightning_config, rank=0):
13 | workdir = os.path.join(logdir, name)
14 | ckptdir = os.path.join(workdir, "checkpoints")
15 | cfgdir = os.path.join(workdir, "configs")
16 | loginfo = os.path.join(workdir, "loginfo")
17 |
18 | # Create logdirs and save configs (all ranks will do to avoid missing directory error if rank:0 is slower)
19 | os.makedirs(workdir, exist_ok=True)
20 | os.makedirs(ckptdir, exist_ok=True)
21 | os.makedirs(cfgdir, exist_ok=True)
22 | os.makedirs(loginfo, exist_ok=True)
23 |
24 | if rank == 0:
25 | if (
26 | "callbacks" in lightning_config
27 | and "metrics_over_trainsteps_checkpoint" in lightning_config.callbacks
28 | ):
29 | os.makedirs(os.path.join(ckptdir, "trainstep_checkpoints"), exist_ok=True)
30 | OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml"))
31 | OmegaConf.save(
32 | OmegaConf.create({"lightning": lightning_config}),
33 | os.path.join(cfgdir, "lightning.yaml"),
34 | )
35 | return workdir, ckptdir, cfgdir, loginfo
36 |
37 |
38 | def check_config_attribute(config, name):
39 | if name in config:
40 | value = getattr(config, name)
41 | return value
42 | else:
43 | return None
44 |
45 |
46 | def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger):
47 | default_callbacks_cfg = {
48 | "model_checkpoint": {
49 | "target": "pytorch_lightning.callbacks.ModelCheckpoint",
50 | "params": {
51 | "dirpath": ckptdir,
52 | "filename": "{epoch}",
53 | "verbose": True,
54 | "save_last": True,
55 | },
56 | },
57 | "batch_logger": {
58 | "target": "utils.callbacks.ImageLogger",
59 | "params": {
60 | "save_dir": logdir,
61 | "batch_frequency": 1000,
62 | "max_images": 4,
63 | "clamp": True,
64 | },
65 | },
66 | "learning_rate_logger": {
67 | "target": "pytorch_lightning.callbacks.LearningRateMonitor",
68 | "params": {"logging_interval": "step", "log_momentum": False},
69 | },
70 | "cuda_callback": {"target": "utils.callbacks.CUDACallback"},
71 | }
72 |
73 | ## optional setting for saving checkpoints
74 | monitor_metric = check_config_attribute(config.model.params, "monitor")
75 | if monitor_metric is not None:
76 | mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.")
77 | default_callbacks_cfg["model_checkpoint"]["params"]["monitor"] = monitor_metric
78 | default_callbacks_cfg["model_checkpoint"]["params"]["save_top_k"] = 3
79 | default_callbacks_cfg["model_checkpoint"]["params"]["mode"] = "min"
80 |
81 | if "metrics_over_trainsteps_checkpoint" in lightning_config.callbacks:
82 | mainlogger.info(
83 | "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
84 | )
85 | default_metrics_over_trainsteps_ckpt_dict = {
86 | "metrics_over_trainsteps_checkpoint": {
87 | "target": "pytorch_lightning.callbacks.ModelCheckpoint",
88 | "params": {
89 | "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
90 | "filename": "{epoch}-{step}",
91 | "verbose": True,
92 | "save_top_k": -1,
93 | "every_n_train_steps": 10000,
94 | "save_weights_only": True,
95 | },
96 | }
97 | }
98 | default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
99 |
100 | if "callbacks" in lightning_config:
101 | callbacks_cfg = lightning_config.callbacks
102 | else:
103 | callbacks_cfg = OmegaConf.create()
104 | callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
105 |
106 | return callbacks_cfg
107 |
108 |
109 | def get_trainer_logger(lightning_config, logdir, on_debug):
110 | default_logger_cfgs = {
111 | "tensorboard": {
112 | "target": "pytorch_lightning.loggers.TensorBoardLogger",
113 | "params": {
114 | "save_dir": logdir,
115 | "name": "tensorboard",
116 | },
117 | },
118 | "testtube": {
119 | "target": "pytorch_lightning.loggers.CSVLogger",
120 | "params": {
121 | "name": "testtube",
122 | "save_dir": logdir,
123 | },
124 | },
125 | }
126 | os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True)
127 | default_logger_cfg = default_logger_cfgs["tensorboard"]
128 | if "logger" in lightning_config:
129 | logger_cfg = lightning_config.logger
130 | else:
131 | logger_cfg = OmegaConf.create()
132 | logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
133 | return logger_cfg
134 |
135 |
136 | def get_trainer_strategy(lightning_config):
137 | default_strategy_dict = {
138 | "target": "pytorch_lightning.strategies.DDPShardedStrategy"
139 | }
140 | if "strategy" in lightning_config:
141 | strategy_cfg = lightning_config.strategy
142 | return strategy_cfg
143 | else:
144 | strategy_cfg = OmegaConf.create()
145 |
146 | strategy_cfg = OmegaConf.merge(default_strategy_dict, strategy_cfg)
147 | return strategy_cfg
148 |
149 |
150 | def load_checkpoints(model, model_cfg):
151 | ## special load setting for adapter training
152 | if check_config_attribute(model_cfg, "adapter_only"):
153 | pretrained_ckpt = model_cfg.pretrained_checkpoint
154 | assert os.path.exists(pretrained_ckpt), (
155 | "Error: Pre-trained checkpoint NOT found at:%s" % pretrained_ckpt
156 | )
157 | mainlogger.info(
158 | ">>> Load weights from pretrained checkpoint (training adapter only)"
159 | )
160 | print(f"Loading model from {pretrained_ckpt}")
161 | ## only load weight for the backbone model (e.g. latent diffusion model)
162 |         state_dict = torch.load(pretrained_ckpt, map_location="cpu")
163 | if "state_dict" in list(state_dict.keys()):
164 | state_dict = state_dict["state_dict"]
165 | else:
166 | # deepspeed
167 | dp_state_dict = OrderedDict()
168 | for key in state_dict["module"].keys():
169 | dp_state_dict[key[16:]] = state_dict["module"][key]
170 | state_dict = dp_state_dict
171 | model.load_state_dict(state_dict, strict=False)
172 | model.empty_paras = None
173 | return model
174 | empty_paras = None
175 |
176 | if check_config_attribute(model_cfg, "pretrained_checkpoint"):
177 | pretrained_ckpt = model_cfg.pretrained_checkpoint
178 | assert os.path.exists(pretrained_ckpt), (
179 | "Error: Pre-trained checkpoint NOT found at:%s" % pretrained_ckpt
180 | )
181 | mainlogger.info(">>> Load weights from pretrained checkpoint")
182 | # mainlogger.info(pretrained_ckpt)
183 |         print(f"Loading model from {pretrained_ckpt}")
184 | pl_sd = torch.load(pretrained_ckpt, map_location="cpu")
185 | try:
186 | if "state_dict" in pl_sd.keys():
187 | model.load_state_dict(pl_sd["state_dict"])
188 | else:
189 | # deepspeed
190 | new_pl_sd = OrderedDict()
191 | for key in pl_sd["module"].keys():
192 | new_pl_sd[key[16:]] = pl_sd["module"][key]
193 | model.load_state_dict(new_pl_sd)
194 |         except Exception:
195 | model.load_state_dict(pl_sd)
196 | else:
197 | empty_paras = None
198 |
199 | ## record empty params
200 | model.empty_paras = empty_paras
201 | return model
202 |
203 |
204 | def get_autoresume_path(logdir):
205 | ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
206 | if os.path.exists(ckpt):
207 | try:
208 | tmp = torch.load(ckpt, map_location="cpu")
209 | e = tmp["epoch"]
210 | gs = tmp["global_step"]
211 | mainlogger.info(f"[INFO] Resume from epoch {e}, global step {gs}!")
212 | del tmp
213 |         except Exception:
214 |             try:
215 |                 mainlogger.info("Loading last.ckpt failed!")
216 | ckpts = sorted(
217 | [
218 | f
219 | for f in os.listdir(os.path.join(logdir, "checkpoints"))
220 | if not os.path.isdir(f)
221 | ]
222 | )
223 |                 mainlogger.info(f"all available checkpoints: {ckpts}")
224 | ckpts.remove("last.ckpt")
225 | if "trainstep_checkpoints" in ckpts:
226 | ckpts.remove("trainstep_checkpoints")
227 | ckpt_path = ckpts[-1]
228 | ckpt = os.path.join(logdir, "checkpoints", ckpt_path)
229 | mainlogger.info(f"Select resuming ckpt: {ckpt}")
230 | except ValueError:
231 |                 mainlogger.info("Loading last.ckpt failed and no other checkpoints are available!")
232 |
233 | resume_checkpt_path = ckpt
234 | mainlogger.info(f"[INFO] resume from: {ckpt}")
235 | else:
236 | resume_checkpt_path = None
237 | mainlogger.info(
238 | f"[INFO] no checkpoint found in current workspace: {os.path.join(logdir, 'checkpoints')}"
239 | )
240 |
241 | return resume_checkpt_path
242 |
243 |
244 | def set_logger(logfile, name="mainlogger"):
245 | logger = logging.getLogger(name)
246 | logger.setLevel(logging.INFO)
247 | fh = logging.FileHandler(logfile, mode="w")
248 | fh.setLevel(logging.INFO)
249 | ch = logging.StreamHandler()
250 | ch.setLevel(logging.DEBUG)
251 | fh.setFormatter(logging.Formatter("%(asctime)s-%(levelname)s: %(message)s"))
252 | ch.setFormatter(logging.Formatter("%(message)s"))
253 | logger.addHandler(fh)
254 | logger.addHandler(ch)
255 | return logger
256 |
--------------------------------------------------------------------------------
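A minimal sketch of the workspace/logging bootstrap that train.py performs with the helpers above (assuming the repo root is on PYTHONPATH; the config contents and the target path are placeholders):

    import os
    from omegaconf import OmegaConf
    from utils.train_utils import init_workspace, set_logger

    model_config = OmegaConf.create({"model": {"target": "your_package.YourModel", "params": {}}})
    lightning_config = OmegaConf.create({"trainer": {"devices": 1}})

    # creates logs/demo_run/{checkpoints,configs,loginfo} and saves the configs at rank 0
    workdir, ckptdir, cfgdir, loginfo = init_workspace(
        "demo_run", "logs", model_config, lightning_config, rank=0
    )
    logger = set_logger(logfile=os.path.join(loginfo, "log_0.txt"))
    logger.info("workspace ready: %s", workdir)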