├── 2.5D_visual_sound.png ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data ├── __init__.py ├── audioVisual_dataset.py ├── base_data_loader.py ├── base_dataset.py ├── custom_dataset_data_loader.py └── data_loader.py ├── demo.py ├── evaluate.py ├── models ├── __init__.py ├── audioVisual_model.py ├── criterion.py ├── models.py └── networks.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── reEncodeAudio.py ├── train.py └── util ├── __init__.py └── util.py /2.5D_visual_sound.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/2.5D-Visual-Sound/73c73755bd8faa4478c1ab09441bc1e77fe7f1e7/2.5D_visual_sound.png -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to 2.5D-Visual-Sound 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | * 80 character line length 31 | 32 | ## License 33 | By contributing to 2.5D-Visual-Sound, you agree that your contributions will be licensed 34 | under the LICENSE file in the root directory of this source tree. 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. 
Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution 4.0 International Public License 58 | 59 | By exercising the Licensed Rights (defined below), You accept and agree 60 | to be bound by the terms and conditions of this Creative Commons 61 | Attribution 4.0 International Public License ("Public License"). To the 62 | extent this Public License may be interpreted as a contract, You are 63 | granted the Licensed Rights in consideration of Your acceptance of 64 | these terms and conditions, and the Licensor grants You such rights in 65 | consideration of benefits the Licensor receives from making the 66 | Licensed Material available under these terms and conditions. 67 | 68 | Section 1 -- Definitions. 69 | 70 | a. 
Adapted Material means material subject to Copyright and Similar 71 | Rights that is derived from or based upon the Licensed Material 72 | and in which the Licensed Material is translated, altered, 73 | arranged, transformed, or otherwise modified in a manner requiring 74 | permission under the Copyright and Similar Rights held by the 75 | Licensor. For purposes of this Public License, where the Licensed 76 | Material is a musical work, performance, or sound recording, 77 | Adapted Material is always produced where the Licensed Material is 78 | synched in timed relation with a moving image. 79 | 80 | b. Adapter's License means the license You apply to Your Copyright 81 | and Similar Rights in Your contributions to Adapted Material in 82 | accordance with the terms and conditions of this Public License. 83 | 84 | c. Copyright and Similar Rights means copyright and/or similar rights 85 | closely related to copyright including, without limitation, 86 | performance, broadcast, sound recording, and Sui Generis Database 87 | Rights, without regard to how the rights are labeled or 88 | categorized. For purposes of this Public License, the rights 89 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 90 | Rights. 91 | 92 | d. Effective Technological Measures means those measures that, in the 93 | absence of proper authority, may not be circumvented under laws 94 | fulfilling obligations under Article 11 of the WIPO Copyright 95 | Treaty adopted on December 20, 1996, and/or similar international 96 | agreements. 97 | 98 | e. Exceptions and Limitations means fair use, fair dealing, and/or 99 | any other exception or limitation to Copyright and Similar Rights 100 | that applies to Your use of the Licensed Material. 101 | 102 | f. Licensed Material means the artistic or literary work, database, 103 | or other material to which the Licensor applied this Public 104 | License. 105 | 106 | g. Licensed Rights means the rights granted to You subject to the 107 | terms and conditions of this Public License, which are limited to 108 | all Copyright and Similar Rights that apply to Your use of the 109 | Licensed Material and that the Licensor has authority to license. 110 | 111 | h. Licensor means the individual(s) or entity(ies) granting rights 112 | under this Public License. 113 | 114 | i. Share means to provide material to the public by any means or 115 | process that requires permission under the Licensed Rights, such 116 | as reproduction, public display, public performance, distribution, 117 | dissemination, communication, or importation, and to make material 118 | available to the public including in ways that members of the 119 | public may access the material from a place and at a time 120 | individually chosen by them. 121 | 122 | j. Sui Generis Database Rights means rights other than copyright 123 | resulting from Directive 96/9/EC of the European Parliament and of 124 | the Council of 11 March 1996 on the legal protection of databases, 125 | as amended and/or succeeded, as well as other essentially 126 | equivalent rights anywhere in the world. 127 | 128 | k. You means the individual or entity exercising the Licensed Rights 129 | under this Public License. Your has a corresponding meaning. 130 | 131 | Section 2 -- Scope. 132 | 133 | a. License grant. 134 | 135 | 1. 
Subject to the terms and conditions of this Public License, 136 | the Licensor hereby grants You a worldwide, royalty-free, 137 | non-sublicensable, non-exclusive, irrevocable license to 138 | exercise the Licensed Rights in the Licensed Material to: 139 | 140 | a. reproduce and Share the Licensed Material, in whole or 141 | in part; and 142 | 143 | b. produce, reproduce, and Share Adapted Material. 144 | 145 | 2. Exceptions and Limitations. For the avoidance of doubt, where 146 | Exceptions and Limitations apply to Your use, this Public 147 | License does not apply, and You do not need to comply with 148 | its terms and conditions. 149 | 150 | 3. Term. The term of this Public License is specified in Section 151 | 6(a). 152 | 153 | 4. Media and formats; technical modifications allowed. The 154 | Licensor authorizes You to exercise the Licensed Rights in 155 | all media and formats whether now known or hereafter created, 156 | and to make technical modifications necessary to do so. The 157 | Licensor waives and/or agrees not to assert any right or 158 | authority to forbid You from making technical modifications 159 | necessary to exercise the Licensed Rights, including 160 | technical modifications necessary to circumvent Effective 161 | Technological Measures. For purposes of this Public License, 162 | simply making modifications authorized by this Section 2(a) 163 | (4) never produces Adapted Material. 164 | 165 | 5. Downstream recipients. 166 | 167 | a. Offer from the Licensor -- Licensed Material. Every 168 | recipient of the Licensed Material automatically 169 | receives an offer from the Licensor to exercise the 170 | Licensed Rights under the terms and conditions of this 171 | Public License. 172 | 173 | b. No downstream restrictions. You may not offer or impose 174 | any additional or different terms or conditions on, or 175 | apply any Effective Technological Measures to, the 176 | Licensed Material if doing so restricts exercise of the 177 | Licensed Rights by any recipient of the Licensed 178 | Material. 179 | 180 | 6. No endorsement. Nothing in this Public License constitutes or 181 | may be construed as permission to assert or imply that You 182 | are, or that Your use of the Licensed Material is, connected 183 | with, or sponsored, endorsed, or granted official status by, 184 | the Licensor or others designated to receive attribution as 185 | provided in Section 3(a)(1)(A)(i). 186 | 187 | b. Other rights. 188 | 189 | 1. Moral rights, such as the right of integrity, are not 190 | licensed under this Public License, nor are publicity, 191 | privacy, and/or other similar personality rights; however, to 192 | the extent possible, the Licensor waives and/or agrees not to 193 | assert any such rights held by the Licensor to the limited 194 | extent necessary to allow You to exercise the Licensed 195 | Rights, but not otherwise. 196 | 197 | 2. Patent and trademark rights are not licensed under this 198 | Public License. 199 | 200 | 3. To the extent possible, the Licensor waives any right to 201 | collect royalties from You for the exercise of the Licensed 202 | Rights, whether directly or through a collecting society 203 | under any voluntary or waivable statutory or compulsory 204 | licensing scheme. In all other cases the Licensor expressly 205 | reserves any right to collect such royalties. 206 | 207 | Section 3 -- License Conditions. 208 | 209 | Your exercise of the Licensed Rights is expressly made subject to the 210 | following conditions. 211 | 212 | a. Attribution. 
213 | 214 | 1. If You Share the Licensed Material (including in modified 215 | form), You must: 216 | 217 | a. retain the following if it is supplied by the Licensor 218 | with the Licensed Material: 219 | 220 | i. identification of the creator(s) of the Licensed 221 | Material and any others designated to receive 222 | attribution, in any reasonable manner requested by 223 | the Licensor (including by pseudonym if 224 | designated); 225 | 226 | ii. a copyright notice; 227 | 228 | iii. a notice that refers to this Public License; 229 | 230 | iv. a notice that refers to the disclaimer of 231 | warranties; 232 | 233 | v. a URI or hyperlink to the Licensed Material to the 234 | extent reasonably practicable; 235 | 236 | b. indicate if You modified the Licensed Material and 237 | retain an indication of any previous modifications; and 238 | 239 | c. indicate the Licensed Material is licensed under this 240 | Public License, and include the text of, or the URI or 241 | hyperlink to, this Public License. 242 | 243 | 2. You may satisfy the conditions in Section 3(a)(1) in any 244 | reasonable manner based on the medium, means, and context in 245 | which You Share the Licensed Material. For example, it may be 246 | reasonable to satisfy the conditions by providing a URI or 247 | hyperlink to a resource that includes the required 248 | information. 249 | 250 | 3. If requested by the Licensor, You must remove any of the 251 | information required by Section 3(a)(1)(A) to the extent 252 | reasonably practicable. 253 | 254 | 4. If You Share Adapted Material You produce, the Adapter's 255 | License You apply must not prevent recipients of the Adapted 256 | Material from complying with this Public License. 257 | 258 | Section 4 -- Sui Generis Database Rights. 259 | 260 | Where the Licensed Rights include Sui Generis Database Rights that 261 | apply to Your use of the Licensed Material: 262 | 263 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 264 | to extract, reuse, reproduce, and Share all or a substantial 265 | portion of the contents of the database; 266 | 267 | b. if You include all or a substantial portion of the database 268 | contents in a database in which You have Sui Generis Database 269 | Rights, then the database in which You have Sui Generis Database 270 | Rights (but not its individual contents) is Adapted Material; and 271 | 272 | c. You must comply with the conditions in Section 3(a) if You Share 273 | all or a substantial portion of the contents of the database. 274 | 275 | For the avoidance of doubt, this Section 4 supplements and does not 276 | replace Your obligations under this Public License where the Licensed 277 | Rights include other Copyright and Similar Rights. 278 | 279 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 280 | 281 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 282 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 283 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 284 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 285 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 286 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 287 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 288 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 289 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 290 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 291 | 292 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 293 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 294 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 295 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 296 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 297 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 298 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 299 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 300 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 301 | 302 | c. The disclaimer of warranties and limitation of liability provided 303 | above shall be interpreted in a manner that, to the extent 304 | possible, most closely approximates an absolute disclaimer and 305 | waiver of all liability. 306 | 307 | Section 6 -- Term and Termination. 308 | 309 | a. This Public License applies for the term of the Copyright and 310 | Similar Rights licensed here. However, if You fail to comply with 311 | this Public License, then Your rights under this Public License 312 | terminate automatically. 313 | 314 | b. Where Your right to use the Licensed Material has terminated under 315 | Section 6(a), it reinstates: 316 | 317 | 1. automatically as of the date the violation is cured, provided 318 | it is cured within 30 days of Your discovery of the 319 | violation; or 320 | 321 | 2. upon express reinstatement by the Licensor. 322 | 323 | For the avoidance of doubt, this Section 6(b) does not affect any 324 | right the Licensor may have to seek remedies for Your violations 325 | of this Public License. 326 | 327 | c. For the avoidance of doubt, the Licensor may also offer the 328 | Licensed Material under separate terms or conditions or stop 329 | distributing the Licensed Material at any time; however, doing so 330 | will not terminate this Public License. 331 | 332 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 333 | License. 334 | 335 | Section 7 -- Other Terms and Conditions. 336 | 337 | a. The Licensor shall not be bound by any additional or different 338 | terms or conditions communicated by You unless expressly agreed. 339 | 340 | b. Any arrangements, understandings, or agreements regarding the 341 | Licensed Material not stated herein are separate from and 342 | independent of the terms and conditions of this Public License. 343 | 344 | Section 8 -- Interpretation. 345 | 346 | a. For the avoidance of doubt, this Public License does not, and 347 | shall not be interpreted to, reduce, limit, restrict, or impose 348 | conditions on any use of the Licensed Material that could lawfully 349 | be made without permission under this Public License. 350 | 351 | b. To the extent possible, if any provision of this Public License is 352 | deemed unenforceable, it shall be automatically reformed to the 353 | minimum extent necessary to make it enforceable. If the provision 354 | cannot be reformed, it shall be severed from this Public License 355 | without affecting the enforceability of the remaining terms and 356 | conditions. 357 | 358 | c. No term or condition of this Public License will be waived and no 359 | failure to comply consented to unless expressly agreed to by the 360 | Licensor. 361 | 362 | d. 
Nothing in this Public License constitutes or may be interpreted 363 | as a limitation upon, or waiver of, any privileges and immunities 364 | that apply to the Licensor or You, including from the legal 365 | processes of any jurisdiction or authority. 366 | 367 | ======================================================================= 368 | 369 | Creative Commons is not a party to its public licenses. 370 | Notwithstanding, Creative Commons may elect to apply one of its public 371 | licenses to material it publishes and in those instances will be 372 | considered the "Licensor." Except for the limited purpose of indicating 373 | that material is shared under a Creative Commons public license or as 374 | otherwise permitted by the Creative Commons policies published at 375 | creativecommons.org/policies, Creative Commons does not authorize the 376 | use of the trademark "Creative Commons" or any other trademark or logo 377 | of Creative Commons without its prior written consent including, 378 | without limitation, in connection with any unauthorized modifications 379 | to any of its public licenses or any other arrangements, 380 | understandings, or agreements concerning use of licensed material. For 381 | the avoidance of doubt, this paragraph does not form part of the public 382 | licenses. 383 | 384 | Creative Commons may be contacted at creativecommons.org. 385 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 2.5D Visual Sound 2 | [[Project Page]](http://vision.cs.utexas.edu/projects/2.5D_visual_sound/) [[arXiv]](https://arxiv.org/abs/1812.04204) [[Video]](https://www.youtube.com/watch?v=Wrx3pv_ixdI) [[Dataset]](https://github.com/facebookresearch/FAIR-Play)
3 | 4 | <img src='2.5D_visual_sound.png'> 5 | 6 |
7 | 8 | [2.5D Visual Sound](https://arxiv.org/abs/1812.04204) 9 | [Ruohan Gao](https://www.cs.utexas.edu/~rhgao/)<sup>1</sup> and [Kristen Grauman](http://www.cs.utexas.edu/~grauman/)<sup>2</sup>
10 | <sup>1</sup>UT Austin, <sup>2</sup>Facebook AI Research 11 | In Conference on Computer Vision and Pattern Recognition (**CVPR**), 2019 12 | 13 |
14 | 15 | If you find our code or project useful in your research, please cite: 16 | 17 | @inproceedings{gao2019visualsound, 18 | title={2.5D Visual Sound}, 19 | author={Gao, Ruohan and Grauman, Kristen}, 20 | booktitle={CVPR}, 21 | year={2019} 22 | } 23 | 24 | ### FAIR-Play Dataset 25 | The [FAIR-Play](https://github.com/facebookresearch/FAIR-Play) repository contains the dataset we collected and used in our paper. It contains 1,871 video clips and their corresponding binaural audio clips recorded in a music room. The code provided can be used to train mono2binaural models on this dataset. 26 | 27 | ### Training and Testing 28 | (The code has been tested under the following system environment: Ubuntu 16.04.6 LTS, CUDA 9.0, Python 2.7.15, PyTorch 1.0.0) 29 | 1. Download the FAIR-Play dataset and prepare the hdf5 split files accordingly by adding the correct root prefix (an illustrative sketch of the expected hdf5 contents is included as a comment at the top of data/audioVisual_dataset.py below). 30 | 31 | 2. [OPTIONAL] Preprocess the audio files using reEncodeAudio.py to accelerate the training process. 32 | 33 | 3. Use the following command to train the mono2binaural model: 34 | ``` 35 | python train.py --hdf5FolderPath /YOUR_CODE_PATH/2.5d_visual_sound/hdf5/ --name mono2binaural --model audioVisual --checkpoints_dir /YOUR_CHECKPOINT_PATH/ --save_epoch_freq 50 --display_freq 10 --save_latest_freq 100 --batchSize 256 --learning_rate_decrease_itr 10 --niter 1000 --lr_visual 0.0001 --lr_audio 0.001 --nThreads 32 --gpu_ids 0,1,2,3,4,5,6,7 --validation_on --validation_freq 100 --validation_batches 50 --tensorboard True |& tee -a mono2binaural.log 36 | ``` 37 | 38 | 4. Use the following command to test your trained mono2binaural model: 39 | ``` 40 | python demo.py --input_audio_path /BINAURAL_AUDIO_PATH --video_frame_path /VIDEO_FRAME_PATH --weights_visual /VISUAL_MODEL_PATH --weights_audio /AUDIO_MODEL_PATH --output_dir_root /YOUR_OUTPUT_DIR/ --input_audio_length 10 --hop_size 0.05 41 | ``` 42 | The model trained on split1 of FAIR-Play is shared at: https://drive.google.com/drive/folders/1fq-SK4OBFVegLM0PfVcLerR9exSwcDWl?usp=drive_link 43 | 44 | 5. Use the following command for evaluation: 45 | ``` 46 | python evaluate.py --results_root /YOUR_RESULTS --normalization True 47 | ``` 48 | 49 | ### Acknowledgements 50 | Portions of the code are adapted from the CycleGAN implementation (https://github.com/junyanz/CycleGAN) and the Sound-of-Pixels implementation (https://github.com/hangzhaomit/Sound-of-Pixels). Please also refer to the original licenses of these projects. 51 | 52 | 53 | ### License 54 | The code for 2.5D Visual Sound is CC BY 4.0 licensed, as found in the LICENSE file. 55 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/2.5D-Visual-Sound/73c73755bd8faa4478c1ab09441bc1e77fe7f1e7/data/__init__.py -------------------------------------------------------------------------------- /data/audioVisual_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree.
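#
# Illustrative sketch (not part of the original repository): step 1 of the README above requires
# train.h5 / val.h5 / test.h5 split files. Based on initialize() below, which only reads
# h5f['audio'][:], each file is assumed to hold a single dataset named 'audio' containing full
# paths to the binaural .wav clips; the root path and helper name here are hypothetical. Frames
# are assumed to be extracted at 10 fps into <root>/frames/<clip>.mp4/000001.png, ..., as implied
# by __getitem__ below.
#
#     import glob, os
#     import h5py
#
#     def build_split_file(wav_root, out_path):
#         # store the binaural audio paths under the 'audio' key read by initialize()
#         audio_paths = sorted(glob.glob(os.path.join(wav_root, '*.wav')))
#         with h5py.File(out_path, 'w') as f:
#             f.create_dataset('audio', data=[p.encode('utf-8') for p in audio_paths])
#
#     build_split_file('/YOUR_ROOT/FAIR-Play/binaural_audios/', 'hdf5/train.h5')
#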
8 | # 9 | 10 | import os.path 11 | import time 12 | import librosa 13 | import h5py 14 | import random 15 | import math 16 | import numpy as np 17 | import glob 18 | import torch 19 | from PIL import Image, ImageEnhance 20 | import torchvision.transforms as transforms 21 | from data.base_dataset import BaseDataset 22 | 23 | def normalize(samples, desired_rms = 0.1, eps = 1e-4): 24 | rms = np.maximum(eps, np.sqrt(np.mean(samples**2))) 25 | samples = samples * (desired_rms / rms) 26 | return samples 27 | 28 | def generate_spectrogram(audio): 29 | spectro = librosa.core.stft(audio, n_fft=512, hop_length=160, win_length=400, center=True) 30 | real = np.expand_dims(np.real(spectro), axis=0) 31 | imag = np.expand_dims(np.imag(spectro), axis=0) 32 | spectro_two_channel = np.concatenate((real, imag), axis=0) 33 | return spectro_two_channel 34 | 35 | def process_image(image, augment): 36 | image = image.resize((480,240)) 37 | w,h = image.size 38 | w_offset = w - 448 39 | h_offset = h - 224 40 | left = random.randrange(0, w_offset + 1) 41 | upper = random.randrange(0, h_offset + 1) 42 | image = image.crop((left, upper, left+448, upper+224)) 43 | 44 | if augment: 45 | enhancer = ImageEnhance.Brightness(image) 46 | image = enhancer.enhance(random.random()*0.6 + 0.7) 47 | enhancer = ImageEnhance.Color(image) 48 | image = enhancer.enhance(random.random()*0.6 + 0.7) 49 | return image 50 | 51 | class AudioVisualDataset(BaseDataset): 52 | def initialize(self, opt): 53 | self.opt = opt 54 | self.audios = [] 55 | 56 | #load hdf5 file here 57 | h5f_path = os.path.join(opt.hdf5FolderPath, opt.mode+".h5") 58 | h5f = h5py.File(h5f_path, 'r') 59 | self.audios = h5f['audio'][:] 60 | 61 | normalize = transforms.Normalize( 62 | mean=[0.485, 0.456, 0.406], 63 | std=[0.229, 0.224, 0.225] 64 | ) 65 | vision_transform_list = [transforms.ToTensor(), normalize] 66 | self.vision_transform = transforms.Compose(vision_transform_list) 67 | 68 | def __getitem__(self, index): 69 | #load audio 70 | audio, audio_rate = librosa.load(self.audios[index], sr=self.opt.audio_sampling_rate, mono=False) 71 | 72 | #randomly get a start time for the audio segment from the 10s clip 73 | audio_start_time = random.uniform(0, 9.9 - self.opt.audio_length) 74 | audio_end_time = audio_start_time + self.opt.audio_length 75 | audio_start = int(audio_start_time * self.opt.audio_sampling_rate) 76 | audio_end = audio_start + int(self.opt.audio_length * self.opt.audio_sampling_rate) 77 | audio = audio[:, audio_start:audio_end] 78 | audio = normalize(audio) 79 | audio_channel1 = audio[0,:] 80 | audio_channel2 = audio[1,:] 81 | 82 | #get the frame dir path based on audio path 83 | path_parts = self.audios[index].strip().split('/') 84 | path_parts[-1] = path_parts[-1][:-4] + '.mp4' 85 | path_parts[-2] = 'frames' 86 | frame_path = '/'.join(path_parts) 87 | 88 | # get the closest frame to the audio segment 89 | #frame_index = int(round((audio_start_time + audio_end_time) / 2.0 + 0.5)) #1 frame extracted per second 90 | frame_index = int(round(((audio_start_time + audio_end_time) / 2.0 + 0.05) * 10)) #10 frames extracted per second 91 | frame = process_image(Image.open(os.path.join(frame_path, str(frame_index).zfill(6) + '.png')).convert('RGB'), self.opt.enable_data_augmentation) 92 | frame = self.vision_transform(frame) 93 | 94 | #passing the spectrogram of the difference 95 | audio_diff_spec = torch.FloatTensor(generate_spectrogram(audio_channel1 - audio_channel2)) 96 | audio_mix_spec = torch.FloatTensor(generate_spectrogram(audio_channel1 + 
audio_channel2)) 97 | 98 | return {'frame': frame, 'audio_diff_spec':audio_diff_spec, 'audio_mix_spec':audio_mix_spec} 99 | 100 | def __len__(self): 101 | return len(self.audios) 102 | 103 | def name(self): 104 | return 'AudioVisualDataset' 105 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | class BaseDataLoader(): 10 | def __init__(self): 11 | pass 12 | 13 | def initialize(self, opt): 14 | self.opt = opt 15 | pass 16 | 17 | def load_data(): 18 | return None 19 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import torch.utils.data as data 10 | from PIL import Image 11 | import torchvision.transforms as transforms 12 | 13 | class BaseDataset(data.Dataset): 14 | def __init__(self): 15 | super(BaseDataset, self).__init__() 16 | 17 | def name(self): 18 | return 'BaseDataset' 19 | 20 | def initialize(self, opt): 21 | pass 22 | -------------------------------------------------------------------------------- /data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import torch.utils.data 10 | from data.base_data_loader import BaseDataLoader 11 | 12 | def CreateDataset(opt): 13 | dataset = None 14 | if opt.model == 'audioVisual': 15 | from data.audioVisual_dataset import AudioVisualDataset 16 | dataset = AudioVisualDataset() 17 | else: 18 | raise ValueError("Dataset [%s] not recognized." % opt.model) 19 | 20 | print("dataset [%s] was created" % (dataset.name())) 21 | dataset.initialize(opt) 22 | return dataset 23 | 24 | class CustomDatasetDataLoader(BaseDataLoader): 25 | def name(self): 26 | return 'CustomDatasetDataLoader' 27 | 28 | def initialize(self, opt): 29 | BaseDataLoader.initialize(self, opt) 30 | self.dataset = CreateDataset(opt) 31 | self.dataloader = torch.utils.data.DataLoader( 32 | self.dataset, 33 | batch_size=opt.batchSize, 34 | shuffle=True, 35 | num_workers=int(opt.nThreads)) 36 | 37 | def load_data(self): 38 | return self 39 | 40 | def __len__(self): 41 | return len(self.dataset) 42 | 43 | def __iter__(self): 44 | for i, data in enumerate(self.dataloader): 45 | yield data 46 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 
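#
# Usage sketch (an assumption; train.py is not included in this listing): CreateDataLoader
# defined below wraps CustomDatasetDataLoader above, whose load_data() returns the loader
# itself and whose iterator yields batch dicts with the keys produced by
# AudioVisualDataset.__getitem__:
#
#     data_loader = CreateDataLoader(opt)
#     dataset = data_loader.load_data()
#     print('#clips = %d' % len(data_loader))
#     for i, data in enumerate(dataset):
#         frame = data['frame']                  # B x 3 x 224 x 448 image tensors
#         audio_diff = data['audio_diff_spec']   # B x 2 x 257 x T spectrograms
#         audio_mix = data['audio_mix_spec']
#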
5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | def CreateDataLoader(opt): 10 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 11 | data_loader = CustomDatasetDataLoader() 12 | print(data_loader.name()) 13 | data_loader.initialize(opt) 14 | return data_loader 15 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import os 10 | import argparse 11 | import librosa 12 | import numpy as np 13 | from PIL import Image 14 | import subprocess 15 | from options.test_options import TestOptions 16 | import torchvision.transforms as transforms 17 | import torch 18 | from models.models import ModelBuilder 19 | from models.audioVisual_model import AudioVisualModel 20 | from data.audioVisual_dataset import generate_spectrogram 21 | 22 | def audio_normalize(samples, desired_rms = 0.1, eps = 1e-4): 23 | rms = np.maximum(eps, np.sqrt(np.mean(samples**2))) 24 | samples = samples * (desired_rms / rms) 25 | return rms / desired_rms, samples 26 | 27 | def main(): 28 | #load test arguments 29 | opt = TestOptions().parse() 30 | opt.device = torch.device("cuda") 31 | 32 | # network builders 33 | builder = ModelBuilder() 34 | net_visual = builder.build_visual(weights=opt.weights_visual) 35 | net_audio = builder.build_audio( 36 | ngf=opt.unet_ngf, 37 | input_nc=opt.unet_input_nc, 38 | output_nc=opt.unet_output_nc, 39 | weights=opt.weights_audio) 40 | nets = (net_visual, net_audio) 41 | 42 | # construct our audio-visual model 43 | model = AudioVisualModel(nets, opt) 44 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) 45 | model.to(opt.device) 46 | model.eval() 47 | 48 | #load the audio to perform separation 49 | audio, audio_rate = librosa.load(opt.input_audio_path, sr=opt.audio_sampling_rate, mono=False) 50 | audio_channel1 = audio[0,:] 51 | audio_channel2 = audio[1,:] 52 | 53 | #define the transformation to perform on visual frames 54 | vision_transform_list = [transforms.Resize((224,448)), transforms.ToTensor()] 55 | vision_transform_list.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) 56 | vision_transform = transforms.Compose(vision_transform_list) 57 | 58 | #perform spatialization over the whole audio using a sliding window approach 59 | overlap_count = np.zeros((audio.shape)) #count the number of times a data point is calculated 60 | binaural_audio = np.zeros((audio.shape)) 61 | 62 | #perform spatialization over the whole spectrogram in a siliding-window fashion 63 | sliding_window_start = 0 64 | data = {} 65 | samples_per_window = int(opt.audio_length * opt.audio_sampling_rate) 66 | while sliding_window_start + samples_per_window < audio.shape[-1]: 67 | sliding_window_end = sliding_window_start + samples_per_window 68 | normalizer, audio_segment = audio_normalize(audio[:,sliding_window_start:sliding_window_end]) 69 | audio_segment_channel1 = audio_segment[0,:] 70 | audio_segment_channel2 = audio_segment[1,:] 71 | audio_segment_mix = audio_segment_channel1 + audio_segment_channel2 72 | 73 | data['audio_diff_spec'] = torch.FloatTensor(generate_spectrogram(audio_segment_channel1 - 
audio_segment_channel2)).unsqueeze(0) #unsqueeze to add a batch dimension 74 | data['audio_mix_spec'] = torch.FloatTensor(generate_spectrogram(audio_segment_channel1 + audio_segment_channel2)).unsqueeze(0) #unsqueeze to add a batch dimension 75 | #get the frame index for current window 76 | frame_index = int(round((((sliding_window_start + samples_per_window / 2.0) / audio.shape[-1]) * opt.input_audio_length + 0.05) * 10 )) 77 | image = Image.open(os.path.join(opt.video_frame_path, str(frame_index).zfill(6) + '.png')).convert('RGB') 78 | #image = image.transpose(Image.FLIP_LEFT_RIGHT) 79 | frame = vision_transform(image).unsqueeze(0) #unsqueeze to add a batch dimension 80 | data['frame'] = frame 81 | 82 | output = model.forward(data) 83 | predicted_spectrogram = output['binaural_spectrogram'][0,:,:,:].data[:].cpu().numpy() 84 | 85 | #ISTFT to convert back to audio 86 | reconstructed_stft_diff = predicted_spectrogram[0,:,:] + (1j * predicted_spectrogram[1,:,:]) 87 | reconstructed_signal_diff = librosa.istft(reconstructed_stft_diff, hop_length=160, win_length=400, center=True, length=samples_per_window) 88 | reconstructed_signal_left = (audio_segment_mix + reconstructed_signal_diff) / 2 89 | reconstructed_signal_right = (audio_segment_mix - reconstructed_signal_diff) / 2 90 | reconstructed_binaural = np.concatenate((np.expand_dims(reconstructed_signal_left, axis=0), np.expand_dims(reconstructed_signal_right, axis=0)), axis=0) * normalizer 91 | 92 | binaural_audio[:,sliding_window_start:sliding_window_end] = binaural_audio[:,sliding_window_start:sliding_window_end] + reconstructed_binaural 93 | overlap_count[:,sliding_window_start:sliding_window_end] = overlap_count[:,sliding_window_start:sliding_window_end] + 1 94 | sliding_window_start = sliding_window_start + int(opt.hop_size * opt.audio_sampling_rate) 95 | 96 | #deal with the last segment 97 | normalizer, audio_segment = audio_normalize(audio[:,-samples_per_window:]) 98 | audio_segment_channel1 = audio_segment[0,:] 99 | audio_segment_channel2 = audio_segment[1,:] 100 | data['audio_diff_spec'] = torch.FloatTensor(generate_spectrogram(audio_segment_channel1 - audio_segment_channel2)).unsqueeze(0) #unsqueeze to add a batch dimension 101 | data['audio_mix_spec'] = torch.FloatTensor(generate_spectrogram(audio_segment_channel1 + audio_segment_channel2)).unsqueeze(0) #unsqueeze to add a batch dimension 102 | #get the frame index for last window 103 | frame_index = int(round(((opt.input_audio_length - opt.audio_length / 2.0) + 0.05) * 10)) 104 | image = Image.open(os.path.join(opt.video_frame_path, str(frame_index).zfill(6) + '.png')).convert('RGB') 105 | #image = image.transpose(Image.FLIP_LEFT_RIGHT) 106 | frame = vision_transform(image).unsqueeze(0) #unsqueeze to add a batch dimension 107 | data['frame'] = frame 108 | output = model.forward(data) 109 | predicted_spectrogram = output['binaural_spectrogram'][0,:,:,:].data[:].cpu().numpy() 110 | #ISTFT to convert back to audio 111 | reconstructed_stft_diff = predicted_spectrogram[0,:,:] + (1j * predicted_spectrogram[1,:,:]) 112 | reconstructed_signal_diff = librosa.istft(reconstructed_stft_diff, hop_length=160, win_length=400, center=True, length=samples_per_window) 113 | reconstructed_signal_left = (audio_segment_mix + reconstructed_signal_diff) / 2 114 | reconstructed_signal_right = (audio_segment_mix - reconstructed_signal_diff) / 2 115 | reconstructed_binaural = np.concatenate((np.expand_dims(reconstructed_signal_left, axis=0), np.expand_dims(reconstructed_signal_right, axis=0)), axis=0) * 
normalizer 116 | 117 | #add the spatialized audio to reconstructed_binaural 118 | binaural_audio[:,-samples_per_window:] = binaural_audio[:,-samples_per_window:] + reconstructed_binaural 119 | overlap_count[:,-samples_per_window:] = overlap_count[:,-samples_per_window:] + 1 120 | 121 | #divide aggregated predicted audio by their corresponding counts 122 | predicted_binaural_audio = np.divide(binaural_audio, overlap_count) 123 | 124 | #check output directory 125 | if not os.path.isdir(opt.output_dir_root): 126 | os.mkdir(opt.output_dir_root) 127 | 128 | mixed_mono = (audio_channel1 + audio_channel2) / 2 129 | librosa.output.write_wav(os.path.join(opt.output_dir_root, 'predicted_binaural.wav'), predicted_binaural_audio, opt.audio_sampling_rate) 130 | librosa.output.write_wav(os.path.join(opt.output_dir_root, 'mixed_mono.wav'), mixed_mono, opt.audio_sampling_rate) 131 | librosa.output.write_wav(os.path.join(opt.output_dir_root, 'input_binaural.wav'), audio, opt.audio_sampling_rate) 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | 10 | import os 11 | import librosa 12 | import argparse 13 | import numpy as np 14 | from numpy import linalg as LA 15 | from scipy.signal import hilbert 16 | from data.audioVisual_dataset import generate_spectrogram 17 | import statistics as stat 18 | 19 | def normalize(samples): 20 | return samples / np.maximum(1e-20, np.max(np.abs(samples))) 21 | 22 | def STFT_L2_distance(predicted_binaural, gt_binaural): 23 | #channel1 24 | predicted_spect_channel1 = librosa.core.stft(np.asfortranarray(predicted_binaural[0,:]), n_fft=512, hop_length=160, win_length=400, center=True) 25 | gt_spect_channel1 = librosa.core.stft(np.asfortranarray(gt_binaural[0,:]), n_fft=512, hop_length=160, win_length=400, center=True) 26 | real = np.expand_dims(np.real(predicted_spect_channel1), axis=0) 27 | imag = np.expand_dims(np.imag(predicted_spect_channel1), axis=0) 28 | predicted_realimag_channel1 = np.concatenate((real, imag), axis=0) 29 | real = np.expand_dims(np.real(gt_spect_channel1), axis=0) 30 | imag = np.expand_dims(np.imag(gt_spect_channel1), axis=0) 31 | gt_realimag_channel1 = np.concatenate((real, imag), axis=0) 32 | channel1_distance = np.mean(np.power((predicted_realimag_channel1 - gt_realimag_channel1), 2)) 33 | 34 | #channel2 35 | predicted_spect_channel2 = librosa.core.stft(np.asfortranarray(predicted_binaural[1,:]), n_fft=512, hop_length=160, win_length=400, center=True) 36 | gt_spect_channel2 = librosa.core.stft(np.asfortranarray(gt_binaural[1,:]), n_fft=512, hop_length=160, win_length=400, center=True) 37 | real = np.expand_dims(np.real(predicted_spect_channel2), axis=0) 38 | imag = np.expand_dims(np.imag(predicted_spect_channel2), axis=0) 39 | predicted_realimag_channel2 = np.concatenate((real, imag), axis=0) 40 | real = np.expand_dims(np.real(gt_spect_channel2), axis=0) 41 | imag = np.expand_dims(np.imag(gt_spect_channel2), axis=0) 42 | gt_realimag_channel2 = np.concatenate((real, imag), axis=0) 43 | channel2_distance = np.mean(np.power((predicted_realimag_channel2 - gt_realimag_channel2), 2)) 44 | 45 | #sum the distance between two 
channels 46 | stft_l2_distance = channel1_distance + channel2_distance 47 | return float(stft_l2_distance) 48 | 49 | def Envelope_distance(predicted_binaural, gt_binaural): 50 | #channel1 51 | pred_env_channel1 = np.abs(hilbert(predicted_binaural[0,:])) 52 | gt_env_channel1 = np.abs(hilbert(gt_binaural[0,:])) 53 | channel1_distance = np.sqrt(np.mean((gt_env_channel1 - pred_env_channel1)**2)) 54 | 55 | #channel2 56 | pred_env_channel2 = np.abs(hilbert(predicted_binaural[1,:])) 57 | gt_env_channel2 = np.abs(hilbert(gt_binaural[1,:])) 58 | channel2_distance = np.sqrt(np.mean((gt_env_channel2 - pred_env_channel2)**2)) 59 | 60 | #sum the distance between two channels 61 | envelope_distance = channel1_distance + channel2_distance 62 | return float(envelope_distance) 63 | 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('--results_root', type=str, required=True) 68 | parser.add_argument('--audio_sampling_rate', default=16000, type=int, help='audio sampling rate') 69 | parser.add_argument('--real_mono', default=False, type=bool, help='whether the input predicted binaural audio is mono audio') 70 | parser.add_argument('--normalization', default=False, type=bool) 71 | args = parser.parse_args() 72 | stft_distance_list = [] 73 | envelope_distance_list = [] 74 | 75 | audioNames = os.listdir(args.results_root) 76 | index = 1 77 | for audio_name in audioNames: 78 | if index % 10 == 0: 79 | print "Evaluating testing example " + str(index) + " :", audio_name 80 | #check whether input binaural is mono, replicate to two channels if it's mono 81 | if args.real_mono: 82 | mono_sound, audio_rate = librosa.load(os.path.join(args.results_root, audio_name, 'mixed_mono.wav'), sr=args.audio_sampling_rate) 83 | predicted_binaural = np.repeat(np.expand_dims(mono_sound, 0), 2, axis=0) 84 | if args.normalization: 85 | predicted_binaural = normalize(predicted_binaural) 86 | else: 87 | predicted_binaural, audio_rate = librosa.load(os.path.join(args.results_root, audio_name, 'predicted_binaural.wav'), sr=args.audio_sampling_rate, mono=False) 88 | if args.normalization: 89 | predicted_binaural = normalize(predicted_binaural) 90 | gt_binaural, audio_rate = librosa.load(os.path.join(args.results_root, audio_name, 'input_binaural.wav'), sr=args.audio_sampling_rate, mono=False) 91 | if args.normalization: 92 | gt_binaural = normalize(gt_binaural) 93 | 94 | #get results for this audio 95 | stft_distance_list.append(STFT_L2_distance(predicted_binaural, gt_binaural)) 96 | envelope_distance_list.append(Envelope_distance(predicted_binaural, gt_binaural)) 97 | index = index + 1 98 | 99 | #print the results 100 | print "STFT L2 Distance: ", stat.mean(stft_distance_list), stat.stdev(stft_distance_list), stat.stdev(stft_distance_list) / np.sqrt(len(stft_distance_list)) 101 | print "Average Envelope Distance: ", stat.mean(envelope_distance_list), stat.stdev(envelope_distance_list), stat.stdev(envelope_distance_list) / np.sqrt(len(envelope_distance_list)) 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/2.5D-Visual-Sound/73c73755bd8faa4478c1ab09441bc1e77fe7f1e7/models/__init__.py -------------------------------------------------------------------------------- /models/audioVisual_model.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import os 10 | import numpy as np 11 | import torch 12 | from torch import optim 13 | import torch.nn.functional as F 14 | from . import networks,criterion 15 | from torch.autograd import Variable 16 | 17 | class AudioVisualModel(torch.nn.Module): 18 | def name(self): 19 | return 'AudioVisualModel' 20 | 21 | def __init__(self, nets, opt): 22 | super(AudioVisualModel, self).__init__() 23 | self.opt = opt 24 | #initialize model 25 | self.net_visual, self.net_audio = nets 26 | 27 | def forward(self, input, volatile=False): 28 | visual_input = input['frame'] 29 | audio_diff = input['audio_diff_spec'] 30 | audio_mix = input['audio_mix_spec'] 31 | audio_gt = Variable(audio_diff[:,:,:-1,:], requires_grad=False) 32 | 33 | input_spectrogram = Variable(audio_mix, requires_grad=False, volatile=volatile) 34 | visual_feature = self.net_visual(Variable(visual_input, requires_grad=False, volatile=volatile)) 35 | mask_prediction = self.net_audio(input_spectrogram, visual_feature) 36 | 37 | #complex masking to obtain the predicted spectrogram 38 | spectrogram_diff_real = input_spectrogram[:,0,:-1,:] * mask_prediction[:,0,:,:] - input_spectrogram[:,1,:-1,:] * mask_prediction[:,1,:,:] 39 | spectrogram_diff_img = input_spectrogram[:,0,:-1,:] * mask_prediction[:,1,:,:] + input_spectrogram[:,1,:-1,:] * mask_prediction[:,0,:,:] 40 | binaural_spectrogram = torch.cat((spectrogram_diff_real.unsqueeze(1), spectrogram_diff_img.unsqueeze(1)), 1) 41 | 42 | output = {'mask_prediction': mask_prediction, 'binaural_spectrogram': binaural_spectrogram, 'audio_gt': audio_gt} 43 | return output 44 | -------------------------------------------------------------------------------- /models/criterion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
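#
# Usage sketch (an assumption; train.py is not included in this listing): the losses below are
# presumably applied to the AudioVisualModel output above, e.g. an L2 loss between the predicted
# difference spectrogram and its ground truth:
#
#     loss_criterion = L2Loss()
#     output = model.forward(data)
#     loss = loss_criterion(output['binaural_spectrogram'], output['audio_gt'])
#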
8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BaseLoss(nn.Module): 15 | def __init__(self): 16 | super(BaseLoss, self).__init__() 17 | 18 | def forward(self, preds, targets, weight=None): 19 | if isinstance(preds, list): 20 | N = len(preds) 21 | if weight is None: 22 | weight = preds[0].new_ones(1) 23 | 24 | errs = [self._forward(preds[n], targets[n], weight[n]) 25 | for n in range(N)] 26 | err = torch.mean(torch.stack(errs)) 27 | 28 | elif isinstance(preds, torch.Tensor): 29 | if weight is None: 30 | weight = preds.new_ones(1) 31 | err = self._forward(preds, targets, weight) 32 | 33 | return err 34 | 35 | 36 | class L1Loss(BaseLoss): 37 | def __init__(self): 38 | super(L1Loss, self).__init__() 39 | 40 | def _forward(self, pred, target, weight): 41 | return torch.mean(weight * torch.abs(pred - target)) 42 | 43 | 44 | class L2Loss(BaseLoss): 45 | def __init__(self): 46 | super(L2Loss, self).__init__() 47 | 48 | def _forward(self, pred, target, weight): 49 | return torch.mean(weight * torch.pow(pred - target, 2)) 50 | 51 | class MSELoss(BaseLoss): 52 | def __init__(self): 53 | super(MSELoss, self).__init__() 54 | 55 | def _forward(self, pred, target): 56 | return F.mse_loss(pred, target) 57 | 58 | class BCELoss(BaseLoss): 59 | def __init__(self): 60 | super(BCELoss, self).__init__() 61 | 62 | def _forward(self, pred, target, weight): 63 | return F.binary_cross_entropy(pred, target, weight=weight) 64 | 65 | class BCEWithLogitsLoss(BaseLoss): 66 | def __init__(self): 67 | super(BCEWithLogitsLoss, self).__init__() 68 | 69 | def _forward(self, pred, target, weight): 70 | return F.binary_cross_entropy_with_logits(pred, target, weight=weight) 71 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import torch 10 | import torchvision 11 | from .networks import VisualNet, AudioNet, weights_init 12 | 13 | class ModelBuilder(): 14 | # builder for visual stream 15 | def build_visual(self, weights=''): 16 | pretrained = True 17 | original_resnet = torchvision.models.resnet18(pretrained) 18 | net = VisualNet(original_resnet) 19 | 20 | if len(weights) > 0: 21 | print('Loading weights for visual stream') 22 | net.load_state_dict(torch.load(weights)) 23 | return net 24 | 25 | #builder for audio stream 26 | def build_audio(self, ngf=64, input_nc=2, output_nc=2, weights=''): 27 | #AudioNet: 5 layer UNet 28 | net = AudioNet(ngf, input_nc, output_nc) 29 | 30 | net.apply(weights_init) 31 | if len(weights) > 0: 32 | print('Loading weights for audio stream') 33 | net.load_state_dict(torch.load(weights)) 34 | return net 35 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import functools 13 | 14 | def unet_conv(input_nc, output_nc, norm_layer=nn.BatchNorm2d): 15 | downconv = nn.Conv2d(input_nc, output_nc, kernel_size=4, stride=2, padding=1) 16 | downrelu = nn.LeakyReLU(0.2, True) 17 | downnorm = norm_layer(output_nc) 18 | return nn.Sequential(*[downconv, downnorm, downrelu]) 19 | 20 | def unet_upconv(input_nc, output_nc, outermost=False, norm_layer=nn.BatchNorm2d): 21 | upconv = nn.ConvTranspose2d(input_nc, output_nc, kernel_size=4, stride=2, padding=1) 22 | uprelu = nn.ReLU(True) 23 | upnorm = norm_layer(output_nc) 24 | if not outermost: 25 | return nn.Sequential(*[upconv, upnorm, uprelu]) 26 | else: 27 | return nn.Sequential(*[upconv, nn.Sigmoid()]) 28 | 29 | def create_conv(input_channels, output_channels, kernel, paddings, batch_norm=True, Relu=True, stride=1): 30 | model = [nn.Conv2d(input_channels, output_channels, kernel, stride = stride, padding = paddings)] 31 | if(batch_norm): 32 | model.append(nn.BatchNorm2d(output_channels)) 33 | if(Relu): 34 | model.append(nn.ReLU()) 35 | return nn.Sequential(*model) 36 | 37 | def weights_init(m): 38 | classname = m.__class__.__name__ 39 | if classname.find('Conv') != -1: 40 | m.weight.data.normal_(0.0, 0.02) 41 | elif classname.find('BatchNorm2d') != -1: 42 | m.weight.data.normal_(1.0, 0.02) 43 | m.bias.data.fill_(0) 44 | elif classname.find('Linear') != -1: 45 | m.weight.data.normal_(0.0, 0.02) 46 | 47 | class VisualNet(nn.Module): 48 | def __init__(self, original_resnet): 49 | super(VisualNet, self).__init__() 50 | layers = list(original_resnet.children())[0:-2] 51 | self.feature_extraction = nn.Sequential(*layers) #features before conv1x1 52 | 53 | def forward(self, x): 54 | x = self.feature_extraction(x) 55 | return x 56 | 57 | class AudioNet(nn.Module): 58 | def __init__(self, ngf=64, input_nc=2, output_nc=2): 59 | super(AudioNet, self).__init__() 60 | #initialize layers 61 | self.audionet_convlayer1 = unet_conv(input_nc, ngf) 62 | self.audionet_convlayer2 = unet_conv(ngf, ngf * 2) 63 | self.audionet_convlayer3 = unet_conv(ngf * 2, ngf * 4) 64 | self.audionet_convlayer4 = unet_conv(ngf * 4, ngf * 8) 65 | self.audionet_convlayer5 = unet_conv(ngf * 8, ngf * 8) 66 | self.audionet_upconvlayer1 = unet_upconv(1296, ngf * 8) #1296 (audio-visual feature) = 784 (visual feature) + 512 (audio feature) 67 | self.audionet_upconvlayer2 = unet_upconv(ngf * 16, ngf *4) 68 | self.audionet_upconvlayer3 = unet_upconv(ngf * 8, ngf * 2) 69 | self.audionet_upconvlayer4 = unet_upconv(ngf * 4, ngf) 70 | self.audionet_upconvlayer5 = unet_upconv(ngf * 2, output_nc, True) #outermost layer use a sigmoid to bound the mask 71 | self.conv1x1 = create_conv(512, 8, 1, 0) #reduce dimension of extracted visual features 72 | 73 | def forward(self, x, visual_feat): 74 | audio_conv1feature = self.audionet_convlayer1(x) 75 | audio_conv2feature = self.audionet_convlayer2(audio_conv1feature) 76 | audio_conv3feature = self.audionet_convlayer3(audio_conv2feature) 77 | audio_conv4feature = self.audionet_convlayer4(audio_conv3feature) 78 | audio_conv5feature = self.audionet_convlayer5(audio_conv4feature) 79 | 80 | visual_feat = self.conv1x1(visual_feat) 81 | visual_feat = visual_feat.view(visual_feat.shape[0], -1, 1, 1) #flatten visual feature 82 | visual_feat = visual_feat.repeat(1, 1, audio_conv5feature.shape[-2], audio_conv5feature.shape[-1]) #tile visual feature 83 | 84 | audioVisual_feature = torch.cat((visual_feat, audio_conv5feature), dim=1) 85 | 86 | 
audio_upconv1feature = self.audionet_upconvlayer1(audioVisual_feature) 87 | audio_upconv2feature = self.audionet_upconvlayer2(torch.cat((audio_upconv1feature, audio_conv4feature), dim=1)) 88 | audio_upconv3feature = self.audionet_upconvlayer3(torch.cat((audio_upconv2feature, audio_conv3feature), dim=1)) 89 | audio_upconv4feature = self.audionet_upconvlayer4(torch.cat((audio_upconv3feature, audio_conv2feature), dim=1)) 90 | mask_prediction = self.audionet_upconvlayer5(torch.cat((audio_upconv4feature, audio_conv1feature), dim=1)) * 2 - 1 91 | return mask_prediction 92 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/2.5D-Visual-Sound/73c73755bd8faa4478c1ab09441bc1e77fe7f1e7/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import argparse 10 | import os 11 | from util import util 12 | import torch 13 | 14 | class BaseOptions(): 15 | def __init__(self): 16 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | self.initialized = False 18 | 19 | def initialize(self): 20 | self.parser.add_argument('--hdf5FolderPath', help='path to the folder that contains train.h5, val.h5 and test.h5') 21 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 22 | self.parser.add_argument('--name', type=str, default='spatialAudioVisual', help='name of the experiment. It decides where to store models') 23 | self.parser.add_argument('--checkpoints_dir', type=str, default='checkpoints/', help='models are saved here') 24 | self.parser.add_argument('--model', type=str, default='audioVisual', help='chooses how datasets are loaded.') 25 | self.parser.add_argument('--batchSize', type=int, default=32, help='input batch size') 26 | self.parser.add_argument('--nThreads', default=16, type=int, help='# threads for loading data') 27 | self.parser.add_argument('--audio_sampling_rate', default=16000, type=int, help='audio sampling rate') 28 | self.parser.add_argument('--audio_length', default=0.63, type=float, help='audio length, default 0.63s') 29 | self.enable_data_augmentation = True 30 | self.initialized = True 31 | 32 | def parse(self): 33 | if not self.initialized: 34 | self.initialize() 35 | self.opt = self.parser.parse_args() 36 | 37 | self.opt.mode = self.mode 38 | self.opt.isTrain = self.isTrain 39 | self.opt.enable_data_augmentation = self.enable_data_augmentation 40 | 41 | str_ids = self.opt.gpu_ids.split(',') 42 | self.opt.gpu_ids = [] 43 | for str_id in str_ids: 44 | id = int(str_id) 45 | if id >= 0: 46 | self.opt.gpu_ids.append(id) 47 | 48 | # set gpu ids 49 | if len(self.opt.gpu_ids) > 0: 50 | torch.cuda.set_device(self.opt.gpu_ids[0]) 51 | 52 | 53 | #I should process the opt here, like gpu ids, etc. 
54 | args = vars(self.opt) 55 | print('------------ Options -------------') 56 | for k, v in sorted(args.items()): 57 | print('%s: %s' % (str(k), str(v))) 58 | print('-------------- End ----------------') 59 | 60 | 61 | # save to the disk 62 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 63 | util.mkdirs(expr_dir) 64 | file_name = os.path.join(expr_dir, 'opt.txt') 65 | with open(file_name, 'wt') as opt_file: 66 | opt_file.write('------------ Options -------------\n') 67 | for k, v in sorted(args.items()): 68 | opt_file.write('%s: %s\n' % (str(k), str(v))) 69 | opt_file.write('-------------- End ----------------\n') 70 | return self.opt 71 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | from .base_options import BaseOptions 10 | 11 | class TestOptions(BaseOptions): 12 | def initialize(self): 13 | BaseOptions.initialize(self) 14 | 15 | self.parser.add_argument('--input_audio_path', required=True, help='path to the input audio file') 16 | self.parser.add_argument('--video_frame_path', required=True, help='path to the input video frames') 17 | self.parser.add_argument('--output_dir_root', type=str, default='test_output', help='path to the output files') 18 | self.parser.add_argument('--input_audio_length', type=float, default=10, help='length of the testing video/audio') 19 | self.parser.add_argument('--hop_size', default=0.05, type=float, help='the hop length to perform audio spatialization in a sliding window approach') 20 | 21 | #model arguments 22 | self.parser.add_argument('--weights_visual', type=str, default='', help="weights for visual stream") 23 | self.parser.add_argument('--weights_audio', type=str, default='', help="weights for audio stream") 24 | self.parser.add_argument('--unet_ngf', type=int, default=64, help="unet base channel dimension") 25 | self.parser.add_argument('--unet_input_nc', type=int, default=2, help="input spectrogram number of channels") 26 | self.parser.add_argument('--unet_output_nc', type=int, default=2, help="output spectrogram number of channels") 27 | 28 | self.mode = "test" 29 | self.isTrain = False 30 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
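# --- Editor's note (illustrative, not part of train_options.py): combining BaseOptions above with
# the TrainOptions defined below, a training run might be launched as follows (the hdf5 folder
# path and the experiment name are placeholders):
#
#   python train.py \
#       --hdf5FolderPath /path/to/hdf5_splits \
#       --name mono2binaural \
#       --batchSize 32 --nThreads 16 \
#       --niter 1000 --learning_rate_decrease_itr 10 \
#       --validation_on --validation_freq 100 \
#       --tensorboard True
#
# Caveat: options declared with type=bool (e.g. --tensorboard, --measure_time) are converted with
# Python's bool(), which treats any non-empty string, including "False", as True; pass such a flag
# only when you want it enabled and omit it otherwise.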
8 | 9 | from .base_options import BaseOptions 10 | 11 | class TrainOptions(BaseOptions): 12 | def initialize(self): 13 | BaseOptions.initialize(self) 14 | self.parser.add_argument('--display_freq', type=int, default=50, help='frequency of displaying average loss') 15 | self.parser.add_argument('--save_epoch_freq', type=int, default=50, help='frequency of saving checkpoints at the end of epochs') 16 | self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 17 | self.parser.add_argument('--niter', type=int, default=1000, help='# of epochs to train') 18 | self.parser.add_argument('--learning_rate_decrease_itr', type=int, default=-1, help='decrease the learning rate by decay_factor once every this many epochs; -1 disables the decay') 19 | self.parser.add_argument('--decay_factor', type=float, default=0.94, help='learning rate decay factor') 20 | self.parser.add_argument('--tensorboard', type=bool, default=False, help='use tensorboard to visualize loss changes') 21 | self.parser.add_argument('--measure_time', type=bool, default=False, help='measure time of different steps during training') 22 | self.parser.add_argument('--validation_on', action='store_true', help='whether to test on validation set during training') 23 | self.parser.add_argument('--validation_freq', type=int, default=100, help='frequency of testing on validation set') 24 | self.parser.add_argument('--validation_batches', type=int, default=10, help='number of batches to test for validation') 25 | self.parser.add_argument('--enable_data_augmentation', type=bool, default=True, help='whether to augment input frame') 26 | 27 | #model arguments 28 | self.parser.add_argument('--weights_visual', type=str, default='', help="weights for visual stream") 29 | self.parser.add_argument('--weights_audio', type=str, default='', help="weights for audio stream") 30 | self.parser.add_argument('--unet_ngf', type=int, default=64, help="unet base channel dimension") 31 | self.parser.add_argument('--unet_input_nc', type=int, default=2, help="input spectrogram number of channels") 32 | self.parser.add_argument('--unet_output_nc', type=int, default=2, help="output spectrogram number of channels") 33 | 34 | #optimizer arguments 35 | self.parser.add_argument('--lr_visual', type=float, default=0.0001, help='learning rate for visual stream') 36 | self.parser.add_argument('--lr_audio', type=float, default=0.001, help='learning rate for unet') 37 | self.parser.add_argument('--optimizer', default='adam', type=str, help='adam or sgd for optimization') 38 | self.parser.add_argument('--beta1', default=0.9, type=float, help='momentum for sgd, beta1 for adam') 39 | self.parser.add_argument('--weight_decay', default=0.0005, type=float, help='weights regularizer') 40 | 41 | self.mode = "train" 42 | self.isTrain = True 43 | self.enable_data_augmentation = True 44 | -------------------------------------------------------------------------------- /reEncodeAudio.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree.
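# --- Editor's note (illustrative, not part of reEncodeAudio.py): the script below rewrites every
# .wav file under a directory at a new sampling rate, e.g.
#
#   python reEncodeAudio.py --audio_dir_path /path/to/binaural_audios --new_rate 16000
#
# It relies on scikits.audiolab, a Python-2-era package. If that dependency is unavailable, an
# equivalent re-encode step could be sketched with the soundfile package instead (an editor's
# substitute, not the repo's code):
import glob
import resampy
import soundfile as sf

def reencode_wav(path, new_rate=16000):
    signal, rate = sf.read(path, always_2d=True)   # (num_frames, num_channels)
    if rate != new_rate:
        signal = resampy.resample(signal, rate, new_rate, axis=0, filter='kaiser_best')
    sf.write(path, signal, new_rate)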
8 | # 9 | 10 | import os 11 | import glob 12 | import argparse 13 | import numpy as np 14 | import resampy 15 | from scikits.audiolab import Sndfile, Format 16 | 17 | def load_wav(fname, rate=None): 18 | fp = Sndfile(fname, 'r') 19 | _signal = fp.read_frames(fp.nframes) 20 | _signal = _signal.reshape((-1, fp.channels)) 21 | _rate = fp.samplerate 22 | 23 | if _signal.ndim == 1: 24 | _signal = _signal.reshape((-1, 1)) 25 | if rate is not None and rate != _rate: 26 | signal = resampy.resample(_signal, _rate, rate, axis=0, filter='kaiser_best') 27 | else: 28 | signal = _signal 29 | rate = _rate 30 | 31 | return signal, rate 32 | 33 | def save_wav(fname, signal, rate): 34 | fp = Sndfile(fname, 'w', Format('wav'), signal.shape[1], rate) 35 | fp.write_frames(signal) 36 | fp.close() 37 | 38 | def reEncodeAudio(audio_path, new_rate): 39 | audio, audio_rate = load_wav(audio_path, new_rate) 40 | save_wav(audio_path, audio, new_rate) 41 | 42 | def main(): 43 | parser = argparse.ArgumentParser(description="re-encode all audios under a directory") 44 | parser.add_argument("--audio_dir_path", type=str, required=True) 45 | parser.add_argument("--new_rate", type=int, default=16000) 46 | args = parser.parse_args() 47 | 48 | audio_list = glob.glob(args.audio_dir_path + '/*.wav') 49 | print("Total number of audios to re-encode:", len(audio_list)) 50 | for audio_path in audio_list: 51 | reEncodeAudio(audio_path, args.new_rate) #glob already returns the directory-prefixed path 52 | 53 | if __name__ == '__main__': 54 | main() 55 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree.
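# --- Editor's note (illustrative sketch, not part of train.py): the script below minimizes an MSE
# loss between output['binaural_spectrogram'] and output['audio_gt'], which correspond to the
# predicted and ground-truth spectrograms of the *difference* between the two binaural channels.
# Assuming mono_mix = left + right and predicted_diff ~ left - right as time-domain signals (an
# assumption for this sketch; the inverse STFT is handled elsewhere in the repo), the two channels
# can be recovered with simple arithmetic:
def reconstruct_binaural(mono_mix, predicted_diff):
    left = (mono_mix + predicted_diff) / 2   # (L + R + L - R) / 2 = L
    right = (mono_mix - predicted_diff) / 2  # (L + R - L + R) / 2 = R
    return left, right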
8 | 9 | import os 10 | import time 11 | import torch 12 | from options.train_options import TrainOptions 13 | from data.data_loader import CreateDataLoader 14 | from models.models import ModelBuilder 15 | from models.audioVisual_model import AudioVisualModel 16 | from torch.autograd import Variable 17 | from tensorboardX import SummaryWriter 18 | 19 | def create_optimizer(nets, opt): 20 | (net_visual, net_audio) = nets 21 | param_groups = [{'params': net_visual.parameters(), 'lr': opt.lr_visual}, 22 | {'params': net_audio.parameters(), 'lr': opt.lr_audio}] 23 | if opt.optimizer == 'sgd': 24 | return torch.optim.SGD(param_groups, momentum=opt.beta1, weight_decay=opt.weight_decay) 25 | elif opt.optimizer == 'adam': 26 | return torch.optim.Adam(param_groups, betas=(opt.beta1,0.999), weight_decay=opt.weight_decay) 27 | 28 | def decrease_learning_rate(optimizer, decay_factor=0.94): 29 | for param_group in optimizer.param_groups: 30 | param_group['lr'] *= decay_factor 31 | 32 | #used to display validation loss 33 | def display_val(model, loss_criterion, writer, index, dataset_val, opt): 34 | losses = [] 35 | with torch.no_grad(): 36 | for i, val_data in enumerate(dataset_val): 37 | if i < opt.validation_batches: 38 | output = model.forward(val_data) 39 | loss = loss_criterion(output['binaural_spectrogram'], output['audio_gt']) 40 | losses.append(loss.item()) 41 | else: 42 | break 43 | avg_loss = sum(losses)/len(losses) 44 | if opt.tensorboard: 45 | writer.add_scalar('data/val_loss', avg_loss, index) 46 | print('val loss: %.3f' % avg_loss) 47 | return avg_loss 48 | 49 | #parse arguments 50 | opt = TrainOptions().parse() 51 | opt.device = torch.device("cuda") 52 | 53 | #construct data loader 54 | data_loader = CreateDataLoader(opt) 55 | dataset = data_loader.load_data() 56 | dataset_size = len(data_loader) 57 | print('#training clips = %d' % dataset_size) 58 | 59 | #create validation set data loader if validation_on option is set 60 | if opt.validation_on: 61 | #temporarily set to val to load val data 62 | opt.mode = 'val' 63 | data_loader_val = CreateDataLoader(opt) 64 | dataset_val = data_loader_val.load_data() 65 | dataset_size_val = len(data_loader_val) 66 | print('#validation clips = %d' % dataset_size_val) 67 | opt.mode = 'train' #set it back 68 | 69 | if opt.tensorboard: 70 | from tensorboardX import SummaryWriter 71 | writer = SummaryWriter(comment=opt.name) 72 | else: 73 | writer = None 74 | 75 | # network builders 76 | builder = ModelBuilder() 77 | net_visual = builder.build_visual(weights=opt.weights_visual) 78 | net_audio = builder.build_audio( 79 | ngf=opt.unet_ngf, 80 | input_nc=opt.unet_input_nc, 81 | output_nc=opt.unet_output_nc, 82 | weights=opt.weights_audio) 83 | nets = (net_visual, net_audio) 84 | 85 | # construct our audio-visual model 86 | model = AudioVisualModel(nets, opt) 87 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) 88 | model.to(opt.device) 89 | 90 | # set up optimizer 91 | optimizer = create_optimizer(nets, opt) 92 | 93 | # set up loss function 94 | loss_criterion = torch.nn.MSELoss() 95 | if(len(opt.gpu_ids) > 0): 96 | loss_criterion.cuda(opt.gpu_ids[0]) 97 | 98 | # initialization 99 | total_steps = 0 100 | data_loading_time = [] 101 | model_forward_time = [] 102 | model_backward_time = [] 103 | batch_loss = [] 104 | best_err = float("inf") 105 | 106 | for epoch in range(1, opt.niter+1): 107 | torch.cuda.synchronize() 108 | epoch_start_time = time.time() 109 | 110 | if(opt.measure_time): 111 | iter_start_time = time.time() 112 | for i, data in
enumerate(dataset): 113 | if(opt.measure_time): 114 | torch.cuda.synchronize() 115 | iter_data_loaded_time = time.time() 116 | 117 | total_steps += opt.batchSize 118 | 119 | # forward pass 120 | model.zero_grad() 121 | output = model.forward(data) 122 | 123 | # compute loss 124 | loss = loss_criterion(output['binaural_spectrogram'], Variable(output['audio_gt'], requires_grad=False)) 125 | batch_loss.append(loss.item()) 126 | 127 | if(opt.measure_time): 128 | torch.cuda.synchronize() 129 | iter_data_forwarded_time = time.time() 130 | 131 | # update optimizer 132 | optimizer.zero_grad() 133 | loss.backward() 134 | optimizer.step() 135 | 136 | if(opt.measure_time): 137 | iter_model_backwarded_time = time.time() 138 | data_loading_time.append(iter_data_loaded_time - iter_start_time) 139 | model_forward_time.append(iter_data_forwarded_time - iter_data_loaded_time) 140 | model_backward_time.append(iter_model_backwarded_time - iter_data_forwarded_time) 141 | 142 | if(total_steps // opt.batchSize % opt.display_freq == 0): 143 | print('Display training progress at (epoch %d, total_steps %d)' % (epoch, total_steps)) 144 | avg_loss = sum(batch_loss) / len(batch_loss) 145 | print('Average loss: %.3f' % (avg_loss)) 146 | batch_loss = [] 147 | if opt.tensorboard: 148 | writer.add_scalar('data/loss', avg_loss, total_steps) 149 | if(opt.measure_time): 150 | print('average data loading time: ' + str(sum(data_loading_time)/len(data_loading_time))) 151 | print('average forward time: ' + str(sum(model_forward_time)/len(model_forward_time))) 152 | print('average backward time: ' + str(sum(model_backward_time)/len(model_backward_time))) 153 | data_loading_time = [] 154 | model_forward_time = [] 155 | model_backward_time = [] 156 | print('end of display \n') 157 | 158 | if(total_steps // opt.batchSize % opt.save_latest_freq == 0): 159 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) 160 | torch.save(net_visual.state_dict(), os.path.join('.', opt.checkpoints_dir, opt.name, 'visual_latest.pth')) 161 | torch.save(net_audio.state_dict(), os.path.join('.', opt.checkpoints_dir, opt.name, 'audio_latest.pth')) 162 | 163 | if(total_steps // opt.batchSize % opt.validation_freq == 0 and opt.validation_on): 164 | model.eval() 165 | opt.mode = 'val' 166 | print('Display validation results at (epoch %d, total_steps %d)' % (epoch, total_steps)) 167 | val_err = display_val(model, loss_criterion, writer, total_steps, dataset_val, opt) 168 | print('end of display \n') 169 | model.train() 170 | opt.mode = 'train' 171 | #save the model that achieves the smallest validation error 172 | if val_err < best_err: 173 | best_err = val_err 174 | print('saving the best model (epoch %d, total_steps %d) with validation error %.3f\n' % (epoch, total_steps, val_err)) 175 | torch.save(net_visual.state_dict(), os.path.join('.', opt.checkpoints_dir, opt.name, 'visual_best.pth')) 176 | torch.save(net_audio.state_dict(), os.path.join('.', opt.checkpoints_dir, opt.name, 'audio_best.pth')) 177 | 178 | if(opt.measure_time): 179 | iter_start_time = time.time() 180 | 181 | if(epoch % opt.save_epoch_freq == 0): 182 | print('saving the model at the end of epoch %d, total_steps %d' % (epoch, total_steps)) 183 | torch.save(net_visual.state_dict(), os.path.join('.', opt.checkpoints_dir, opt.name, str(epoch) + '_visual.pth')) 184 | torch.save(net_audio.state_dict(), os.path.join('.', opt.checkpoints_dir, opt.name, str(epoch) + '_audio.pth')) 185 | 186 | #decrease learning rate 6% every opt.learning_rate_decrease_itr epochs 
187 | if(opt.learning_rate_decrease_itr > 0 and epoch % opt.learning_rate_decrease_itr == 0): 188 | decrease_learning_rate(optimizer, opt.decay_factor) 189 | print('decayed learning rate by a factor of', opt.decay_factor) 190 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/2.5D-Visual-Sound/73c73755bd8faa4478c1ab09441bc1e77fe7f1e7/util/__init__.py -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import os 10 | 11 | def mkdirs(paths): 12 | if isinstance(paths, list) and not isinstance(paths, str): 13 | for path in paths: 14 | mkdir(path) 15 | else: 16 | mkdir(paths) 17 | 18 | def mkdir(path): 19 | if not os.path.exists(path): 20 | os.makedirs(path) 21 | --------------------------------------------------------------------------------
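# --- Editor's note: a worked example of the decay schedule at the end of train.py above.
# decrease_learning_rate() multiplies each parameter group's learning rate by decay_factor, and it
# is called once every --learning_rate_decrease_itr epochs. With the default decay_factor of 0.94
# and an assumed --learning_rate_decrease_itr of 10, after 50 epochs:
#
#   lr_audio  = 0.001  * 0.94**5 ≈ 7.3e-4
#   lr_visual = 0.0001 * 0.94**5 ≈ 7.3e-5
#
# i.e. each decay step lowers the learning rate by 6%, matching the option's help text.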