├── .gitattributes ├── .gitignore ├── DATASETS.md ├── LICENSE ├── README.md ├── assets ├── deepavfusion.png ├── demos │ ├── demo1_loc.png │ ├── demo1_orig.mp4 │ ├── demo1_orig.png │ ├── demo1_orig.wav │ ├── demo1_sep1.mp4 │ ├── demo1_sep1.png │ ├── demo1_sep1.wav │ ├── demo1_sep2.mp4 │ ├── demo1_sep2.png │ ├── demo1_sep2.wav │ ├── demo2_loc.png │ ├── demo2_orig.mp4 │ ├── demo2_orig.png │ ├── demo2_orig.wav │ ├── demo2_sep1.mp4 │ ├── demo2_sep1.png │ ├── demo2_sep1.wav │ ├── demo2_sep2.mp4 │ ├── demo2_sep2.png │ ├── demo2_sep2.wav │ ├── demo3_loc.png │ ├── demo3_orig.mp4 │ ├── demo3_orig.png │ ├── demo3_orig.wav │ ├── demo3_sep1.mp4 │ ├── demo3_sep1.png │ ├── demo3_sep1.wav │ ├── demo3_sep2.mp4 │ ├── demo3_sep2.png │ ├── demo3_sep2.wav │ ├── demo4_loc.png │ ├── demo4_orig.mp4 │ ├── demo4_orig.png │ ├── demo4_orig.wav │ ├── demo4_sep1.mp4 │ ├── demo4_sep1.png │ ├── demo4_sep1.wav │ ├── demo4_sep2.mp4 │ ├── demo4_sep2.png │ └── demo4_sep2.wav ├── models │ └── vitbase_audiomae_as2m.pth └── training_curve.png ├── avreader.py ├── checkpoints ├── deepavfusion_vitb_as2m_ep200 │ └── checkpoints │ │ └── checkpoint_latest.pth └── deepavfusion_vitb_vggsound_ep200 │ └── checkpoints │ └── checkpoint_latest.pth ├── configs ├── avsegm.yaml ├── avsrcsep.yaml ├── avsync.yaml ├── deepavfusion.yaml ├── env │ └── default.yaml ├── finetune.yaml ├── hydra │ └── default.yaml ├── linprobe.yaml ├── log │ └── default.yaml └── nn_probe │ └── default.yaml ├── datasets.py ├── eval_avsegm.py ├── eval_avsrcsep.py ├── eval_finetune.py ├── eval_linprobe.py ├── launcher.py ├── metadata ├── avsbench_test.csv ├── avsbench_train.txt ├── avsbench_val.csv ├── flickr_10k.txt ├── flickr_144k.txt ├── flickr_sup_train.txt ├── flickr_test.csv ├── music_duet.json ├── music_duet_test.csv ├── music_duet_train.txt ├── vgginstruments_test.csv ├── vgginstruments_train.txt ├── vggmusic_eval_ss.csv ├── vggmusic_train.txt ├── vggsound_test.csv ├── vggsound_train.txt ├── vggss.json ├── vggss_10k.txt ├── vggss_144k.txt ├── vggss_heard.txt ├── vggss_heard_test.csv ├── vggss_test.csv ├── vggss_train.txt └── vggss_unheard_test.csv ├── models ├── avmae.py ├── avsegm.py ├── avsrcsep.py ├── classifier.py ├── deepavfusion.py ├── fusion_blocks.py ├── swin.py ├── video_earlyfusion.py ├── video_vits.py └── vits.py ├── requirements.yml ├── train.py └── util ├── audio_transforms.py ├── data.py ├── distributed.py ├── image_labels_transforms.py ├── knn_probe.py ├── lars.py ├── lr_sched.py ├── meters.py ├── misc.py └── pos_embed.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pth filter=lfs diff=lfs merge=lfs -text 2 | assets/models/vitbase_audiomae_as2m.pth filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /DATASETS.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | In this work, we used a variety of datasets, including VGGSound, AudioSet, MUSIC and AVSBench. 3 | We used datasets of mp4 files encoded with a constant key frame rate of 16 and 360p resolution (length of short-side). 4 | Saving frequent key-frames allows faster decoding, thus accelerating training. 
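A minimal sketch of this re-encoding step is shown below. The repository does not include the exact preprocessing command, so the ffmpeg flags, the `reencode` helper, and the example paths are assumptions based only on the description above (a key frame every 16 frames and a 360-pixel short side).

```
# Hypothetical preprocessing sketch (assumed flags, not the exact command used for the released data).
import subprocess

def reencode(src, dst):
    # Rescale so the short side is 360 px while keeping the aspect ratio (width/height rounded to even).
    scale = "scale='if(lt(iw,ih),360,-2)':'if(lt(iw,ih),-2,360)'"
    subprocess.run([
        "ffmpeg", "-y", "-i", src,
        "-c:v", "libx264", "-g", "16",   # -g 16: force a key frame every 16 frames for fast seeking
        "-vf", scale,
        "-c:a", "aac",
        dst,
    ], check=True)

reencode("raw/xxx.mp4", "clips/violin/xxx.mp4")
```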
5 | 6 | ## VGGSounds 7 | ([Download](https://www.robots.ox.ac.uk/~vgg/data/vggsound/)) 8 | Expected folder structure 9 | ``` 10 | ${ROOT} 11 | ${ROOT}/annotations 12 | ${ROOT}/annotations/vggsound.csv 13 | 14 | ${ROOT}/clips 15 | ${ROOT}/clips/${CLASS} 16 | ${ROOT}/clips/${CLASS}/${FILENAME}.mp4 17 | ``` 18 | 19 | ## AudioSet 20 | ([Download](https://research.google.com/audioset/index.html)) 21 | Expected folder structure 22 | ``` 23 | ${ROOT} 24 | ${ROOT}/annotations 25 | ${ROOT}/annotations/class_labels_indices.csv 26 | ${ROOT}/annotations/unbalanced_train_segments.csv 27 | ${ROOT}/annotations/balanced_train_segments.csv 28 | ${ROOT}/annotations/eval_segments.csv 29 | 30 | ${ROOT}/clips 31 | ${ROOT}/clips/${VID[:2]} 32 | ${ROOT}/clips/${VID[:2]}/${FILENAME}.mp4 33 | ``` 34 | 35 | ## MUSIC 36 | ([Download](https://github.com/roudimit/MUSIC_dataset)) 37 | Expected folder structure 38 | ``` 39 | ${ROOT} 40 | ${ROOT}/anno 41 | ${ROOT}/anno/music_solo.csv 42 | 43 | ${ROOT}/clips 44 | ${ROOT}/clips/${CLASS} 45 | ${ROOT}/clips/${CLASS}/${FILENAME}.mp4 46 | ``` 47 | 48 | ## AVSBench 49 | ([Download](https://opennlplab.github.io/AVSBench/)) 50 | Expected folder structure 51 | ``` 52 | ${ROOT} 53 | ${ROOT}/metadata.csv 54 | ${ROOT}/label2idx.json 55 | 56 | ${ROOT}/v1m 57 | ${ROOT}/v1m/${FILE} 58 | ${ROOT}/v1m/${FILE}/audio.wav 59 | ${ROOT}/v1m/${FILE}/frames 60 | ${ROOT}/v1m/${FILE}/labels_rgb 61 | ${ROOT}/v1m/${FILE}/labels_semantic 62 | 63 | ${ROOT}/v1s 64 | ${ROOT}/v1s/${FILE} 65 | ${ROOT}/v1s/${FILE}/audio.wav 66 | ${ROOT}/v1s/${FILE}/frames 67 | ${ROOT}/v1s/${FILE}/labels_rgb 68 | ${ROOT}/v1s/${FILE}/labels_semantic 69 | 70 | ${ROOT}/v2 71 | ${ROOT}/v2/${FILE} 72 | ${ROOT}/v2/${FILE}/audio.wav 73 | ${ROOT}/v2/${FILE}/frames 74 | ${ROOT}/v2/${FILE}/labels_rgb 75 | ${ROOT}/v2/${FILE}/labels_semantic 76 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /assets/deepavfusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/deepavfusion.png -------------------------------------------------------------------------------- /assets/demos/demo1_loc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo1_loc.png -------------------------------------------------------------------------------- /assets/demos/demo1_orig.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo1_orig.mp4 -------------------------------------------------------------------------------- /assets/demos/demo1_orig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo1_orig.png -------------------------------------------------------------------------------- /assets/demos/demo1_orig.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo1_orig.wav -------------------------------------------------------------------------------- /assets/demos/demo1_sep1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo1_sep1.mp4 -------------------------------------------------------------------------------- /assets/demos/demo1_sep1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo1_sep1.png -------------------------------------------------------------------------------- /assets/demos/demo1_sep1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo1_sep1.wav -------------------------------------------------------------------------------- /assets/demos/demo1_sep2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo1_sep2.mp4 -------------------------------------------------------------------------------- /assets/demos/demo1_sep2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo1_sep2.png -------------------------------------------------------------------------------- /assets/demos/demo1_sep2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo1_sep2.wav -------------------------------------------------------------------------------- 
/assets/demos/demo2_loc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo2_loc.png -------------------------------------------------------------------------------- /assets/demos/demo2_orig.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo2_orig.mp4 -------------------------------------------------------------------------------- /assets/demos/demo2_orig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo2_orig.png -------------------------------------------------------------------------------- /assets/demos/demo2_orig.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo2_orig.wav -------------------------------------------------------------------------------- /assets/demos/demo2_sep1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo2_sep1.mp4 -------------------------------------------------------------------------------- /assets/demos/demo2_sep1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo2_sep1.png -------------------------------------------------------------------------------- /assets/demos/demo2_sep1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo2_sep1.wav -------------------------------------------------------------------------------- /assets/demos/demo2_sep2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo2_sep2.mp4 -------------------------------------------------------------------------------- /assets/demos/demo2_sep2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo2_sep2.png -------------------------------------------------------------------------------- /assets/demos/demo2_sep2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo2_sep2.wav -------------------------------------------------------------------------------- /assets/demos/demo3_loc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo3_loc.png -------------------------------------------------------------------------------- /assets/demos/demo3_orig.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo3_orig.mp4 -------------------------------------------------------------------------------- /assets/demos/demo3_orig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo3_orig.png -------------------------------------------------------------------------------- /assets/demos/demo3_orig.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo3_orig.wav -------------------------------------------------------------------------------- /assets/demos/demo3_sep1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo3_sep1.mp4 -------------------------------------------------------------------------------- /assets/demos/demo3_sep1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo3_sep1.png -------------------------------------------------------------------------------- /assets/demos/demo3_sep1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo3_sep1.wav -------------------------------------------------------------------------------- /assets/demos/demo3_sep2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo3_sep2.mp4 -------------------------------------------------------------------------------- /assets/demos/demo3_sep2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo3_sep2.png -------------------------------------------------------------------------------- /assets/demos/demo3_sep2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo3_sep2.wav -------------------------------------------------------------------------------- /assets/demos/demo4_loc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo4_loc.png -------------------------------------------------------------------------------- /assets/demos/demo4_orig.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo4_orig.mp4 -------------------------------------------------------------------------------- /assets/demos/demo4_orig.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo4_orig.png -------------------------------------------------------------------------------- /assets/demos/demo4_orig.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo4_orig.wav -------------------------------------------------------------------------------- /assets/demos/demo4_sep1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo4_sep1.mp4 -------------------------------------------------------------------------------- /assets/demos/demo4_sep1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo4_sep1.png -------------------------------------------------------------------------------- /assets/demos/demo4_sep1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo4_sep1.wav -------------------------------------------------------------------------------- /assets/demos/demo4_sep2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo4_sep2.mp4 -------------------------------------------------------------------------------- /assets/demos/demo4_sep2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo4_sep2.png -------------------------------------------------------------------------------- /assets/demos/demo4_sep2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/demos/demo4_sep2.wav -------------------------------------------------------------------------------- /assets/models/vitbase_audiomae_as2m.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2ac444b8df9264a40bdf78e09ea878497611fc1b05254f90aaaf573c94aaadde 3 | size 548136466 4 | -------------------------------------------------------------------------------- /assets/training_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stoneMo/DeepAVFusion/2768e7f77b65c45e39e019b0b13315f0b4c4fddd/assets/training_curve.png -------------------------------------------------------------------------------- /avreader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | import av 4 | from fractions import Fraction 5 | import numpy as np 6 | 7 | 8 | class VideoReader: 9 | def __init__(self, filename=None, container=None): 10 | self.container = av.open(filename) if container is None else container 11 | self.stream = self.container.streams.video[0] 12 | self.stream.thread_count = 4 13 | 14 | def 
quick_random_frame(self, t_min=None, t_max=None): 15 | t_min = self.start_time if t_min is None else t_min 16 | t_max = self.start_time + self.duration if t_max is None else t_max 17 | rnd_t = random.uniform(t_min, t_max) 18 | self.container.seek(int(rnd_t * av.time_base)) 19 | for frame in self.container.decode(video=0): 20 | frame_ts = float(frame.pts * frame.time_base) 21 | frame = frame.to_image() 22 | return frame, frame_ts 23 | 24 | def precise_frame(self, t, seek=True): 25 | if seek: 26 | self.container.seek(int(t * av.time_base)) 27 | for frame in self.container.decode(video=0): 28 | frame_ts = float(frame.pts * frame.time_base) 29 | if t - frame_ts < 1 / self.fps: 30 | frame = frame.to_image() 31 | return frame, frame_ts 32 | 33 | def get_clip(self, t_start=None, t_end=None): 34 | t_start = self.start_time if t_start is None else t_start 35 | t_end = self.start_time + self.duration if t_end is None else t_end 36 | self.container.seek(int(t_start * av.time_base)) 37 | clip, ts = [], [] 38 | for frame in self.container.decode(video=0): 39 | frame_ts = float(frame.pts * frame.time_base) 40 | if frame_ts < t_start: 41 | continue 42 | if frame_ts > t_end: 43 | return clip, ts 44 | clip.append(frame.to_image()), ts.append(frame_ts) 45 | return clip, ts 46 | 47 | def __iter__(self): 48 | for frame in self.container.decode(video=0): 49 | frame_ts = float(frame.pts * frame.time_base) 50 | frame = frame.to_image() 51 | yield frame, frame_ts 52 | 53 | def __len__(self): 54 | return self.num_frames 55 | 56 | @property 57 | def fps(self): 58 | return self.stream.average_rate 59 | 60 | @property 61 | def num_frames(self): 62 | return self.stream.frames 63 | 64 | @property 65 | def duration(self): 66 | return self.stream.duration * self.stream.time_base 67 | 68 | @property 69 | def start_time(self): 70 | return self.stream.start_time * self.stream.time_base 71 | 72 | 73 | class AudioReader: 74 | def __init__(self, filename=None, container=None, rate=None, layout='mono'): 75 | self.container = av.open(filename) if container is None else container 76 | self.stream = self.container.streams.audio[0] 77 | self.stream.thread_count = 4 78 | self.resampler = None 79 | self.rate = self.orig_rate 80 | if rate is not None: 81 | self.resampler = av.audio.resampler.AudioResampler(format="s16p", layout=layout, rate=rate) 82 | self.rate = rate 83 | 84 | def read(self, t_min=None, t_max=None, seek=True): 85 | t_min = self.start_time if t_min is None else t_min 86 | t_max = self.start_time + self.duration if t_max is None else t_max 87 | if seek: 88 | self.container.seek(int(t_min * av.time_base)) 89 | 90 | # Read data 91 | chunks = [] 92 | for chunk in self.container.decode(audio=0): 93 | chunk_ts = chunk.pts * chunk.time_base 94 | chunk_end = chunk_ts + Fraction(chunk.samples, chunk.rate) 95 | if chunk_end < t_min: # Skip until start time 96 | continue 97 | if chunk_ts > t_max: # Exit if clip has been extracted 98 | break 99 | 100 | # Resample 101 | chunk.pts = None 102 | if self.resampler is not None: 103 | chunk = self.resampler.resample(chunk) 104 | if isinstance(chunk, list): 105 | assert len(chunk) == 1 106 | chunk = chunk[0] 107 | chunk = chunk.to_ndarray() 108 | chunk = chunk / np.iinfo(chunk.dtype).max 109 | else: 110 | chunk = chunk.to_ndarray() 111 | 112 | if chunk_ts < t_min: 113 | chunk = chunk[:, int((t_min - chunk_ts) * self.rate):] 114 | if chunk_end > t_max: 115 | chunk = chunk[:, :-int((chunk_end - t_max) * self.rate)] 116 | chunks.append(chunk) 117 | 118 | # Trim for frame accuracy 119 | 
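        # Decoded packets rarely align exactly with [t_min, t_max]; the waveform assembled
        # below is symmetrically padded or truncated so it always contains exactly
        # (t_max - t_min) * self.rate samples, independent of packet boundaries.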
audio = np.concatenate(chunks, 1) 120 | 121 | nframes = int((t_max - t_min) * self.rate) 122 | if nframes > audio.shape[1]: 123 | audio = np.pad(audio, [(0, 0), (0, nframes-audio.shape[1])],mode='symmetric') 124 | if nframes < audio.shape[1]: 125 | audio = audio[:, :nframes] 126 | 127 | return audio 128 | 129 | @property 130 | def orig_rate(self): 131 | return self.stream.rate 132 | 133 | @property 134 | def num_frames(self): 135 | return self.stream.frames 136 | 137 | @property 138 | def duration(self): 139 | return self.stream.duration * self.stream.time_base 140 | 141 | @property 142 | def start_time(self): 143 | return self.stream.start_time * self.stream.time_base if self.stream.start_time is not None else 0. 144 | 145 | 146 | if __name__ == '__main__': 147 | import time 148 | fns = glob.glob('/home/pmorgado/datasets/vggsound/clips/*/*.mp4') 149 | t_open, t_load, t_load_audio, t_load_prec = 0., 0., 0., 0. 150 | for i in range(100): 151 | t = time.time() 152 | vreader = VideoReader(fns[random.randint(0, len(fns)-1)]) 153 | areader = AudioReader(fns[random.randint(0, len(fns)-1)], rate=20500) 154 | midpoint = vreader.start_time + vreader.duration / 2. 155 | t_open += time.time() - t 156 | 157 | t = time.time() 158 | frame, ts = vreader.quick_random_frame(midpoint - 3/2, midpoint + 3/2) 159 | t_load += time.time() - t 160 | 161 | t = time.time() 162 | frame = areader.read(midpoint - 3/2, midpoint + 3/2) 163 | t_load_audio += time.time() - t 164 | 165 | t = time.time() 166 | frame, ts = vreader.precise_frame(random.uniform(midpoint - 3/2, midpoint + 3/2)) 167 | t_load_prec += time.time() - t 168 | print(t_open/100, t_load/100, t_load_prec/100) -------------------------------------------------------------------------------- /checkpoints/deepavfusion_vitb_as2m_ep200/checkpoints/checkpoint_latest.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:109dbf04dbf86786a71ea0e2f8ac4762c9e19366887ea50d6260db8b96f5cd14 3 | size 1516280189 4 | -------------------------------------------------------------------------------- /checkpoints/deepavfusion_vitb_vggsound_ep200/checkpoints/checkpoint_latest.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:51437180761408f651ec51847164038a8c180afb2f45639f07f8fced2e8e63db 3 | size 1282543933 4 | -------------------------------------------------------------------------------- /configs/avsegm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra: default 3 | - env: default 4 | - log: default 5 | 6 | worker: eval_avsegm 7 | output_dir: checkpoints 8 | job_name: avsegm 9 | pretrain_job_name: 10 | checkpoint: 11 | encoder_prefix: 'encoder.' 12 | pretrain_resume_epoch: latest 13 | eval: False 14 | debug: False 15 | 16 | model: 17 | image: 18 | backbone: vit_base 19 | pretrained: vit_base_mae_in1k 20 | audio: 21 | backbone: vit_base 22 | pretrained: vit_base_audiomae_as2m 23 | fusion: 24 | arch: factorized_mmi 25 | layers: all 26 | num_fusion_tkns: 16 27 | num_aggr_image_tkns: 8 28 | num_aggr_audio_tkns: 8 29 | mlp_ratio: 4.0 30 | attn_ratio: 0.25 31 | num_heads: 12 32 | 33 | opt: 34 | resume: True 35 | use_amp: False 36 | batch_size: 16 37 | epochs: 100 38 | warmup_epochs: 20 39 | accum_iter: 8 40 | clip_grad: 41 | weight_decay: 0.05 42 | layer_decay: 0.75 43 | lr: 44 | blr: 3e-4 45 | min_lr: 0. 46 | drop_path: 0. 
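  # attn_drop / proj_drop: dropout on the attention weights and on block projections
  # inside the encoder; kept disabled here, like drop_path above.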
47 | attn_drop: 0. 48 | proj_drop: 0. 49 | 50 | data: 51 | dataset: avsbench_avss 52 | data_path: /srv/home/groups/pmorgado/datasets/avsbench 53 | audio_rate: 16000 54 | audio_dur: 3. 55 | audio_mels: 128 56 | image_size: 224 57 | crop_min: 0.5 58 | -------------------------------------------------------------------------------- /configs/avsrcsep.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra: default 3 | - env: default 4 | - log: default 5 | 6 | worker: eval_avsrcsep 7 | output_dir: checkpoints 8 | job_name: avsrcsep 9 | pretrain_job_name: 10 | checkpoint: 11 | encoder_prefix: 'encoder.' 12 | pretrain_resume_epoch: latest 13 | eval: False 14 | debug: False 15 | 16 | model: 17 | image: 18 | backbone: vit_base 19 | pretrained: vit_base_mae_in1k 20 | audio: 21 | backbone: vit_base 22 | pretrained: vit_base_audiomae_as2m 23 | fusion: 24 | arch: factorized_mmi 25 | layers: all 26 | num_fusion_tkns: 16 27 | num_aggr_image_tkns: 8 28 | num_aggr_audio_tkns: 8 29 | mlp_ratio: 4.0 30 | attn_ratio: 0.25 31 | num_heads: 12 32 | 33 | avss: 34 | log_freq: True 35 | weighted_loss: False 36 | binary_mask: True 37 | num_mixtures: 2 38 | 39 | opt: 40 | resume: True 41 | batch_size: 16 42 | epochs: 300 43 | warmup_epochs: 40 44 | accum_iter: 8 45 | clip_grad: 46 | weight_decay: 0.05 47 | layer_decay: 0.75 48 | lr: 49 | blr: 3e-4 50 | min_lr: 0. 51 | drop_path: 0. 52 | attn_drop: 0. 53 | proj_drop: 0. 54 | use_amp: False 55 | 56 | data: 57 | dataset: vggsound 58 | data_path: /srv/home/groups/pmorgado/datasets/vggsound 59 | audio_rate: 16000 60 | audio_dur: 3. 61 | audio_mels: 128 62 | image_size: 224 63 | crop_min: 0.5 64 | -------------------------------------------------------------------------------- /configs/avsync.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra: default 3 | - env: default 4 | - log: default 5 | 6 | worker: eval_avsync 7 | output_dir: checkpoints 8 | job_name: avsync 9 | pretrain_job_name: 10 | checkpoint: 11 | encoder_prefix: 'encoder.' 12 | pretrain_resume_epoch: latest 13 | eval: False 14 | debug: False 15 | 16 | model: 17 | video: 18 | backbone: video_vit_base 19 | pretrained: vit_base_mae_in1k 20 | audio: 21 | backbone: vit_base 22 | pretrained: vit_base_audiomae_as2m 23 | fusion: 24 | layers: all 25 | num_fusion_tkns: 16 26 | num_aggr_visual_tkns: 8 27 | num_aggr_audio_tkns: 8 28 | mlp_ratio: 4.0 29 | attn_ratio: 0.25 30 | num_heads: 12 31 | 32 | opt: 33 | resume: True 34 | joint_loss: True 35 | batch_size: 32 36 | epochs: 100 37 | warmup_epochs: 20 38 | accum_iter: 4 39 | clip_grad: 40 | weight_decay: 0.05 41 | layer_decay: 0.75 42 | smoothing: 0.1 43 | lr: 44 | blr: 3e-4 45 | min_lr: 0. 46 | drop_path: 0.2 47 | attn_drop: 0. 48 | proj_drop: 0. 49 | use_amp: False 50 | 51 | data: 52 | dataset: vggsounds 53 | data_path: /srv/home/groups/pmorgado/datasets/vggsounds 54 | audio_rate: 16000 55 | audio_dur: 3. 
56 | audio_mels: 128 57 | crop_size: 224 58 | crop_min: 0.5 59 | num_frames: 16 60 | video_rate: 8 -------------------------------------------------------------------------------- /configs/deepavfusion.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra: default 3 | - env: default 4 | - log: default 5 | - nn_probe: default 6 | 7 | worker: train 8 | output_dir: checkpoints 9 | job_name: deepavfusion_${data.dataset}_ep${opt.epochs}_bs${opt.batch_size}x${env.ngpu}x${opt.accum_iter}_blr${opt.blr} 10 | debug: False 11 | 12 | model: 13 | image: 14 | backbone: vit_base 15 | pretrained: vit_base_mae_in1k 16 | decoder_arch: plain 17 | decoder_depth: 8 18 | mask_ratio: 0.75 19 | norm_loss: True 20 | audio: 21 | backbone: vit_base 22 | pretrained: vit_base_audiomae_as2m 23 | decoder_arch: plain 24 | decoder_depth: 8 25 | mask_ratio: 0.8 26 | norm_loss: True 27 | fusion: 28 | arch: factorized_mmi 29 | layers: all 30 | num_fusion_tkns: 16 31 | num_aggr_image_tkns: 8 32 | num_aggr_audio_tkns: 8 33 | mlp_ratio: 4.0 34 | attn_ratio: 0.25 35 | num_heads: 12 36 | 37 | data: 38 | dataset: vggsound 39 | data_path: /srv/home/groups/pmorgado/datasets/vggsound 40 | audio_rate: 16000 41 | audio_dur: 3. 42 | audio_mels: 128 43 | image_size: 224 44 | crop_min: 0.5 45 | 46 | opt: 47 | resume: True 48 | epochs: 300 49 | warmup_epochs: 50 50 | batch_size: 128 51 | accum_iter: 1 52 | weight_decay: 0.05 53 | blr: 1.5e-4 54 | min_lr: 0. 55 | lr: 56 | pt_lr_mult_start: 0 57 | pt_lr_mult_end: 1 58 | pt_warmup_epochs: ${opt.epochs}/2 59 | clip_grad: 60 | use_amp: True 61 | -------------------------------------------------------------------------------- /configs/env/default.yaml: -------------------------------------------------------------------------------- 1 | world_size: 1 2 | rank: 0 3 | dist_url: tcp://localhost:50000 4 | dist_backend: nccl 5 | port: 50000 6 | node: localhost 7 | 8 | distributed: True 9 | seed: 10 | gpu: 11 | ngpu: 1 12 | mem_gb: 13 | workers: 14 | 15 | slurm: True 16 | slurm_suffix: "" 17 | slurm_partition: morgadolab 18 | slurm_timeout: 1440 19 | nodelist: "" 20 | exclude: "euler01,euler02,euler03,euler04,euler05,euler06,euler07,euler08,euler09,euler11,euler12,euler13,euler14,euler15,euler16,euler21,euler24,euler25,euler26,euler27" 21 | -------------------------------------------------------------------------------- /configs/finetune.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra: default 3 | - env: default 4 | - log: default 5 | 6 | worker: eval_finetune 7 | output_dir: checkpoints 8 | job_name: finetune 9 | pretrain_job_name: 10 | checkpoint: 11 | encoder_prefix: 'encoder.' 12 | pretrain_resume_epoch: latest 13 | eval: False 14 | debug: False 15 | 16 | model: 17 | image: 18 | backbone: vit_base 19 | pretrained: vit_base_mae_in1k 20 | audio: 21 | backbone: vit_base 22 | pretrained: vit_base_audiomae_as2m 23 | fusion: 24 | arch: factorized_mmi 25 | layers: all 26 | num_fusion_tkns: 16 27 | num_aggr_image_tkns: 8 28 | num_aggr_audio_tkns: 8 29 | mlp_ratio: 4.0 30 | attn_ratio: 0.25 31 | num_heads: 12 32 | 33 | opt: 34 | resume: True 35 | joint_loss: True 36 | batch_size: 32 37 | epochs: 100 38 | warmup_epochs: 20 39 | accum_iter: 4 40 | clip_grad: 41 | weight_decay: 0.05 42 | layer_decay: 0.75 43 | smoothing: 0.1 44 | lr: 45 | blr: 3e-4 46 | min_lr: 0. 47 | drop_path: 0.2 48 | attn_drop: 0. 49 | proj_drop: 0. 
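  # use_amp: automatic mixed precision stays off for fine-tuning; the deepavfusion.yaml
  # pretraining config enables it (opt.use_amp: True).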
50 | use_amp: False 51 | 52 | data: 53 | dataset: vggsounds 54 | data_path: /srv/home/groups/pmorgado/datasets/vggsounds 55 | audio_rate: 16000 56 | audio_dur: 3. 57 | audio_mels: 128 58 | image_size: 224 59 | crop_min: 0.5 60 | 61 | mixup: 1. 62 | cutmix: 0. 63 | cutmix_minmax: 64 | mixup_prob: 1. 65 | mixup_switch_prob: 0.5 66 | mixup_mode: batch -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | dir: . 3 | sweep: 4 | dir: . 5 | subdir: . 6 | -------------------------------------------------------------------------------- /configs/linprobe.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra: default 3 | - env: default 4 | - log: default 5 | 6 | worker: eval_linprobe 7 | output_dir: checkpoints 8 | job_name: linprobe 9 | pretrain_job_name: 10 | checkpoint: 11 | encoder_prefix: 'encoder.' 12 | pretrain_resume_epoch: latest 13 | eval: False 14 | debug: False 15 | 16 | model: 17 | image: 18 | backbone: vit_base 19 | pretrained: vit_base_mae_in1k 20 | audio: 21 | backbone: vit_base 22 | pretrained: vit_base_audiomae_as2m 23 | fusion: 24 | arch: factorized_mmi 25 | layers: all 26 | num_fusion_tkns: 16 27 | num_aggr_image_tkns: 8 28 | num_aggr_audio_tkns: 8 29 | mlp_ratio: 4.0 30 | attn_ratio: 0.25 31 | num_heads: 12 32 | 33 | opt: 34 | resume: True 35 | use_amp: False 36 | batch_size: 64 37 | epochs: 60 38 | warmup_epochs: 10 39 | accum_iter: 4 40 | clip_grad: 41 | weight_decay: 0.0 42 | lr: 43 | blr: 0.3 44 | min_lr: 0. 45 | 46 | data: 47 | dataset: vggsounds 48 | data_path: /srv/home/groups/pmorgado/datasets/vggsounds 49 | audio_rate: 16000 50 | audio_dur: 3. 
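  # eval_linprobe.py turns this 3 s, 16 kHz waveform into a log-mel spectrogram with
  # n_fft = 0.05 * audio_rate and hop = audio_rate / 64 (64 frames per second),
  # so the audio branch receives an audio_mels x 192 input.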
51 | audio_mels: 128 52 | image_size: 224 53 | crop_min: 0.5 54 | -------------------------------------------------------------------------------- /configs/log/default.yaml: -------------------------------------------------------------------------------- 1 | # Logging frequency 2 | print_freq: 100 3 | save_freq: 50 4 | eval_freq: 10 5 | wandb_watch_freq: 0 6 | debug: False 7 | 8 | # Logging 9 | use_wandb: True 10 | wandb_entity: pmorgado 11 | wandb_project: efav 12 | -------------------------------------------------------------------------------- /configs/nn_probe/default.yaml: -------------------------------------------------------------------------------- 1 | dataset: vggsound 2 | data_path: ${data.data_path} 3 | batch_size: ${opt.batch_size} 4 | audio_rate: ${data.audio_rate} 5 | audio_dur: ${data.audio_dur} 6 | audio_mels: ${data.audio_mels} 7 | image_size: ${data.image_size} 8 | crop_min: ${data.crop_min} -------------------------------------------------------------------------------- /eval_avsrcsep.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | from typing import Iterable 4 | 5 | import torch 6 | from torch import nn 7 | import numpy as np 8 | 9 | import datasets as myDBs 10 | from torchvision import transforms as vT 11 | from util import audio_transforms as aT 12 | 13 | from models.deepavfusion import DeepAVFusion 14 | from models.avsrcsep import AVSrcSep 15 | from mir_eval.separation import bss_eval_sources 16 | 17 | from util import distributed as dist_utils 18 | from util import misc as misc_utils 19 | from util import data as data_utils 20 | from util import meters, lr_sched 21 | 22 | 23 | def main_worker(local_rank, args): 24 | # Setup environment 25 | job_dir = f"{args.output_dir}/{args.job_name}" 26 | dist_utils.init_distributed_mode(local_rank, args, log_fn=f"{job_dir}/train.log") 27 | device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda') 28 | print(f'job dir: {job_dir}') 29 | misc_utils.print_args(args) 30 | 31 | # Adjust learning rate to batch size 32 | num_tasks = dist_utils.get_world_size() 33 | num_tasks_per_node = max(1, torch.cuda.device_count()) 34 | args.env.workers = args.env.workers // num_tasks_per_node 35 | eff_batch_size = args.opt.batch_size * args.opt.accum_iter * num_tasks 36 | if args.opt.lr is None: # only base_lr is specified 37 | args.opt.lr = args.opt.blr * eff_batch_size / 256 38 | print("base lr: %.2e" % args.opt.blr) 39 | print("actual lr: %.2e" % args.opt.lr) 40 | print("accumulate grad iterations: %d" % args.opt.accum_iter) 41 | print("effective batch size: %d" % eff_batch_size) 42 | 43 | # Dataloaders 44 | dataset_train = myDBs.load_dataset( 45 | args.data.dataset, 46 | args.data.data_path, 47 | dataset_type='mixed_audio', 48 | visual_transform=vT.Compose([ 49 | vT.RandomResizedCrop(args.data.image_size, scale=(args.data.crop_min, 1.)), 50 | vT.RandomHorizontalFlip(), 51 | vT.ToTensor(), 52 | vT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 53 | ]), 54 | audio_transform=aT.Compose([ 55 | aT.Pad(rate=args.data.audio_rate, dur=args.data.audio_dur), 56 | aT.MelSpectrogram(sample_rate=args.data.audio_rate, n_fft=int(args.data.audio_rate * 0.05), hop_length=int(args.data.audio_rate / 64), n_mels=args.data.audio_mels), 57 | aT.Log() 58 | ]), 59 | train=True, 60 | num_mixtures=args.avss.num_mixtures, 61 | audio_dur=args.data.audio_dur, 62 | audio_rate=args.data.audio_rate, 63 | temporal_jitter=True 64 | ) 65 | loader_train = 
data_utils.get_dataloader( 66 | dataset_train, args.env.distributed, args.opt.batch_size, args.env.workers, shuffle=True, drop_last=True) 67 | print(dataset_train) 68 | 69 | dataset_val = myDBs.load_dataset( 70 | args.data.dataset, 71 | args.data.data_path, 72 | dataset_type='mixed_audio', 73 | visual_transform=vT.Compose([ 74 | vT.Resize(int(args.data.image_size/0.875)), 75 | vT.CenterCrop(args.data.image_size), 76 | vT.ToTensor(), 77 | vT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]), 78 | audio_transform=aT.Compose([ 79 | aT.Pad(rate=args.data.audio_rate, dur=args.data.audio_dur), 80 | aT.MelSpectrogram(sample_rate=args.data.audio_rate, n_fft=int(args.data.audio_rate * 0.05), hop_length=int(args.data.audio_rate / 64), n_mels=args.data.audio_mels), 81 | aT.Log()]), 82 | train=False, 83 | num_mixtures=args.avss.num_mixtures, 84 | audio_dur=args.data.audio_dur, 85 | audio_rate=args.data.audio_rate, 86 | ) 87 | loader_val = data_utils.get_dataloader( 88 | dataset_val, args.env.distributed, args.opt.batch_size, args.env.workers, shuffle=False, drop_last=False) 89 | print(dataset_val) 90 | 91 | # Create model 92 | image_size, audio_size = (args.data.image_size, args.data.image_size), (args.data.audio_mels, int(args.data.audio_dur*64)) 93 | encoder = DeepAVFusion( 94 | image_arch=args.model.image.backbone, image_pretrained=args.model.image.pretrained, image_size=image_size, 95 | audio_arch=args.model.audio.backbone, audio_pretrained=args.model.audio.pretrained, audio_size=audio_size, 96 | fusion_arch=args.model.fusion.arch, 97 | fusion_layers=args.model.fusion.layers, 98 | num_fusion_tkns=(args.model.fusion.num_fusion_tkns, 99 | args.model.fusion.num_aggr_image_tkns, 100 | args.model.fusion.num_aggr_audio_tkns), 101 | drop_path=args.opt.drop_path, 102 | attn_drop=args.opt.attn_drop, 103 | drop=args.opt.proj_drop, 104 | fusion_mlp_ratio=args.model.fusion.mlp_ratio, 105 | fusion_attn_ratio=args.model.fusion.attn_ratio, 106 | fusion_num_heads=args.model.fusion.num_heads 107 | ) 108 | model = AVSrcSep(encoder=encoder, 109 | log_freq=args.avss.log_freq, 110 | weighted_loss=args.avss.weighted_loss, 111 | binary_mask=args.avss.binary_mask) 112 | model.to(device) 113 | print("Model = %s" % str(model)) 114 | 115 | if args.checkpoint or args.pretrain_job_name: 116 | pretrain_ckpt = args.checkpoint or f"{args.output_dir}/checkpoints/checkpoint_{args.pretrain_resume_epoch}.pth" 117 | encoder.load_checkpoint(pretrain_ckpt, args.encoder_prefix) 118 | 119 | # Optimizer with layer-wise lr decay (lrd) 120 | param_groups = lr_sched.param_groups_lrd( 121 | model, args.opt.weight_decay, 122 | no_weight_decay_list=[n for n, p in model.named_parameters() if 'bias' in n or 'norm' in n], 123 | layer_decay=args.opt.layer_decay) 124 | optimizer = torch.optim.AdamW(param_groups, lr=args.opt.lr) 125 | 126 | # Trainer 127 | trainer = misc_utils.Trainer( 128 | model, 129 | optimizer=optimizer, 130 | use_amp=args.opt.use_amp, 131 | accum_iter=args.opt.accum_iter, 132 | distributed=args.env.distributed 133 | ) 134 | 135 | # Checkpointing and logging 136 | ckpt_manager = misc_utils.CheckpointManager( 137 | modules=trainer.module_dict(), 138 | ckpt_dir=f"{job_dir}/checkpoints", 139 | epochs=args.opt.epochs, 140 | save_freq=args.log.save_freq) 141 | start_epoch = ckpt_manager.resume()[0] if args.opt.resume else 0 142 | wb_logger = misc_utils.WBLogger( 143 | f"{job_dir}/wandb", args.log.wandb_entity, args.log.wandb_project+'-avsrcsep', args.job_name, 144 | model, args) 145 | 146 | if args.eval: 147 | 
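        # Evaluation-only mode: run a single pass over the validation split with the
        # restored checkpoint, report SDR/SIR/SAR, and exit without training.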
evaluate(trainer.eval_model, loader_val, start_epoch, device, args) 148 | exit(0) 149 | 150 | # =============================================================== # 151 | # Training loop 152 | print(f"Start training for {args.opt.epochs} epochs") 153 | for epoch in range(start_epoch, args.opt.epochs): 154 | if args.env.distributed: 155 | loader_train.sampler.set_epoch(epoch) 156 | 157 | # train for one epoch 158 | train_one_epoch(trainer, loader_train, epoch, 159 | device=device, wb_logger=wb_logger, args=args) 160 | 161 | # evaluate 162 | if epoch % args.log.eval_freq == 0 or epoch == args.opt.epochs - 1 or epoch == start_epoch: 163 | global_step = (len(loader_train) // trainer.accum_iter) * (epoch + 1) 164 | test_stats = evaluate(trainer.eval_model, loader_val, epoch, device, args) 165 | wb_logger.log(test_stats, step=global_step, force=True) 166 | 167 | # save checkpoint 168 | ckpt_manager.checkpoint(epoch+1, {'epoch': epoch+1}) 169 | 170 | 171 | def train_one_epoch(trainer: misc_utils.Trainer, 172 | loader: Iterable, 173 | epoch: int = 0, 174 | wb_logger: misc_utils.WBLogger = None, 175 | device: torch.device = torch.device('cpu'), 176 | args=None): 177 | trainer.model.train(True) 178 | metric_logger = meters.MetricLogger(delimiter=" ") 179 | header = f'[Train][Ep-{epoch}/{args.opt.epochs}]' 180 | 181 | trainer.zero_grad() 182 | for step, (image, audio_mix, anno) in enumerate(metric_logger.log_every(loader, args.log.print_freq, header)): 183 | sys.stdout.flush() 184 | global_step = (len(loader) // trainer.accum_iter) * epoch + step // trainer.accum_iter 185 | if step % args.opt.accum_iter == 0: 186 | lr = lr_sched.adjust_learning_rate(trainer.optimizer, epoch + step / len(loader), args) 187 | metric_logger.update(lr=lr) 188 | 189 | # Prepare data 190 | image = image[0].to(device, non_blocking=True).float() 191 | audio_mix = audio_mix.to(device, non_blocking=True).float() 192 | audio_trg = anno['mel_specs'][:, 0].to(device, non_blocking=True).float() 193 | 194 | # Forward pass 195 | with trainer.autocast(), trainer.autosync(): 196 | loss = trainer.model(image, audio_mix, audio_trg)[0] 197 | 198 | if not math.isfinite(loss.item()): 199 | raise f"Loss is {loss.item()}, stopping training" 200 | 201 | # Backward pass and model update 202 | grad_norm, amp_scale = trainer.step(loss) 203 | 204 | # Log 205 | if trainer.accums == 0: 206 | metric_logger.update(loss=loss.item(), grad_norm=grad_norm, amp_scale=amp_scale, n=image.shape[0]) 207 | wb_logger.log(metric_logger.latest(), step=global_step) 208 | 209 | if args.debug and step == 100: 210 | break 211 | 212 | # gather the stats from all processes 213 | print("Syncing meters...") 214 | metric_logger.synchronize_between_processes() 215 | print("Averaged stats:", metric_logger) 216 | trainer.zero_grad() 217 | return metric_logger.averages() 218 | 219 | 220 | @torch.no_grad() 221 | def evaluate(model: nn.Module, 222 | loader: Iterable, 223 | epoch: int = 0, 224 | device: torch.device = torch.device('cpu'), 225 | args=None): 226 | model.train(False) 227 | metric_logger = meters.MetricLogger(delimiter=" ") 228 | header = f'[Eval][Ep-{epoch}/{args.opt.epochs}]' 229 | 230 | evaluator = AVSrcSepEvaluator() 231 | masking_separation = SpectrogramMasking(args.data.audio_rate, args.data.audio_mels) 232 | for step, (image, audio, anno, name) in enumerate(metric_logger.log_every(loader, args.log.print_freq, header)): 233 | # Prepare data 234 | audio = audio.to(device, non_blocking=True).float() 235 | audio_mix = anno['mixed_audio'].to(device, 
non_blocking=True).float() 236 | frames1 = anno['frames'][:, 0].to(device, non_blocking=True).float() 237 | frames2 = anno['frames'][:, 1].to(device, non_blocking=True).float() 238 | 239 | # Separate 240 | _, pred_mask1, _ = model(frames1, audio_mix, audio) 241 | _, pred_mask2, _ = model(frames2, audio_mix, audio) 242 | 243 | # Compute separation metrics 244 | mix_waveforms = anno['waveforms'].sum(1) 245 | for i in range(audio.shape[0]): 246 | waveform_gt = anno['waveforms'][i].squeeze(1) 247 | waveform_pred1 = masking_separation(mix_waveforms[i], pred_mask1[i]) 248 | waveform_pred2 = masking_separation(mix_waveforms[i], pred_mask2[i]) 249 | waveform_pred = np.stack((waveform_pred1, waveform_pred2), axis=0).squeeze(1) 250 | if torch.any(waveform_gt.pow(2).sum(-1) < 1e-5): 251 | continue 252 | if np.any((waveform_pred**2).sum(-1) < 1e-5): 253 | continue 254 | evaluator.update(waveform_gt, waveform_pred, name[i]) 255 | 256 | if args.debug and step == 8: 257 | break 258 | 259 | sdr, sir, sar = evaluator.average_sdr_sir_sar() 260 | print(f'{header} SDR={round(sdr,5)} SIR={round(sir,5)} SAR={round(sar,5)}') 261 | return {'sdr': sdr, 'sir': sir, 'sar': sar} 262 | 263 | 264 | class SpectrogramMasking: 265 | def __init__(self, audio_rate, audio_mels): 266 | n_fft = int(audio_rate * 0.05) 267 | hop_length = int(audio_rate / 64) 268 | self.spectrogram_t = aT.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None) 269 | self.inv_spectrogram_t = aT.InverseSpectrogram(n_fft=n_fft, hop_length=hop_length) 270 | self.mel_spec = aT.MelSpectrogram(sample_rate=audio_rate, n_fft=n_fft, hop_length=hop_length, n_mels=audio_mels) 271 | 272 | def __call__(self, waveform_mix, pred_mask): 273 | stft_mix = self.spectrogram_t(waveform_mix) 274 | pred_mask = torch.cat((torch.sigmoid(pred_mask).detach().cpu(), torch.zeros(*pred_mask.shape[:2], 1)), dim=2) 275 | pred_mask = torch.einsum('bmt,fm->bft', pred_mask, self.mel_spec.mel_scale.fb) 276 | waveform_pred = self.inv_spectrogram_t(pred_mask * stft_mix) 277 | return waveform_pred 278 | 279 | 280 | class AVSrcSepEvaluator(object): 281 | def __init__(self, ): 282 | super(AVSrcSepEvaluator, self).__init__() 283 | 284 | self.name_list = [] 285 | self.sdr_list = [] 286 | self.sir_list = [] 287 | self.sar_list = [] 288 | 289 | def average_sdr_sir_sar(self): 290 | sdr = np.mean(self.sdr_list) 291 | sir = np.mean(self.sir_list) 292 | sar = np.mean(self.sar_list) 293 | return sdr, sir, sar 294 | 295 | def clear(self): 296 | self.name_list = [] 297 | self.sdr_list = [] 298 | self.sir_list = [] 299 | self.sar_list = [] 300 | 301 | def update(self, waveform_gt, waveform_pred, name): 302 | if isinstance(waveform_gt, torch.Tensor): 303 | waveform_gt = waveform_gt.detach().cpu().numpy() 304 | if isinstance(waveform_pred, torch.Tensor): 305 | waveform_pred = waveform_pred.detach().cpu().numpy() 306 | 307 | sdr, sir, sar, _ = bss_eval_sources(waveform_gt, waveform_pred, False) 308 | 309 | # Save 310 | self.name_list.append(name) 311 | self.sdr_list.append(sdr) 312 | self.sir_list.append(sir) 313 | self.sar_list.append(sar) -------------------------------------------------------------------------------- /eval_linprobe.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | from typing import Iterable 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | import datasets as myDBs 9 | from torchvision import transforms as vT 10 | from util import audio_transforms as aT 11 | 12 | from models.deepavfusion import 
DeepAVFusion 13 | from models.classifier import AVClassifier 14 | 15 | from util import distributed as dist_utils 16 | from util import misc as misc_utils 17 | from util import data as data_utils 18 | from util import meters, lars, lr_sched 19 | from timm.utils import accuracy 20 | 21 | 22 | def main_worker(local_rank, args): 23 | # Setup environment 24 | job_dir = f"{args.output_dir}/{args.job_name}" 25 | dist_utils.init_distributed_mode(local_rank, args, log_fn=f"{job_dir}/train.log") 26 | device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda') 27 | print(f'job dir: {job_dir}') 28 | misc_utils.print_args(args) 29 | 30 | # Adjust learning rate to batch size 31 | num_tasks = dist_utils.get_world_size() 32 | num_tasks_per_node = max(1, torch.cuda.device_count()) 33 | args.env.workers = args.env.workers // num_tasks_per_node 34 | eff_batch_size = args.opt.batch_size * args.opt.accum_iter * num_tasks 35 | if args.opt.lr is None: # only base_lr is specified 36 | args.opt.lr = args.opt.blr * eff_batch_size / 256 37 | print("base lr: %.2e" % args.opt.blr) 38 | print("actual lr: %.2e" % args.opt.lr) 39 | print("accumulate grad iterations: %d" % args.opt.accum_iter) 40 | print("effective batch size: %d" % eff_batch_size) 41 | 42 | # Dataloaders 43 | dataset_train = myDBs.load_dataset( 44 | args.data.dataset, 45 | args.data.data_path, 46 | dataset_type='simple', 47 | visual_transform=vT.Compose([ 48 | vT.RandomResizedCrop(args.data.image_size, scale=(args.data.crop_min, 1.)), 49 | vT.RandomHorizontalFlip(), 50 | vT.ToTensor(), 51 | vT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]), 52 | audio_transform=aT.Compose([ 53 | aT.Pad(rate=args.data.audio_rate, dur=args.data.audio_dur), 54 | aT.RandomVol(), 55 | aT.MelSpectrogram(sample_rate=args.data.audio_rate, n_fft=int(args.data.audio_rate * 0.05), hop_length=int(args.data.audio_rate / 64), n_mels=args.data.audio_mels), 56 | aT.Log()]), 57 | train=True, 58 | audio_dur=args.data.audio_dur, 59 | audio_rate=args.data.audio_rate, 60 | temporal_jitter=True 61 | ) 62 | loader_train = data_utils.get_dataloader( 63 | dataset_train, args.env.distributed, args.opt.batch_size, args.env.workers, shuffle=True, drop_last=True) 64 | print(dataset_train) 65 | 66 | dataset_val = myDBs.load_dataset( 67 | args.data.dataset, 68 | args.data.data_path, 69 | dataset_type='simple', 70 | visual_transform=vT.Compose([ 71 | vT.Resize(int(args.data.image_size/0.875)), 72 | vT.CenterCrop(args.data.image_size), 73 | vT.ToTensor(), 74 | vT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]), 75 | audio_transform=aT.Compose([ 76 | aT.Pad(rate=args.data.audio_rate, dur=args.data.audio_dur), 77 | aT.MelSpectrogram(sample_rate=args.data.audio_rate, n_fft=int(args.data.audio_rate * 0.05), hop_length=int(args.data.audio_rate / 64), n_mels=args.data.audio_mels), 78 | aT.Log()]), 79 | train=False, 80 | audio_dur=args.data.audio_dur, 81 | audio_rate=args.data.audio_rate, 82 | temporal_jitter=False 83 | ) 84 | loader_val = data_utils.get_dataloader( 85 | dataset_val, args.env.distributed, args.opt.batch_size, args.env.workers, shuffle=False, drop_last=False) 86 | print(dataset_val) 87 | 88 | # Create model 89 | image_size, audio_size = (args.data.image_size, args.data.image_size), (args.data.audio_mels, int(args.data.audio_dur*64)) 90 | encoder = DeepAVFusion( 91 | image_arch=args.model.image.backbone, image_pretrained=args.model.image.pretrained, image_size=image_size, 92 | audio_arch=args.model.audio.backbone, 
audio_pretrained=args.model.audio.pretrained, audio_size=audio_size, 93 | fusion_arch=args.model.fusion.arch, 94 | fusion_layers=args.model.fusion.layers, 95 | num_fusion_tkns=(args.model.fusion.num_fusion_tkns, 96 | args.model.fusion.num_aggr_image_tkns, 97 | args.model.fusion.num_aggr_audio_tkns), 98 | fusion_mlp_ratio=args.model.fusion.mlp_ratio, 99 | fusion_attn_ratio=args.model.fusion.attn_ratio, 100 | fusion_num_heads=args.model.fusion.num_heads 101 | ) 102 | model = AVClassifier(encoder, myDBs.NUM_CLASSES[args.data.dataset], freeze_encoder=True, input_norm=True) 103 | model.to(device) 104 | print("Model = %s" % str(model)) 105 | 106 | if args.checkpoint or args.pretrain_job_name: 107 | pretrain_ckpt = args.checkpoint or f"{args.output_dir}/checkpoints/checkpoint_{args.pretrain_resume_epoch}.pth" 108 | encoder.load_checkpoint(pretrain_ckpt, args.encoder_prefix) 109 | 110 | # criterion 111 | criterion = nn.BCEWithLogitsLoss() if myDBs.MULTI_CLASS_DBS[args.data.dataset] else nn.CrossEntropyLoss() 112 | print("criterion = %s" % str(criterion)) 113 | 114 | # Optimizer 115 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 116 | assert len(parameters) == 6 117 | optimizer = lars.LARS(parameters, lr=args.opt.lr, weight_decay=args.opt.weight_decay) 118 | 119 | # Trainer 120 | trainer = misc_utils.Trainer( 121 | model, 122 | criterion=criterion, 123 | optimizer=optimizer, 124 | use_amp=args.opt.use_amp, 125 | accum_iter=args.opt.accum_iter, 126 | distributed=args.env.distributed 127 | ) 128 | 129 | # Checkpointing 130 | ckpt_manager = misc_utils.CheckpointManager( 131 | modules=trainer.module_dict(), 132 | ckpt_dir=f"{job_dir}/checkpoints", 133 | epochs=args.opt.epochs, 134 | save_freq=args.log.save_freq) 135 | start_epoch = ckpt_manager.resume()[0] if args.opt.resume else 0 136 | wb_logger = misc_utils.WBLogger( 137 | f"{job_dir}/wandb", args.log.wandb_entity, args.log.wandb_project+'-linprobe', args.job_name, 138 | model, args) 139 | 140 | if args.eval: 141 | evaluate(trainer.eval_model, loader_val, start_epoch, device, args) 142 | exit(0) 143 | 144 | # =============================================================== # 145 | # Training loop 146 | print(f"Start training for {args.opt.epochs} epochs") 147 | for epoch in range(start_epoch, args.opt.epochs): 148 | if args.env.distributed: 149 | loader_train.sampler.set_epoch(epoch) 150 | 151 | # train for one epoch 152 | train_one_epoch(trainer, loader_train, epoch, 153 | device=device, wb_logger=wb_logger, args=args) 154 | 155 | # evaluate 156 | if epoch % args.log.eval_freq == 0 or epoch == args.opt.epochs - 1 or epoch == start_epoch: 157 | global_step = (len(loader_train) // trainer.accum_iter) * (epoch + 1) 158 | test_stats = evaluate(trainer.eval_model, loader_val, epoch, device, args) 159 | wb_logger.log(test_stats, step=global_step, force=True) 160 | 161 | # save checkpoint 162 | ckpt_manager.checkpoint(epoch+1, {'epoch': epoch+1}) 163 | 164 | 165 | def train_one_epoch(trainer: misc_utils.Trainer, 166 | loader: Iterable, 167 | epoch: int = 0, 168 | wb_logger: misc_utils.WBLogger = None, 169 | device: torch.device = torch.device('cpu'), 170 | args=None): 171 | trainer.model.train(True) 172 | header = f'[Train][Ep-{epoch}/{args.opt.epochs}]' 173 | metric_logger = meters.MetricLogger(delimiter=" ") 174 | metric_logger.add_meter('lr', meters.SmoothedValue(window_size=1, fmt='{value:.6f}')) 175 | 176 | trainer.zero_grad() 177 | for step, (image, audio, anno) in enumerate(metric_logger.log_every(loader, 
args.log.print_freq, header)): 178 | sys.stdout.flush() 179 | global_step = (len(loader) // trainer.accum_iter) * epoch + step // trainer.accum_iter 180 | if step % args.opt.accum_iter == 0: 181 | lr = lr_sched.adjust_learning_rate(trainer.optimizer, epoch + step / len(loader), args) 182 | metric_logger.update(lr=lr) 183 | 184 | # Prepare data 185 | image = image.to(device, non_blocking=True).float() 186 | audio = audio.to(device, non_blocking=True).float() 187 | targets = anno['class'].to(device, non_blocking=True) 188 | targets = targets.float() if myDBs.MULTI_CLASS_DBS[args.data.dataset] else targets.long() 189 | 190 | # Forward pass 191 | with trainer.autocast(), trainer.autosync(): 192 | preds_image, preds_audio, preds_fusion = trainer.model(image, audio) 193 | preds = (preds_image + preds_audio + preds_fusion) / 3. 194 | loss = trainer.criterion(preds, targets) 195 | if not math.isfinite(loss.item()): 196 | raise RuntimeError(f"Loss is {loss.item()}, stopping training") 197 | 198 | # Backward pass and model update 199 | grad_norm, amp_scale = trainer.step(loss) 200 | 201 | # Log 202 | if trainer.accums == 0: 203 | metric_logger.update(loss=loss.item(), grad_norm=grad_norm, amp_scale=amp_scale, n=image.shape[0]) 204 | wb_logger.log(metric_logger.latest(), step=global_step) 205 | 206 | if args.debug and step == 100: 207 | break 208 | 209 | # gather the stats from all processes 210 | print("Syncing meters...") 211 | metric_logger.synchronize_between_processes() 212 | print("Averaged stats:", metric_logger) 213 | trainer.zero_grad() 214 | return metric_logger.averages() 215 | 216 | 217 | @torch.no_grad() 218 | def evaluate(model: nn.Module, 219 | loader: Iterable, 220 | epoch: int = 0, 221 | device: torch.device = torch.device('cpu'), 222 | args=None): 223 | model.train(False) 224 | metric_logger = meters.MetricLogger(delimiter=" ") 225 | header = f'[Eval][Ep-{epoch}/{args.opt.epochs}]' 226 | 227 | preds_all, preds_image_all, preds_audio_all, preds_fusion_all, labels_all = [], [], [], [], [] 228 | for step, (image, audio, anno) in enumerate(metric_logger.log_every(loader, args.log.print_freq, header)): 229 | image = image.to(device, non_blocking=True).float() 230 | audio = audio.to(device, non_blocking=True).float() 231 | label = anno['class'].to(device, non_blocking=True).long() 232 | 233 | preds_image, preds_audio, preds_fusion = model(image, audio) 234 | preds = (preds_image + preds_audio + preds_fusion) / 3.
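# the per-head logits are also collected below, so metrics can be reported for each modality (image/audio/fusion) as well as for their average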
235 | preds_image_all.append(preds_image), preds_audio_all.append(preds_audio), preds_fusion_all.append(preds_fusion) 236 | preds_all.append(preds), labels_all.append(label) 237 | 238 | if args.debug and step == 8: 239 | break 240 | 241 | # Synchronize across gpus 242 | preds_image_all = dist_utils.concat_all_gather(torch.cat(preds_image_all)) 243 | preds_audio_all = dist_utils.concat_all_gather(torch.cat(preds_audio_all)) 244 | preds_fusion_all = dist_utils.concat_all_gather(torch.cat(preds_fusion_all)) 245 | preds_all = dist_utils.concat_all_gather(torch.cat(preds_all)) 246 | labels_all = dist_utils.concat_all_gather(torch.cat(labels_all)) 247 | 248 | # measure performance 249 | stats = dict() 250 | if myDBs.MULTI_CLASS_DBS[args.data.dataset]: 251 | labels_all = labels_all.cpu().numpy() 252 | 253 | for mod, preds in [('image', preds_image_all), ('audio', preds_audio_all), ('fusion', preds_fusion_all), ('all', preds_all)]: 254 | preds = preds.cpu().numpy() 255 | stats_mod = misc_utils.calc_multi_class_stats(labels_all, preds) 256 | stats.update({f'{k}_{mod}': v for k, v in stats_mod.items()}) 257 | else: 258 | stats.update( 259 | val_acc1_image=accuracy(preds_image_all, labels_all)[0].item(), 260 | val_acc1_audio=accuracy(preds_audio_all, labels_all)[0].item(), 261 | val_acc1_fusion=accuracy(preds_fusion_all, labels_all)[0].item(), 262 | val_acc1_all=accuracy(preds_all, labels_all)[0].item(), 263 | ) 264 | prefix = 'val_' 265 | stats = {f"{prefix}{k}": v for k, v in stats.items()} 266 | 267 | msg = ' | '.join([f'{k}={v:.2f}' for k, v in stats.items()]) 268 | print(f"{header} {msg}") 269 | return stats 270 | -------------------------------------------------------------------------------- /launcher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import copy 4 | import time 5 | import warnings 6 | import logging 7 | from pathlib import Path 8 | 9 | import hydra 10 | import hydra.utils as hydra_utils 11 | import submitit 12 | 13 | import torch 14 | import torch.multiprocessing as mp 15 | import importlib 16 | import numpy as np 17 | 18 | os.environ['MKL_THREADING_LAYER'] = 'GNU' 19 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 20 | 21 | log = logging.getLogger(__name__) 22 | 23 | 24 | def update_pythonpath_relative_hydra(): 25 | """Update PYTHONPATH to only have absolute paths.""" 26 | # NOTE: We do not change sys.path: we want to update paths for future instantiations 27 | # of python using the current environment (namely, when submitit loads the job 28 | # pickle). 29 | try: 30 | original_cwd = Path(hydra_utils.get_original_cwd()).resolve() 31 | except (AttributeError, ValueError): 32 | # Assume hydra is not initialized, we don't need to do anything. 33 | # In hydra 0.11, this returns AttributeError; later it will return ValueError 34 | # https://github.com/facebookresearch/hydra/issues/496 35 | # I don't know how else to reliably check whether Hydra is initialized. 
36 | return 37 | paths = [] 38 | for orig_path in os.environ["PYTHONPATH"].split(":"): 39 | path = Path(orig_path) 40 | if not path.is_absolute(): 41 | path = original_cwd / path 42 | paths.append(path.resolve()) 43 | os.environ["PYTHONPATH"] = ":".join([str(x) for x in paths]) 44 | log.info('PYTHONPATH: {}'.format(os.environ["PYTHONPATH"])) 45 | 46 | 47 | class Worker: 48 | def __call__(self, args): 49 | mp.set_start_method('spawn') 50 | main_function = getattr(importlib.import_module(args.worker), 'main_worker') 51 | 52 | np.set_printoptions(precision=3) 53 | socket_name = os.popen("ip r | grep default | awk '{print $5}'").read().strip('\n') 54 | print("Setting GLOO and NCCL sockets IFNAME to: {}".format(socket_name)) 55 | os.environ["GLOO_SOCKET_IFNAME"] = socket_name 56 | 57 | # Use random port to avoid collision between parallel jobs 58 | # if args.env.world_size == 1: 59 | # args.env.port = np.random.randint(40000, 50000) 60 | # 61 | # if args.env.slurm: 62 | # job_env = submitit.JobEnvironment() 63 | # args.env.rank = job_env.global_rank 64 | # args.env.dist_url = f'tcp://{job_env.hostnames[0]}:{args.env.port}' 65 | # else: 66 | # args.env.rank = 0 67 | # args.env.dist_url = f'tcp://localhost:{args.env.port}' 68 | # print('Using url {}'.format(args.env.dist_url)) 69 | 70 | if args.env.slurm: 71 | job_env = submitit.JobEnvironment() 72 | args.env.rank = job_env.global_rank 73 | 74 | if args.env.ngpu > 1: 75 | os.environ['NCCL_P2P_DISABLE'] = '1' 76 | os.environ['NCCL_P2P_LEVEL'] = 'LOC' 77 | os.environ['NCCL_SOCKET_FAMILY'] = 'AF_INET6' 78 | args.env.dist_url = f"file://{args.output_dir}/{args.job_name}/.dist" 79 | os.makedirs(os.path.dirname(args.env.dist_url[7:]), exist_ok=True) 80 | time.sleep(2) 81 | 82 | if args.env.gpu is not None: 83 | warnings.warn( 84 | 'You have chosen a specific GPU. 
This will completely ' 85 | 'disable data parallelism.') 86 | 87 | # Run code 88 | ngpus_per_node = torch.cuda.device_count() 89 | args.env.distributed = args.env.world_size > 1 or (args.env.distributed and ngpus_per_node > 1) 90 | if args.env.distributed: 91 | mp.spawn(main_function, nprocs=ngpus_per_node, args=(args,)) 92 | else: 93 | main_function(0, args) 94 | 95 | def checkpoint(self, *args, **kwargs) -> submitit.helpers.DelayedSubmission: 96 | return submitit.helpers.DelayedSubmission(Worker(), *args, **kwargs) # submits to requeuing 97 | 98 | 99 | def my_jobs(): 100 | return os.popen('squeue -o %j -u $USER').read().split("\n") 101 | 102 | 103 | @hydra.main(config_path='configs/', config_name='efav', version_base='1.1') 104 | def main(args): 105 | update_pythonpath_relative_hydra() 106 | args.output_dir = hydra_utils.to_absolute_path(args.output_dir) 107 | args.job_name = str(args.job_name) # Resolve job name, before config vars change 108 | if 'pretrain_job_name' in args: # If eval, fix paths 109 | args.output_dir = f"{args.output_dir}/{args.pretrain_job_name}" 110 | os.makedirs(f"{args.output_dir}/{args.job_name}", exist_ok=True) 111 | 112 | # defaults 113 | if args.env.workers is None: 114 | args.env.workers = 15 * args.env.ngpu 115 | if args.env.mem_gb is None: 116 | args.env.mem_gb = 60 * args.env.ngpu 117 | 118 | # If job is running, ignore 119 | if args.env.slurm: 120 | slurm_job_name = f"{args.job_name}-{args.pretrain_job_name}" if 'pretrain_job_name' in args and args.pretrain_job_name else args.job_name 121 | if slurm_job_name in my_jobs(): 122 | print(f'Skipping {args.job_name} because already in queue') 123 | return 124 | 125 | # Submit jobs 126 | executor = submitit.AutoExecutor( 127 | folder=f"{args.output_dir}/{args.job_name}/submitit_logs", 128 | slurm_max_num_timeout=100, 129 | cluster=None if args.env.slurm else "debug", 130 | ) 131 | additional_parameters = {} 132 | if args.env.nodelist != "": 133 | additional_parameters.update({"nodelist": args.env.nodelist}) 134 | if args.env.exclude != "": 135 | additional_parameters.update({"exclude": args.env.exclude}) 136 | executor.update_parameters( 137 | timeout_min=args.env.slurm_timeout, 138 | slurm_partition=args.env.slurm_partition, 139 | cpus_per_task=args.env.workers, 140 | gpus_per_node=args.env.ngpu, 141 | nodes=args.env.world_size, 142 | tasks_per_node=1, 143 | mem_gb=args.env.mem_gb, 144 | slurm_additional_parameters=additional_parameters, 145 | slurm_signal_delay_s=120) 146 | executor.update_parameters(name=slurm_job_name) 147 | executor.submit(Worker(), args) 148 | else: 149 | Worker()(args) 150 | 151 | 152 | if __name__ == '__main__': 153 | main() 154 | -------------------------------------------------------------------------------- /metadata/flickr_test.csv: -------------------------------------------------------------------------------- 1 | 10000130166,202 2 | 10007936344,0 3 | 10008553263,309 4 | 10009662863,0 5 | 10013411946,92 6 | 10016382545,297 7 | 10031203703,484 8 | 10035917404,0 9 | 10045181004,60 10 | 10060697266,0 11 | 10061269855,0 12 | 10067897816,47 13 | 10069511285,701 14 | 10079676496,357 15 | 10080652613,79 16 | 10106776154,164 17 | 10109936483,406 18 | 10110769444,0 19 | 10111706074,0 20 | 10119607346,865 21 | 10119736773,315 22 | 10129791115,105 23 | 10129813456,0 24 | 10163352444,275 25 | 10165890496,61 26 | 10173808724,816 27 | 10202025415,0 28 | 10221582415,179 29 | 10234680056,0 30 | 10246013484,205 31 | 10268129763,0 32 | 10278084464,0 33 | 10283938426,371 34 | 10289643764,152 35 | 
10304554474,64 36 | 10308472565,411 37 | 10369815925,0 38 | 10389014583,93 39 | 10409146004,126 40 | 10409351364,0 41 | 10413912845,0 42 | 10437048305,0 43 | 10440265465,0 44 | 10441843213,173 45 | 10446128125,0 46 | 10451898725,20 47 | 10468313544,0 48 | 10476442285,438 49 | 10476446674,394 50 | 10477544895,0 51 | 10481128583,0 52 | 10520382525,66 53 | 10545790805,285 54 | 10548273474,132 55 | 10549911663,0 56 | 10555704135,323 57 | 10557361454,419 58 | 10610974774,0 59 | 10624261424,424 60 | 10635080206,703 61 | 10654057765,0 62 | 10659354545,0 63 | 10666396623,0 64 | 10667001125,0 65 | 10701841844,0 66 | 10743985374,112 67 | 10751849775,0 68 | 10761625676,0 69 | 10771238614,185 70 | 10795023275,771 71 | 10796050285,0 72 | 10805310706,2431 73 | 10847897294,0 74 | 10859255295,0 75 | 10862172945,0 76 | 10939270325,2125 77 | 10952523225,96 78 | 10979853044,354 79 | 10980225216,0 80 | 11002352834,7 81 | 11036641226,48 82 | 11050409385,103 83 | 11071986404,112 84 | 11101358996,59 85 | 11154049584,0 86 | 11314266424,420 87 | 11650557765,0 88 | 11764496003,381 89 | 11765015695,0 90 | 11776645404,371 91 | 12015590114,69 92 | 12031178616,0 93 | 12048726756,412 94 | 12066557153,1526 95 | 12103535156,250 96 | 12158276143,293 97 | 12328837165,60 98 | 12373955905,0 99 | 12444360833,0 100 | 12512072175,0 101 | 12534178685,364 102 | 12598610153,0 103 | 12599965145,0 104 | 12729873904,0 105 | 13153991894,1 106 | 13234495505,0 107 | 13409856383,0 108 | 13447842855,0 109 | 13579448834,0 110 | 13660204414,794 111 | 13721325644,135 112 | 13861270725,0 113 | 14062380939,0 114 | 14172932360,222 115 | 14174434304,168 116 | 14203929291,83 117 | 14414061986,430 118 | 14566072101,0 119 | 14648401332,431 120 | 14652810592,94 121 | 14678892718,0 122 | 14680986811,0 123 | 14742398172,62 124 | 14778198185,58 125 | 14899015793,496 126 | 15135657108,651 127 | 15210828119,0 128 | 15679023977,444 129 | 15799251138,189 130 | 15994498891,0 131 | 16012165838,0 132 | 16101081359,501 133 | 16369370669,58 134 | 16419739420,0 135 | 16563757295,0 136 | 16659090834,758 137 | 16804898716,104 138 | 16950760048,38 139 | 17140316339,1233 140 | 17537079145,277 141 | 19306860381,0 142 | 20854461578,608 143 | 2401957951,198 144 | 2404965389,219 145 | 2432219254,0 146 | 2462141223,0 147 | 2465634368,0 148 | 2499673064,0 149 | 2509250774,207 150 | 2542659118,348 151 | 2630218696,80 152 | 2695985181,0 153 | 2698329253,0 154 | 2766580444,57 155 | 2778088729,0 156 | 2808068937,436 157 | 2811549163,155 158 | 2819123278,0 159 | 2858344348,170 160 | 2895192729,311 161 | 2897081671,1250 162 | 2897653916,596 163 | 2907554899,0 164 | 2935372113,0 165 | 2961466546,0 166 | 3012435981,17 167 | 3052033339,0 168 | 3102685146,3 169 | 3109408703,250 170 | 3202731977,0 171 | 3207674548,0 172 | 3230152596,30 173 | 3403348616,284 174 | 3413119417,1818 175 | 3416359816,0 176 | 3484198977,17 177 | 3535748358,62 178 | 3568893693,0 179 | 3668283413,0 180 | 3709678770,0 181 | 3719297546,384 182 | 3727937033,390 183 | 3749326537,747 184 | 3790233127,69 185 | 3801318146,0 186 | 3828103578,449 187 | 3896768873,0 188 | 3908875350,0 189 | 3956341886,313 190 | 4007045926,460 191 | 4041216001,50 192 | 4180455681,178 193 | 4303617148,577 194 | 4351373013,0 195 | 4407899725,249 196 | 4479260803,0 197 | 4543136011,238 198 | 4576746825,90 199 | 4646464908,0 200 | 4657718822,61 201 | 4717096777,5 202 | 4755507106,1024 203 | 4758950312,87 204 | 4767923306,438 205 | 4876943924,445 206 | 4938432980,0 207 | 4957886467,373 208 | 4965189170,450 209 | 5006362787,299 210 | 
5130986168,194 211 | 5179649119,0 212 | 5194793239,0 213 | 5220919991,256 214 | 5237296528,107 215 | 5303633386,103 216 | 5344317532,5 217 | 5480070309,188 218 | 5490101320,233 219 | 5601291878,250 220 | 5607310495,63 221 | 5697766486,356 222 | 5801923729,202 223 | 5873098388,424 224 | 5906878321,1304 225 | 5991393107,163 226 | 6067281197,0 227 | 6098668510,384 228 | 6158154336,369 229 | 6185482260,1918 230 | 6289328021,105 231 | 6458319057,445 232 | 6669466181,103 233 | 6755812421,128 234 | 6897499873,186 235 | 7178293884,431 236 | 7387939334,0 237 | 7560517176,1549 238 | 7623513024,356 239 | 7733838448,64 240 | 7740990330,0 241 | 7897462346,0 242 | 8250285374,0 243 | 8294094053,531 244 | 8311669586,654 245 | 8396975855,344 246 | 9060997789,1017 247 | 9309351423,1135 248 | 9456627660,0 249 | 9636000842,0 250 | 9761803312,917 -------------------------------------------------------------------------------- /metadata/vgginstruments_test.csv: -------------------------------------------------------------------------------- 1 | sample0000 2 | sample0001 3 | sample0002 4 | sample0003 5 | sample0004 6 | sample0005 7 | sample0006 8 | sample0007 9 | sample0009 10 | sample0010 11 | sample0011 12 | sample0012 13 | sample0013 14 | sample0014 15 | sample0015 16 | sample0016 17 | sample0017 18 | sample0018 19 | sample0019 20 | sample0020 21 | sample0021 22 | sample0022 23 | sample0023 24 | sample0024 25 | sample0026 26 | sample0027 27 | sample0028 28 | sample0029 29 | sample0030 30 | sample0031 31 | sample0032 32 | sample0033 33 | sample0034 34 | sample0035 35 | sample0036 36 | sample0037 37 | sample0038 38 | sample0039 39 | sample0040 40 | sample0043 41 | sample0044 42 | sample0045 43 | sample0046 44 | sample0048 45 | sample0049 46 | sample0050 47 | sample0051 48 | sample0052 49 | sample0053 50 | sample0054 51 | sample0055 52 | sample0056 53 | sample0057 54 | sample0059 55 | sample0060 56 | sample0062 57 | sample0063 58 | sample0064 59 | sample0065 60 | sample0066 61 | sample0068 62 | sample0069 63 | sample0070 64 | sample0072 65 | sample0073 66 | sample0074 67 | sample0075 68 | sample0076 69 | sample0078 70 | sample0079 71 | sample0080 72 | sample0081 73 | sample0082 74 | sample0083 75 | sample0084 76 | sample0085 77 | sample0086 78 | sample0087 79 | sample0089 80 | sample0090 81 | sample0091 82 | sample0092 83 | sample0093 84 | sample0094 85 | sample0095 86 | sample0096 87 | sample0098 88 | sample0099 89 | sample0100 90 | sample0101 91 | sample0102 92 | sample0103 93 | sample0105 94 | sample0106 95 | sample0107 96 | sample0108 97 | sample0109 98 | sample0110 99 | sample0111 100 | sample0112 101 | sample0113 102 | sample0114 103 | sample0115 104 | sample0116 105 | sample0117 106 | sample0118 107 | sample0119 108 | sample0120 109 | sample0121 110 | sample0123 111 | sample0124 112 | sample0125 113 | sample0126 114 | sample0127 115 | sample0129 116 | sample0130 117 | sample0131 118 | sample0132 119 | sample0133 120 | sample0134 121 | sample0135 122 | sample0136 123 | sample0137 124 | sample0138 125 | sample0139 126 | sample0140 127 | sample0141 128 | sample0142 129 | sample0143 130 | sample0144 131 | sample0145 132 | sample0146 133 | sample0147 134 | sample0148 135 | sample0149 136 | sample0150 137 | sample0151 138 | sample0152 139 | sample0154 140 | sample0155 141 | sample0156 142 | sample0157 143 | sample0158 144 | sample0159 145 | sample0160 146 | sample0161 147 | sample0162 148 | sample0164 149 | sample0165 150 | sample0166 151 | sample0167 152 | sample0169 153 | sample0170 154 | sample0172 155 | 
sample0173 156 | sample0174 157 | sample0175 158 | sample0176 159 | sample0179 160 | sample0180 161 | sample0181 162 | sample0182 163 | sample0183 164 | sample0184 165 | sample0185 166 | sample0186 167 | sample0187 168 | sample0188 169 | sample0189 170 | sample0190 171 | sample0191 172 | sample0192 173 | sample0193 174 | sample0194 175 | sample0195 176 | sample0196 177 | sample0197 178 | sample0198 179 | sample0199 180 | sample0200 181 | sample0201 182 | sample0204 183 | sample0205 184 | sample0206 185 | sample0208 186 | sample0209 187 | sample0211 188 | sample0212 189 | sample0213 190 | sample0214 191 | sample0215 192 | sample0217 193 | sample0218 194 | sample0219 195 | sample0220 196 | sample0221 197 | sample0222 198 | sample0223 199 | sample0224 200 | sample0225 201 | sample0226 202 | sample0227 203 | sample0228 204 | sample0229 205 | sample0230 206 | sample0231 207 | sample0233 208 | sample0234 209 | sample0237 210 | sample0238 211 | sample0239 212 | sample0240 213 | sample0242 214 | sample0243 215 | sample0244 216 | sample0245 217 | sample0246 218 | sample0247 219 | sample0248 220 | sample0249 221 | sample0250 222 | sample0252 223 | sample0254 224 | sample0255 225 | sample0256 226 | sample0257 227 | sample0258 228 | sample0259 229 | sample0260 230 | sample0261 231 | sample0262 232 | sample0263 233 | sample0265 234 | sample0266 235 | sample0267 236 | sample0268 237 | sample0269 238 | sample0270 239 | sample0271 240 | sample0272 241 | sample0273 242 | sample0274 243 | sample0275 244 | sample0276 245 | sample0277 246 | sample0278 247 | sample0279 248 | sample0280 249 | sample0281 250 | sample0282 251 | sample0283 252 | sample0284 253 | sample0285 254 | sample0286 255 | sample0287 256 | sample0288 257 | sample0289 258 | sample0290 259 | sample0291 260 | sample0292 261 | sample0293 262 | sample0294 263 | sample0295 264 | sample0296 265 | sample0297 266 | sample0298 267 | sample0299 268 | sample0300 269 | sample0302 270 | sample0303 271 | sample0304 272 | sample0305 273 | sample0306 274 | sample0307 275 | sample0308 276 | sample0309 277 | sample0310 278 | sample0311 279 | sample0312 280 | sample0313 281 | sample0314 282 | sample0316 283 | sample0317 284 | sample0318 285 | sample0319 286 | sample0320 287 | sample0321 288 | sample0322 289 | sample0323 290 | sample0325 291 | sample0326 292 | sample0328 293 | sample0330 294 | sample0331 295 | sample0332 296 | sample0333 297 | sample0334 298 | sample0335 299 | sample0336 300 | sample0337 301 | sample0338 302 | sample0339 303 | sample0340 304 | sample0341 305 | sample0342 306 | sample0343 307 | sample0344 308 | sample0345 309 | sample0346 310 | sample0347 311 | sample0348 312 | sample0349 313 | sample0350 314 | sample0352 315 | sample0353 316 | sample0354 317 | sample0355 318 | sample0357 319 | sample0358 320 | sample0359 321 | sample0360 322 | sample0361 323 | sample0362 324 | sample0363 325 | sample0365 326 | sample0366 327 | sample0367 328 | sample0368 329 | sample0369 330 | sample0370 331 | sample0372 332 | sample0373 333 | sample0374 334 | sample0375 335 | sample0377 336 | sample0378 337 | sample0379 338 | sample0380 339 | sample0381 340 | sample0382 341 | sample0383 342 | sample0384 343 | sample0385 344 | sample0386 345 | sample0387 346 | sample0388 347 | sample0389 348 | sample0390 349 | sample0391 350 | sample0392 351 | sample0393 352 | sample0394 353 | sample0396 354 | sample0397 355 | sample0398 356 | sample0399 357 | sample0400 358 | sample0402 359 | sample0403 360 | sample0404 361 | sample0405 362 | sample0406 363 | sample0407 364 | 
sample0409 365 | sample0410 366 | sample0411 367 | sample0412 368 | sample0414 369 | sample0415 370 | sample0416 371 | sample0417 372 | sample0418 373 | sample0419 374 | sample0420 375 | sample0421 376 | sample0422 377 | sample0423 378 | sample0424 379 | sample0425 380 | sample0426 381 | sample0427 382 | sample0428 383 | sample0429 384 | sample0430 385 | sample0431 386 | sample0433 387 | sample0434 388 | sample0435 389 | sample0436 390 | sample0437 391 | sample0439 392 | sample0440 393 | sample0443 394 | sample0444 395 | sample0445 396 | sample0446 397 | sample0447 398 | sample0449 399 | sample0450 400 | sample0451 401 | sample0452 402 | sample0453 403 | sample0454 404 | sample0455 405 | sample0457 406 | sample0458 407 | sample0459 408 | sample0460 409 | sample0461 410 | sample0462 411 | sample0463 412 | sample0464 413 | sample0465 414 | sample0466 415 | sample0467 416 | sample0468 417 | sample0469 418 | sample0470 419 | sample0471 420 | sample0472 421 | sample0473 422 | sample0475 423 | sample0476 424 | sample0477 425 | sample0478 426 | sample0479 427 | sample0480 428 | sample0481 429 | sample0482 430 | sample0483 431 | sample0484 432 | sample0485 433 | sample0486 434 | sample0487 435 | sample0488 436 | sample0489 437 | sample0490 438 | sample0491 439 | sample0492 440 | sample0493 441 | sample0494 442 | sample0495 443 | sample0496 444 | sample0497 445 | sample0498 446 | sample0499 -------------------------------------------------------------------------------- /models/avmae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from util.pos_embed import get_2d_sincos_pos_embed 5 | from timm.models.vision_transformer import Block 6 | from models.swin import SwinTransformerBlock 7 | 8 | 9 | class AVMAE(nn.Module): 10 | def __init__( 11 | self, encoder, encoder_dim, 12 | image_decoder_arch='plain', image_decoder_depth=8, image_mask_ratio=0.75, image_norm_loss=False, 13 | audio_decoder_arch='plain', audio_decoder_depth=8, audio_mask_ratio=0.8, audio_norm_loss=False, 14 | decoder_dim=512, num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm 15 | ): 16 | super(AVMAE, self).__init__() 17 | 18 | self.image_mask_ratio = image_mask_ratio 19 | self.image_norm_loss = image_norm_loss 20 | self.audio_mask_ratio = audio_mask_ratio 21 | self.audio_norm_loss = audio_norm_loss 22 | self.decoder_dim = decoder_dim 23 | 24 | # -------------------------------------------------------------------------- # 25 | # Audio visual encoder 26 | self.encoder = encoder 27 | self.image_gs, self.audio_gs = encoder.image.patch_embed.grid_size, encoder.audio.patch_embed.grid_size 28 | self.image_ps, self.audio_ps = encoder.image.patch_embed.patch_size, encoder.audio.patch_embed.patch_size 29 | 30 | # -------------------------------------------------------------------------- # 31 | # Audio decoder 32 | self.audio_decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True) 33 | self.audio_decoder_mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) 34 | self.audio_decoder_pos_embed = nn.Parameter(torch.zeros(1, self.audio_gs[0]*self.audio_gs[1], decoder_dim)) # fixed sin-cos embedding 35 | 36 | self.audio_decoder_arch = audio_decoder_arch 37 | if self.audio_decoder_arch == 'swin': 38 | self.audio_decoder_blocks = nn.ModuleList([ 39 | SwinTransformerBlock( 40 | dim=decoder_dim, 41 | input_resolution=self.audio_gs, 42 | window_size=4, 43 | shift_size=(index % 2)*2, 44 | num_heads=num_heads, 45 | mlp_ratio=mlp_ratio, 46 | drop=0.0, 47 | 
attn_drop=0.0, 48 | drop_path=0.0, 49 | norm_layer=norm_layer, 50 | ) 51 | for index in range(audio_decoder_depth)]) 52 | else: 53 | self.audio_decoder_blocks = nn.ModuleList([ 54 | Block(decoder_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 55 | for _ in range(audio_decoder_depth)]) 56 | 57 | self.audio_decoder_norm = norm_layer(decoder_dim) 58 | self.audio_decoder_pred = nn.Linear(decoder_dim, self.audio_ps[0]*self.audio_ps[1], bias=True) # decoder to patch 59 | 60 | # -------------------------------------------------------------------------- # 61 | # Image decoder 62 | self.image_decoder_embed = nn.Linear(encoder_dim, decoder_dim, bias=True) 63 | self.image_decoder_mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) 64 | self.image_decoder_pos_embed = nn.Parameter(torch.zeros(1, self.image_gs[0]*self.image_gs[1], decoder_dim)) # fixed sin-cos embedding 65 | 66 | self.image_decoder_arch = image_decoder_arch 67 | if self.image_decoder_arch == 'swin': 68 | self.image_decoder_blocks = nn.ModuleList([ 69 | SwinTransformerBlock( 70 | dim=decoder_dim, 71 | input_resolution=self.image_gs, 72 | window_size=4, 73 | shift_size=(index % 2)*2, 74 | num_heads=num_heads, 75 | mlp_ratio=mlp_ratio, 76 | drop=0.0, 77 | attn_drop=0.0, 78 | drop_path=0.0, 79 | norm_layer=norm_layer, 80 | ) 81 | for index in range(image_decoder_depth)]) 82 | else: 83 | self.image_decoder_blocks = nn.ModuleList([ 84 | Block(decoder_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 85 | for _ in range(image_decoder_depth)]) 86 | 87 | self.image_decoder_norm = norm_layer(decoder_dim) 88 | self.image_decoder_pred = nn.Linear(decoder_dim, self.image_ps[0]*self.image_ps[1]*3, bias=True) # decoder to patch 89 | 90 | self.initialize_weights() 91 | 92 | def initialize_weights(self): 93 | 94 | # initialize (and freeze) pos_embed by sin-cos embedding 95 | pes = get_2d_sincos_pos_embed(self.decoder_dim, self.image_gs, cls_token=False) 96 | self.image_decoder_pos_embed.data.copy_(torch.from_numpy(pes).float().unsqueeze(0)) 97 | pes = get_2d_sincos_pos_embed(self.decoder_dim, self.audio_gs, cls_token=False) 98 | self.audio_decoder_pos_embed.data.copy_(torch.from_numpy(pes).float().unsqueeze(0)) 99 | 100 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 101 | torch.nn.init.normal_(self.image_decoder_mask_token, std=.02) 102 | torch.nn.init.normal_(self.audio_decoder_mask_token, std=.02) 103 | 104 | # Initialize decoder (avoid init encoder to prevent overriding pretrained weights) 105 | for n, m in self.named_modules(): 106 | if not n.startswith('encoder'): 107 | self._init_weights(m) 108 | 109 | def _init_weights(self, m): 110 | if isinstance(m, nn.Linear): 111 | # we use xavier_uniform following official JAX ViT: 112 | torch.nn.init.xavier_uniform_(m.weight) 113 | if isinstance(m, nn.Linear) and m.bias is not None: 114 | nn.init.constant_(m.bias, 0) 115 | 116 | elif isinstance(m, nn.LayerNorm): 117 | nn.init.constant_(m.bias, 0) 118 | nn.init.constant_(m.weight, 1.0) 119 | 120 | def random_masking(self, N, L, mask_ratio, device): 121 | """ 122 | Perform per-sample random masking by per-sample shuffling. 123 | Per-sample shuffling is done by argsort random noise. 
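Argsort of i.i.d. uniform noise yields an independent random permutation per sample; the first len_keep = L*(1-mask_ratio) indices are kept, and ids_restore (the inverse permutation) maps the binary mask (0 = keep, 1 = remove) back to the original patch order. E.g. with L=4, mask_ratio=0.5 and ids_shuffle=[2,0,3,1], patches 2 and 0 are kept and the returned mask is [0,1,0,1].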
124 | """ 125 | 126 | # sort noise for each sample 127 | noise = torch.rand(N, L, device=device) # noise in [0, 1] 128 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 129 | ids_restore = torch.argsort(ids_shuffle, dim=1) 130 | 131 | # keep first subset 132 | len_keep = int(L * (1 - mask_ratio)) 133 | ids_keep = ids_shuffle[:, :len_keep] 134 | 135 | # generate the binary mask: 0 is keep, 1 is remove 136 | mask = torch.ones([N, L], device=device) 137 | mask[:, :len_keep] = 0 138 | 139 | # unshuffle to get the binary mask 140 | mask = torch.gather(mask, dim=1, index=ids_restore) 141 | 142 | return ids_keep, mask, ids_restore 143 | 144 | def forward_encoder(self, image, audio): 145 | return self.encoder(image, audio) 146 | 147 | def forward_decoder(self, x, x_fusion, ids_restore, modality='image'): 148 | bs, nFus, nMask = x.shape[0], x_fusion.shape[1], ids_restore.shape[1]-x.shape[1] 149 | embed = self.__getattr__(f'{modality}_decoder_embed') 150 | mask_token = self.__getattr__(f'{modality}_decoder_mask_token') 151 | pes = self.__getattr__(f'{modality}_decoder_pos_embed') 152 | arch = self.__getattribute__(f'{modality}_decoder_arch') 153 | blocks = self.__getattr__(f'{modality}_decoder_blocks') 154 | norm = self.__getattr__(f'{modality}_decoder_norm') 155 | pred = self.__getattr__(f'{modality}_decoder_pred') 156 | 157 | # embed tokens 158 | x, x_fusion = embed(x), embed(x_fusion) 159 | 160 | # append mask tokens to sequence and unshuffle 161 | x = torch.cat([x, mask_token.repeat(bs, nMask, 1)], dim=1) 162 | x = x.gather(dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 163 | 164 | # add pos embed 165 | x = x + pes 166 | 167 | # apply Transformer blocks 168 | if arch == 'plain': 169 | x = torch.cat([x_fusion, x], dim=1) 170 | for blk in blocks: 171 | x = blk(x) 172 | x = x[:, nFus:, :] 173 | 174 | elif arch == 'swin': 175 | for blk in blocks: 176 | x, x_fusion = blk(x, x_fusion) 177 | 178 | # apply predictor 179 | x = pred(norm(x)) 180 | return x 181 | 182 | @staticmethod 183 | def forward_loss(target, pred, mask, norm_pix_loss=True): 184 | """ 185 | target: [N, L, p*p*3] 186 | pred: [N, L, p*p*3] 187 | mask: [N, L], 0 is keep, 1 is remove, 188 | """ 189 | if norm_pix_loss: 190 | mean = target.mean(dim=-1, keepdim=True) 191 | var = target.var(dim=-1, keepdim=True) 192 | target = (target - mean) / (var + 1.e-6)**.5 193 | 194 | loss = (pred - target) ** 2 195 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 196 | 197 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 198 | return loss 199 | 200 | @staticmethod 201 | def patchify(x, patch_size): 202 | """ 203 | x: (N, C, H, W) 204 | patch_size = (H/p)*(W/p) 205 | """ 206 | bs, c = x.shape[:2] 207 | pH, pW = patch_size 208 | gH, gW = x.shape[2] // pH, x.shape[3] // pW 209 | 210 | x = x.reshape(shape=(bs, c, gH, pH, gW, pW)) 211 | x = torch.einsum('nchpwq->nhwpqc', x) 212 | x = x.reshape(shape=(bs, gH * gW, pH*pW * c)) 213 | 214 | return x 215 | 216 | def forward(self, image, audio): 217 | B, device = image.shape[0], image.device 218 | 219 | # embed patches 220 | image_ids_keep, image_mask, image_ids_restore = self.random_masking(B, self.image_gs[0]*self.image_gs[1], self.image_mask_ratio, device=device) 221 | audio_ids_keep, audio_mask, audio_ids_restore = self.random_masking(B, self.audio_gs[0]*self.audio_gs[1], self.audio_mask_ratio, device=device) 222 | 223 | # Encoder image and audio 224 | x_image, x_audio, x_fusion = self.encoder(image, audio, 
image_ids_keep=image_ids_keep, audio_ids_keep=audio_ids_keep) 225 | 226 | # Decode image 227 | target_image = self.patchify(image, self.image_ps) 228 | pred_image = self.forward_decoder(x_image, x_fusion, image_ids_restore, modality='image') 229 | loss_image = self.forward_loss(target_image, pred_image, image_mask, norm_pix_loss=self.image_norm_loss) 230 | 231 | # Decode audio 232 | target_audio = self.patchify(audio, self.audio_ps) 233 | pred_audio = self.forward_decoder(x_audio, x_fusion, audio_ids_restore, modality='audio') 234 | loss_audio = self.forward_loss(target_audio, pred_audio, audio_mask, norm_pix_loss=self.audio_norm_loss) 235 | 236 | return loss_image, loss_audio, pred_image, pred_audio 237 | -------------------------------------------------------------------------------- /models/avsegm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from models.avsrcsep import Up, DoubleConv 6 | 7 | 8 | class Interpolate(nn.Module): 9 | """Interpolation module. 10 | """ 11 | 12 | def __init__(self, scale_factor, mode, align_corners=False): 13 | """Init. 14 | Args: 15 | scale_factor (float): scaling 16 | mode (str): interpolation mode 17 | """ 18 | super(Interpolate, self).__init__() 19 | 20 | self.interp = nn.functional.interpolate 21 | self.scale_factor = scale_factor 22 | self.mode = mode 23 | self.align_corners = align_corners 24 | 25 | def forward(self, x): 26 | """Forward pass. 27 | Args: 28 | x (tensor): input 29 | Returns: 30 | tensor: interpolated data 31 | """ 32 | 33 | x = self.interp( 34 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 35 | ) 36 | 37 | return x 38 | 39 | 40 | class AVSegmSimple(nn.Module): 41 | def __init__(self, encoder, num_classes=71): 42 | super(AVSegmSimple, self).__init__() 43 | self.encoder = encoder 44 | self.num_classes = num_classes 45 | 46 | scales = [1, 2, 4, 8] 47 | embed_dim = self.encoder.embed_dim 48 | layer_dims = [max(128, embed_dim // scale) for scale in scales] 49 | self.normv = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(len(scales))]) 50 | self.proja = nn.ModuleList([nn.Linear(embed_dim, layer_dims[d]) for d in range(len(scales))]) 51 | self.norma = nn.ModuleList([nn.LayerNorm(layer_dims[d]) for d in range(len(scales))]) 52 | self.top = DoubleConv(embed_dim*2, embed_dim) 53 | self.lat = nn.ModuleList([Up(embed_dim, layer_dims[d], factor=scales[d], bilinear=False) for d in range(1, len(scales))]) 54 | self.up = nn.ModuleList([Up(layer_dims[d], layer_dims[d+1], in2_channels=layer_dims[d+1]*2, bilinear=False) 55 | for d in range(len(scales)-1)]) 56 | 57 | self.predictor = nn.Sequential( 58 | nn.Conv2d(layer_dims[-1], 128, kernel_size=(3, 3), stride=(1, 1), padding=1), 59 | Interpolate(scale_factor=2, mode="bilinear"), 60 | nn.Conv2d(128, num_classes, kernel_size=(3, 3), stride=(1, 1), padding=1), 61 | ) 62 | 63 | # Initialize decoder weights 64 | for n, m in self.named_modules(): 65 | if not n.startswith('encoder'): 66 | self._init_weights(m) 67 | 68 | def _init_weights(self, m): 69 | if isinstance(m, nn.Linear): 70 | # we use xavier_uniform following official JAX ViT: 71 | torch.nn.init.xavier_uniform_(m.weight) 72 | if isinstance(m, nn.Linear) and m.bias is not None: 73 | nn.init.constant_(m.bias, 0) 74 | elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 75 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 76 | if m.bias is 
not None: 77 | nn.init.constant_(m.bias, 0) 78 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)): 79 | nn.init.constant_(m.weight, 1) 80 | nn.init.constant_(m.bias, 0) 81 | 82 | def params_layer_ids(self): 83 | params_layer_ids = [] 84 | params_layer_ids.extend(self.encoder.params_layer_ids()) 85 | params_layer_ids.extend([(p, len(self.encoder.image.blocks)+1) 86 | for n, p in self.named_parameters() if not n.startswith('encoder')]) 87 | return params_layer_ids 88 | 89 | def forward(self, image, audio, gt_segm=None): 90 | # Forward encoder 91 | _, _, _, all_embs = self.encoder(image, audio, return_embs=True) 92 | image_gs = self.encoder.image.patch_embed.grid_size 93 | xv_list = [all_embs[d][0] for d in np.linspace(0, len(all_embs)-1, len(self.norma), endpoint=True).astype(int)] 94 | xa_norm_list = [norm(proj(all_embs[-1][1])).mean(dim=1) 95 | for norm, proj in zip(self.norma, self.proja)] 96 | xv_norm_list = [norm(xv).view(image.shape[0], *image_gs, -1).permute(0, 3, 1, 2) 97 | for norm, xv in zip(self.normv, xv_list)] 98 | # print(x1.shape, x2.shape, x3.shape, x4.shape) 99 | 100 | xa_top = xa_norm_list[0][:, :, None, None].repeat(1, 1, *image_gs) 101 | x = self.top(torch.cat((xv_norm_list[0], xa_top), dim=1)) 102 | for i, (xv, xa) in enumerate(zip(xv_norm_list[1:], xa_norm_list[1:])): 103 | xv = self.lat[i](xv) 104 | xa = xa[:, :, None, None].repeat(1, 1, *xv.shape[2:]) 105 | x = self.up[i](x, torch.cat((xv, xa), dim=1)) 106 | 107 | logits = self.predictor(x) # BF x C x 224 x 224 108 | 109 | loss = None 110 | if gt_segm is not None: 111 | if self.num_classes == 1: 112 | loss = F.binary_cross_entropy_with_logits(logits[:, 0], gt_segm) 113 | else: 114 | loss = F.cross_entropy(logits, gt_segm) 115 | loss = loss + torch.stack([p.sum()*0. 
for p in self.parameters()]).sum() 116 | return loss, logits -------------------------------------------------------------------------------- /models/avsrcsep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class DoubleConv(nn.Module): 8 | """(convolution => [BN] => ReLU) * 2""" 9 | 10 | def __init__(self, in_channels, out_channels, mid_channels=None): 11 | super().__init__() 12 | if not mid_channels: 13 | mid_channels = out_channels 14 | self.double_conv = nn.Sequential( 15 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 16 | nn.BatchNorm2d(mid_channels), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 19 | nn.BatchNorm2d(out_channels), 20 | nn.ReLU(inplace=True) 21 | ) 22 | 23 | def forward(self, x): 24 | return self.double_conv(x) 25 | 26 | 27 | class Up(nn.Module): 28 | """Upscaling then double conv""" 29 | 30 | def __init__(self, in_channels, out_channels, in2_channels=0, factor=2, bilinear=True): 31 | super().__init__() 32 | 33 | # if bilinear, use the normal convolutions to reduce the number of channels 34 | if bilinear: 35 | self.up = nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=True) 36 | self.conv = DoubleConv(in_channels + in2_channels, out_channels, in_channels // factor) 37 | else: 38 | self.up = nn.ConvTranspose2d(in_channels, in_channels // factor, kernel_size=factor, stride=factor) 39 | self.conv = DoubleConv(in_channels // factor + in2_channels, out_channels) 40 | 41 | def forward(self, x1, x2=None): 42 | x1 = self.up(x1) 43 | if x2 is not None: 44 | return self.conv(torch.cat([x1, x2], dim=1)) 45 | else: 46 | return self.conv(x1) 47 | 48 | 49 | class AVSrcSepUNet(nn.Module): 50 | def __init__(self, embed_dim, bilinear=False): 51 | super().__init__() 52 | 53 | self.xv_norm = nn.LayerNorm(embed_dim) 54 | self.xa_norm = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(5)]) 55 | 56 | self.cond5 = nn.Linear(embed_dim, embed_dim) 57 | self.cond4 = nn.Linear(embed_dim, embed_dim // 2) 58 | self.cond3 = nn.Linear(embed_dim, embed_dim // 4) 59 | self.cond2 = nn.Linear(embed_dim, embed_dim // 8) 60 | self.cond1 = nn.Linear(embed_dim, embed_dim // 16) 61 | 62 | self.top = DoubleConv(embed_dim*2, embed_dim) 63 | 64 | self.lat4 = Up(embed_dim, embed_dim // 2, factor=2, bilinear=bilinear) 65 | self.lat3 = Up(embed_dim, embed_dim // 4, factor=4, bilinear=bilinear) 66 | self.lat2 = Up(embed_dim, embed_dim // 8, factor=8, bilinear=bilinear) 67 | self.lat1 = Up(embed_dim, embed_dim // 16, factor=16, bilinear=bilinear) 68 | 69 | self.up4 = Up(embed_dim // 1, embed_dim // 2, in2_channels=embed_dim // 1, bilinear=bilinear) 70 | self.up3 = Up(embed_dim // 2, embed_dim // 4, in2_channels=embed_dim // 2, bilinear=bilinear) 71 | self.up2 = Up(embed_dim // 4, embed_dim // 8, in2_channels=embed_dim // 4, bilinear=bilinear) 72 | self.up1 = Up(embed_dim // 8, embed_dim // 16, in2_channels=embed_dim // 8, bilinear=bilinear) 73 | 74 | self.pred = nn.Conv2d(embed_dim // 16, 1, kernel_size=(3, 3), padding=(1, 1)) 75 | 76 | self.apply(self._init_weights) 77 | 78 | def _init_weights(self, m): 79 | if isinstance(m, nn.Linear): 80 | # we use xavier_uniform following official JAX ViT: 81 | torch.nn.init.xavier_uniform_(m.weight) 82 | if isinstance(m, nn.Linear) and m.bias is not None: 83 | nn.init.constant_(m.bias, 0) 84 | elif isinstance(m, 
(nn.Conv2d, nn.ConvTranspose2d)): 85 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 86 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 87 | nn.init.constant_(m.weight, 1) 88 | nn.init.constant_(m.bias, 0) 89 | 90 | def forward(self, xa_embs, xv, audio_gs=(8, 12)): 91 | bs = xa_embs[0].shape[0] 92 | enc_idx = np.linspace(0, len(xa_embs)-1, 5, endpoint=True)[::-1].astype(int) # [11 8 5 2 0] 93 | xa1, xa2, xa3, xa4, xa5 = [self.xa_norm[i](xa_embs[e]).view(bs, *audio_gs, -1).permute(0, 3, 1, 2) 94 | for i, e in enumerate(enc_idx)] 95 | xv = self.xv_norm(xv).mean(1) 96 | 97 | xv5 = self.cond5(xv)[:, :, None, None].repeat(1, 1, *audio_gs) 98 | x = self.top(torch.cat((xa5, xv5), dim=1)) 99 | 100 | xv4 = self.cond4(xv)[:, :, None, None].repeat(1, 1, audio_gs[0]*2, audio_gs[1]*2) 101 | lat4 = torch.cat((self.lat4(xa4), xv4), dim=1) 102 | x = self.up4(x, lat4) 103 | 104 | xv3 = self.cond3(xv)[:, :, None, None].repeat(1, 1, audio_gs[0]*4, audio_gs[1]*4) 105 | lat3 = torch.cat((self.lat3(xa3), xv3), dim=1) 106 | x = self.up3(x, lat3) 107 | 108 | xv2 = self.cond2(xv)[:, :, None, None].repeat(1, 1, audio_gs[0]*8, audio_gs[1]*8) 109 | lat2 = torch.cat((self.lat2(xa2), xv2), dim=1) 110 | x = self.up2(x, lat2) 111 | 112 | xv1 = self.cond1(xv)[:, :, None, None].repeat(1, 1, audio_gs[0]*16, audio_gs[1]*16) 113 | lat1 = torch.cat((self.lat1(xa1), xv1), dim=1) 114 | x = self.up1(x, lat1) 115 | 116 | logits = self.pred(x) 117 | return logits 118 | 119 | 120 | class AVSrcSep(nn.Module): 121 | def __init__(self, encoder, log_freq=True, weighted_loss=True, binary_mask=True): 122 | super().__init__() 123 | self.log_freq = log_freq 124 | self.weighted_loss = weighted_loss 125 | self.binary_mask = binary_mask 126 | 127 | self.encoder = encoder 128 | self.avss_decoder = AVSrcSepUNet(embed_dim=self.encoder.embed_dim) 129 | 130 | def params_layer_ids(self): 131 | params_layer_ids = [] 132 | params_layer_ids.extend(self.encoder.params_layer_ids()) 133 | params_layer_ids.extend([(p, len(self.encoder.image.blocks)+1) for p in self.avss_decoder.parameters()]) 134 | return params_layer_ids 135 | 136 | def loss_mask_prediction(self, pred_mask, log_spec_mix, log_spec): 137 | spec = torch.pow(10., log_spec) 138 | spec_mix = torch.pow(10., log_spec_mix) 139 | 140 | # Calculate loss weighting coefficient: magnitude of input mixture 141 | if self.weighted_loss: 142 | weight = torch.log1p(spec_mix) 143 | weight = torch.clamp(weight, 1e-3, 10) 144 | else: 145 | weight = torch.ones_like(spec_mix) 146 | 147 | # Compute ground truth masks 148 | if self.binary_mask: 149 | gt_masks = (spec > spec_mix).float() 150 | else: 151 | gt_masks = spec / (spec + spec_mix + 1e-5) 152 | gt_masks.clamp_(0., 1.) 153 | 154 | loss_avss = F.binary_cross_entropy_with_logits(pred_mask, gt_masks, weight) 155 | return loss_avss, gt_masks 156 | 157 | def forward(self, image, audio_mix, audio_gt=None): 158 | # Encode audio and visuals 159 | _, _, _, all_embs = self.encoder(image, audio_mix, return_embs=True) 160 | xv = all_embs[-1][0] 161 | xa_embs = [x[1] for x in all_embs] 162 | 163 | # Prediction head 164 | audio_gs = self.encoder.audio.patch_embed.grid_size 165 | logits_mask = self.avss_decoder(xa_embs, xv, audio_gs) 166 | 167 | # source separation loss 168 | loss = gt_masks = None 169 | if audio_gt is not None: 170 | loss, gt_masks = self.loss_mask_prediction(logits_mask, audio_mix, audio_gt) 171 | loss = loss + torch.stack([p.sum()*0. 
for p in self.parameters()]).sum() 172 | 173 | return loss, logits_mask, gt_masks 174 | -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class AVClassifier(nn.Module): 5 | def __init__(self, encoder, num_classes, freeze_encoder=False, input_norm=False): 6 | super(AVClassifier, self).__init__() 7 | self.encoder = encoder 8 | 9 | self.freeze_encoder = freeze_encoder 10 | if self.freeze_encoder: 11 | for p in self.encoder.parameters(): 12 | p.requires_grad = False 13 | 14 | self.input_norm = input_norm 15 | if self.input_norm: 16 | self.image_norm = nn.BatchNorm1d(self.encoder.embed_dim, affine=False, eps=1e-6) 17 | self.audio_norm = nn.BatchNorm1d(self.encoder.embed_dim, affine=False, eps=1e-6) 18 | self.fusion_norm = nn.BatchNorm1d(self.encoder.embed_dim, affine=False, eps=1e-6) 19 | 20 | self.image_head = nn.Linear(self.encoder.embed_dim, num_classes) 21 | self.audio_head = nn.Linear(self.encoder.embed_dim, num_classes) 22 | self.fusion_head = nn.Linear(self.encoder.embed_dim, num_classes) 23 | 24 | self.initialize_weights() 25 | 26 | def initialize_weights(self): 27 | nn.init.xavier_uniform_(self.image_head.weight) 28 | nn.init.zeros_(self.image_head.bias) 29 | nn.init.xavier_uniform_(self.audio_head.weight) 30 | nn.init.zeros_(self.audio_head.bias) 31 | nn.init.xavier_uniform_(self.fusion_head.weight) 32 | nn.init.zeros_(self.fusion_head.bias) 33 | 34 | def params_layer_ids(self): 35 | params_layer_ids = [] 36 | params_layer_ids.extend(self.encoder.params_layer_ids()) 37 | params_layer_ids.extend([(p, len(self.encoder.audio.blocks)+1) for p in self.image_head.parameters()]) 38 | params_layer_ids.extend([(p, len(self.encoder.audio.blocks)+1) for p in self.audio_head.parameters()]) 39 | params_layer_ids.extend([(p, len(self.encoder.audio.blocks)+1) for p in self.fusion_head.parameters()]) 40 | return params_layer_ids 41 | 42 | def forward(self, image, audio): 43 | if self.freeze_encoder: 44 | with torch.no_grad(): 45 | x_image, x_audio, x_fusion = self.encoder(image, audio) 46 | else: 47 | x_image, x_audio, x_fusion = self.encoder(image, audio) 48 | 49 | x_image, x_audio, x_fusion = x_image.mean(dim=1), x_audio.mean(dim=1), x_fusion.mean(dim=1) 50 | if self.input_norm: 51 | x_image = self.image_norm(x_image) 52 | x_audio = self.audio_norm(x_audio) 53 | x_fusion = self.fusion_norm(x_fusion) 54 | 55 | pred_image = self.image_head(x_image) 56 | pred_audio = self.audio_head(x_audio) 57 | pred_fusion = self.fusion_head(x_fusion) 58 | 59 | return pred_image, pred_audio, pred_fusion 60 | 61 | def train(self, mode: bool = True): 62 | super().train(mode) 63 | if self.freeze_encoder: 64 | self.encoder.train(False) 65 | 66 | -------------------------------------------------------------------------------- /models/deepavfusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models import fusion_blocks, vits 4 | 5 | 6 | class DeepAVFusion(nn.Module): 7 | def __init__( 8 | self, 9 | image_arch='vit_base', image_pretrained=True, image_size=(224, 224), 10 | audio_arch='vit_base', audio_pretrained=True, audio_size=(128, 192), 11 | fusion_arch='factorized_mmi', 12 | fusion_layers='all', 13 | num_fusion_tkns=(4, 8, 4), 14 | fusion_mlp_ratio=1.0, fusion_attn_ratio=0.25, fusion_num_heads=12, 15 | drop_path=0., attn_drop=0., drop=0., 16 | ): 
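# fusion_layers selects which encoder depths get a multi-modal fusion block: 'all', 'none', a single index, or a dash-separated list of indices (e.g. '3-7-11'); it is parsed further down in this constructor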
17 | super(DeepAVFusion, self).__init__() 18 | 19 | # Audio and visual encoders 20 | self.image = vits.__dict__[image_arch](pretrained=image_pretrained, input_size=image_size, in_chans=3, use_cls_token=False, drop_path=drop_path, attn_drop=attn_drop, drop=drop) 21 | self.audio = vits.__dict__[audio_arch](pretrained=audio_pretrained, input_size=audio_size, in_chans=1, use_cls_token=False, drop_path=drop_path, attn_drop=attn_drop, drop=drop) 22 | self.embed_dim = self.image.embed_dim 23 | self.fusion_arch = fusion_arch 24 | 25 | # NOTE: multi-modal fusion blocks and tokens 26 | self.num_fusion = num_fusion_tkns 27 | self.fusion_tokens = nn.Parameter(torch.zeros(1, sum(num_fusion_tkns), self.embed_dim)) 28 | 29 | FusionBlock = None 30 | if fusion_arch == 'token': 31 | FusionBlock = fusion_blocks.FusionBlock_LocalAVTokens 32 | elif fusion_arch == 'dense_mmi': 33 | FusionBlock = fusion_blocks.FusionBlock_DenseAVInteractions 34 | elif fusion_arch == 'factorized_mmi': 35 | from functools import partial 36 | FusionBlock = partial(fusion_blocks.FusionBlock_FactorizedAVInteractions, fusion_tkns=num_fusion_tkns) 37 | 38 | max_depth = max(len(self.image.blocks), len(self.audio.blocks)) 39 | if fusion_layers == 'all': 40 | fusion_layers = set(range(max_depth)) 41 | elif fusion_layers == 'none': 42 | fusion_layers = set([]) 43 | elif isinstance(fusion_layers, int): 44 | fusion_layers = {fusion_layers} 45 | else: 46 | fusion_layers = set([int(l) for l in fusion_layers.split('-')]) 47 | self.fusion_blocks = nn.ModuleList([ 48 | None if i not in fusion_layers or FusionBlock is None else FusionBlock( 49 | dim=self.embed_dim, num_heads=fusion_num_heads, attn_ratio=fusion_attn_ratio, mlp_ratio=fusion_mlp_ratio, qkv_bias=True, 50 | drop=drop, attn_drop=attn_drop, drop_path=drop_path, norm_layer=nn.LayerNorm) 51 | for i in range(max_depth)]) 52 | self.fusion_norm = nn.LayerNorm(self.embed_dim) 53 | 54 | self.initialize_weights() 55 | 56 | def initialize_weights(self): 57 | torch.nn.init.normal_(self.fusion_tokens, std=.02) 58 | self.fusion_blocks.apply(self._init_weights) 59 | 60 | def _init_weights(self, m): 61 | if isinstance(m, nn.Linear): 62 | # we use xavier_uniform following official JAX ViT: 63 | torch.nn.init.xavier_uniform_(m.weight) 64 | if isinstance(m, nn.Linear) and m.bias is not None: 65 | nn.init.constant_(m.bias, 0) 66 | elif isinstance(m, nn.LayerNorm): 67 | nn.init.constant_(m.bias, 0) 68 | nn.init.constant_(m.weight, 1.0) 69 | 70 | def params_layer_ids(self): 71 | params_layer_ids = [] 72 | params_layer_ids.extend(self.image.params_layer_ids()) 73 | params_layer_ids.extend(self.audio.params_layer_ids()) 74 | params_layer_ids.extend([(self.fusion_tokens, 0)]) 75 | for i, blk in enumerate(self.fusion_blocks): 76 | if blk is not None: 77 | params_layer_ids.extend([(p, i+1) for p in blk.parameters()]) 78 | params_layer_ids.extend([(p, len(self.fusion_blocks)+1) for p in self.fusion_norm.parameters()]) 79 | return params_layer_ids 80 | 81 | def load_checkpoint(self, ckpt_fn, prefix): 82 | ckpt = torch.load(ckpt_fn, map_location='cpu') 83 | ckpt = ckpt['state_dict'] 84 | ckpt = {k[len(prefix):]: ckpt[k] for k in ckpt if k.startswith(prefix)} 85 | self.load_state_dict(ckpt, strict=True) 86 | print(f"Loaded pre-trained checkpoint: {ckpt_fn}") 87 | 88 | def forward(self, image, audio, image_ids_keep=None, audio_ids_keep=None, return_embs=False): 89 | B = image.shape[0] 90 | 91 | # embed patches 92 | x_image = self.image.prepare_patch_tokens(image, image_ids_keep) 93 | x_audio = 
self.audio.prepare_patch_tokens(audio, audio_ids_keep) 94 | 95 | # apply blocks 96 | embs = [] 97 | x_fusion = self.fusion_tokens.expand(B, -1, -1) 98 | nI, nA, nF = x_image.shape[1], x_audio.shape[1], self.fusion_tokens.shape[1] 99 | for blk_image, blk_audio, blk_fusion in zip(self.image.blocks, self.audio.blocks, self.fusion_blocks): 100 | if blk_fusion is None: 101 | x_image = blk_image(x_image) 102 | x_audio = blk_audio(x_audio) 103 | else: 104 | _, _x_image = blk_image(torch.cat((x_fusion, x_image), dim=1)).split((nF, nI), dim=1) 105 | _, _x_audio = blk_audio(torch.cat((x_fusion, x_audio), dim=1)).split((nF, nA), dim=1) 106 | x_fusion = blk_fusion(x_fusion, x_image, x_audio) 107 | x_image, x_audio = _x_image, _x_audio 108 | if return_embs: 109 | embs.append((x_image, x_audio, x_fusion)) 110 | 111 | x_image = self.image.norm(x_image) 112 | x_audio = self.audio.norm(x_audio) 113 | x_fusion = self.fusion_norm(x_fusion) 114 | 115 | if not return_embs: 116 | return x_image, x_audio, x_fusion 117 | else: 118 | return x_image, x_audio, x_fusion, embs 119 | -------------------------------------------------------------------------------- /models/fusion_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from timm.models.vision_transformer import DropPath, Mlp 4 | 5 | 6 | class Attention(nn.Module): 7 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 8 | super().__init__() 9 | self.num_heads = num_heads 10 | head_dim = dim // num_heads 11 | self.scale = head_dim ** -0.5 12 | 13 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 14 | self.attn_drop = nn.Dropout(attn_drop) 15 | self.proj = nn.Linear(dim, dim) 16 | self.proj_drop = nn.Dropout(proj_drop) 17 | 18 | def forward(self, x): 19 | B, N, C = x.shape 20 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 21 | q, k, v = qkv[0], qkv[1], qkv[2] 22 | 23 | attn = (q @ k.transpose(-2, -1)) * self.scale 24 | attn = attn.softmax(dim=-1) 25 | attn = self.attn_drop(attn) 26 | 27 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 28 | x = self.proj(x) 29 | x = self.proj_drop(x) 30 | return x, attn 31 | 32 | 33 | class CrossAttention(nn.Module): 34 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 35 | super().__init__() 36 | self.num_heads = num_heads 37 | head_dim = dim // num_heads 38 | self.scale = head_dim ** -0.5 39 | 40 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 41 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 42 | self.attn_drop = nn.Dropout(attn_drop) 43 | self.proj = nn.Linear(dim, dim) 44 | self.proj_drop = nn.Dropout(proj_drop) 45 | 46 | def forward(self, x1, x2): 47 | (B, N1, C), N2 = x1.shape, x2.shape[1] 48 | q = self.q(x1).reshape(B, N1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 49 | kv = self.kv(x2).reshape(B, N2, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 50 | k, v = kv[0], kv[1] 51 | 52 | attn = (q @ k.transpose(-2, -1)) * self.scale 53 | attn = attn.softmax(dim=-1) 54 | attn = self.attn_drop(attn) 55 | 56 | x1 = (attn @ v).transpose(1, 2).reshape(B, N1, C) 57 | x1 = self.proj(x1) 58 | x1 = self.proj_drop(x1) 59 | return x1, attn 60 | 61 | 62 | class Block(nn.Module): 63 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, 64 | drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 65 | super().__init__() 66 | self.norm1 = norm_layer(dim) 67 | self.attn = 
Attention( 68 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 69 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 70 | self.norm2 = norm_layer(dim) 71 | mlp_hidden_dim = int(dim * mlp_ratio) 72 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 73 | 74 | def forward(self, x, return_attention=False): 75 | y, attn = self.attn(self.norm1(x)) 76 | if return_attention: 77 | return attn 78 | x = x + self.drop_path(y) 79 | x = x + self.drop_path(self.mlp(self.norm2(x))) 80 | return x 81 | 82 | ######################################################################################### 83 | # 84 | # m1 m2 ... mF 85 | # |/ \|/ \| 86 | # m1 m2 ... mF --- vi \forall i -- aj \forall j --- 87 | # 88 | ######################################################################################### 89 | class CrossAttention_LocalAVTokens(nn.Module): 90 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., dim_ratio=1.): 91 | super().__init__() 92 | self.num_heads = num_heads 93 | self.dim = int(dim * dim_ratio) 94 | head_dim = self.dim // num_heads 95 | self.scale = head_dim ** -0.5 96 | 97 | self.q = nn.Linear(dim, self.dim, bias=qkv_bias) 98 | self.kv = nn.Linear(dim, self.dim * 2, bias=qkv_bias) 99 | self.attn_drop = nn.Dropout(attn_drop) 100 | self.proj = nn.Linear(self.dim, dim) 101 | self.proj_drop = nn.Dropout(proj_drop) 102 | 103 | def forward(self, xmm, xv, xa): 104 | (bs, nmm, _), nv, na = xmm.shape, xv.shape[1], xa.shape[1] 105 | 106 | x_src = torch.cat((xv, xa), dim=1) 107 | q = self.q(xmm).reshape(bs, nmm, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) 108 | k, v = self.kv(x_src).reshape(bs, nv+na, 2, self.num_heads, self.dim // self.num_heads).permute(2, 0, 3, 1, 4) 109 | 110 | attn = (q @ k.transpose(-2, -1)) * self.scale 111 | attn = attn.softmax(dim=-1) 112 | attn = self.attn_drop(attn) 113 | 114 | xmm = (attn @ v).transpose(1, 2).reshape(bs, nmm, self.dim) 115 | xmm = self.proj(xmm) 116 | xmm = self.proj_drop(xmm) 117 | return xmm, attn 118 | 119 | 120 | class FusionBlock_LocalAVTokens(nn.Module): 121 | def __init__(self, dim, num_heads, attn_ratio=0.25, mlp_ratio=4., qkv_bias=False, 122 | drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 123 | super().__init__() 124 | self.norm1_mm = norm_layer(dim) 125 | self.norm1_aud = norm_layer(dim) 126 | self.norm1_img = norm_layer(dim) 127 | self.attn = CrossAttention_LocalAVTokens( 128 | dim, num_heads=num_heads, qkv_bias=qkv_bias, 129 | attn_drop=attn_drop, proj_drop=drop, dim_ratio=attn_ratio) 130 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 131 | self.norm2 = norm_layer(dim) 132 | mlp_hidden_dim = int(dim * mlp_ratio) 133 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 134 | 135 | def forward(self, xmm, xa, xv, return_attention=False): 136 | xmm, xv, xa = self.norm1_mm(xmm), self.norm1_img(xv), self.norm1_aud(xa) 137 | res_fusion, attn = self.attn(xmm, xv, xa) 138 | xmm = xmm + self.drop_path(res_fusion) 139 | if return_attention: 140 | return attn 141 | 142 | res_fusion = self.mlp(self.norm2(xmm)) 143 | xmm = xmm + self.drop_path(res_fusion) 144 | return xmm 145 | 146 | 147 | ######################################################################################### 148 | # 149 | # m1 m2 ... mF 150 | # |/ \|/ \| 151 | # m1 m2 ... 
mF --- cat(vi,aj) \forall i,j --- 152 | # 153 | ######################################################################################### 154 | class CrossAttention_DenseAVInteractions(nn.Module): 155 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., dim_ratio=1.): 156 | super().__init__() 157 | self.num_heads = num_heads 158 | head_dim = dim // num_heads 159 | self.scale = head_dim ** -0.5 160 | self.dim = int(dim * dim_ratio) 161 | 162 | self.q = nn.Linear(dim, self.dim, bias=qkv_bias) 163 | self.kv = nn.Linear(dim * 2, self.dim * 2, bias=qkv_bias) 164 | self.attn_drop = nn.Dropout(attn_drop) 165 | self.proj = nn.Linear(self.dim, dim) 166 | self.proj_drop = nn.Dropout(proj_drop) 167 | 168 | def forward(self, xmm, xa, xv): 169 | (bs, nmm, _), nv, na = xmm.shape, xv.shape[1], xa.shape[1] 170 | 171 | xva = torch.cat(( 172 | xv.unsqueeze(2).repeat(1, 1, na, 1), 173 | xa.unsqueeze(1).repeat(1, nv, 1, 1), 174 | ), dim=3).flatten(1, 2) 175 | 176 | q = self.q(xmm).reshape(bs, nmm, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) 177 | k, v = self.kv(xva).reshape(bs, nv*na, 2, self.num_heads, self.dim // self.num_heads).permute(2, 0, 3, 1, 4) 178 | 179 | attn = (q @ k.transpose(-2, -1)) * self.scale 180 | attn = attn.softmax(dim=-1) 181 | attn = self.attn_drop(attn) 182 | 183 | xmm = (attn @ v).transpose(1, 2).reshape(bs, nmm, self.dim) 184 | xmm = self.proj(xmm) 185 | xmm = self.proj_drop(xmm) 186 | return xmm, attn 187 | 188 | 189 | class FusionBlock_DenseAVInteractions(nn.Module): 190 | def __init__(self, dim, num_heads, attn_ratio=0.25, mlp_ratio=4., qkv_bias=False, 191 | drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 192 | super().__init__() 193 | self.norm1_mm = norm_layer(dim) 194 | self.norm1_aud = norm_layer(dim) 195 | self.norm1_img = norm_layer(dim) 196 | self.attn = CrossAttention_DenseAVInteractions( 197 | dim, num_heads=num_heads, qkv_bias=qkv_bias, 198 | attn_drop=attn_drop, proj_drop=drop, dim_ratio=attn_ratio) 199 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 200 | self.norm2 = norm_layer(dim) 201 | mlp_hidden_dim = int(dim * mlp_ratio) 202 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 203 | 204 | def forward(self, xmm, xv, xa, return_attention=False): 205 | xmm, xv, xa = self.norm1_mm(xmm), self.norm1_img(xv), self.norm1_aud(xa) 206 | res_fusion, attn = self.attn(xmm, xv, xa) 207 | xmm = xmm + self.drop_path(res_fusion) 208 | if return_attention: 209 | return attn 210 | 211 | res_fusion = self.mlp(self.norm2(xmm)) 212 | xmm = xmm + self.drop_path(res_fusion) 213 | return xmm 214 | 215 | 216 | class CrossAttention_FactorizedAVInteractions(nn.Module): 217 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., dim_ratio=1., fusion_tkns=(8, 4, 4)): 218 | super().__init__() 219 | self.num_heads = num_heads 220 | head_dim = dim // num_heads 221 | self.scale = head_dim ** -0.5 222 | self.dim = int(dim * dim_ratio) 223 | self.fusion_tkns = fusion_tkns 224 | 225 | self.attn_v = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) 226 | self.attn_a = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) 227 | 228 | self.q = nn.Linear(dim, self.dim, bias=qkv_bias) 229 | self.k = nn.Linear(dim * 2, self.dim, bias=qkv_bias) 230 | self.v = nn.Linear(dim * 2, dim, bias=qkv_bias) 231 | self.attn_drop = nn.Dropout(attn_drop) 232 | self.proj = nn.Linear(dim, dim) 233 | self.proj_drop = nn.Dropout(proj_drop) 234 | 235 | def forward(self, xmm, xv, xa): 236 | bs = xmm.shape[0] 237 | nmm, nv, na = self.fusion_tkns 238 | 239 | # 16 8 8 240 | xmm2, xmm_v, xmm_a = xmm.split((nmm, nv, na), dim=1) 241 | xmm_v, _ = self.attn_v(xmm_v, xv) # Linearly with #V 242 | xmm_a, _ = self.attn_a(xmm_a, xa) # Linearly with #A 243 | 244 | # All VA pairs 245 | xva = torch.cat(( 246 | xmm_v.unsqueeze(2).repeat(1, 1, na, 1), 247 | xmm_a.unsqueeze(1).repeat(1, nv, 1, 1), 248 | ), dim=3).flatten(1, 2) 249 | 250 | q = self.q(xmm2).reshape(bs, nmm, self.num_heads, -1).permute(0, 2, 1, 3) 251 | k = self.k(xva).reshape(bs, nv*na, self.num_heads, -1).permute(0, 2, 1, 3) 252 | v = self.v(xva).reshape(bs, nv*na, self.num_heads, -1).permute(0, 2, 1, 3) 253 | 254 | attn = (q @ k.transpose(-2, -1)) * self.scale 255 | attn = attn.softmax(dim=-1) 256 | attn = self.attn_drop(attn) 257 | 258 | xmm2 = (attn @ v).transpose(1, 2).flatten(2) 259 | xmm2 = self.proj(xmm2) 260 | xmm2 = self.proj_drop(xmm2) 261 | 262 | xmm = torch.cat((xmm2, xmm_v, xmm_a), dim=1) 263 | return xmm, attn 264 | 265 | 266 | class FusionBlock_FactorizedAVInteractions(nn.Module): 267 | def __init__(self, dim, num_heads, attn_ratio=0.25, mlp_ratio=4., qkv_bias=False, fusion_tkns=(8, 4, 4), 268 | drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 269 | super().__init__() 270 | self.norm1_mm = norm_layer(dim) 271 | self.norm1_aud = norm_layer(dim) 272 | self.norm1_img = norm_layer(dim) 273 | self.attn = CrossAttention_FactorizedAVInteractions( 274 | dim, num_heads=num_heads, qkv_bias=qkv_bias, 275 | attn_drop=attn_drop, proj_drop=drop, dim_ratio=attn_ratio, fusion_tkns=fusion_tkns) 276 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 277 | self.norm2 = norm_layer(dim) 278 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) 279 | 280 | def forward(self, xmm, xv, xa, return_attention=False): 281 | xmm, xv, xa = self.norm1_mm(xmm), self.norm1_img(xv), self.norm1_aud(xa) 282 | res_fusion, attn = self.attn(xmm, xv, xa) 283 | xmm = xmm + self.drop_path(res_fusion) 284 | if return_attention: 285 | return attn 286 | 287 | res_fusion = self.mlp(self.norm2(xmm)) 288 | xmm = xmm + self.drop_path(res_fusion) 289 | return xmm 290 | -------------------------------------------------------------------------------- /models/swin.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from timm.models.swin_transformer import get_relative_position_index, window_partition, window_reverse 6 | from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert 7 | 8 | 9 | class WindowAttention(nn.Module): 10 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 11 | It supports both of shifted and non-shifted window. 12 | 13 | Args: 14 | dim (int): Number of input channels. 15 | num_heads (int): Number of attention heads. 16 | head_dim (int): Number of channels per head (dim // num_heads if not set) 17 | window_size (tuple[int]): The height and width of the window. 18 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 19 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 20 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 21 | """ 22 | 23 | def __init__(self, dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.): 24 | 25 | super().__init__() 26 | self.dim = dim 27 | self.window_size = to_2tuple(window_size) # Wh, Ww 28 | win_h, win_w = self.window_size 29 | self.window_area = win_h * win_w 30 | self.num_heads = num_heads 31 | head_dim = head_dim or dim // num_heads 32 | attn_dim = head_dim * num_heads 33 | self.scale = head_dim ** -0.5 34 | 35 | # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH 36 | self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)) 37 | 38 | # get pair-wise relative position index for each token inside the window 39 | self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w)) 40 | 41 | self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias) 42 | self.attn_drop = nn.Dropout(attn_drop) 43 | self.proj = nn.Linear(attn_dim, dim) 44 | self.proj_drop = nn.Dropout(proj_drop) 45 | 46 | trunc_normal_(self.relative_position_bias_table, std=.02) 47 | self.softmax = nn.Softmax(dim=-1) 48 | 49 | def _get_rel_pos_bias(self) -> torch.Tensor: 50 | relative_position_bias = self.relative_position_bias_table[ 51 | self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH 52 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 53 | return relative_position_bias.unsqueeze(0) 54 | 55 | def forward(self, x, mask: Optional[torch.Tensor] = None): 56 | """ 57 | Args: 58 | x: input features with shape of (num_windows*B, N, C) 59 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 60 | """ 61 | B_, N, C = x.shape 62 | qkv = self.qkv(x).reshape(B_, N, 3, 
self.num_heads, -1).permute(2, 0, 3, 1, 4) 63 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 64 | 65 | q = q * self.scale 66 | attn = (q @ k.transpose(-2, -1)) 67 | 68 | # Add relative position bias 69 | bias = self._get_rel_pos_bias() 70 | if bias.shape[-1] != attn.shape[-1] or bias.shape[-2] != attn.shape[-2]: 71 | bias = torch.nn.functional.pad(bias, (0, attn.shape[-1] - bias.shape[-1], 0, attn.shape[-2] - bias.shape[-2])) 72 | attn = attn + bias 73 | 74 | if mask is not None: 75 | num_win = mask.shape[0] 76 | if mask.shape[-1] != attn.shape[-1] or mask.shape[-2] != attn.shape[-2]: 77 | mask = torch.nn.functional.pad(mask, (0, attn.shape[-1] - mask.shape[-1], 0, attn.shape[-2] - mask.shape[-2])) 78 | attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 79 | attn = attn.view(-1, self.num_heads, N, N) 80 | attn = self.softmax(attn) 81 | else: 82 | attn = self.softmax(attn) 83 | 84 | attn = self.attn_drop(attn) 85 | 86 | x = (attn @ v).transpose(1, 2).reshape(B_, N, -1) 87 | x = self.proj(x) 88 | x = self.proj_drop(x) 89 | return x 90 | 91 | 92 | class SwinTransformerBlock(nn.Module): 93 | r""" Swin Transformer Block. 94 | 95 | Args: 96 | dim (int): Number of input channels. 97 | input_resolution (tuple[int]): Input resulotion. 98 | window_size (int): Window size. 99 | num_heads (int): Number of attention heads. 100 | head_dim (int): Enforce the number of channels per head 101 | shift_size (int): Shift size for SW-MSA. 102 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 103 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 104 | drop (float, optional): Dropout rate. Default: 0.0 105 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 106 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 107 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 108 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 109 | """ 110 | 111 | def __init__( 112 | self, dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0, 113 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 114 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 115 | super().__init__() 116 | self.dim = dim 117 | self.input_resolution = input_resolution 118 | self.window_size = window_size 119 | self.shift_size = shift_size 120 | self.mlp_ratio = mlp_ratio 121 | if min(self.input_resolution) <= self.window_size: 122 | # if window size is larger than input resolution, we don't partition windows 123 | self.shift_size = 0 124 | self.window_size = min(self.input_resolution) 125 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 126 | 127 | self.norm1 = norm_layer(dim) 128 | self.attn = WindowAttention( 129 | dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size), 130 | qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 131 | 132 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 133 | self.norm2 = norm_layer(dim) 134 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) 135 | 136 | if self.shift_size > 0: 137 | # calculate attention mask for SW-MSA 138 | H, W = self.input_resolution 139 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 140 | cnt = 0 141 | for h in ( 142 | slice(0, -self.window_size), 143 | slice(-self.window_size, -self.shift_size), 144 | slice(-self.shift_size, None)): 145 | for w in ( 146 | slice(0, -self.window_size), 147 | slice(-self.window_size, -self.shift_size), 148 | slice(-self.shift_size, None)): 149 | img_mask[:, h, w, :] = cnt 150 | cnt += 1 151 | mask_windows = window_partition(img_mask, self.window_size) # num_win, window_size, window_size, 1 152 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 153 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 154 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 155 | else: 156 | attn_mask = None 157 | 158 | self.register_buffer("attn_mask", attn_mask) 159 | 160 | def forward(self, x, x_fusion=None): 161 | H, W = self.input_resolution 162 | B, L, C = x.shape 163 | assert(L == H * W, "input feature has wrong size") 164 | 165 | shortcut = x 166 | x = self.norm1(x) 167 | x = x.view(B, H, W, C) 168 | if x_fusion is not None: 169 | shortcut_fusion = x_fusion 170 | x_fusion = self.norm1(x_fusion) 171 | 172 | # cyclic shift 173 | shifted_x = x 174 | if self.shift_size > 0: 175 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 176 | 177 | # partition windows 178 | x_windows = window_partition(shifted_x, self.window_size) # num_win*B, window_size, window_size, C 179 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # num_win*B, window_size*window_size, C 180 | 181 | # W-MSA/SW-MSA 182 | if x_fusion is not None: 183 | Lf, Lw, num_win = x_fusion.shape[1], x_windows.shape[1], x_windows.shape[0]//x.shape[0] 184 | x_win_fus = torch.cat((x_windows, x_fusion[:, None].repeat(1, num_win, 1, 1).flatten(0, 1)), dim=1) 185 | attn_win_fus = self.attn(x_win_fus, mask=self.attn_mask) # num_win*B, window_size*window_size, C 186 | attn_windows, attn_fusion = attn_win_fus.split((Lw, Lf), dim=1) 187 | else: 188 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_win*B, window_size*window_size, C 189 | 190 | # merge windows 191 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 192 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 193 | 194 | # reverse cyclic shift 195 | x = shifted_x 196 | if self.shift_size > 0: 197 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 198 | x = x.view(B, H * W, C) 199 | 200 | # FFN 201 | if x_fusion is not None: 202 | attn_fusion = attn_fusion.view(B, num_win, Lf, C).mean(1) 203 | x = torch.cat((shortcut, shortcut_fusion), dim=1) + self.drop_path(torch.cat((x, attn_fusion), dim=1)) 204 | x = x + self.drop_path(self.mlp(self.norm2(x))) 205 | return x.split((L, Lf), dim=1) 206 | else: 207 | x = shortcut + self.drop_path(x) 208 | x = x + self.drop_path(self.mlp(self.norm2(x))) 209 | return x -------------------------------------------------------------------------------- /models/video_earlyfusion.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Optional, Union 2 | 3 | import torch 4 | from torch 
import nn 5 | from models import video_vits, vits 6 | from models.fusion_blocks import FusionBlock_FactorizedAVInteractions 7 | 8 | 9 | class VideoEarlyFusion(nn.Module): 10 | def __init__( 11 | self, 12 | video_arch: str = 'video_vit_base', 13 | video_pretrained: str = '', 14 | video_size: Tuple[int] = (24, 224, 224), 15 | audio_arch: str = 'audio_vit_base', 16 | audio_pretrained: str = '', 17 | audio_size: Tuple[int] = (128, 298), 18 | # fusion setting 19 | fusion_layers: str = 'all', 20 | num_fusion_tkns: Tuple[int] = (8, 16, 16), 21 | fusion_mlp_ratio: float = 1., 22 | fusion_attn_ratio: float = .25, 23 | fusion_num_heads: int = 12, 24 | drop_path: float = 0., 25 | attn_drop: float = 0., 26 | drop: float = 0., 27 | ): 28 | super().__init__() 29 | 30 | # Audio and visual encoders 31 | self.video = video_vits.__dict__[video_arch](pretrained=video_pretrained, input_size=video_size, in_chans=3, use_cls_token=False, drop_path=drop_path, attn_drop=attn_drop, drop=drop) 32 | self.audio = vits.__dict__[audio_arch](pretrained=audio_pretrained, input_size=audio_size, in_chans=1, use_cls_token=False, drop_path=drop_path, attn_drop=attn_drop, drop=drop) 33 | self.embed_dim = self.video.embed_dim 34 | 35 | # NOTE: multi-modal fusion blocks and tokens 36 | self.num_fusion = num_fusion_tkns 37 | self.fusion_tokens = nn.Parameter(torch.zeros(1, sum(num_fusion_tkns), self.embed_dim)) 38 | 39 | max_depth = max(len(self.video.blocks), len(self.audio.blocks)) 40 | if fusion_layers == 'all': 41 | fusion_layers = set(range(max_depth)) 42 | elif fusion_layers == 'none': 43 | fusion_layers = set([]) 44 | elif isinstance(fusion_layers, int): 45 | fusion_layers = {fusion_layers} 46 | else: 47 | fusion_layers = set([int(l) for l in fusion_layers.split('-')]) 48 | self.fusion_blocks = nn.ModuleList([ 49 | None if i not in fusion_layers else FusionBlock_FactorizedAVInteractions( 50 | dim=self.embed_dim, fusion_tkns=num_fusion_tkns, num_heads=fusion_num_heads, 51 | attn_ratio=fusion_attn_ratio, mlp_ratio=fusion_mlp_ratio, qkv_bias=True, 52 | drop=drop, attn_drop=attn_drop, drop_path=drop_path) 53 | for i in range(max_depth)]) 54 | self.fusion_norm = nn.LayerNorm(self.embed_dim) 55 | 56 | self.initialize_weights() 57 | 58 | def initialize_weights(self): 59 | torch.nn.init.normal_(self.fusion_tokens, std=.02) 60 | self.fusion_blocks.apply(self._init_weights) 61 | 62 | def _init_weights(self, m): 63 | if isinstance(m, nn.Linear): 64 | # we use xavier_uniform following official JAX ViT: 65 | torch.nn.init.xavier_uniform_(m.weight) 66 | if isinstance(m, nn.Linear) and m.bias is not None: 67 | nn.init.constant_(m.bias, 0) 68 | elif isinstance(m, nn.LayerNorm): 69 | nn.init.constant_(m.bias, 0) 70 | nn.init.constant_(m.weight, 1.0) 71 | 72 | def params_layer_ids(self): 73 | params_layer_ids = [] 74 | params_layer_ids.extend(self.video.params_layer_ids()) 75 | params_layer_ids.extend(self.audio.params_layer_ids()) 76 | params_layer_ids.extend([(self.fusion_tokens, 0)]) 77 | for i, blk in enumerate(self.fusion_blocks): 78 | if blk is not None: 79 | params_layer_ids.extend([(p, i+1) for p in blk.parameters()]) 80 | params_layer_ids.extend([(p, len(self.fusion_blocks)+1) for p in self.fusion_norm.parameters()]) 81 | return params_layer_ids 82 | 83 | def load_checkpoint(self, ckpt_fn, prefix): 84 | ckpt = torch.load(ckpt_fn, map_location='cpu') 85 | ckpt = ckpt['state_dict'] 86 | ckpt = {k[len(prefix):]: ckpt[k] for k in ckpt if k.startswith(prefix)} 87 | # Adapt image model checkpoint for video 88 | ckpt = 
{k.replace('image.', 'video.'): ckpt[k] for k in ckpt} 89 | ckpt['video.pos_embed'] = self.video.state_dict()['pos_embed'] 90 | if self.video.patch_embed.proj.weight.ndim > ckpt['video.patch_embed.proj.weight'].ndim: 91 | ckpt['video.patch_embed.proj.weight'] = ckpt['video.patch_embed.proj.weight'].unsqueeze(2).repeat(1, 1, self.video.patch_size[0], 1, 1) 92 | self.load_state_dict(ckpt, strict=True) 93 | print(f"Loaded pre-trained checkpoint: {ckpt_fn}") 94 | 95 | def forward(self, video, audio, video_ids_keep=None, audio_ids_keep=None, return_embs=False): 96 | ''' 97 | @video: (b c t h w) 98 | @audio: (b c n t) 99 | ''' 100 | B = video.shape[0] 101 | 102 | # embed patches 103 | x_video = self.video.prepare_patch_tokens(video, video_ids_keep) # (b, t*h*w, c) 104 | x_audio = self.audio.prepare_patch_tokens(audio, audio_ids_keep) # (b, n*t, c) 105 | 106 | # apply blocks 107 | embs = [] 108 | x_fusion = self.fusion_tokens.expand(B, -1, -1) 109 | nV, nA, nF = x_video.shape[1], x_audio.shape[1], self.fusion_tokens.shape[1] 110 | for blk_video, blk_audio, blk_fusion in zip(self.video.blocks, self.audio.blocks, self.fusion_blocks): 111 | if blk_fusion is None: 112 | x_video = blk_video(x_video) 113 | x_audio = blk_audio(x_audio) 114 | else: 115 | _, _x_video = blk_video(torch.cat((x_fusion, x_video), dim=1)).split((nF, nV), dim=1) 116 | _, _x_audio = blk_audio(torch.cat((x_fusion, x_audio), dim=1)).split((nF, nA), dim=1) 117 | 118 | x_fusion = blk_fusion(x_fusion, x_video, x_audio) 119 | x_video, x_audio = _x_video, _x_audio 120 | 121 | if return_embs: 122 | embs.append((x_video, x_audio, x_fusion)) 123 | 124 | x_video = self.video.norm(x_video) 125 | x_audio = self.audio.norm(x_audio) 126 | x_fusion = self.fusion_norm(x_fusion) 127 | 128 | if not return_embs: 129 | return x_video, x_audio, x_fusion 130 | else: 131 | return x_video, x_audio, x_fusion, embs 132 | 133 | 134 | # set recommended archs 135 | def video_efav_small(video_pretrained='', audio_pretrained='', **kwargs): 136 | assert video_pretrained == '' 137 | assert audio_pretrained == '' 138 | model = VideoEarlyFusion( 139 | video_arch='video_vit_small', video_pretrained=video_pretrained, 140 | audio_arch='vit_small', audio_pretrained=audio_pretrained, 141 | fusion_layers='all', num_fusion_tkns=(8, 4, 4), fusion_num_heads=6, **kwargs) 142 | return model 143 | 144 | 145 | def video_efav_base(video_pretrained='', audio_pretrained='', **kwargs): 146 | assert video_pretrained == '' 147 | assert audio_pretrained == '' 148 | model = VideoEarlyFusion( 149 | video_arch='video_vit_base', video_pretrained=video_pretrained, 150 | audio_arch='vit_base', audio_pretrained=audio_pretrained, 151 | fusion_layers='all', num_fusion_tkns=(16, 8, 8), fusion_num_heads=12, **kwargs) 152 | return model 153 | 154 | 155 | def video_efav_large(video_pretrained='', audio_pretrained='', **kwargs): 156 | assert video_pretrained == '' 157 | assert audio_pretrained == '' 158 | model = VideoEarlyFusion( 159 | video_arch='video_vit_large', video_pretrained=video_pretrained, 160 | audio_arch='vit_large', audio_pretrained=audio_pretrained, 161 | fusion_layers='all', num_fusion_tkns=(32, 12, 12), fusion_num_heads=16, **kwargs) 162 | return model 163 | 164 | 165 | def video_efav_huge(video_pretrained='', audio_pretrained='', **kwargs): 166 | assert video_pretrained == '' 167 | assert audio_pretrained == '' 168 | model = VideoEarlyFusion( 169 | video_arch='video_vit_huge', video_pretrained=video_pretrained, 170 | audio_arch='vit_huge', audio_pretrained=audio_pretrained, 171 | 
fusion_layers='all', num_fusion_tkns=(64, 16, 16), fusion_num_heads=16, **kwargs) 172 | return model 173 | 174 | 175 | if __name__ == '__main__': 176 | import time 177 | import torch.amp 178 | model = video_efav_base(video_pretrained='', audio_pretrained='', video_size=(8, 224, 224), audio_size=(128, 192)).cuda() 179 | for b in range(10): 180 | bs = 2**b 181 | xv = torch.randn(bs, 3, 8, 224, 224).cuda() 182 | xa = torch.randn(bs, 1, 128, 192).cuda() 183 | ts = time.time() 184 | with torch.amp.autocast(device_type='cuda'): 185 | vv, va, vf = model(xv, xa) 186 | (vv.sum()+va.sum()+vf.sum()).backward() 187 | gpu_mem = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 188 | print(f"BS={bs}\t GPU mem: {gpu_mem:.1f} Gb\t Fwd Bwd: {time.time()-ts:.2f} sec") -------------------------------------------------------------------------------- /models/vits.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | from timm.models.vision_transformer import PatchEmbed, Block 6 | from util.pos_embed import get_2d_sincos_pos_embed 7 | 8 | 9 | PRETRAINED_WEIGHTS = { 10 | 'vit_base_audiomae_as2m': ('assets/models/vitbase_audiomae_as2m.pth', ''), 11 | 'vit_base_mae_in1k': ('https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth', ''), 12 | 'vit_large_mae_in1k': ('https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth', ''), 13 | 'vit_huge_mae_in1k': ('https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth', ''), 14 | } 15 | 16 | class ViT(nn.Module): 17 | """ VisionTransformer backbone 18 | """ 19 | def __init__(self, 20 | input_size=224, patch_size=16, in_chans=3, 21 | embed_dim=1024, depth=24, num_heads=16, 22 | mlp_ratio=4., norm_layer=nn.LayerNorm, use_cls_token=False, 23 | drop_path=0., attn_drop=0., drop=0.): 24 | super().__init__() 25 | 26 | self.embed_dim = embed_dim 27 | self.patch_embed = PatchEmbed(input_size, patch_size, in_chans, embed_dim) 28 | num_patches = self.patch_embed.num_patches 29 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False) # fixed sin-cos embedding 30 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if use_cls_token else None 31 | 32 | self.blocks = nn.ModuleList([ 33 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path, attn_drop=attn_drop, proj_drop=drop) 34 | for _ in range(depth)]) 35 | self.norm = norm_layer(embed_dim) 36 | self.initialize_weights() 37 | 38 | def initialize_weights(self): 39 | # initialize (and freeze) pos_embed by sin-cos embedding 40 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=False) 41 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 42 | 43 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 44 | w = self.patch_embed.proj.weight.data 45 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 46 | 47 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
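        # For reference, the recommended factories below (vit_small/vit_base/vit_large/vit_huge)
        # wrap this class; the audio-visual encoders instantiate them roughly as
        #   image_enc = vit_base(pretrained='vit_base_mae_in1k', input_size=(224, 224), in_chans=3, use_cls_token=False)
        #   audio_enc = vit_base(pretrained='vit_base_audiomae_as2m', input_size=(128, 192), in_chans=1, use_cls_token=False)
        # With use_cls_token=False, self.cls_token is None and the branch below is skipped.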
48 | if self.cls_token is not None: 49 | torch.nn.init.normal_(self.cls_token, std=.02) 50 | 51 | # initialize nn.Linear and nn.LayerNorm 52 | self.apply(self._init_weights) 53 | 54 | def _init_weights(self, m): 55 | if isinstance(m, nn.Linear): 56 | # we use xavier_uniform following official JAX ViT: 57 | torch.nn.init.xavier_uniform_(m.weight) 58 | if isinstance(m, nn.Linear) and m.bias is not None: 59 | nn.init.constant_(m.bias, 0) 60 | elif isinstance(m, nn.LayerNorm): 61 | nn.init.constant_(m.bias, 0) 62 | nn.init.constant_(m.weight, 1.0) 63 | 64 | def load_checkpoint(self, ckpt_fn, prefix='', skip_keys_prefix=('decoder', 'mask_token')): 65 | try: 66 | ckpt = torch.load(ckpt_fn, map_location="cpu") 67 | except Exception: 68 | ckpt = torch.hub.load_state_dict_from_url(url=ckpt_fn, map_location="cpu") 69 | 70 | if 'state_dict' in ckpt: 71 | ckpt = ckpt['state_dict'] 72 | elif 'model' in ckpt: 73 | ckpt = ckpt['model'] 74 | ckpt = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)} 75 | ckpt = {k: v for k, v in ckpt.items() if not k.startswith(skip_keys_prefix)} 76 | 77 | if self.cls_token is None and 'cls_token' in ckpt: 78 | del ckpt['cls_token'] 79 | ckpt['pos_embed'] = self.state_dict()['pos_embed'] 80 | self.load_state_dict(ckpt, strict=True) 81 | 82 | def params_layer_ids(self): 83 | params_layer_ids = [] 84 | params_layer_ids.extend([(p, 0) for p in self.patch_embed.parameters()]) 85 | params_layer_ids.extend([(self.cls_token, 0)]) 86 | for i, blk in enumerate(self.blocks): 87 | params_layer_ids.extend([(p, i+1) for p in blk.parameters()]) 88 | params_layer_ids.extend([(p, len(self.blocks)+1) for p in self.norm.parameters()]) 89 | return params_layer_ids 90 | 91 | def prepare_patch_tokens(self, x, ids_keep=None): 92 | # embed patches 93 | x = self.patch_embed(x) 94 | 95 | # add pos embed w/o cls token 96 | x = x + self.pos_embed 97 | 98 | # masking 99 | if ids_keep is not None: 100 | x = x.gather(dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1])) 101 | 102 | # append cls token 103 | if self.cls_token is not None: 104 | cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) 105 | x = torch.cat((cls_tokens, x), dim=1) 106 | 107 | return x 108 | 109 | def forward(self, x, ids_keep=None): 110 | # prepare patches 111 | x = self.prepare_patch_tokens(x, ids_keep=ids_keep) 112 | 113 | # apply Transformer blocks 114 | for blk in self.blocks: 115 | x = blk(x) 116 | x = self.norm(x) 117 | 118 | return x 119 | 120 | 121 | def vit_small_patch16(pretrained=False, **kwargs): 122 | assert pretrained == False 123 | model = ViT( 124 | patch_size=16, embed_dim=384, depth=12, num_heads=6, 125 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 126 | ) 127 | return model 128 | 129 | 130 | def vit_base_patch16(pretrained=False, **kwargs): 131 | model = ViT( 132 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 133 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 134 | ) 135 | 136 | if pretrained is not None and pretrained != '': 137 | assert pretrained in {'vit_base_mae_in1k', 'vit_base_audiomae_as2m'} 138 | url, prefix = PRETRAINED_WEIGHTS[pretrained] 139 | model.load_checkpoint(url, prefix=prefix) 140 | 141 | return model 142 | 143 | 144 | def vit_large_patch16(pretrained=False, **kwargs): 145 | model = ViT( 146 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 147 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 148 | ) 149 | 150 | if pretrained is not None: 151 | assert pretrained in 
{'vit_large_mae_in1k'} 152 | url, prefix = PRETRAINED_WEIGHTS[pretrained] 153 | model.load_checkpoint(url, prefix=prefix) 154 | 155 | return model 156 | 157 | 158 | def vit_huge_patch14(pretrained=None, **kwargs): 159 | model = ViT( 160 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 161 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 162 | ) 163 | 164 | if pretrained is not None: 165 | assert pretrained in {'vit_huge_mae_in1k'} 166 | url, prefix = PRETRAINED_WEIGHTS[pretrained] 167 | model.load_checkpoint(url, prefix=prefix) 168 | 169 | return model 170 | 171 | 172 | # set recommended archs 173 | vit_small = vit_small_patch16 174 | vit_base = vit_base_patch16 175 | vit_large = vit_large_patch16 176 | vit_huge = vit_huge_patch14 -------------------------------------------------------------------------------- /requirements.yml: -------------------------------------------------------------------------------- 1 | name: efav 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - python=3.10.11=h7a1cb2a_2 8 | - pytorch=2.0.1=py3.10_cuda11.8_cudnn8.7.0_0 9 | - pytorch-cuda=11.8=h7e8668a_5 10 | - torchaudio=2.0.2=py310_cu118 11 | - torchvision=0.15.2=py310_cu118 12 | - pip=23.0.1=py310h06a4308_0 13 | - pip: 14 | - av==10.0.0 15 | - hydra-core==1.3.2 16 | - jupyter==1.0.0 17 | - matplotlib==3.7.1 18 | - mir-eval==0.7 19 | - scikit-image==0.21.0 20 | - scikit-learn==1.2.2 21 | - scipy==1.10.1 22 | - submitit==1.4.5 23 | - timm==0.9.2 24 | - tqdm==4.65.0 25 | - wandb==0.15.3 26 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | from typing import Iterable 4 | import torch 5 | 6 | import datasets as myDBs 7 | from torchvision import transforms as vT 8 | from util import audio_transforms as aT 9 | 10 | from models.deepavfusion import DeepAVFusion 11 | from models.avmae import AVMAE 12 | 13 | from util import distributed as dist_utils 14 | from util import misc as misc_utils 15 | from util import data as data_utils 16 | from util import meters, lr_sched 17 | from util.knn_probe import EvalAVNNProbe 18 | 19 | 20 | def main_worker(local_rank, args): 21 | # Setup environment 22 | job_dir = f"{args.output_dir}/{args.job_name}" 23 | dist_utils.init_distributed_mode(local_rank, args, log_fn=f"{job_dir}/train.log") 24 | device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda') 25 | print(f'job dir: {job_dir}') 26 | misc_utils.print_args(args) 27 | 28 | # Adjust learning rate to batch size 29 | num_tasks = dist_utils.get_world_size() 30 | num_tasks_per_node = max(1, torch.cuda.device_count()) 31 | args.env.workers = args.env.workers // num_tasks_per_node 32 | eff_batch_size = args.opt.batch_size * args.opt.accum_iter * num_tasks 33 | if args.opt.lr is None: # only base_lr is specified 34 | args.opt.lr = args.opt.blr * eff_batch_size / 256 35 | print("base lr: %.2e" % args.opt.blr) 36 | print("actual lr: %.2e" % args.opt.lr) 37 | print("accumulate grad iterations: %d" % args.opt.accum_iter) 38 | print("effective batch size: %d" % eff_batch_size) 39 | 40 | # Dataloaders 41 | dataset = myDBs.load_dataset( 42 | args.data.dataset, 43 | args.data.data_path, 44 | dataset_type='simple', 45 | visual_transform=vT.Compose([ 46 | vT.RandomResizedCrop(args.data.image_size, scale=(args.data.crop_min, 1.)), 47 | vT.RandomHorizontalFlip(), 48 | vT.ToTensor(), 49 | 
vT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]), 50 | audio_transform=aT.Compose([ 51 | aT.Pad(rate=args.data.audio_rate, dur=args.data.audio_dur), 52 | aT.RandomVol(), 53 | aT.MelSpectrogram(sample_rate=args.data.audio_rate, n_fft=int(args.data.audio_rate * 0.05), hop_length=int(args.data.audio_rate / 64), n_mels=args.data.audio_mels), 54 | aT.Log()]), 55 | train=True, 56 | audio_dur=args.data.audio_dur, 57 | audio_rate=args.data.audio_rate, 58 | temporal_jitter=True, 59 | ) 60 | loader = data_utils.get_dataloader( 61 | dataset, args.env.distributed, args.opt.batch_size, args.env.workers, shuffle=True, drop_last=True) 62 | print(dataset) 63 | 64 | # Create model 65 | image_size, audio_size = (args.data.image_size, args.data.image_size), (args.data.audio_mels, int(args.data.audio_dur * 64)) 66 | encoder = DeepAVFusion( 67 | image_arch=args.model.image.backbone, image_pretrained=args.model.image.pretrained, image_size=image_size, 68 | audio_arch=args.model.audio.backbone, audio_pretrained=args.model.audio.pretrained, audio_size=audio_size, 69 | fusion_arch=args.model.fusion.arch, 70 | fusion_layers=args.model.fusion.layers, 71 | num_fusion_tkns=(args.model.fusion.num_fusion_tkns, 72 | args.model.fusion.num_aggr_image_tkns, 73 | args.model.fusion.num_aggr_audio_tkns), 74 | fusion_mlp_ratio=args.model.fusion.mlp_ratio, 75 | fusion_attn_ratio=args.model.fusion.attn_ratio, 76 | fusion_num_heads=args.model.fusion.num_heads 77 | ) 78 | model = AVMAE( 79 | encoder, encoder.embed_dim, 80 | image_decoder_arch=args.model.image.decoder_arch, image_decoder_depth=args.model.image.decoder_depth, 81 | image_mask_ratio=args.model.image.mask_ratio, image_norm_loss=args.model.image.norm_loss, 82 | audio_decoder_arch=args.model.audio.decoder_arch, audio_decoder_depth=args.model.audio.decoder_depth, 83 | audio_mask_ratio=args.model.audio.mask_ratio, audio_norm_loss=args.model.audio.norm_loss 84 | ) 85 | model.to(device) 86 | print("Model = %s" % str(model)) 87 | 88 | # Optimizer 89 | no_weight_decay_list = [n for n, p in model.named_parameters() if 'bias' in n or 'norm' in n] 90 | param_groups = lr_sched.param_groups_pretrained( 91 | model, args.opt.weight_decay, no_weight_decay_list=no_weight_decay_list, 92 | image_pt=args.model.image.pretrained, audio_pt=args.model.audio.pretrained) 93 | optimizer = torch.optim.AdamW(param_groups, lr=args.opt.lr, betas=(0.9, 0.95)) 94 | print(optimizer) 95 | 96 | # Trainer 97 | trainer = misc_utils.Trainer( 98 | model, 99 | optimizer=optimizer, 100 | use_amp=args.opt.use_amp, 101 | accum_iter=args.opt.accum_iter, 102 | distributed=args.env.distributed 103 | ) 104 | 105 | # Checkpointing and logging 106 | ckpt_manager = misc_utils.CheckpointManager( 107 | modules=trainer.module_dict(), 108 | ckpt_dir=f"{job_dir}/checkpoints", 109 | epochs=args.opt.epochs, 110 | save_freq=args.log.save_freq) 111 | start_epoch = ckpt_manager.resume()[0] if args.opt.resume else 0 112 | wb_logger = misc_utils.WBLogger( 113 | f"{job_dir}/wandb", args.log.wandb_entity, args.log.wandb_project, args.job_name, 114 | model, args) 115 | 116 | # Set up probes 117 | knn_probe = EvalAVNNProbe(args.nn_probe, args.log, args.env) 118 | 119 | # =============================================================== # 120 | # Training loop 121 | print(f"Start training for {args.opt.epochs} epochs") 122 | for epoch in range(start_epoch, args.opt.epochs): 123 | if args.env.distributed: 124 | loader.sampler.set_epoch(epoch) 125 | 126 | # train for one epoch 127 | train_one_epoch(loader, trainer, 
epoch, 128 | device=device, wb_logger=wb_logger, args=args) 129 | 130 | # evaluate 131 | if epoch % args.log.eval_freq == 0 or epoch == args.opt.epochs - 1 or epoch == start_epoch: 132 | global_step = (len(loader) // trainer.accum_iter) * (epoch + 1) 133 | knn_stats = knn_probe.evaluate(trainer.eval_model, epoch=epoch) 134 | wb_logger.log(knn_stats, step=global_step, force=True) 135 | 136 | # save checkpoint 137 | ckpt_manager.checkpoint(epoch+1, {'epoch': epoch+1}) 138 | 139 | 140 | def train_one_epoch(loader: Iterable, 141 | trainer: misc_utils.Trainer, 142 | epoch: int = 0, 143 | wb_logger: misc_utils.WBLogger = None, 144 | device: torch.device = torch.device('cpu'), 145 | args=None): 146 | trainer.model.train(True) 147 | metric_logger = meters.MetricLogger(delimiter=" ") 148 | header = f'[Train][Ep-{epoch}/{args.opt.epochs}]' 149 | 150 | trainer.zero_grad() 151 | for step, (image, audio, _) in enumerate(metric_logger.log_every(loader, args.log.print_freq, header)): 152 | sys.stdout.flush() 153 | global_step = (len(loader) // trainer.accum_iter) * epoch + step // trainer.accum_iter 154 | if step % args.opt.accum_iter == 0: 155 | lr = lr_sched.adjust_learning_rate(trainer.optimizer, epoch + step / len(loader), args) 156 | metric_logger.update(lr=lr) 157 | 158 | # Prepare data 159 | image = image.to(device, non_blocking=True).float() 160 | audio = audio.to(device, non_blocking=True).float() 161 | 162 | # Forward pass 163 | with trainer.autocast(), trainer.autosync(): 164 | loss_image, loss_audio = trainer.model(image, audio)[:2] 165 | loss = loss_image + loss_audio 166 | if not math.isfinite(loss.item()): 167 | raise f"Loss is {loss.item()}, stopping training" 168 | 169 | # Backward pass and model update 170 | grad_norm, amp_scale = trainer.step(loss) 171 | 172 | # Log 173 | if trainer.accums == 0: 174 | metric_logger.update( 175 | loss=loss.item(), loss_image=loss_image.item(), loss_audio=loss_audio.item(), 176 | grad_norm=grad_norm, amp_scale=amp_scale, n=image.shape[0]) 177 | wb_logger.log(metric_logger.latest(), step=global_step) 178 | 179 | if args.debug and step == 100: 180 | break 181 | 182 | # gather the stats from all processes 183 | print("Syncing meters...") 184 | metric_logger.synchronize_between_processes() 185 | print("Averaged stats:", metric_logger) 186 | trainer.zero_grad() 187 | return metric_logger.averages() 188 | -------------------------------------------------------------------------------- /util/audio_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from torchaudio.transforms import * 4 | import torchaudio.functional as F 5 | from torchvision.transforms import Compose 6 | 7 | 8 | class RandomVol(torch.nn.Module): 9 | def __init__(self, gain=(-6, 6)): 10 | super(RandomVol, self).__init__() 11 | self.gain = gain 12 | 13 | def forward(self, waveform): 14 | gain = random.uniform(self.gain[0], self.gain[1]) 15 | waveform = F.gain(waveform, gain) 16 | waveform = torch.clamp(waveform, -1, 1) 17 | return waveform 18 | 19 | class Pad(torch.nn.Module): 20 | def __init__(self, dur, rate): 21 | super(Pad, self).__init__() 22 | self.samples = int(dur * rate) 23 | 24 | def forward(self, waveform): 25 | while waveform.shape[-1] < self.samples: 26 | waveform = torch.cat((waveform, torch.flip(waveform, dims=(1,))), dim=1) 27 | return waveform[:, :self.samples] 28 | 29 | class Log(torch.nn.Module): 30 | def __init__(self, eps=1e-7): 31 | super(Log, self).__init__() 32 | self.eps = eps 33 | 34 | def 
forward(self, spec): 35 | return torch.log10(spec + self.eps) -------------------------------------------------------------------------------- /util/data.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | from torch.utils.data._utils.collate import default_collate 3 | from util import distributed as dist_utils 4 | 5 | 6 | def get_dataloader(db, distributed, batch_size, workers, collate_fn=default_collate, shuffle=True, drop_last=True): 7 | if distributed: 8 | num_tasks = dist_utils.get_world_size() 9 | global_rank = dist_utils.get_rank() 10 | sampler = data.DistributedSampler(db, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 11 | else: 12 | sampler = data.RandomSampler(db, replacement=True) 13 | return data.DataLoader( 14 | db, 15 | sampler=sampler, 16 | batch_size=batch_size, 17 | num_workers=workers, 18 | pin_memory=True, 19 | drop_last=drop_last, 20 | collate_fn=collate_fn, 21 | persistent_workers=workers>0, 22 | ) -------------------------------------------------------------------------------- /util/distributed.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import datetime 3 | import sys 4 | import random 5 | import numpy as np 6 | import warnings 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from torch.backends import cudnn 11 | 12 | 13 | def setup_for_distributed(is_master, log_fn=None): 14 | """ 15 | This function disables printing when not in master process 16 | """ 17 | builtin_print = builtins.print 18 | def print(*args, **kwargs): 19 | force = kwargs.pop('force', False) 20 | force = force or (get_rank() % 8 == 0) 21 | if is_master or force: 22 | # print with time stamp 23 | now = datetime.datetime.now().time() 24 | msg = f'[{now}] ' + ' '.join([str(ct) for ct in args]) 25 | 26 | # Print to terminal 27 | builtin_print(msg, **kwargs) 28 | sys.stdout.flush() 29 | 30 | # Log to file 31 | if log_fn is not None: 32 | open(log_fn, 'a').write(msg + '\n') 33 | 34 | builtins.print = print 35 | 36 | 37 | def is_dist_avail_and_initialized(): 38 | if not dist.is_available(): 39 | return False 40 | if not dist.is_initialized(): 41 | return False 42 | return True 43 | 44 | 45 | def get_world_size(): 46 | if not is_dist_avail_and_initialized(): 47 | return 1 48 | return dist.get_world_size() 49 | 50 | 51 | def get_rank(): 52 | if not is_dist_avail_and_initialized(): 53 | return 0 54 | return dist.get_rank() 55 | 56 | 57 | def is_main_process(): 58 | return get_rank() == 0 59 | 60 | 61 | def save_on_master(*args, **kwargs): 62 | if is_main_process(): 63 | torch.save(*args, **kwargs) 64 | 65 | 66 | def init_distributed_mode(local_rank, args, log_fn): 67 | ngpus_per_node = torch.cuda.device_count() 68 | args.env.distributed = ngpus_per_node > 0 69 | if args.env.distributed: 70 | args.env.world_size = ngpus_per_node * args.env.world_size 71 | args.env.rank = args.env.rank * ngpus_per_node + local_rank 72 | else: 73 | print('Not using distributed mode') 74 | setup_for_distributed(is_master=True, log_fn=log_fn) # hack 75 | args.env.world_size = 1 76 | args.env.rank = 0 77 | return 78 | 79 | dist.init_process_group(backend='nccl', init_method=args.env.dist_url, 80 | world_size=args.env.world_size, rank=args.env.rank) 81 | 82 | torch.cuda.set_device(local_rank) 83 | print('Distributed init (rank {}): {}, gpu {}'.format( 84 | args.env.rank, args.env.dist_url, local_rank), flush=True) 85 | print('before barrier') 86 | 
torch.distributed.barrier() 87 | print('after barrier') 88 | setup_for_distributed(args.env.rank == 0, log_fn=log_fn) 89 | 90 | if args.env.seed is not None: 91 | seed = args.env.seed + get_rank() 92 | random.seed(seed) 93 | torch.manual_seed(seed) 94 | np.random.seed(seed) 95 | cudnn.deterministic = True 96 | warnings.warn('You have chosen to seed training. ' 97 | 'This will turn on the CUDNN deterministic setting, ' 98 | 'which can slow down your training considerably! ' 99 | 'You may see unexpected behavior when restarting ' 100 | 'from checkpoints.') 101 | 102 | 103 | def all_reduce_mean(x): 104 | world_size = get_world_size() 105 | if world_size > 1: 106 | x_reduce = torch.tensor(x).cuda() 107 | dist.all_reduce(x_reduce) 108 | x_reduce /= world_size 109 | return x_reduce.item() 110 | else: 111 | return x 112 | 113 | 114 | @torch.no_grad() 115 | def concat_all_gather(tensor): 116 | """ 117 | Performs all_gather operation on the provided tensors. 118 | *** Warning ***: torch.distributed.all_gather has no gradient. 119 | """ 120 | if not torch.distributed.is_initialized(): 121 | return tensor 122 | tensors_gather = [torch.ones_like(tensor) 123 | for _ in range(torch.distributed.get_world_size())] 124 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 125 | 126 | output = torch.cat(tensors_gather, dim=0) 127 | return output 128 | 129 | 130 | def all_gather(obj): 131 | gather_list = [None] * get_world_size() 132 | dist.all_gather_object(gather_list, obj) 133 | return gather_list -------------------------------------------------------------------------------- /util/image_labels_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as T 3 | import torchvision.transforms.functional as F 4 | from torchvision.transforms.functional import InterpolationMode 5 | import numpy as np 6 | 7 | class Compose(T.Compose): 8 | def __call__(self, img, lbls=()): 9 | for t in self.transforms: 10 | img, lbls = t(img, lbls) 11 | return img, lbls 12 | 13 | 14 | 15 | class RandomResizedCrop(T.RandomResizedCrop): 16 | def forward(self, img, lbls=()): 17 | if not isinstance(lbls, (list, tuple)): 18 | lbls = [lbls] 19 | 20 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 21 | img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) 22 | lbls = [F.resized_crop(lbl, i, j, h, w, self.size, InterpolationMode.NEAREST, antialias=self.antialias) for lbl in lbls] 23 | return img, lbls 24 | 25 | 26 | class Resize(T.Resize): 27 | def forward(self, img, lbls=()): 28 | if not isinstance(lbls, (list, tuple)): 29 | lbls = [lbls] 30 | img = F.resize(img, self.size, self.interpolation, self.max_size, self.antialias) 31 | lbls = [F.resize(lbl, self.size, InterpolationMode.NEAREST, self.max_size, self.antialias) for lbl in lbls] 32 | return img, lbls 33 | 34 | 35 | class CenterCrop(T.CenterCrop): 36 | def forward(self, img, lbls=()): 37 | if not isinstance(lbls, (list, tuple)): 38 | lbls = [lbls] 39 | img = F.center_crop(img, self.size) 40 | lbls = [F.center_crop(lbl, self.size) for lbl in lbls] 41 | return img, lbls 42 | 43 | 44 | class RandomHorizontalFlip(T.RandomHorizontalFlip): 45 | def forward(self, img, lbls=()): 46 | if not isinstance(lbls, (list, tuple)): 47 | lbls = [lbls] 48 | if torch.rand(1) < self.p: 49 | return F.hflip(img), [F.hflip(lbl) for lbl in lbls] 50 | return img, lbls 51 | 52 | 53 | class ToTensor(T.ToTensor): 54 | def __call__(self, img, lbls=()): 55 
| return F.to_tensor(img), [torch.tensor(np.array(lbl)).long() for lbl in lbls] 56 | 57 | 58 | class Normalize(T.Normalize): 59 | def forward(self, img, lbls=()): 60 | return F.normalize(img, self.mean, self.std, self.inplace), lbls 61 | 62 | -------------------------------------------------------------------------------- /util/knn_probe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | 5 | import numpy as np 6 | from collections import defaultdict 7 | 8 | parent = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 9 | sys.path.append(parent) 10 | 11 | from datasets import get_vggsound, get_audioset 12 | import torch.utils.data 13 | from sklearn import metrics 14 | 15 | 16 | from util import distributed as dist_utils 17 | from util import meters 18 | from torch.nn import functional as F 19 | from torchvision import transforms as vT 20 | from util import audio_transforms as aT 21 | 22 | 23 | class EvalAVNNProbe: 24 | def __init__(self, probe_args, log_args, env_args): 25 | self.device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda') 26 | self.distributed = env_args.distributed 27 | self.eval_freq = log_args.eval_freq 28 | self.print_freq = log_args.print_freq 29 | self.dataset = probe_args.dataset 30 | 31 | image_transform = vT.Compose([ 32 | vT.Resize(int(probe_args.image_size/0.875)), 33 | vT.CenterCrop(probe_args.image_size), 34 | vT.ToTensor(), 35 | vT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 36 | ]) 37 | audio_transform = vT.Compose([ 38 | aT.Pad(rate=probe_args.audio_rate, dur=probe_args.audio_dur), 39 | aT.MelSpectrogram(sample_rate=probe_args.audio_rate, n_fft=int(probe_args.audio_rate * 0.05), hop_length=int(probe_args.audio_rate / 64), n_mels=probe_args.audio_mels), 40 | aT.Log(), 41 | ]) 42 | 43 | if self.dataset == 'vggsound': 44 | self.db = get_vggsound( 45 | probe_args.data_path, 46 | partition='test', 47 | audio_dur=probe_args.audio_dur, 48 | audio_rate=probe_args.audio_rate, 49 | visual_transform=image_transform, 50 | audio_transform=audio_transform 51 | ) 52 | self.multi_label = False 53 | elif self.dataset == 'audioset': 54 | self.db = get_audioset( 55 | probe_args.data_path, 56 | partition='eval', 57 | audio_dur=probe_args.audio_dur, 58 | audio_rate=probe_args.audio_rate, 59 | visual_transform=image_transform, 60 | audio_transform=audio_transform 61 | ) 62 | self.multi_label = True 63 | else: 64 | raise NotImplementedError 65 | 66 | if self.distributed: 67 | num_tasks = dist_utils.get_world_size() 68 | global_rank = dist_utils.get_rank() 69 | self.sampler = torch.utils.data.DistributedSampler( 70 | self.db, num_replicas=num_tasks, rank=global_rank, shuffle=True) 71 | else: 72 | self.sampler = torch.utils.data.RandomSampler(self.db) 73 | self.loader = torch.utils.data.DataLoader( 74 | self.db, 75 | sampler=self.sampler, 76 | batch_size=max(probe_args.batch_size//4, 1), 77 | num_workers=max(env_args.workers, 1), 78 | pin_memory=False, 79 | drop_last=True, 80 | ) 81 | 82 | @torch.no_grad() 83 | def evaluate(self, model, epoch=0): 84 | # model = copy.deepcopy(model) 85 | model.train(False) 86 | if self.distributed: 87 | self.sampler.set_epoch(0) 88 | 89 | # Extract features 90 | a_feats, v_feats, mm_feats, labels = [], [], [], [] 91 | metric_logger = meters.MetricLogger(delimiter=" ") 92 | for image, spec, anno in metric_logger.log_every(self.loader, self.print_freq, 'Extract features'): 93 | # Prepare data 94 | spec = 
spec.to(self.device, non_blocking=True).float() 95 | image = image.to(self.device, non_blocking=True).float() 96 | lbl = anno['class'].to(self.device, non_blocking=True).long() 97 | 98 | # Extract features 99 | x_v, x_a, x_mm = model.forward_encoder(image, spec)[:3] 100 | 101 | # Collect features and labels 102 | v_feats.append(x_v.mean(dim=1)) 103 | a_feats.append(x_a.mean(dim=1)) 104 | mm_feats.append(x_mm.mean(dim=1)) 105 | labels.append(lbl) 106 | 107 | # Synchronize across gpus 108 | a_feats = dist_utils.concat_all_gather(F.normalize(torch.cat(a_feats), p=2, dim=1)) 109 | v_feats = dist_utils.concat_all_gather(F.normalize(torch.cat(v_feats), p=2, dim=1)) 110 | mm_feats = dist_utils.concat_all_gather(F.normalize(torch.cat(mm_feats), p=2, dim=1)) 111 | labels = dist_utils.concat_all_gather(torch.cat(labels)) 112 | n_data = labels.shape[0] 113 | 114 | # kNN Evaluation 115 | metric_logger = meters.MetricLogger(delimiter=" ") 116 | preds = defaultdict(list) 117 | for i in metric_logger.log_every(range(0, n_data, 128), 250): 118 | qa_feats = a_feats[i:i+128] 119 | qv_feats = v_feats[i:i+128] 120 | qmm_feats = mm_feats[i:i+128] 121 | 122 | # Find nearest neighbor and aggregate predictions 123 | scores_a = torch.einsum('qd,nd->qn', qa_feats, a_feats) 124 | scores_v = torch.einsum('qd,nd->qn', qv_feats, v_feats) 125 | scores_mm = torch.einsum('qd,nd->qn', qmm_feats, mm_feats) 126 | 127 | for mod, scores in [('audio', scores_a), ('image', scores_v), ('fusion', scores_mm), ('all', scores_v+scores_a+scores_mm)]: 128 | scores, nn_idx = torch.topk(scores, k=2, dim=1, sorted=True) 129 | preds[mod].append(( 130 | labels[nn_idx[:, 1]], 131 | scores[:, 1] 132 | )) 133 | 134 | # Compute accuracies/auc 135 | metrics_dict = {} 136 | labels = labels.cpu().numpy() 137 | if self.multi_label: 138 | seen_classes = labels.sum(0) > 0 139 | for mod in preds: 140 | scores = torch.cat([(ypred * yscore[:, None]) for ypred, yscore in preds[mod]]).cpu().numpy() 141 | ap = metrics.average_precision_score(labels[:, seen_classes], scores[:, seen_classes], average=None) 142 | auc = metrics.roc_auc_score(labels[:, seen_classes], scores[:, seen_classes], average=None) 143 | metrics_dict.update({f'{mod}_nn_ap': ap.mean(), f'{mod}_nn_auc': auc.mean()}) 144 | 145 | else: 146 | for mod in preds: 147 | ypred = torch.cat([ypred for ypred, _ in preds[mod]]).cpu().numpy() 148 | acc = np.mean(ypred == labels)*100 149 | metrics_dict.update({f'{mod}_nn_acc': acc}) 150 | 151 | print(metrics_dict) 152 | return metrics_dict 153 | 154 | -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LARS(torch.optim.Optimizer): 5 | """ 6 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
7 | """ 8 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 9 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 10 | super().__init__(params, defaults) 11 | 12 | @torch.no_grad() 13 | def step(self): 14 | for g in self.param_groups: 15 | for p in g['params']: 16 | dp = p.grad 17 | 18 | if dp is None: 19 | continue 20 | 21 | if p.ndim > 1: # if not normalization gamma/beta or bias 22 | dp = dp.add(p, alpha=g['weight_decay']) 23 | param_norm = torch.norm(p) 24 | update_norm = torch.norm(dp) 25 | one = torch.ones_like(param_norm) 26 | q = torch.where(param_norm > 0., 27 | torch.where(update_norm > 0, 28 | (g['trust_coefficient'] * param_norm / update_norm), one), 29 | one) 30 | dp = dp.mul(q) 31 | 32 | param_state = self.state[p] 33 | if 'mu' not in param_state: 34 | param_state['mu'] = torch.zeros_like(p) 35 | mu = param_state['mu'] 36 | mu.mul_(g['momentum']).add_(dp) 37 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, epoch, args): 5 | wu = args.opt.get('warmup_epochs', 0) 6 | if epoch < wu: 7 | lr = args.opt.lr * epoch / wu 8 | else: 9 | lr = args.opt.lr * 0.5 * (1. + math.cos(math.pi * (epoch - wu) / (args.opt.epochs - wu))) 10 | 11 | # Learning rate adjustment for pretrained components 12 | pt_warmup_epochs = eval(str(args.opt.get('pt_warmup_epochs', -1))) 13 | if epoch < pt_warmup_epochs: 14 | lr_pt_scale = (0.5 - 0.5 * math.cos(math.pi * epoch / pt_warmup_epochs)) * (args.opt.pt_lr_mult_end - args.opt.pt_lr_mult_start) + args.opt.pt_lr_mult_start 15 | else: 16 | lr_pt_scale = args.opt.get('pt_lr_mult_end', 1.) 17 | 18 | for param_group in optimizer.param_groups: 19 | lr_layer_scale = param_group.get('lr_scale', 1.) 20 | if param_group.get('pretrained', False): 21 | param_group["lr"] = lr * lr_layer_scale * lr_pt_scale 22 | else: 23 | param_group["lr"] = lr * lr_layer_scale 24 | return lr 25 | 26 | 27 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 28 | """ 29 | Parameter groups for layer-wise lr decay 30 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 31 | """ 32 | param_groups = {} 33 | 34 | weights_layer_id = {k: v for k, v in model.params_layer_ids()} 35 | num_layers = max(list(weights_layer_id.values())) 36 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 37 | 38 | for n, p in model.named_parameters(): 39 | if not p.requires_grad: 40 | continue 41 | 42 | # no decay: all 1D parameters and model specific ones 43 | if p.ndim == 1 or n in no_weight_decay_list: 44 | g_decay = "no_decay" 45 | this_decay = 0. 
46 | else: 47 | g_decay = "decay" 48 | this_decay = weight_decay 49 | 50 | group_name = "layer_%d_%s" % (weights_layer_id[p], g_decay) 51 | if group_name not in param_groups: 52 | param_groups[group_name] = { 53 | "lr_scale": layer_scales[weights_layer_id[p]], 54 | "weight_decay": this_decay, 55 | "params": [], 56 | } 57 | param_groups[group_name]["params"].append(p) 58 | 59 | return list(param_groups.values()) 60 | 61 | 62 | def get_layer_id_for_vit(name, num_layers): 63 | """ 64 | Assign a parameter with its layer id 65 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 66 | """ 67 | if name in ['cls_token', 'pos_embed']: 68 | return 0 69 | elif name.startswith('patch_embed'): 70 | return 0 71 | elif name.startswith('blocks'): 72 | return int(name.split('.')[1]) + 1 73 | else: 74 | return num_layers 75 | 76 | 77 | def param_groups_pretrained(model, weight_decay=0.05, no_weight_decay_list=[], image_pt=None, audio_pt=None): 78 | from timm.optim import optim_factory 79 | param_groups = optim_factory.param_groups_weight_decay( 80 | model, weight_decay, 81 | no_weight_decay_list=no_weight_decay_list) 82 | param_groups_pt = [] 83 | if image_pt is not None: 84 | param_groups_pt += optim_factory.param_groups_weight_decay(model.encoder.image, weight_decay, no_weight_decay_list=no_weight_decay_list) 85 | if audio_pt is not None: 86 | param_groups_pt += optim_factory.param_groups_weight_decay(model.encoder.audio, weight_decay, no_weight_decay_list=no_weight_decay_list) 87 | for group in param_groups_pt: 88 | group['pretrained'] = True 89 | params_pt = set([p for group in param_groups_pt for p in group['params']]) 90 | for group in param_groups: 91 | group['params'] = [p for p in group['params'] if p not in params_pt] 92 | param_groups += param_groups_pt 93 | return param_groups -------------------------------------------------------------------------------- /util/meters.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | from collections import defaultdict, deque 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from util import distributed as dist_utils 8 | 9 | 10 | class SmoothedValue(object): 11 | """Track a series of values and provide access to smoothed values over a 12 | window or the global series average. 13 | """ 14 | 15 | def __init__(self, window_size=20, fmt=None): 16 | if fmt is None: 17 | fmt = "{median:.4f} ({global_avg:.4f})" 18 | self.deque = deque(maxlen=window_size) 19 | self.total = 0.0 20 | self.count = 0 21 | self.fmt = fmt 22 | 23 | def update(self, value, n=1): 24 | self.deque.append(value) 25 | self.count += n 26 | self.total += value * n 27 | 28 | def synchronize_between_processes(self): 29 | """ 30 | Warning: does not synchronize the deque! 
31 | """ 32 | if not dist_utils.is_dist_avail_and_initialized(): 33 | return 34 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 35 | dist.barrier() 36 | dist.all_reduce(t) 37 | t = t.tolist() 38 | self.count = int(t[0]) 39 | self.total = t[1] 40 | 41 | @property 42 | def median(self): 43 | d = torch.tensor(list(self.deque)) 44 | return d.median().item() 45 | 46 | @property 47 | def avg(self): 48 | d = torch.tensor(list(self.deque), dtype=torch.float32) 49 | return d.mean().item() 50 | 51 | @property 52 | def global_avg(self): 53 | return self.total / self.count 54 | 55 | @property 56 | def max(self): 57 | return max(self.deque) 58 | 59 | @property 60 | def value(self): 61 | return self.deque[-1] 62 | 63 | def __str__(self): 64 | return self.fmt.format( 65 | median=self.median, 66 | avg=self.avg, 67 | global_avg=self.global_avg, 68 | max=self.max, 69 | value=self.value) 70 | 71 | 72 | class MetricLogger(object): 73 | def __init__(self, delimiter="\t", print_freq=20, num_iters=0, header=''): 74 | self.delimiter = delimiter 75 | self.print_freq = print_freq 76 | self.num_iters = num_iters 77 | self.header = header 78 | self.meters = defaultdict(SmoothedValue) 79 | 80 | self._it = 0 81 | self._t = time.time() 82 | self._iter_time = SmoothedValue(fmt='{avg:.4f}') 83 | 84 | space_fmt = ':' + str(len(str(self.num_iters))) + 'd' 85 | log_msg = [ 86 | header, 87 | '[{0' + space_fmt + '}/{1}]', 88 | 'time: {time}', 89 | 'eta: {eta}', 90 | '{meters}', 91 | ] 92 | if torch.cuda.is_available(): 93 | log_msg.append('max mem: {memory:.0f}') 94 | self.log_msg = delimiter.join(log_msg) 95 | 96 | def update(self, n=1, **kwargs): 97 | for k, v in kwargs.items(): 98 | if v is None: 99 | continue 100 | if isinstance(v, torch.Tensor): 101 | v = v.item() 102 | assert isinstance(v, (float, int)) 103 | self.meters[k].update(v, n=n) 104 | 105 | def __getattr__(self, attr): 106 | if attr in self.meters: 107 | return self.meters[attr] 108 | if attr in self.__dict__: 109 | return self.__dict__[attr] 110 | raise AttributeError("'{}' object has no attribute '{}'".format( 111 | type(self).__name__, attr)) 112 | 113 | def __str__(self): 114 | msg_list = [] 115 | for name, meter in self.meters.items(): 116 | msg_list.append("{}: {}".format(name, str(meter))) 117 | return self.delimiter.join(msg_list) 118 | 119 | def synchronize_between_processes(self): 120 | for meter in self.meters.values(): 121 | meter.synchronize_between_processes() 122 | 123 | def add_meter(self, name, meter): 124 | self.meters[name] = meter 125 | 126 | def log_iter(self): 127 | self._iter_time.update(time.time() - self._t) 128 | self._t = time.time() 129 | self._it = self._it + 1 130 | 131 | MB = 1024.0 * 1024.0 132 | if self._it % self.print_freq == 0 or self._it == self.num_iters - 1: 133 | eta_seconds = self._iter_time.median * (self.num_iters - self._it) 134 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 135 | 136 | if torch.cuda.is_available(): 137 | print(self.log_msg.format( 138 | self._it, self.num_iters, eta=eta_string, 139 | meters=str(self), time=str(self._iter_time), 140 | memory=torch.cuda.max_memory_allocated() / MB)) 141 | else: 142 | print(self.log_msg.format( 143 | self._it, self.num_iters, eta=eta_string, 144 | meters=str(self), time=str(self._iter_time))) 145 | sys.stdout.flush() 146 | 147 | def log_every(self, iterable, print_freq, header=None): 148 | i = 0 149 | if not header: 150 | header = '' 151 | start_time = time.time() 152 | end = time.time() 153 | iter_time = 
SmoothedValue(fmt='{avg:.4f}') 154 | data_time = SmoothedValue(fmt='{avg:.4f}') 155 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 156 | log_msg = [ 157 | header, 158 | '[{0' + space_fmt + '}/{1}]', 159 | 'eta: {eta}', 160 | '{meters}', 161 | 'time: {time}', 162 | 'data: {data}' 163 | ] 164 | if torch.cuda.is_available(): 165 | log_msg.append('max mem: {memory:.0f}') 166 | log_msg = self.delimiter.join(log_msg) 167 | MB = 1024.0 * 1024.0 168 | for obj in iterable: 169 | data_time.update(time.time() - end) 170 | yield obj 171 | iter_time.update(time.time() - end) 172 | if i % print_freq == 0 or i == len(iterable) - 1: 173 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 174 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 175 | if torch.cuda.is_available(): 176 | print(log_msg.format( 177 | i, len(iterable), eta=eta_string, 178 | meters=str(self), 179 | time=str(iter_time), data=str(data_time), 180 | memory=torch.cuda.max_memory_allocated() / MB)) 181 | else: 182 | print(log_msg.format( 183 | i, len(iterable), eta=eta_string, 184 | meters=str(self), 185 | time=str(iter_time), data=str(data_time))) 186 | i += 1 187 | end = time.time() 188 | total_time = time.time() - start_time 189 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 190 | print('{} Total time: {} ({:.4f} s / it)'.format( 191 | header, total_time_str, total_time / len(iterable))) 192 | 193 | def latest(self): 194 | return {k: meter.value for k, meter in self.meters.items()} 195 | 196 | def averages(self): 197 | return {k: meter.global_avg for k, meter in self.meters.items()} 198 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.data 7 | 8 | from util import distributed as dist_utils 9 | import numpy as np 10 | from sklearn import metrics 11 | from collections import deque 12 | import copy 13 | 14 | import wandb 15 | import contextlib 16 | 17 | 18 | def print_args(args, prefix=''): 19 | from omegaconf.dictconfig import DictConfig 20 | for k in args: 21 | if isinstance(args[k], DictConfig): 22 | print_args(args[k], prefix=prefix + f'{k}.') 23 | else: 24 | print(f"{prefix}{k}: {str(args[k])}") 25 | 26 | 27 | class Trainer: 28 | def __init__(self, model, criterion=None, optimizer=None, accum_iter=1, use_amp=True, distributed=False): 29 | self.distributed = distributed 30 | self.model_without_ddp = model 31 | self.n_steps = torch.tensor([0]) 32 | if self.distributed: 33 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 34 | model = torch.nn.parallel.DistributedDataParallel(model) 35 | self.model = model 36 | self.criterion = criterion 37 | self.optimizer = optimizer 38 | self.scaler = torch.cuda.amp.GradScaler() if use_amp else None 39 | 40 | self.accum_iter = accum_iter 41 | self.accums = 0 42 | 43 | self.eval_model = self.model_without_ddp 44 | self.zero_grad() 45 | 46 | def module_dict(self): 47 | d = {'state_dict': self.model_without_ddp, 'n_steps': self.n_steps} 48 | if self.criterion is not None: 49 | d['criterion'] = self.criterion 50 | if self.optimizer is not None: 51 | d['optimizer'] = self.optimizer 52 | if self.scaler is not None: 53 | d['scaler'] = self.scaler 54 | return d 55 | 56 | def zero_grad(self): 57 | self.optimizer.zero_grad() 58 | for param_group in self.optimizer.param_groups: 59 | for param in param_group['params']: 60 
| param.grad_bkp = None 61 | self.accums = 0 62 | 63 | def get_scale(self): 64 | if self.scaler is not None: 65 | return self.scaler.get_scale() 66 | else: 67 | return 1. 68 | 69 | def backward(self, loss, create_graph=False): 70 | # Backward pass 71 | if self.scaler is not None: 72 | loss = self.scaler.scale(loss) 73 | 74 | loss.backward(create_graph=create_graph) 75 | self.accums += 1 # Track number of gradients accumulated 76 | 77 | # Track gradient norm (adjust by loss scale and # accumulated grads) 78 | norm = get_grad_norm_(self.model.parameters()) / (self.accums * self.get_scale()) 79 | return norm.item(), self.get_scale() 80 | 81 | def backup_grads(self): 82 | for param_group in self.optimizer.param_groups: 83 | for param in param_group['params']: 84 | param.grad_bkp = param.grad 85 | param.grad = None 86 | 87 | def restore_grads(self): 88 | for param_group in self.optimizer.param_groups: 89 | for param in param_group['params']: 90 | if param.grad is None: 91 | param.grad = param.grad_bkp 92 | else: 93 | param.grad += param.grad_bkp 94 | param.grad_bkp = None 95 | 96 | def step(self, loss, create_graph=False, clip_grad=None, skip_grad=None): 97 | # Backup grads in case norm too large 98 | if skip_grad is not None: 99 | self.backup_grads() 100 | norm, scale = self.backward(loss, create_graph=create_graph) 101 | if norm > skip_grad: 102 | self.optimizer.zero_grad() 103 | self.accums -= 1 104 | self.restore_grads() 105 | else: 106 | norm, scale = self.backward(loss, create_graph=create_graph) 107 | 108 | # Time to update? 109 | if self.accums == self.accum_iter: 110 | # Unscale gradients in-place 111 | if self.scaler is not None: 112 | self.scaler.unscale_(self.optimizer) 113 | 114 | # Adjust for grad accumulation 115 | if self.accum_iter > 1: 116 | for group in self.optimizer.param_groups: 117 | for param in group["params"]: 118 | if param.grad is not None: 119 | param.grad /= self.accum_iter 120 | 121 | # Clip gradients in-place 122 | if clip_grad is not None: 123 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_grad) 124 | 125 | # Update parameters 126 | if self.scaler is not None: 127 | self.scaler.step(self.optimizer) 128 | self.scaler.update() 129 | else: 130 | self.optimizer.step() 131 | 132 | # Reset grads 133 | self.zero_grad() 134 | self.n_steps += 1 135 | 136 | return norm, scale 137 | 138 | def autocast(self): 139 | if self.scaler is not None: 140 | return torch.cuda.amp.autocast() 141 | else: 142 | return contextlib.nullcontext() 143 | 144 | def autosync(self): 145 | if self.distributed and self.accums < self.accum_iter - 1: 146 | return self.model.no_sync() 147 | else: 148 | return contextlib.nullcontext() 149 | 150 | 151 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 152 | if isinstance(parameters, torch.Tensor): 153 | parameters = [parameters] 154 | parameters = [p for p in parameters if p.grad is not None] 155 | norm_type = float(norm_type) 156 | if len(parameters) == 0: 157 | return torch.tensor(0.) 
158 | device = parameters[0].grad.device 159 | if norm_type == math.inf: 160 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 161 | else: 162 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 163 | return total_norm 164 | 165 | 166 | class WBLogger: 167 | def __init__(self, wandb_dir, entity, project, job_name, model, args): 168 | self.mute = not args.log.use_wandb or not dist_utils.is_main_process() 169 | if self.mute: 170 | return 171 | self.wandb_dir = wandb_dir 172 | self.entity = entity 173 | self.project = project 174 | self.job_name = job_name 175 | 176 | self.log_counter = 0 # Counter for spreading out logs 177 | self.log_freq = args.log.print_freq 178 | self.watch_freq = args.log.wandb_watch_freq 179 | 180 | self.init_wandb(args, model) 181 | 182 | def init_wandb(self, args, model): 183 | if self.mute: 184 | return 185 | 186 | os.makedirs(self.wandb_dir, exist_ok=True) 187 | runid = None 188 | if os.path.exists(f"{self.wandb_dir}/runid.txt"): 189 | runid = open(f"{self.wandb_dir}/runid.txt").read() 190 | wandb.init(entity=self.entity, project=self.project, name=self.job_name, 191 | dir=self.wandb_dir, resume="allow", id=runid) 192 | open(f"{self.wandb_dir}/runid.txt", 'w').write(wandb.run.id) 193 | 194 | # Push config 195 | def flatten_args(args): 196 | from omegaconf import DictConfig 197 | args_flat = {} 198 | for k in dict(args).keys(): 199 | if isinstance(args[k], DictConfig): 200 | args_flat.update({f"{k}.{k2}": v for k2, v in flatten_args(args[k]).items()}) 201 | else: 202 | args_flat[k] = args[k] 203 | return args_flat 204 | wandb.config.update({k: v for k, v in flatten_args(args).items() if k not in wandb.config}) 205 | 206 | # Log model stats 207 | 208 | if self.watch_freq > 0: 209 | wandb.watch(model, log="all", log_freq=self.watch_freq) 210 | 211 | def log(self, metrics, step=None, force=False): 212 | if self.mute: 213 | return 214 | if force or step is None: 215 | wandb.log(metrics, step=step) 216 | else: 217 | self.log_counter += 1 218 | if self.log_counter % self.log_freq == 0: 219 | wandb.log(metrics, step=step) 220 | 221 | 222 | class CheckpointManager: 223 | def __init__(self, 224 | modules, 225 | ckpt_dir, 226 | epochs, 227 | save_freq=None): 228 | self.modules = modules 229 | self.ckpt_dir = ckpt_dir 230 | self.epochs = epochs 231 | self.save_freq = save_freq 232 | 233 | self.world_size = dist_utils.get_world_size() 234 | self.rank = dist_utils.get_rank() 235 | 236 | if self.rank == 0: 237 | os.makedirs(os.path.join(self.ckpt_dir), exist_ok=True) 238 | 239 | def map_location(self, state, device): 240 | if isinstance(state, dict): 241 | return {k: self.map_location(state[k], device) for k in state} 242 | elif isinstance(state, list): 243 | return [self.map_location(st, device) for st in state] 244 | elif isinstance(state, tuple): 245 | return tuple(self.map_location(st, device) for st in state) 246 | elif isinstance(state, torch.Tensor): 247 | return state.to(device) 248 | else: 249 | return state 250 | 251 | def create_state_dict(self, save_dict): 252 | state = {} 253 | for k in self.modules: 254 | if self.modules[k] is None: 255 | state[k] = None 256 | elif isinstance(self.modules[k], torch.Tensor): 257 | state[k] = copy.deepcopy(self.modules[k]).cpu() 258 | else: 259 | module_state = copy.deepcopy(self.modules[k].state_dict()) 260 | state[k] = self.map_location(module_state, 'cpu') 261 | 262 | if save_dict is not None: 263 | state.update(save_dict) 264 | 
return state 265 | 266 | def load_state_dict(self, checkpoint): 267 | for k in self.modules: 268 | self.modules[k].load_state_dict(checkpoint[k]) 269 | metrics = {k: checkpoint[k] for k in checkpoint if k not in self.modules} 270 | return metrics 271 | 272 | def resume(self): 273 | ckpt_fname = os.path.join(self.ckpt_dir, f'checkpoint_latest.pth') 274 | print(f"Loading {ckpt_fname}") 275 | start_epoch, metrics = 0, {} 276 | if os.path.isfile(ckpt_fname): 277 | checkpoint = torch.load(ckpt_fname, map_location='cpu') 278 | 279 | # Load state dict 280 | for k in self.modules: 281 | if self.modules[k] is None: 282 | continue 283 | elif isinstance(self.modules[k], torch.Tensor): 284 | self.modules[k].data[:] = checkpoint[k].data 285 | else: 286 | self.modules[k].load_state_dict(checkpoint[k]) 287 | start_epoch = checkpoint['epoch'] 288 | metrics = {k: checkpoint[k] for k in checkpoint if k not in set(self.modules.keys()) and k != 'epoch'} 289 | print(f"=> loaded checkpoint '{ckpt_fname}' (epoch {checkpoint['epoch']})") 290 | 291 | return start_epoch, metrics 292 | 293 | def checkpoint(self, epoch, save_dict=None, is_best=False): 294 | if self.rank != 0: 295 | return 296 | state = self.create_state_dict(save_dict) 297 | ckpt_fname = os.path.join(self.ckpt_dir, f'checkpoint_latest.pth') 298 | torch.save(state, ckpt_fname) 299 | print(f"=> saved checkpoint '{ckpt_fname}' (epoch {epoch})") 300 | 301 | if is_best: 302 | best_fname = os.path.join(self.ckpt_dir, f'checkpoint_best.pth') 303 | torch.save(state, best_fname) 304 | print(f"=> saved best checkpoint '{best_fname}' (epoch {epoch})") 305 | 306 | if self.save_freq is not None and ((epoch % self.save_freq == 0) or epoch == self.epochs): 307 | ckpt_fname = os.path.join(self.ckpt_dir, f'checkpoint_{epoch:04d}.pth') 308 | torch.save(state, ckpt_fname) 309 | print(f"=> saved checkpoint '{ckpt_fname}' (epoch {epoch})") 310 | 311 | 312 | def calc_multi_class_stats(labels, preds): 313 | assert labels.shape[0] == preds.shape[0] 314 | 315 | seen_classes = labels.sum(0) > 0 316 | labels, preds = labels[:, seen_classes], preds[:, seen_classes] 317 | num_classes = seen_classes.sum() 318 | 319 | ap = np.array([metrics.average_precision_score(labels[:, cls], preds[:, cls], average=None) 320 | for cls in range(num_classes)]) 321 | auc = np.array([metrics.roc_auc_score(labels[:, cls], preds[:, cls], average=None) 322 | for cls in range(num_classes)]) 323 | 324 | return dict( 325 | ap=ap.mean()*100., 326 | auc=auc.mean()*100., 327 | ) -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from timm.models.layers import to_2tuple 6 | 7 | 8 | # -------------------------------------------------------- 9 | # 2D sine-cosine position embedding 10 | # References: 11 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 12 | # MoCo v3: https://github.com/facebookresearch/moco-v3 13 | # -------------------------------------------------------- 14 | 15 | 16 | def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=0, thw_props=(2, 1, 1)): 17 | """ 18 | grid_size: int of the grid height and width 19 | return: 20 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 21 | """ 22 | h_dim = int(embed_dim * (thw_props[1] / float(sum(thw_props)))) 23 | w_dim = int(embed_dim * 
(thw_props[2] / float(sum(thw_props)))) 24 | t_dim = embed_dim - h_dim - w_dim 25 | 26 | grid_t = np.arange(grid_size[0], dtype=np.float32) 27 | grid_h = np.arange(grid_size[1], dtype=np.float32) 28 | grid_w = np.arange(grid_size[2], dtype=np.float32) 29 | grid = np.meshgrid(grid_t, grid_w, grid_h, indexing="ij") 30 | grid = np.stack(grid, axis=0) 31 | 32 | grid = grid.reshape([3, 1, grid_size[0], grid_size[1], grid_size[2]]) 33 | # use half of dimensions to encode grid_h and grid_w 34 | emb_t = get_1d_sincos_pos_embed_from_grid(t_dim, grid[0]) # (T*H*W, x) 35 | emb_h = get_1d_sincos_pos_embed_from_grid(h_dim, grid[1]) # (T*H*W, y) 36 | emb_w = get_1d_sincos_pos_embed_from_grid(w_dim, grid[2]) # (T*H*W, z) 37 | pos_embed = np.concatenate([emb_t, emb_h, emb_w], axis=1) # (T*H*W, D) 38 | if cls_token: 39 | pos_embed = np.concatenate([np.zeros([int(cls_token), embed_dim]), pos_embed], axis=0) 40 | return pos_embed 41 | 42 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 43 | """ 44 | grid_size: int of the grid height and width 45 | return: 46 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 47 | """ 48 | grid_size = to_2tuple(grid_size) 49 | grid_h = np.arange(grid_size[0], dtype=np.float32) 50 | grid_w = np.arange(grid_size[1], dtype=np.float32) 51 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 52 | grid = np.stack(grid, axis=0) 53 | 54 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) 55 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 56 | if cls_token: 57 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 58 | return pos_embed 59 | 60 | 61 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 62 | assert embed_dim % 2 == 0 63 | 64 | # use half of dimensions to encode grid_h 65 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 66 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 67 | 68 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 69 | return emb 70 | 71 | 72 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 73 | """ 74 | embed_dim: output dimension for each position 75 | pos: a list of positions to be encoded: size (M,) 76 | out: (M, D) 77 | """ 78 | assert embed_dim % 2 == 0 79 | omega = np.arange(embed_dim // 2, dtype=np.float32) 80 | omega /= embed_dim / 2. 81 | omega = 1. 
/ 10000**omega # (D/2,) 82 | 83 | pos = pos.reshape(-1) # (M,) 84 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 85 | 86 | emb_sin = np.sin(out) # (M, D/2) 87 | emb_cos = np.cos(out) # (M, D/2) 88 | 89 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 90 | return emb 91 | 92 | 93 | # -------------------------------------------------------- 94 | # Interpolate position embeddings for high-resolution 95 | # References: 96 | # DeiT: https://github.com/facebookresearch/deit 97 | # -------------------------------------------------------- 98 | def interpolate_pos_embed(model, checkpoint_model): 99 | if 'pos_embed' in checkpoint_model: 100 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 101 | embedding_size = pos_embed_checkpoint.shape[-1] 102 | num_patches = model.patch_embed.num_patches 103 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 104 | # height (== width) for the checkpoint position embedding 105 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 106 | # height (== width) for the new position embedding 107 | new_size = int(num_patches ** 0.5) 108 | # class_token and dist_token are kept unchanged 109 | if orig_size != new_size: 110 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 111 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 112 | # only the position tokens are interpolated 113 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 114 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 115 | pos_tokens = torch.nn.functional.interpolate( 116 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 117 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 118 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 119 | checkpoint_model['pos_embed'] = new_pos_embed 120 | 121 | 122 | 123 | class PatchEmbed3D(nn.Module): 124 | """ 125 | Flexible Image to Patch Embedding 126 | """ 127 | def __init__(self, input_size=(16, 224, 224), patch_size=(2, 16, 16), in_chans=3, embed_dim=768, stride=(2, 16, 16)): 128 | super().__init__() 129 | self.input_size = input_size 130 | self.patch_size = patch_size 131 | self.in_chans = in_chans 132 | 133 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) 134 | _, _, t, h, w = self.get_output_shape(input_size) # n, emb_dim, h, w 135 | self.patch_thw = (t, h, w) 136 | self.num_patches = t * h * w 137 | 138 | def get_output_shape(self, input_size): 139 | # todo: don't be lazy.. 140 | return self.proj(torch.randn(1, self.in_chans, input_size[0], input_size[1], input_size[2])).shape 141 | 142 | def forward(self, x): 143 | x = self.proj(x) # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14 144 | x = x.flatten(2) # 32, 768, 1568 145 | x = x.transpose(1, 2) # 32, 768, 1568 -> 32, 1568, 768 146 | return x --------------------------------------------------------------------------------
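A minimal usage sketch for the helpers in `util/pos_embed.py`, assuming only that the repository root is on `PYTHONPATH` and a ViT-Base-style 14x14 patch grid (224px input, 16px patches); the dimensions are illustrative and not taken from any config in this repository:

```python
# Illustrative sketch: build fixed sin-cos position tables with util/pos_embed.py.
# Assumes the repo root is importable; embed_dim/grid_size below are examples only.
import numpy as np
import torch

from util.pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid

# 2D sin-cos table for a 14x14 patch grid, with one extra row prepended for a CLS token.
embed_dim, grid_size = 768, 14
pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
assert pos_embed.shape == (1 + grid_size * grid_size, embed_dim)

# The table is typically copied into a frozen parameter of a ViT-style encoder.
pe = torch.nn.Parameter(
    torch.zeros(1, 1 + grid_size * grid_size, embed_dim), requires_grad=False)
pe.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

# 1D variant, e.g. for a sequence of 10 positions with 128-dim embeddings.
frame_pe = get_1d_sincos_pos_embed_from_grid(128, np.arange(10, dtype=np.float32))
assert frame_pe.shape == (10, 128)
```

When loading a checkpoint at a different input resolution, `interpolate_pos_embed(model, checkpoint_model)` defined in the same file resizes the grid portion of the checkpoint's `pos_embed` with bicubic interpolation (keeping the CLS/extra tokens unchanged) before the usual `load_state_dict` call.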