├── .gitignore ├── LICENSE ├── README.md ├── evaluations ├── AudioCLIP │ ├── __init__.py │ ├── demo.py │ ├── get_embedding.py │ ├── ignite_trainer │ │ ├── README.md │ │ ├── __init__.py │ │ ├── _interfaces.py │ │ ├── _trainer.py │ │ ├── _utils.py │ │ ├── _visdom.py │ │ └── version.py │ ├── main.py │ ├── model │ │ ├── __init__.py │ │ ├── audioclip.py │ │ ├── clip │ │ │ ├── __init__.py │ │ │ ├── clip.py │ │ │ └── model.py │ │ └── esresnet │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── base.py │ │ │ └── fbsp.py │ ├── protocols │ │ ├── audioclip-esc50.json │ │ └── audioclip-us8k.json │ └── utils │ │ ├── __init__.py │ │ ├── datasets │ │ ├── __init__.py │ │ ├── esc50.py │ │ └── us8k.py │ │ ├── simple_tokenizer.py │ │ └── transforms.py ├── c3d │ ├── __init__.py │ └── c3d_ft.py ├── compute_fvd_kvd.py ├── compute_image_is.py ├── compute_video_is.py ├── evaluator.py ├── fvd │ ├── __init__.py │ ├── convert_tf_pretrained.py │ ├── download.py │ ├── fvd.py │ └── pytorch_i3d.py └── util.py ├── fig ├── MM-UNet2.png ├── aist++.mp4 ├── audioset.mp4 ├── landscape.mp4 └── teaser.png ├── mm_diffusion ├── __init__.py ├── common.py ├── dist_util.py ├── dpm_solver_plus.py ├── evaluator.py ├── fp16_util.py ├── gaussian_diffusion.py ├── image_datasets.py ├── image_unet.py ├── logger.py ├── losses.py ├── multimodal_datasets.py ├── multimodal_dpm_solver_plus.py ├── multimodal_gaussian_diffusion.py ├── multimodal_respace.py ├── multimodal_script_util.py ├── multimodal_train_util.py ├── multimodal_unet.py ├── nn.py ├── optimization.py ├── real_image_datasets.py ├── resample.py ├── respace.py ├── script_util.py └── train_util.py ├── py_scripts ├── audio2video_sample_sr.py ├── eval.py ├── image_sr_train.py ├── multimodal_sample_sr.py ├── multimodal_train.py └── video2audio_sample.py ├── requirement.txt └── ssh_scripts ├── audio2video_sample_sr.sh ├── image_sr_train.sh ├── multimodal_eval.sh ├── multimodal_sample_sr.sh ├── multimodal_train.sh └── video2audio_sample.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Multimedia Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MM-Diffusion(CVPR 2023) 2 | This is the official PyTorch implementation of the paper [MM-Diffusion: Learning Multi-Modal Diffusion Models for Joint Audio and Video Generation](https://arxiv.org/abs/2212.09478), which has been accepted by CVPR 2023. 3 | 4 | ## Contents 5 | - [MM-Diffusion(CVPR 2023)](#mm-diffusioncvpr-2023) 6 | - [Contents](#contents) 7 | - [Introduction](#introduction) 8 | - [Overview](#overview) 9 | - [Visual](#visual) 10 | - [Requirements and dependencies](#requirements-and-dependencies) 11 | - [Models](#models) 12 | - [Datasets](#datasets) 13 | - [Test](#test) 14 | - [Train](#train) 15 | - [Conditional Generation](#conditional-generation) 16 | - [Related projects](#related-projects) 17 | - [Citation](#citation) 18 | - [Contact](#contact) 19 | 20 | ## Introduction 21 | We propose MM-Diffusion, the first joint audio-video generation framework, which delivers engaging watching and listening experiences simultaneously through high-quality, realistic videos. MM-Diffusion consists of a sequential multi-modal U-Net: two subnets for audio and video learn to gradually generate aligned audio-video pairs from Gaussian noise. 22 | 23 | 24 | 25 | 26 | ### Overview 27 | 28 | 29 | 30 | ### Visual 31 | The generated audio-video examples on Landscape: 32 | 33 | https://user-images.githubusercontent.com/105475691/207589456-52914a01-1175-4f77-b8f5-112d97013f7c.mp4 34 | 35 | The generated audio-video examples on AIST++: 36 | 37 | https://user-images.githubusercontent.com/105475691/207589611-fe300424-e5e6-4379-a917-d9a07e9dd8fb.mp4 38 | 39 | The generated audio-video examples on AudioSet: 40 | 41 | https://user-images.githubusercontent.com/105475691/207589639-0a371435-f207-4ff4-a78e-3e9c0868d523.mp4 42 | 43 | ## Requirements and dependencies 44 | * Python 3.8 (we recommend using [Anaconda](https://www.anaconda.com/)) 45 | * PyTorch >= 1.11.0 46 | ``` 47 | git clone https://github.com/researchmm/MM-Diffusion.git 48 | cd MM-Diffusion 49 | 50 | conda create -n mmdiffusion python=3.8 51 | conda activate mmdiffusion 52 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch-nightly -c nvidia 53 | conda install mpi4py 54 | pip install -r requirement.txt 55 | ``` 56 | ## Models 57 | Pre-trained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1Mno4A3BUXELAdX4m650CJ1VfuMVlkz5p?usp=share_link) or [Baidu Cloud](https://pan.baidu.com/s/1vJIZCHBVlmcq9np1ytstbQ?pwd=vqon). 58 | * *Landscape.pt*: trained on the Landscape dataset to generate audio-video pairs. 59 | * *Landscape_SR.pt*: trained on the Landscape dataset to upsample frames from resolution 64x64 to 256x256. 60 | * *AIST++.pt*: trained on the AIST++ dataset to generate audio-video pairs. 61 | * *AIST++_SR.pt*: trained on the AIST++ dataset to upsample frames from resolution 64x64 to 256x256. 62 | * *guided-diffusion_64_256_upsampler.pt*: from [guided-diffusion](https://github.com/openai/guided-diffusion), used as the initialization of the image SR model. 63 | 64 | * *i3d_pretrained_400.pt*: model for evaluating videos' FVD and KVD. Manually download it to ```~/.cache/mmdiffusion/``` if the automatic download procedure fails. 65 | * *AudioCLIP-Full-Training.pt*: model for evaluating audios' FAD. Manually download it to ```~/.cache/mmdiffusion/``` if the automatic download procedure fails. 66 | 67 | 69 | 70 | ## Datasets 71 | 1. Landscape 72 | 2. 
AIST++_crop 73 | 74 | The datasets can be downloaded from [Google Drive](https://drive.google.com/drive/folders/14A1zaQI5EfShlv3QirgCGeNFzZBzQ3lq?usp=sharing) or [Baidu Cloud](https://pan.baidu.com/s/1CRUSpUzdATIN7Jt8aNDaUw?pwd=fec8). \ 75 | We only use the training set for training and evaluation. 76 | 77 | You can also run our scripts on your own dataset by providing a directory path that contains the relevant videos; the scripts will collect all videos under that path, regardless of how they are organized. 78 | 79 | ## Test 80 | 81 | 1. Download the pre-trained checkpoints. 82 | 2. Download the datasets: Landscape or AIST++_crop. 83 | 3. Modify the relevant paths and run the generation script to generate audio-video pairs. 84 | ``` 85 | bash ssh_scripts/multimodal_sample_sr.sh 86 | ``` 87 | 4. Modify `REF_DIR`, `SAMPLE_DIR`, and `OUTPUT_DIR`, then run the evaluation script. 88 | ``` 89 | bash ssh_scripts/multimodal_eval.sh 90 | ``` 91 | 92 | ## Train 93 | 94 | 1. Download and prepare the training dataset: Landscape or AIST++_crop. 95 | 2. Run the training scripts: 96 | ``` 97 | # Training the base model 98 | bash ssh_scripts/multimodal_train.sh 99 | 100 | # Training the upsampler from 64x64 -> 256x256 (videos are first extracted into frames for SR training) 101 | bash ssh_scripts/image_sr_train.sh 102 | ``` 103 | 104 | ## Conditional Generation 105 | ``` 106 | # zero-shot conditional generation: audio-to-video 107 | bash ssh_scripts/audio2video_sample_sr.sh 108 | 109 | # zero-shot conditional generation: video-to-audio 110 | bash ssh_scripts/video2audio_sample.sh 111 | ``` 112 | ## Related projects 113 | We also sincerely recommend some other excellent works related to ours. :sparkles: 114 | * [Diffusion Models Beat GANs on Image Synthesis](https://github.com/openai/guided-diffusion) 115 | * [AudioCLIP: Extending CLIP to Image, Text and Audio](https://github.com/AndreyGuzhov/AudioCLIP) 116 | * [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://github.com/LuChengTHU/dpm-solver) 117 | 118 | ## Citation 119 | If you find our work useful for your research, please consider citing our paper.
:blush: 120 | ``` 121 | @inproceedings{ruan2022mmdiffusion, 122 | author = {Ruan, Ludan and Ma, Yiyang and Yang, Huan and He, Huiguo and Liu, Bei and Fu, Jianlong and Yuan, Nicholas Jing and Jin, Qin and Guo, Baining}, 123 | title = {MM-Diffusion: Learning Multi-Modal Diffusion Models for Joint Audio and Video Generation}, 124 | year = {2023}, 125 | booktitle = {CVPR}, 126 | } 127 | ``` 128 | 129 | ## Contact 130 | If you meet any problems, please describe them in issues or contact: 131 | * Ludan Ruan: 132 | * Huan Yang: 133 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/MM-Diffusion/7250222114ee1eca36e32d90e8a41f12d92651e2/evaluations/AudioCLIP/__init__.py -------------------------------------------------------------------------------- /evaluations/AudioCLIP/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | 5 | import librosa 6 | import librosa.display 7 | 8 | import simplejpeg 9 | import numpy as np 10 | 11 | import torch 12 | import torchvision as tv 13 | 14 | 15 | from PIL import Image 16 | 17 | sys.path.append(os.path.abspath(f'{os.getcwd()}')) 18 | 19 | from model import AudioCLIP 20 | from utils.transforms import ToTensor1D 21 | 22 | 23 | torch.set_grad_enabled(False) 24 | 25 | MODEL_FILENAME = 'AudioCLIP-Full-Training.pt' 26 | # derived from ESResNeXt 27 | SAMPLE_RATE = 44100 28 | # derived from CLIP 29 | IMAGE_SIZE = 224 30 | IMAGE_MEAN = 0.48145466, 0.4578275, 0.40821073 31 | IMAGE_STD = 0.26862954, 0.26130258, 0.27577711 32 | 33 | LABELS = ['cat', 'thunderstorm', 'coughing', 'alarm clock', 'car horn'] 34 | 35 | 36 | aclp = AudioCLIP(pretrained=f'assets/{MODEL_FILENAME}') 37 | print(f"model_parameters:{sum(p.sum() for p in aclp.parameters())}") 38 | audio_transforms = ToTensor1D() 39 | 40 | image_transforms = tv.transforms.Compose([ 41 | tv.transforms.ToTensor(), 42 | tv.transforms.Resize(IMAGE_SIZE, interpolation=Image.BICUBIC), 43 | tv.transforms.CenterCrop(IMAGE_SIZE), 44 | tv.transforms.Normalize(IMAGE_MEAN, IMAGE_STD) 45 | ]) 46 | 47 | paths_to_audio = glob.glob('demo/audio/*.wav') 48 | 49 | audio = list() 50 | for path_to_audio in paths_to_audio: 51 | track, _ = librosa.load(path_to_audio, sr=SAMPLE_RATE, dtype=np.float32) 52 | 53 | # compute spectrograms using trained audio-head (fbsp-layer of ESResNeXt) 54 | # thus, the actual time-frequency representation will be visualized 55 | # spec = aclp.audio.spectrogram(torch.from_numpy(track.reshape(1, 1, -1))) 56 | # spec = np.ascontiguousarray(spec.numpy()).view(np.complex64) 57 | # pow_spec = 10 * np.log10(np.abs(spec) ** 2 + 1e-18).squeeze() 58 | 59 | audio.append((track, _)) 60 | 61 | paths_to_images = glob.glob('demo/images/*.jpg') 62 | 63 | images = list() 64 | for path_to_image in paths_to_images: 65 | with open(path_to_image, 'rb') as jpg: 66 | image = simplejpeg.decode_jpeg(jpg.read()) 67 | images.append(image) 68 | # AudioCLIP handles raw audio on input, so the input shape is [batch x channels x duration] 69 | audio = torch.stack([audio_transforms(track.reshape(1, -1)) for track, _ in audio]) 70 | # standard channel-first shape [batch x channels x height x width] 71 | import pdb; pdb.set_trace() 72 | images = torch.stack([image_transforms(image) for image in images]) 73 | 74 | # textual input is processed internally, so no need to 
transform it beforehand 75 | text = [[label] for label in LABELS] 76 | 77 | # AudioCLIP's output: Tuple[Tuple[Features, Logits], Loss] 78 | # Features = Tuple[AudioFeatures, ImageFeatures, TextFeatures] 79 | # Logits = Tuple[AudioImageLogits, AudioTextLogits, ImageTextLogits] 80 | 81 | ((audio_features, _, _), _), _ = aclp(audio=audio) 82 | ((_, image_features, _), _), _ = aclp(image=images) 83 | ((_, _, text_features), _), _ = aclp(text=text) 84 | 85 | audio_features = audio_features / torch.linalg.norm(audio_features, dim=-1, keepdim=True) 86 | image_features = image_features / torch.linalg.norm(image_features, dim=-1, keepdim=True) 87 | text_features = text_features / torch.linalg.norm(text_features, dim=-1, keepdim=True) 88 | 89 | scale_audio_image = torch.clamp(aclp.logit_scale_ai.exp(), min=1.0, max=100.0) 90 | scale_audio_text = torch.clamp(aclp.logit_scale_at.exp(), min=1.0, max=100.0) 91 | scale_image_text = torch.clamp(aclp.logit_scale.exp(), min=1.0, max=100.0) 92 | 93 | logits_audio_image = scale_audio_image * audio_features @ image_features.T 94 | logits_audio_text = scale_audio_text * audio_features @ text_features.T 95 | logits_image_text = scale_image_text * image_features @ text_features.T 96 | 97 | print('\t\tFilename, Audio\t\t\tTextual Label (Confidence)', end='\n\n') 98 | 99 | # calculate model confidence 100 | confidence = logits_audio_text.softmax(dim=1) 101 | for audio_idx in range(len(paths_to_audio)): 102 | # acquire Top-3 most similar results 103 | conf_values, ids = confidence[audio_idx].topk(3) 104 | 105 | # format output strings 106 | query = f'{os.path.basename(paths_to_audio[audio_idx]):>30s} ->\t\t' 107 | results = ', '.join([f'{LABELS[i]:>15s} ({v:06.2%})' for v, i in zip(conf_values, ids)]) 108 | 109 | print(query + results) 110 | 111 | 112 | print('\tFilename, Image\t\t\tTextual Label (Confidence)', end='\n\n') 113 | 114 | # calculate model confidence 115 | confidence = logits_image_text.softmax(dim=1) 116 | for image_idx in range(len(paths_to_images)): 117 | # acquire Top-3 most similar results 118 | conf_values, ids = confidence[image_idx].topk(3) 119 | 120 | # format output strings 121 | query = f'{os.path.basename(paths_to_images[image_idx]):>20s} ->\t\t' 122 | results = ', '.join([f'{LABELS[i]:>20s} ({v:06.2%})' for v, i in zip(conf_values, ids)]) 123 | 124 | print(query + results) 125 | 126 | print('\t\tTextual Label\t\tFilename, Audio (Confidence)', end='\n\n') 127 | 128 | # calculate model confidence 129 | confidence = logits_audio_text.softmax(dim=0) 130 | for label_idx in range(len(LABELS)): 131 | # acquire Top-2 most similar results 132 | conf_values, ids = confidence[:, label_idx].topk(2) 133 | 134 | # format output strings 135 | query = f'{LABELS[label_idx]:>25s} ->\t\t' 136 | results = ', '.join([f'{os.path.basename(paths_to_audio[i]):>30s} ({v:06.2%})' for v, i in zip(conf_values, ids)]) 137 | 138 | print(query + results) 139 | 140 | 141 | print('\tTextual Label\t\t\tFilename, Image (Confidence)', end='\n\n') 142 | 143 | # calculate model confidence 144 | confidence = logits_image_text.softmax(dim=0) 145 | for label_idx in range(len(LABELS)): 146 | # acquire Top-3 most similar results 147 | conf_values, ids = confidence[:, label_idx].topk(3) 148 | 149 | # format output strings 150 | query = f'{LABELS[label_idx]:>20s} ->\t\t' 151 | results = ', '.join([f'{os.path.basename(paths_to_images[i]):>20s} ({v:>06.2%})' for v, i in zip(conf_values, ids)]) 152 | 153 | print(query + results) 154 | 155 | print('\tTextual Label\t\t\tFilename, Image 
(Confidence)', end='\n\n') 156 | 157 | # calculate model confidence 158 | confidence = logits_audio_image.softmax(dim=0) 159 | for image_idx in range(len(paths_to_images)): 160 | # acquire Top-2 most similar results 161 | conf_values, ids = confidence[:, image_idx].topk(2) 162 | 163 | # format output strings 164 | query = f'{os.path.basename(paths_to_images[image_idx]):>25s} ->\t\t' 165 | results = ', '.join([f'{os.path.basename(paths_to_audio[i]):>30s} ({v:06.2%})' for v, i in zip(conf_values, ids)]) 166 | 167 | print(query + results) 168 | 169 | 170 | print('\tTextual Label\t\t\tFilename, Image (Confidence)', end='\n\n') 171 | 172 | # calculate model confidence 173 | confidence = logits_audio_image.softmax(dim=1) 174 | for audio_idx in range(len(paths_to_audio)): 175 | # acquire Top-3 most similar results 176 | conf_values, ids = confidence[audio_idx].topk(3) 177 | 178 | # format output strings 179 | query = f'{os.path.basename(paths_to_audio[audio_idx]):>30s} ->\t\t' 180 | results = ', '.join([f'{os.path.basename(paths_to_images[i]):>15s} ({v:06.2%})' for v, i in zip(conf_values, ids)]) 181 | 182 | print(query + results) 183 | 184 | print(logits_audio_image) -------------------------------------------------------------------------------- /evaluations/AudioCLIP/get_embedding.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys 3 | sys.path.append(os.path.dirname (os.path.abspath (__file__))) 4 | from einops import rearrange 5 | import torch.distributed as dist 6 | from model import AudioCLIP 7 | from utils.transforms import ToTensor1D 8 | import torch 9 | import torchvision as tv 10 | from torchvision.transforms import InterpolationMode 11 | 12 | IMAGE_SIZE = 224 13 | IMAGE_MEAN = 0.48145466, 0.4578275, 0.40821073 14 | IMAGE_STD = 0.26862954, 0.26130258, 0.27577711 15 | 16 | AUDIO_TRANSFORM = ToTensor1D() 17 | IMAGE_TRANSFORM = tv.transforms.Compose([ 18 | tv.transforms.ToTensor(), 19 | tv.transforms.Resize(IMAGE_SIZE, interpolation=InterpolationMode.BICUBIC), 20 | tv.transforms.CenterCrop(IMAGE_SIZE), 21 | tv.transforms.Normalize(IMAGE_MEAN, IMAGE_STD) 22 | ]) 23 | torch.set_grad_enabled(False) 24 | #ROOT_PATHES=[os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../copy/assets'), "/mnt/external/code/guided-diffusion/models/"] 25 | ROOT=os.path.expanduser("~/.cache/mmdiffusion") 26 | def download(fname): 27 | destination = os.path.join(ROOT, fname) 28 | if os.path.exists(destination): 29 | return destination 30 | 31 | os.makedirs(ROOT, exist_ok=True) 32 | download_cammand = f"wget -P {ROOT} https://github.com/AndreyGuzhov/AudioCLIP/releases/download/v0.1/{fname}" 33 | 34 | os.system(download_cammand) 35 | return destination 36 | 37 | 38 | 39 | def preprocess_video(videos): 40 | # videos in {0, ..., 255} as np.uint8 array 41 | b, f, c, h, w = videos.shape 42 | # TODO: 43 | # videos = videos.float() / 255. 
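# NOTE (added descriptive comment, stated as an assumption): torchvision's ToTensor inside IMAGE_TRANSFORM converts each uint8 HWC frame to a float CHW tensor in [0, 1] before resizing, cropping, and CLIP normalization, so uint8 input in {0, ..., 255} (as noted above) needs no explicit division by 255 here; only frames that already arrive as float tensors would need the commented-out rescaling above.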
44 | 45 | images = rearrange(videos, "b f c h w -> (b f) h w c").to("cpu").numpy() 46 | images = torch.stack([IMAGE_TRANSFORM(image) for image in images]) 47 | 48 | videos = rearrange(images, "(b f) c h w -> b f c h w", b=b) 49 | return videos # [-0.5, 0.5] -> [-1, 1] 50 | 51 | def preprocess_audio(audios): 52 | b,c,l = audios.shape 53 | # audios = torch.stack([AUDIO_TRANSFORM(track.reshape(1, -1)) for track in audios]) 54 | 55 | return audios 56 | 57 | 58 | 59 | 60 | 61 | def load_audioclip_pretrained(device=torch.device('cpu')): 62 | if dist.get_rank()==0: 63 | filepath = download('AudioCLIP-Full-Training.pt') 64 | dist.barrier() 65 | filepath = download('AudioCLIP-Full-Training.pt') 66 | audioclip = AudioCLIP(pretrained=filepath).to(device) 67 | 68 | return audioclip 69 | 70 | def get_audioclip_embeddings_scores(aclp, videos, audios): 71 | 72 | videos = preprocess_video(videos).to(aclp.device) 73 | audios = preprocess_audio(audios).to(aclp.device) 74 | 75 | with torch.no_grad(): 76 | ((audio_features, video_features, _), (logits_audio_video,_ ,_)), _ = aclp(audio=audios, video=videos) 77 | 78 | scores_audio_video = torch.diag(logits_audio_video) 79 | return video_features, audio_features, scores_audio_video 80 | 81 | 82 | def get_audioclip_a_embeddings(aclp, audios): 83 | 84 | 85 | audios = preprocess_audio(audios).to(aclp.device) 86 | 87 | with torch.no_grad(): 88 | ((audio_features, _, _), _), _ = aclp(audio=audios) 89 | 90 | 91 | return audio_features 92 | 93 | def get_audioclip_v_embeddings(aclp, videos): 94 | videos = preprocess_video(videos).to(aclp.device) 95 | 96 | with torch.no_grad(): 97 | ((_, video_features, _), _), _ = aclp(video=videos) 98 | 99 | 100 | return video_features 101 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/ignite_trainer/README.md: -------------------------------------------------------------------------------- 1 | # Training Wrapper 2 | 3 | Utility code to run training and evaluation of the model. 
4 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/ignite_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import os as _os 2 | import sys as _sys 3 | 4 | from ignite_trainer.version import __version__ 5 | from ._trainer import main, run 6 | from ._utils import load_class 7 | from ._interfaces import AbstractNet, AbstractTransform 8 | 9 | __all__ = [ 10 | '__version__', 11 | 'main', 'run', 12 | 'load_class', 13 | 'AbstractNet', 'AbstractTransform' 14 | ] 15 | 16 | _sys.path.extend([_os.getcwd()]) 17 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/ignite_trainer/_interfaces.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | 4 | from typing import Tuple 5 | from typing import Union 6 | from typing import Callable 7 | from typing import Optional 8 | 9 | 10 | TensorPair = Tuple[torch.Tensor, torch.Tensor] 11 | TensorOrTwo = Union[torch.Tensor, TensorPair] 12 | 13 | 14 | class AbstractNet(abc.ABC, torch.nn.Module): 15 | 16 | @abc.abstractmethod 17 | def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> TensorOrTwo: 18 | pass 19 | 20 | @abc.abstractmethod 21 | def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 22 | pass 23 | 24 | @property 25 | @abc.abstractmethod 26 | def loss_fn_name(self) -> str: 27 | pass 28 | 29 | 30 | class AbstractTransform(abc.ABC, Callable[[torch.Tensor], torch.Tensor]): 31 | 32 | @abc.abstractmethod 33 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 34 | pass 35 | 36 | def __repr__(self): 37 | return self.__class__.__name__ + '()' 38 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/ignite_trainer/_utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import sys 3 | import json 4 | import tqdm 5 | import datetime 6 | import importlib 7 | import contextlib 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.utils.data as td 13 | 14 | import torchvision as tv 15 | 16 | from PIL import Image 17 | 18 | from collections import OrderedDict 19 | 20 | from typing import Any 21 | from typing import Dict 22 | from typing import List 23 | from typing import Type 24 | from typing import Tuple 25 | from typing import Union 26 | from typing import Callable 27 | from typing import Optional 28 | 29 | 30 | @contextlib.contextmanager 31 | def tqdm_stdout(orig_stdout: Optional[io.TextIOBase] = None): 32 | 33 | class DummyFile(object): 34 | file = None 35 | 36 | def __init__(self, file): 37 | self.file = file 38 | 39 | def write(self, x): 40 | if len(x.rstrip()) > 0: 41 | tqdm.tqdm.write(x, file=self.file) 42 | 43 | def flush(self): 44 | return getattr(self.file, 'flush', lambda: None)() 45 | 46 | orig_out_err = sys.stdout, sys.stderr 47 | 48 | try: 49 | if orig_stdout is None: 50 | sys.stdout, sys.stderr = map(DummyFile, orig_out_err) 51 | yield orig_out_err[0] 52 | else: 53 | yield orig_stdout 54 | except Exception as exc: 55 | raise exc 56 | finally: 57 | sys.stdout, sys.stderr = orig_out_err 58 | 59 | 60 | def load_class(package_name: str, class_name: Optional[str] = None) -> Type: 61 | if class_name is None: 62 | package_name, class_name = package_name.rsplit('.', 1) 63 | 64 | importlib.invalidate_caches() 65 | 66 | package = importlib.import_module(package_name) 67 | cls = 
getattr(package, class_name) 68 | 69 | return cls 70 | 71 | 72 | def arg_selector(arg_cmd: Optional[Any], arg_conf: Optional[Any], arg_const: Any) -> Any: 73 | if arg_cmd is not None: 74 | return arg_cmd 75 | else: 76 | if arg_conf is not None: 77 | return arg_conf 78 | else: 79 | return arg_const 80 | 81 | 82 | def collate_fn(batch): 83 | batch_audio, batch_image, batch_text = zip(*batch) 84 | 85 | keep_ids = [idx for idx, (_, _) in enumerate(zip(batch_audio, batch_image))] 86 | 87 | if not all(audio is None for audio in batch_audio): 88 | batch_audio = [batch_audio[idx] for idx in keep_ids] 89 | batch_audio = torch.stack(batch_audio) 90 | else: 91 | batch_audio = None 92 | 93 | if not all(image is None for image in batch_image): 94 | batch_image = [batch_image[idx] for idx in keep_ids] 95 | batch_image = torch.stack(batch_image) 96 | else: 97 | batch_image = None 98 | 99 | if not all(text is None for text in batch_text): 100 | batch_text = [batch_text[idx] for idx in keep_ids] 101 | else: 102 | batch_text = None 103 | 104 | return batch_audio, batch_image, batch_text 105 | 106 | 107 | def get_data_loaders(Dataset: Type, 108 | dataset_args: Dict[str, Any], 109 | batch_train: int = 64, 110 | batch_test: int = 1024, 111 | workers_train: int = 0, 112 | workers_test: int = 0, 113 | transforms_train: Optional[Callable[ 114 | [Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor] 115 | ]] = None, 116 | transforms_test: Optional[Callable[ 117 | [Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor] 118 | ]] = None) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: 119 | 120 | dl_shuffle = dataset_args.pop('dl_shuffle', True) 121 | 122 | dataset_mode_train = {dataset_args['training']['key']: dataset_args['training']['yes']} 123 | dataset_mode_test = {dataset_args['training']['key']: dataset_args['training']['no']} 124 | 125 | dataset_args_train = {**{k: v for k, v in dataset_args.items() if k != 'training'}, **dataset_mode_train} 126 | dataset_args_test = {**{k: v for k, v in dataset_args.items() if k != 'training'}, **dataset_mode_test} 127 | 128 | ds_train = Dataset(**{ 129 | **dataset_args_train, 130 | **{'transform_audio': transforms_train}, 131 | **{'transform_frames': tv.transforms.Compose([ 132 | tv.transforms.ToTensor(), 133 | tv.transforms.Resize(224, interpolation=Image.BICUBIC), 134 | tv.transforms.CenterCrop(224), 135 | tv.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 136 | ])} 137 | }) 138 | train_loader = torch.utils.data.DataLoader( 139 | ds_train, 140 | batch_size=batch_train, 141 | shuffle=dl_shuffle, 142 | num_workers=workers_train, 143 | pin_memory=True, 144 | collate_fn=collate_fn, 145 | drop_last=True 146 | ) 147 | ds_eval = Dataset(**{ 148 | **dataset_args_test, 149 | **{'transform_audio': transforms_test}, 150 | **{'transform_frames': tv.transforms.Compose([ 151 | tv.transforms.ToTensor(), 152 | tv.transforms.Resize(224, interpolation=Image.BICUBIC), 153 | tv.transforms.CenterCrop(224), 154 | tv.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 155 | ])} 156 | }) 157 | eval_loader = torch.utils.data.DataLoader( 158 | ds_eval, 159 | batch_size=batch_test, 160 | num_workers=workers_test, 161 | pin_memory=True, 162 | collate_fn=collate_fn 163 | ) 164 | 165 | return train_loader, eval_loader 166 | 167 | 168 | def build_summary_str(experiment_name: str, 169 | model_short_name: str, 170 | model_class: str, 171 | model_args: Dict[str, Any], 172 
| optimizer_class: str, 173 | optimizer_args: Dict[str, Any], 174 | dataset_class: str, 175 | dataset_args: Dict[str, Any], 176 | transforms: List[Dict[str, Union[str, Dict[str, Any]]]], 177 | epochs: int, 178 | batch_train: int, 179 | log_interval: int, 180 | saved_models_path: str, 181 | scheduler_class: Optional[str] = None, 182 | scheduler_args: Optional[Dict[str, Any]] = None) -> str: 183 | 184 | setup_title = '{}-{}'.format(experiment_name, model_short_name) 185 | 186 | summary_window_text = '

' 187 | summary_window_text += ''.format(setup_title) 188 | 189 | summary_window_text += setup_title 190 | 191 | summary_window_text += '' 192 | summary_window_text += '

' 193 | summary_window_text += '
' 194 | summary_window_text += '' 219 | summary_window_text += '
' 220 | 221 | return summary_window_text 222 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/ignite_trainer/_visdom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | import tqdm 6 | import socket 7 | import subprocess 8 | import numpy as np 9 | 10 | import visdom 11 | 12 | from typing import Tuple 13 | from typing import Optional 14 | 15 | 16 | def calc_ytick_range(vis: visdom.Visdom, window_name: str, env: Optional[str] = None) -> Tuple[float, float]: 17 | lower_bound, upper_bound = -1.0, 1.0 18 | 19 | stats = vis.get_window_data(win=window_name, env=env) 20 | 21 | if stats: 22 | stats = json.loads(stats) 23 | 24 | stats = [np.array(item['y']) for item in stats['content']['data']] 25 | stats = [item[item != np.array([None])].astype(np.float16) for item in stats] 26 | 27 | if stats: 28 | q25s = np.array([np.quantile(item, 0.25) for item in stats if len(item) > 0]) 29 | q75s = np.array([np.quantile(item, 0.75) for item in stats if len(item) > 0]) 30 | 31 | if q25s.shape == q75s.shape and len(q25s) > 0: 32 | iqrs = q75s - q25s 33 | 34 | lower_bounds = q25s - 1.5 * iqrs 35 | upper_bounds = q75s + 1.5 * iqrs 36 | 37 | stats_sanitized = list() 38 | idx = 0 39 | for item in stats: 40 | if len(item) > 0: 41 | item_sanitized = item[(item >= lower_bounds[idx]) & (item <= upper_bounds[idx])] 42 | stats_sanitized.append(item_sanitized) 43 | 44 | idx += 1 45 | 46 | stats_sanitized = np.array(stats_sanitized) 47 | 48 | q25_sanitized = np.array([np.quantile(item, 0.25) for item in stats_sanitized]) 49 | q75_sanitized = np.array([np.quantile(item, 0.75) for item in stats_sanitized]) 50 | 51 | iqr_sanitized = np.sum(q75_sanitized - q25_sanitized) 52 | lower_bound = np.min(q25_sanitized) - 1.5 * iqr_sanitized 53 | upper_bound = np.max(q75_sanitized) + 1.5 * iqr_sanitized 54 | 55 | return lower_bound, upper_bound 56 | 57 | 58 | def plot_line(vis: visdom.Visdom, 59 | window_name: str, 60 | env: Optional[str] = None, 61 | line_label: Optional[str] = None, 62 | x: Optional[np.ndarray] = None, 63 | y: Optional[np.ndarray] = None, 64 | x_label: Optional[str] = None, 65 | y_label: Optional[str] = None, 66 | width: int = 576, 67 | height: int = 416, 68 | draw_marker: bool = False) -> str: 69 | 70 | empty_call = not vis.win_exists(window_name) 71 | 72 | if empty_call and (x is not None or y is not None): 73 | return window_name 74 | 75 | if x is None: 76 | x = np.ones(1) 77 | empty_call = empty_call & True 78 | 79 | if y is None: 80 | y = np.full(1, np.nan) 81 | empty_call = empty_call & True 82 | 83 | if x.shape != y.shape: 84 | x = np.ones_like(y) 85 | 86 | opts = { 87 | 'showlegend': True, 88 | 'markers': draw_marker, 89 | 'markersize': 5, 90 | } 91 | 92 | if empty_call: 93 | opts['title'] = window_name 94 | opts['width'] = width 95 | opts['height'] = height 96 | 97 | window_name = vis.line( 98 | X=x, 99 | Y=y, 100 | win=window_name, 101 | env=env, 102 | update='append', 103 | name=line_label, 104 | opts=opts 105 | ) 106 | 107 | xtickmin, xtickmax = 0.0, np.max(x) * 1.05 108 | ytickmin, ytickmax = calc_ytick_range(vis, window_name, env) 109 | 110 | opts = { 111 | 'showlegend': True, 112 | 'xtickmin': xtickmin, 113 | 'xtickmax': xtickmax, 114 | 'ytickmin': ytickmin, 115 | 'ytickmax': ytickmax, 116 | 'xlabel': x_label, 117 | 'ylabel': y_label 118 | } 119 | 120 | window_name = vis.update_window_opts(win=window_name, opts=opts, env=env) 121 | 122 | return 
window_name 123 | 124 | 125 | def create_summary_window(vis: visdom.Visdom, 126 | visdom_env_name: str, 127 | experiment_name: str, 128 | summary: str) -> str: 129 | 130 | return vis.text( 131 | text=summary, 132 | win=experiment_name, 133 | env=visdom_env_name, 134 | opts={'title': 'Summary', 'width': 576, 'height': 416}, 135 | append=vis.win_exists(experiment_name, visdom_env_name) 136 | ) 137 | 138 | 139 | def connection_is_alive(host: str, port: int) -> bool: 140 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 141 | try: 142 | sock.connect((host, port)) 143 | sock.shutdown(socket.SHUT_RDWR) 144 | 145 | return True 146 | except socket.error: 147 | return False 148 | 149 | 150 | def get_visdom_instance(host: str = 'localhost', 151 | port: int = 8097, 152 | env_name: str = 'main', 153 | env_path: str = 'visdom_env') -> Tuple[visdom.Visdom, Optional[int]]: 154 | 155 | vis_pid = None 156 | 157 | if not connection_is_alive(host, port): 158 | if any(host.strip('/').endswith(lh) for lh in ['127.0.0.1', 'localhost']): 159 | os.makedirs(env_path, exist_ok=True) 160 | 161 | tqdm.tqdm.write('Starting visdom on port {}'.format(port), end='') 162 | 163 | vis_args = [ 164 | sys.executable, 165 | '-m', 'visdom.server', 166 | '-port', str(port), 167 | '-env_path', os.path.join(os.getcwd(), env_path) 168 | ] 169 | vis_proc = subprocess.Popen(vis_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 170 | time.sleep(2.0) 171 | 172 | vis_pid = vis_proc.pid 173 | tqdm.tqdm.write('PID -> {}'.format(vis_pid)) 174 | 175 | trials_left = 5 176 | while not connection_is_alive(host, port): 177 | time.sleep(1.0) 178 | 179 | tqdm.tqdm.write('Trying to connect ({} left)...'.format(trials_left)) 180 | 181 | trials_left -= 1 182 | if trials_left < 1: 183 | raise RuntimeError('Visdom server is not running. 
Please run "python -m visdom.server".') 184 | 185 | vis = visdom.Visdom( 186 | server='http://{}'.format(host), 187 | port=port, 188 | env=env_name 189 | ) 190 | 191 | return vis, vis_pid 192 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/ignite_trainer/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.5b5' 2 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.7 2 | 3 | import ignite_trainer as it 4 | 5 | 6 | def main(): 7 | it.main() 8 | 9 | 10 | if __name__ == '__main__': 11 | main() 12 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | from .esresnet import * 3 | from .audioclip import AudioCLIP 4 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/model/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import CLIP 2 | from .model import convert_weights 3 | 4 | 5 | __all__ = ['CLIP', 'convert_weights'] 6 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/model/clip/clip.py: -------------------------------------------------------------------------------- 1 | # CREDITS: https://github.com/openai/CLIP 2 | 3 | import hashlib 4 | import os 5 | import urllib 6 | import warnings 7 | from typing import Union, List 8 | 9 | import torch 10 | from PIL import Image 11 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 12 | from tqdm import tqdm 13 | 14 | from .model import build_model 15 | from utils.simple_tokenizer import SimpleTokenizer as _Tokenizer 16 | 17 | __all__ = ["available_models", "load", "tokenize"] 18 | _tokenizer = _Tokenizer() 19 | 20 | _MODELS = { 21 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 22 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 23 | } 24 | 25 | 26 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 27 | os.makedirs(root, exist_ok=True) 28 | filename = os.path.basename(url) 29 | 30 | expected_sha256 = url.split("/")[-2] 31 | download_target = os.path.join(root, filename) 32 | 33 | if os.path.exists(download_target) and not os.path.isfile(download_target): 34 | raise RuntimeError(f"{download_target} exists and is not a regular file") 35 | 36 | if os.path.isfile(download_target): 37 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 38 | return download_target 39 | else: 40 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 41 | 42 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 43 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 44 | while True: 45 | buffer = source.read(8192) 46 | if not buffer: 47 | break 48 | 49 | output.write(buffer) 50 | loop.update(len(buffer)) 51 | 52 | if 
hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 53 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 54 | 55 | return download_target 56 | 57 | 58 | def _transform(n_px): 59 | return Compose([ 60 | Resize(n_px, interpolation=Image.BICUBIC), 61 | CenterCrop(n_px), 62 | lambda image: image.convert("RGB"), 63 | ToTensor(), 64 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 65 | ]) 66 | 67 | 68 | def available_models() -> List[str]: 69 | """Returns the names of available CLIP models""" 70 | return list(_MODELS.keys()) 71 | 72 | 73 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): 74 | """Load a CLIP model 75 | 76 | Parameters 77 | ---------- 78 | name : str 79 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 80 | 81 | device : Union[str, torch.device] 82 | The device to put the loaded model 83 | 84 | jit : bool 85 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 86 | 87 | Returns 88 | ------- 89 | model : torch.nn.Module 90 | The CLIP model 91 | 92 | preprocess : Callable[[PIL.Image], torch.Tensor] 93 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 94 | """ 95 | if name in _MODELS: 96 | model_path = _download(_MODELS[name]) 97 | elif os.path.isfile(name): 98 | model_path = name 99 | else: 100 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 101 | 102 | try: 103 | # loading JIT archive 104 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 105 | state_dict = None 106 | except RuntimeError: 107 | # loading saved state dict 108 | if jit: 109 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 110 | jit = False 111 | state_dict = torch.load(model_path, map_location="cpu") 112 | 113 | if not jit: 114 | model = build_model(state_dict or model.state_dict()).to(device) 115 | if str(device) == "cpu": 116 | model.float() 117 | return model, _transform(model.visual.input_resolution) 118 | 119 | # patch the device names 120 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 121 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 122 | 123 | def patch_device(module): 124 | graphs = [module.graph] if hasattr(module, "graph") else [] 125 | if hasattr(module, "forward1"): 126 | graphs.append(module.forward1.graph) 127 | 128 | for graph in graphs: 129 | for node in graph.findAllNodes("prim::Constant"): 130 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 131 | node.copyAttributes(device_node) 132 | 133 | model.apply(patch_device) 134 | patch_device(model.encode_image) 135 | patch_device(model.encode_text) 136 | 137 | # patch dtype to float32 on CPU 138 | if str(device) == "cpu": 139 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 140 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 141 | float_node = float_input.node() 142 | 143 | def patch_float(module): 144 | graphs = [module.graph] if hasattr(module, "graph") else [] 145 | if hasattr(module, "forward1"): 146 | graphs.append(module.forward1.graph) 147 | 148 | for graph in graphs: 149 | for node in graph.findAllNodes("aten::to"): 150 | inputs = list(node.inputs()) 151 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 152 | if inputs[i].node()["value"] == 5: 153 | inputs[i].node().copyAttributes(float_node) 154 | 155 | model.apply(patch_float) 156 | patch_float(model.encode_image) 157 | patch_float(model.encode_text) 158 | 159 | model.float() 160 | 161 | return model, _transform(model.input_resolution.item()) 162 | 163 | 164 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 165 | """ 166 | Returns the tokenized representation of given input string(s) 167 | 168 | Parameters 169 | ---------- 170 | texts : Union[str, List[str]] 171 | An input string or a list of input strings to tokenize 172 | 173 | context_length : int 174 | The context length to use; all CLIP models use 77 as the context length 175 | 176 | Returns 177 | ------- 178 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 179 | """ 180 | if isinstance(texts, str): 181 | texts = [texts] 182 | 183 | sot_token = _tokenizer.encoder["<|startoftext|>"] 184 | eot_token = _tokenizer.encoder["<|endoftext|>"] 185 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 186 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 187 | 188 | for i, tokens in enumerate(all_tokens): 189 | if len(tokens) > context_length: 190 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 191 | result[i, :len(tokens)] = torch.tensor(tokens) 192 | 193 | return result 194 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/model/esresnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ESResNet 2 | from .base import ESResNeXt 3 | from .fbsp import 
ESResNetFBSP 4 | from .fbsp import ESResNeXtFBSP 5 | from .attention import Attention2d 6 | 7 | 8 | __all__ = ['ESResNet', 'ESResNeXt', 'ESResNetFBSP', 'ESResNeXtFBSP', 'Attention2d'] 9 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/model/esresnet/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from typing import Tuple 5 | 6 | 7 | class Attention2d(torch.nn.Module): 8 | 9 | def __init__(self, 10 | in_channels: int, 11 | out_channels: int, 12 | num_kernels: int, 13 | kernel_size: Tuple[int, int], 14 | padding_size: Tuple[int, int]): 15 | 16 | super(Attention2d, self).__init__() 17 | 18 | self.conv_depth = torch.nn.Conv2d( 19 | in_channels=in_channels, 20 | out_channels=in_channels * num_kernels, 21 | kernel_size=kernel_size, 22 | padding=padding_size, 23 | groups=in_channels 24 | ) 25 | self.conv_point = torch.nn.Conv2d( 26 | in_channels=in_channels * num_kernels, 27 | out_channels=out_channels, 28 | kernel_size=(1, 1) 29 | ) 30 | self.bn = torch.nn.BatchNorm2d(num_features=out_channels) 31 | self.activation = torch.nn.Sigmoid() 32 | 33 | def forward(self, x: torch.Tensor, size: torch.Size) -> torch.Tensor: 34 | x = F.adaptive_max_pool2d(x, size) 35 | x = self.conv_depth(x) 36 | x = self.conv_point(x) 37 | x = self.bn(x) 38 | x = self.activation(x) 39 | 40 | return x 41 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/model/esresnet/fbsp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import torchvision as tv 7 | 8 | from utils import transforms 9 | from model.esresnet.base import _ESResNet 10 | from model.esresnet.base import Bottleneck 11 | 12 | from typing import cast 13 | from typing import List 14 | from typing import Tuple 15 | from typing import Union 16 | from typing import Optional 17 | 18 | 19 | class LinearFBSP(torch.nn.Module): 20 | 21 | def __init__(self, out_features: int, bias: bool = True, normalized: bool = False): 22 | super(LinearFBSP, self).__init__() 23 | 24 | self.out_features = out_features 25 | self.normalized = normalized 26 | self.eps = 1e-8 27 | 28 | default_dtype = torch.get_default_dtype() 29 | 30 | self.register_parameter('m', torch.nn.Parameter(torch.zeros(self.out_features, dtype=default_dtype))) 31 | self.register_parameter('fb', torch.nn.Parameter(torch.ones(self.out_features, dtype=default_dtype))) 32 | self.register_parameter('fc', torch.nn.Parameter(torch.arange(self.out_features, dtype=default_dtype))) 33 | self.register_parameter( 34 | 'bias', 35 | torch.nn.Parameter( 36 | torch.normal( 37 | 0.0, 0.5, (self.out_features, 2), dtype=default_dtype 38 | ) if bias else cast( 39 | torch.nn.Parameter, None 40 | ) 41 | ) 42 | ) 43 | 44 | self.m.register_hook(lambda grad: grad / (torch.norm(grad, p=float('inf')) + self.eps)) 45 | self.fb.register_hook(lambda grad: grad / (torch.norm(grad, p=float('inf')) + self.eps)) 46 | self.fc.register_hook(lambda grad: grad / (torch.norm(grad, p=float('inf')) + self.eps)) 47 | 48 | @staticmethod 49 | def power(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: 50 | magnitudes = (x1[..., 0] ** 2 + x1[..., 1] ** 2) ** 0.5 51 | phases = x1[..., 1].atan2(x1[..., 0]) 52 | 53 | power_real = x2[..., 0] 54 | power_imag = x2[..., 1] 55 | 56 | mag_out = ((magnitudes ** 2) ** (0.5 * 
power_real) * torch.exp(-power_imag * phases)) 57 | 58 | return mag_out.unsqueeze(-1) * torch.stack(( 59 | (power_real * phases + 0.5 * power_imag * (magnitudes ** 2).log()).cos(), 60 | (power_real * phases + 0.5 * power_imag * (magnitudes ** 2).log()).sin() 61 | ), dim=-1) 62 | 63 | @staticmethod 64 | def sinc(x: torch.Tensor) -> torch.Tensor: 65 | return torch.where(cast(torch.Tensor, x == 0), torch.ones_like(x), torch.sin(x) / x) 66 | 67 | def _materialize_weights(self, x: torch.Tensor) -> Tuple[torch.Tensor, bool]: 68 | x_is_complex = x.shape[-1] == 2 69 | in_features = x.shape[-1 - int(x_is_complex)] 70 | 71 | t = np.pi * torch.linspace(-1.0, 1.0, in_features, dtype=x.dtype, device=x.device).reshape(1, -1, 1) + self.eps 72 | 73 | m = self.m.reshape(-1, 1, 1) 74 | fb = self.fb.reshape(-1, 1, 1) 75 | fc = self.fc.reshape(-1, 1, 1) 76 | 77 | kernel = torch.cat((torch.cos(fc * t), -torch.sin(fc * t)), dim=-1) # complex 78 | scale = fb.sqrt() # real 79 | win = self.sinc(fb * t / (m + self.eps)) # real 80 | win = self.power( 81 | torch.cat((win, torch.zeros_like(win)), dim=-1), 82 | torch.cat((m, torch.zeros_like(m)), dim=-1) 83 | ) # complex 84 | 85 | weights = scale * torch.cat(( 86 | win[..., :1] * kernel[..., :1] - win[..., 1:] * kernel[..., 1:], 87 | win[..., :1] * kernel[..., 1:] + win[..., 1:] * kernel[..., :1] 88 | ), dim=-1) 89 | 90 | if self.normalized: 91 | weights = weights / (in_features ** 0.5) 92 | 93 | return weights, x_is_complex 94 | 95 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 96 | weights, x_is_complex = self._materialize_weights(x) 97 | 98 | if x_is_complex: 99 | x = torch.stack(( 100 | F.linear(x[..., 0], weights[..., 0]) - F.linear(x[..., 1], weights[..., 1]), 101 | F.linear(x[..., 0], weights[..., 1]) + F.linear(x[..., 1], weights[..., 0]) 102 | ), dim=-1) 103 | else: 104 | x = torch.stack(( 105 | F.linear(x, weights[..., 0]), 106 | F.linear(x, weights[..., 1]) 107 | ), dim=-1) 108 | 109 | if (self.bias is not None) and (self.bias.numel() == (self.out_features * 2)): 110 | x = x + self.bias 111 | 112 | return x, weights 113 | 114 | def extra_repr(self) -> str: 115 | return 'out_features={}, bias={}, normalized={}'.format( 116 | self.out_features, 117 | (self.bias is not None) and (self.bias.numel() == (self.out_features * 2)), 118 | self.normalized 119 | ) 120 | 121 | 122 | ttf_weights = dict() 123 | 124 | 125 | class _ESResNetFBSP(_ESResNet): 126 | 127 | def _inject_members(self): 128 | self.add_module( 129 | 'fbsp', 130 | LinearFBSP( 131 | out_features=int(round(self.n_fft / 2)) + 1 if self.onesided else self.n_fft, 132 | normalized=self.normalized, 133 | bias=False 134 | ) 135 | ) 136 | 137 | def spectrogram(self, x: torch.Tensor) -> torch.Tensor: 138 | with torch.no_grad(): 139 | frames = transforms.frame_signal( 140 | signal=x.view(-1, x.shape[-1]), 141 | frame_length=self.win_length, 142 | hop_length=self.hop_length, 143 | window=self.window 144 | ) 145 | 146 | if self.n_fft > self.win_length: 147 | pad_length = self.n_fft - self.win_length 148 | pad_left = pad_length // 2 149 | pad_right = pad_length - pad_left 150 | frames = F.pad(frames, [pad_left, pad_right]) 151 | 152 | spec, ttf_weights_ = self.fbsp(frames) 153 | 154 | spec = spec.transpose(-2, -3) 155 | ttf_weights[x.device] = ttf_weights_ 156 | 157 | return spec 158 | 159 | def loss_ttf(self, device: torch.device) -> torch.Tensor: 160 | ttf_norm = torch.norm(ttf_weights[device], p=2, dim=[-1, -2]) 161 | loss_ttf_norm = F.mse_loss( 162 | ttf_norm, 163 | 
torch.full_like(ttf_norm, 1.0 if self.normalized else self.n_fft ** 0.5) 164 | ) 165 | 166 | return loss_ttf_norm 167 | 168 | def loss_fn(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 169 | loss_pred = super(_ESResNetFBSP, self).loss_fn(y_pred, y) 170 | loss_ttf_norm = self.loss_ttf(y_pred.device) 171 | loss = loss_pred + loss_ttf_norm 172 | 173 | return loss 174 | 175 | 176 | class ESResNetFBSP(_ESResNetFBSP): 177 | 178 | loading_func = staticmethod(tv.models.resnet50) 179 | 180 | def __init__(self, 181 | n_fft: int = 256, 182 | hop_length: Optional[int] = None, 183 | win_length: Optional[int] = None, 184 | window: Optional[str] = None, 185 | normalized: bool = False, 186 | onesided: bool = True, 187 | spec_height: int = 224, 188 | spec_width: int = 224, 189 | num_classes: int = 1000, 190 | apply_attention: bool = False, 191 | pretrained: bool = False, 192 | lock_pretrained: Optional[Union[bool, List[str]]] = None): 193 | 194 | super(ESResNetFBSP, self).__init__( 195 | block=Bottleneck, 196 | layers=[3, 4, 6, 3], 197 | apply_attention=apply_attention, 198 | n_fft=n_fft, 199 | hop_length=hop_length, 200 | win_length=win_length, 201 | window=window, 202 | normalized=normalized, 203 | onesided=onesided, 204 | spec_height=spec_height, 205 | spec_width=spec_width, 206 | num_classes=num_classes, 207 | pretrained=pretrained, 208 | lock_pretrained=lock_pretrained 209 | ) 210 | 211 | 212 | class ESResNeXtFBSP(_ESResNetFBSP): 213 | 214 | loading_func = staticmethod(tv.models.resnext50_32x4d) 215 | 216 | def __init__(self, 217 | n_fft: int = 256, 218 | hop_length: Optional[int] = None, 219 | win_length: Optional[int] = None, 220 | window: Optional[str] = None, 221 | normalized: bool = False, 222 | onesided: bool = True, 223 | spec_height: int = 224, 224 | spec_width: int = 224, 225 | num_classes: int = 1000, 226 | apply_attention: bool = False, 227 | pretrained: Union[bool, str] = False, 228 | lock_pretrained: Optional[Union[bool, List[str]]] = None): 229 | 230 | super(ESResNeXtFBSP, self).__init__( 231 | block=Bottleneck, 232 | layers=[3, 4, 6, 3], 233 | apply_attention=apply_attention, 234 | n_fft=n_fft, 235 | hop_length=hop_length, 236 | win_length=win_length, 237 | window=window, 238 | normalized=normalized, 239 | onesided=onesided, 240 | spec_height=spec_height, 241 | spec_width=spec_width, 242 | num_classes=num_classes, 243 | pretrained=pretrained, 244 | lock_pretrained=lock_pretrained, 245 | groups=32, 246 | width_per_group=4 247 | ) 248 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/protocols/audioclip-esc50.json: -------------------------------------------------------------------------------- 1 | { 2 | "Visdom": { 3 | "host": null, 4 | "port": null, 5 | "env_path": null 6 | }, 7 | "Setup": { 8 | "name": "Multimodal-Audio", 9 | "suffix": "CV1", 10 | "batch_train": 64, 11 | "batch_test": 64, 12 | "workers_train": 4, 13 | "workers_test": 4, 14 | "epochs": 50, 15 | "log_interval": 10, 16 | "saved_models_path": "/path/to/saved/models" 17 | }, 18 | "Model": { 19 | "class": "model.audioclip.AudioCLIP", 20 | "args": { 21 | "multilabel": false, 22 | "pretrained": "/path/to/assets/trained_AudioCLIP.pt" 23 | } 24 | }, 25 | "Optimizer": { 26 | "class": "torch.optim.SGD", 27 | "args": { 28 | "lr": 5e-5, 29 | "momentum": 0.9, 30 | "nesterov": true, 31 | "weight_decay": 5e-4 32 | } 33 | }, 34 | "Scheduler": { 35 | "class": "torch.optim.lr_scheduler.ExponentialLR", 36 | "args": { 37 | "gamma": 0.96 38 | } 39 | }, 40 | "Dataset": { 
41 | "class": "utils.datasets.ESC50", 42 | "args": { 43 | "dl_shuffle": true, 44 | "root": "/path/to/ESC50", 45 | "sample_rate": 44100, 46 | "fold": 1, 47 | "training": {"key": "train", "yes": true, "no": false} 48 | } 49 | }, 50 | "Transforms": [ 51 | { 52 | "class": "utils.transforms.ToTensor1D", 53 | "args": {} 54 | }, 55 | { 56 | "class": "utils.transforms.RandomFlip", 57 | "args": {"p": 0.5}, 58 | "test": false 59 | }, 60 | { 61 | "class": "utils.transforms.RandomScale", 62 | "args": {"max_scale": 1.50}, 63 | "test": false 64 | }, 65 | { 66 | "class": "utils.transforms.RandomPadding", 67 | "args": {"out_len": 220500}, 68 | "test": false 69 | }, 70 | { 71 | "class": "utils.transforms.RandomCrop", 72 | "args": {"out_len": 220500}, 73 | "test": false 74 | }, 75 | { 76 | "class": "utils.transforms.RandomNoise", 77 | "args": {"snr_min_db": 10.0, "snr_max_db": 120.0, "p": 0.25}, 78 | "test": false 79 | }, 80 | { 81 | "class": "utils.transforms.RandomPadding", 82 | "args": {"out_len": 220500, "train": false}, 83 | "train": false 84 | }, 85 | { 86 | "class": "utils.transforms.RandomCrop", 87 | "args": {"out_len": 220500, "train": false}, 88 | "train": false 89 | } 90 | ], 91 | "Metrics": { 92 | "Performance": { 93 | "window_name": null, 94 | "x_label": "#Epochs", 95 | "y_label": "Accuracy", 96 | "width": 1890, 97 | "height": 416, 98 | "lines": [ 99 | { 100 | "line_label": "Val. Acc.", 101 | "class": "ignite.metrics.Accuracy", 102 | "args": {}, 103 | "is_checkpoint": true 104 | } 105 | ] 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/protocols/audioclip-us8k.json: -------------------------------------------------------------------------------- 1 | { 2 | "Visdom": { 3 | "host": null, 4 | "port": null, 5 | "env_path": null 6 | }, 7 | "Setup": { 8 | "name": "Multimodal-Audio", 9 | "suffix": "CV01", 10 | "batch_train": 64, 11 | "batch_test": 64, 12 | "workers_train": 4, 13 | "workers_test": 4, 14 | "epochs": 50, 15 | "log_interval": 25, 16 | "saved_models_path": "/path/to/saved/models" 17 | }, 18 | "Model": { 19 | "class": "model.audioclip.AudioCLIP", 20 | "args": { 21 | "multilabel": false, 22 | "pretrained": "/path/to/assets/trained_AudioCLIP.pt" 23 | } 24 | }, 25 | "Optimizer": { 26 | "class": "torch.optim.SGD", 27 | "args": { 28 | "lr": 1e-5, 29 | "momentum": 0.9, 30 | "nesterov": true, 31 | "weight_decay": 5e-4 32 | } 33 | }, 34 | "Scheduler": { 35 | "class": "torch.optim.lr_scheduler.ExponentialLR", 36 | "args": { 37 | "gamma": 0.96 38 | } 39 | }, 40 | "Dataset": { 41 | "class": "utils.datasets.UrbanSound8K", 42 | "args": { 43 | "root": "/path/to/UrbanSound8K", 44 | "sample_rate": 44100, 45 | "fold": 1, 46 | "mono": false, 47 | "training": {"key": "train", "yes": true, "no": false} 48 | } 49 | }, 50 | "Transforms": [ 51 | { 52 | "class": "utils.transforms.ToTensor1D", 53 | "args": {} 54 | }, 55 | { 56 | "class": "utils.transforms.RandomFlip", 57 | "args": {"p": 0.5}, 58 | "test": false 59 | }, 60 | { 61 | "class": "utils.transforms.RandomScale", 62 | "args": {"max_scale": 1.50}, 63 | "test": false 64 | }, 65 | { 66 | "class": "utils.transforms.RandomPadding", 67 | "args": {"out_len": 176400}, 68 | "test": false 69 | }, 70 | { 71 | "class": "utils.transforms.RandomCrop", 72 | "args": {"out_len": 176400}, 73 | "test": false 74 | }, 75 | { 76 | "class": "utils.transforms.RandomNoise", 77 | "args": {"snr_min_db": 10.0, "snr_max_db": 120.0, "p": 0.25}, 78 | "test": false 79 | }, 80 | { 81 | "class": 
"utils.transforms.RandomPadding", 82 | "args": {"out_len": 176400, "train": false}, 83 | "train": false 84 | }, 85 | { 86 | "class": "utils.transforms.RandomCrop", 87 | "args": {"out_len": 176400, "train": false}, 88 | "train": false 89 | } 90 | ], 91 | "Metrics": { 92 | "Performance": { 93 | "window_name": null, 94 | "x_label": "#Epochs", 95 | "y_label": "Accuracy", 96 | "width": 1890, 97 | "height": 416, 98 | "lines": [ 99 | { 100 | "line_label": "Val. Acc.", 101 | "class": "ignite.metrics.Accuracy", 102 | "args": {}, 103 | "is_checkpoint": true 104 | } 105 | ] 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datasets 2 | from . import transforms 3 | 4 | __all__ = [ 5 | 'datasets', 6 | 'transforms' 7 | ] 8 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/utils/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .esc50 import ESC50 2 | from .us8k import UrbanSound8K 3 | 4 | __all__ = ['ESC50', 'UrbanSound8K'] 5 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/utils/datasets/esc50.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import multiprocessing as mp 4 | 5 | import tqdm 6 | import librosa 7 | 8 | import numpy as np 9 | import pandas as pd 10 | 11 | import torch.utils.data as td 12 | 13 | from typing import Any 14 | from typing import Dict 15 | from typing import List 16 | from typing import Tuple 17 | from typing import Union 18 | from typing import Optional 19 | 20 | 21 | class ESC50(td.Dataset): 22 | 23 | def __init__(self, 24 | root: str, 25 | sample_rate: int = 22050, 26 | train: bool = True, 27 | fold: Optional[int] = None, 28 | transform_audio=None, 29 | target_transform=None, 30 | **_): 31 | 32 | super(ESC50, self).__init__() 33 | 34 | self.sample_rate = sample_rate 35 | 36 | meta = self.load_meta(os.path.join(root, 'meta', 'esc50.csv')) 37 | 38 | if fold is None: 39 | fold = 5 40 | 41 | self.folds_to_load = set(meta['fold']) 42 | 43 | if fold not in self.folds_to_load: 44 | raise ValueError(f'fold {fold} does not exist') 45 | 46 | self.train = train 47 | self.transform = transform_audio 48 | 49 | if self.train: 50 | self.folds_to_load -= {fold} 51 | else: 52 | self.folds_to_load -= self.folds_to_load - {fold} 53 | 54 | self.data: Dict[Union[str, int], Dict[str, Any]] = dict() 55 | self.load_data(meta, os.path.join(root, 'audio')) 56 | self.indices = list(self.data.keys()) 57 | 58 | self.class_idx_to_label = dict() 59 | for row in self.data.values(): 60 | idx = row['target'] 61 | label = row['category'] 62 | self.class_idx_to_label[idx] = label 63 | self.label_to_class_idx = {lb: idx for idx, lb in self.class_idx_to_label.items()} 64 | 65 | self.target_transform = target_transform 66 | 67 | @staticmethod 68 | def load_meta(path_to_csv: str) -> pd.DataFrame: 69 | meta = pd.read_csv(path_to_csv) 70 | 71 | return meta 72 | 73 | @staticmethod 74 | def _load_worker(idx: int, filename: str, sample_rate: Optional[int] = None) -> Tuple[int, int, np.ndarray]: 75 | wav, sample_rate = librosa.load(filename, sr=sample_rate, mono=True) 76 | 77 | if wav.ndim == 1: 78 | wav = wav[:, np.newaxis] 79 | 80 | wav = wav.T * 32768.0 81 | 82 | return idx, 
sample_rate, wav.astype(np.float32) 83 | 84 | def load_data(self, meta: pd.DataFrame, base_path: str): 85 | items_to_load = dict() 86 | 87 | for idx, row in meta.iterrows(): 88 | if row['fold'] in self.folds_to_load: 89 | items_to_load[idx] = os.path.join(base_path, row['filename']), self.sample_rate 90 | 91 | items_to_load = [(idx, path, sample_rate) for idx, (path, sample_rate) in items_to_load.items()] 92 | 93 | num_processes = os.cpu_count() 94 | warnings.filterwarnings('ignore') 95 | with mp.Pool(processes=num_processes) as pool: 96 | tqdm.tqdm.write(f'Loading {self.__class__.__name__} (train={self.train})') 97 | for idx, sample_rate, wav in pool.starmap( 98 | func=self._load_worker, 99 | iterable=items_to_load, 100 | chunksize=int(np.ceil(len(items_to_load) / num_processes)) or 1 101 | ): 102 | row = meta.loc[idx] 103 | 104 | self.data[idx] = { 105 | 'audio': wav, 106 | 'sample_rate': sample_rate, 107 | 'target': row['target'], 108 | 'category': row['category'].replace('_', ' '), 109 | 'fold': row['fold'], 110 | 'esc10': row['esc10'] 111 | } 112 | 113 | def __getitem__(self, index: int) -> Tuple[np.ndarray, Optional[np.ndarray], List[str]]: 114 | if not (0 <= index < len(self)): 115 | raise IndexError 116 | 117 | audio: np.ndarray = self.data[self.indices[index]]['audio'] 118 | target: str = self.data[self.indices[index]]['category'] 119 | 120 | if self.transform is not None: 121 | audio = self.transform(audio) 122 | if self.target_transform is not None: 123 | target = self.target_transform(target) 124 | 125 | return audio, None, [target] 126 | 127 | def __len__(self) -> int: 128 | return len(self.indices) 129 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/utils/datasets/us8k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import multiprocessing as mp 4 | 5 | import tqdm 6 | import librosa 7 | import soundfile as sf 8 | 9 | import numpy as np 10 | import pandas as pd 11 | 12 | import torch.utils.data as td 13 | 14 | import sklearn.model_selection as skms 15 | 16 | import utils.transforms as transforms 17 | 18 | from typing import Any 19 | from typing import Dict 20 | from typing import List 21 | from typing import Tuple 22 | from typing import Optional 23 | 24 | 25 | class UrbanSound8K(td.Dataset): 26 | 27 | def __init__(self, 28 | root: str, 29 | sample_rate: int = 22050, 30 | train: bool = True, 31 | fold: Optional[int] = None, 32 | mono: bool = False, 33 | transform_audio=None, 34 | target_transform=None, 35 | **_): 36 | 37 | super(UrbanSound8K, self).__init__() 38 | 39 | self.root = root 40 | self.sample_rate = sample_rate 41 | self.train = train 42 | self.random_split_seed = None 43 | 44 | if fold is None: 45 | fold = 1 46 | 47 | if not (1 <= fold <= 10): 48 | raise ValueError(f'Expected fold in range [1, 10], got {fold}') 49 | 50 | self.fold = fold 51 | self.folds_to_load = set(range(1, 11)) 52 | 53 | if self.fold not in self.folds_to_load: 54 | raise ValueError(f'fold {fold} does not exist') 55 | 56 | if self.train: 57 | # if in training mode, keep all but test fold 58 | self.folds_to_load -= {self.fold} 59 | else: 60 | # if in evaluation mode, keep the test samples only 61 | self.folds_to_load -= self.folds_to_load - {self.fold} 62 | 63 | self.mono = mono 64 | 65 | self.transform = transform_audio 66 | self.target_transform = target_transform 67 | 68 | self.data: Dict[str, Dict[str, Any]] = dict() 69 | self.indices = dict() 70 | 
self.load_data() 71 | 72 | self.class_idx_to_label = dict() 73 | for row in self.data.values(): 74 | idx = row['target'] 75 | label = row['category'] 76 | self.class_idx_to_label[idx] = label 77 | self.label_to_class_idx = {lb: idx for idx, lb in self.class_idx_to_label.items()} 78 | 79 | @staticmethod 80 | def _load_worker(fn: str, path_to_file: str, sample_rate: int, mono: bool = False) -> Tuple[str, int, np.ndarray]: 81 | wav, sample_rate_ = sf.read( 82 | path_to_file, 83 | dtype='float32', 84 | always_2d=True 85 | ) 86 | 87 | wav = librosa.resample(wav.T, sample_rate_, sample_rate) 88 | 89 | if wav.shape[0] == 1 and not mono: 90 | wav = np.concatenate((wav, wav), axis=0) 91 | 92 | wav = wav[:, :sample_rate * 4] 93 | wav = transforms.scale(wav, wav.min(), wav.max(), -32768.0, 32767.0) 94 | 95 | return fn, sample_rate, wav.astype(np.float32) 96 | 97 | def load_data(self): 98 | # read metadata 99 | meta = pd.read_csv( 100 | os.path.join(self.root, 'metadata', 'UrbanSound8K.csv'), 101 | sep=',', 102 | index_col='slice_file_name' 103 | ) 104 | 105 | for row_idx, (fn, row) in enumerate(meta.iterrows()): 106 | path = os.path.join(self.root, 'audio', 'fold{}'.format(row['fold']), fn) 107 | self.data[fn] = path, self.sample_rate, self.mono 108 | 109 | # by default, the official split from the metadata is used 110 | files_to_load = list() 111 | # if the random seed is not None, the random split is used 112 | if self.random_split_seed is not None: 113 | # given an integer random seed 114 | skf = skms.StratifiedKFold(n_splits=10, shuffle=True, random_state=self.random_split_seed) 115 | 116 | # split the US8K samples into 10 folds 117 | for fold_idx, (train_ids, test_ids) in enumerate(skf.split( 118 | np.zeros(len(meta)), meta['classID'].values.astype(int) 119 | ), 1): 120 | # if this is the fold we want to load, add the corresponding files to the list 121 | if fold_idx == self.fold: 122 | ids = train_ids if self.train else test_ids 123 | filenames = meta.iloc[ids].index 124 | files_to_load.extend(filenames) 125 | break 126 | else: 127 | # if the random seed is None, use the official split 128 | for fn, row in meta.iterrows(): 129 | if int(row['fold']) in self.folds_to_load: 130 | files_to_load.append(fn) 131 | 132 | self.data = {fn: vals for fn, vals in self.data.items() if fn in files_to_load} 133 | self.indices = {idx: fn for idx, fn in enumerate(self.data)} 134 | 135 | num_processes = os.cpu_count() 136 | warnings.filterwarnings('ignore') 137 | with mp.Pool(processes=num_processes) as pool: 138 | tqdm.tqdm.write(f'Loading {self.__class__.__name__} (train={self.train})') 139 | for fn, sample_rate, wav in pool.starmap( 140 | func=self._load_worker, 141 | iterable=[(fn, path, sr, mono) for fn, (path, sr, mono) in self.data.items()], 142 | chunksize=int(np.ceil(len(meta) / num_processes)) or 1 143 | ): 144 | self.data[fn] = { 145 | 'audio': wav, 146 | 'sample_rate': sample_rate, 147 | 'target': meta.loc[fn, 'classID'], 148 | 'category': meta.loc[fn, 'class'].replace('_', ' ').strip(' '), 149 | 'background': bool(meta.loc[fn, 'salience'] - 1) 150 | } 151 | 152 | def __getitem__(self, index: int) -> Tuple[np.ndarray, Optional[np.ndarray], List[str]]: 153 | if not (0 <= index < len(self)): 154 | raise IndexError 155 | 156 | audio: np.ndarray = self.data[self.indices[index]]['audio'] 157 | target: str = self.data[self.indices[index]]['category'] 158 | 159 | if self.transform is not None: 160 | audio = self.transform(audio) 161 | if self.target_transform is not None: 162 | target = 
self.target_transform(target) 163 | 164 | return audio, None, [target] 165 | 166 | def __len__(self) -> int: 167 | return len(self.data) 168 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/utils/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | # CREDITS: https://github.com/openai/CLIP 2 | 3 | import gzip 4 | import html 5 | import os 6 | from functools import lru_cache 7 | 8 | import ftfy 9 | import regex as re 10 | ROOT_PATHES=[os.path.join(os.path.dirname(os.path.abspath(__file__)), '..','assets'), "/mnt/external/code/guided-diffusion/AudioCLIP/assets/"] 11 | 12 | def download(fname): 13 | for root in ROOT_PATHES: 14 | destination = os.path.join(root, fname) 15 | if os.path.exists(destination): 16 | return destination 17 | root = ROOT_PATHES[0] 18 | destination = os.path.join(root, fname) 19 | os.makedirs(root, exist_ok=True) 20 | download_cammand = f"wget -P {root} https://github.com/AndreyGuzhov/AudioCLIP/releases/download/v0.1/{fname}" 21 | 22 | os.system(download_cammand) 23 | return destination 24 | 25 | 26 | 27 | @lru_cache() 28 | def default_bpe(): 29 | default_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'assets', 'bpe_simple_vocab_16e6.txt.gz') 30 | if os.path.exists(default_path)==False: 31 | default_path = download('bpe_simple_vocab_16e6.txt.gz') 32 | return default_path 33 | 34 | 35 | @lru_cache() 36 | def bytes_to_unicode(): 37 | """ 38 | Returns list of utf-8 byte and a corresponding list of unicode strings. 39 | The reversible bpe codes work on unicode strings. 40 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 41 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 42 | This is a signficant percentage of your normal, say, 32K bpe vocab. 43 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 44 | And avoids mapping to whitespace/control characters the bpe code barfs on. 45 | """ 46 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 47 | cs = bs[:] 48 | n = 0 49 | for b in range(2**8): 50 | if b not in bs: 51 | bs.append(b) 52 | cs.append(2**8+n) 53 | n += 1 54 | cs = [chr(n) for n in cs] 55 | return dict(zip(bs, cs)) 56 | 57 | 58 | def get_pairs(word): 59 | """Return set of symbol pairs in a word. 60 | Word is represented as tuple of symbols (symbols being variable-length strings). 
61 | """ 62 | pairs = set() 63 | prev_char = word[0] 64 | for char in word[1:]: 65 | pairs.add((prev_char, char)) 66 | prev_char = char 67 | return pairs 68 | 69 | 70 | def basic_clean(text): 71 | text = ftfy.fix_text(text) 72 | text = html.unescape(html.unescape(text)) 73 | return text.strip() 74 | 75 | 76 | def whitespace_clean(text): 77 | text = re.sub(r'\s+', ' ', text) 78 | text = text.strip() 79 | return text 80 | 81 | 82 | class SimpleTokenizer(object): 83 | def __init__(self, bpe_path: str = default_bpe()): 84 | self.byte_encoder = bytes_to_unicode() 85 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 86 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 87 | merges = merges[1:49152-256-2+1] 88 | merges = [tuple(merge.split()) for merge in merges] 89 | vocab = list(bytes_to_unicode().values()) 90 | vocab = vocab + [v+'' for v in vocab] 91 | for merge in merges: 92 | vocab.append(''.join(merge)) 93 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 94 | self.encoder = dict(zip(vocab, range(len(vocab)))) 95 | self.decoder = {v: k for k, v in self.encoder.items()} 96 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 97 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 98 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 99 | 100 | def bpe(self, token): 101 | if token in self.cache: 102 | return self.cache[token] 103 | word = tuple(token[:-1]) + ( token[-1] + '',) 104 | pairs = get_pairs(word) 105 | 106 | if not pairs: 107 | return token+'' 108 | 109 | while True: 110 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 111 | if bigram not in self.bpe_ranks: 112 | break 113 | first, second = bigram 114 | new_word = [] 115 | i = 0 116 | while i < len(word): 117 | try: 118 | j = word.index(first, i) 119 | new_word.extend(word[i:j]) 120 | i = j 121 | except: 122 | new_word.extend(word[i:]) 123 | break 124 | 125 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 126 | new_word.append(first+second) 127 | i += 2 128 | else: 129 | new_word.append(word[i]) 130 | i += 1 131 | new_word = tuple(new_word) 132 | word = new_word 133 | if len(word) == 1: 134 | break 135 | else: 136 | pairs = get_pairs(word) 137 | word = ' '.join(word) 138 | self.cache[token] = word 139 | return word 140 | 141 | def encode(self, text): 142 | bpe_tokens = [] 143 | text = whitespace_clean(basic_clean(text)).lower() 144 | for token in re.findall(self.pat, text): 145 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 146 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 147 | return bpe_tokens 148 | 149 | def decode(self, tokens): 150 | text = ''.join([self.decoder[token] for token in tokens]) 151 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 152 | return text 153 | -------------------------------------------------------------------------------- /evaluations/AudioCLIP/utils/transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torchvision as tv 7 | 8 | import ignite_trainer as it 9 | 10 | 11 | def scale(old_value, old_min, old_max, new_min, new_max): 12 | old_range = (old_max - old_min) 13 | new_range = (new_max - new_min) 14 | new_value = (((old_value - old_min) * new_range) / 
old_range) + new_min 15 | 16 | return new_value 17 | 18 | 19 | def frame_signal(signal: torch.Tensor, 20 | frame_length: int, 21 | hop_length: int, 22 | window: torch.Tensor = None) -> torch.Tensor: 23 | 24 | if window is None: 25 | window = torch.ones(frame_length, dtype=signal.dtype, device=signal.device) 26 | 27 | if window.shape[0] != frame_length: 28 | raise ValueError('Wrong `window` length: expected {}, got {}'.format(window.shape[0], frame_length)) 29 | 30 | signal_length = signal.shape[-1] 31 | 32 | if signal_length <= frame_length: 33 | num_frames = 1 34 | else: 35 | num_frames = 1 + int(math.ceil((1.0 * signal_length - frame_length) / hop_length)) 36 | 37 | pad_len = int((num_frames - 1) * hop_length + frame_length) 38 | if pad_len > signal_length: 39 | zeros = torch.zeros(pad_len - signal_length, device=signal.device, dtype=signal.dtype) 40 | 41 | while zeros.dim() < signal.dim(): 42 | zeros.unsqueeze_(0) 43 | 44 | pad_signal = torch.cat((zeros.expand(*signal.shape[:-1], -1)[..., :zeros.shape[-1] // 2], signal), dim=-1) 45 | pad_signal = torch.cat((pad_signal, zeros.expand(*signal.shape[:-1], -1)[..., zeros.shape[-1] // 2:]), dim=-1) 46 | else: 47 | pad_signal = signal 48 | 49 | indices = torch.arange(0, frame_length, device=signal.device).repeat(num_frames, 1) 50 | indices += torch.arange( 51 | 0, 52 | num_frames * hop_length, 53 | hop_length, 54 | device=signal.device 55 | ).repeat(frame_length, 1).t_() 56 | indices = indices.long() 57 | 58 | frames = pad_signal[..., indices] 59 | frames = frames * window 60 | 61 | return frames 62 | 63 | 64 | class ToTensor1D(tv.transforms.ToTensor): 65 | 66 | def __call__(self, tensor: np.ndarray): 67 | tensor_2d = super(ToTensor1D, self).__call__(tensor[..., np.newaxis]) 68 | 69 | return tensor_2d.squeeze_(0) 70 | 71 | 72 | class RandomFlip(it.AbstractTransform): 73 | 74 | def __init__(self, p: float = 0.5): 75 | super(RandomFlip, self).__init__() 76 | 77 | self.p = p 78 | 79 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 80 | if x.dim() > 2: 81 | flip_mask = torch.rand(x.shape[0], device=x.device) <= self.p 82 | x[flip_mask] = x[flip_mask].flip(-1) 83 | else: 84 | if torch.rand(1) <= self.p: 85 | x = x.flip(0) 86 | 87 | return x 88 | 89 | 90 | class RandomScale(it.AbstractTransform): 91 | 92 | def __init__(self, max_scale: float = 1.25): 93 | super(RandomScale, self).__init__() 94 | 95 | self.max_scale = max_scale 96 | 97 | @staticmethod 98 | def random_scale(max_scale: float, signal: torch.Tensor) -> torch.Tensor: 99 | scaling = np.power(max_scale, np.random.uniform(-1, 1)) 100 | output_size = int(signal.shape[-1] * scaling) 101 | ref = torch.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling) 102 | 103 | ref1 = ref.clone().type(torch.int64) 104 | ref2 = torch.min(ref1 + 1, torch.full_like(ref1, signal.shape[-1] - 1, dtype=torch.int64)) 105 | r = ref - ref1.type(ref.type()) 106 | scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r 107 | 108 | return scaled_signal 109 | 110 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 111 | return self.random_scale(self.max_scale, x) 112 | 113 | 114 | class RandomCrop(it.AbstractTransform): 115 | 116 | def __init__(self, out_len: int = 44100, train: bool = True): 117 | super(RandomCrop, self).__init__() 118 | 119 | self.out_len = out_len 120 | self.train = train 121 | 122 | def random_crop(self, signal: torch.Tensor) -> torch.Tensor: 123 | if self.train: 124 | left = np.random.randint(0, signal.shape[-1] - self.out_len) 125 | else: 126 | left = 
int(round(0.5 * (signal.shape[-1] - self.out_len))) 127 | 128 | orig_std = signal.float().std() * 0.5 129 | output = signal[..., left:left + self.out_len] 130 | 131 | out_std = output.float().std() 132 | if out_std < orig_std: 133 | output = signal[..., :self.out_len] 134 | 135 | new_out_std = output.float().std() 136 | if orig_std > new_out_std > out_std: 137 | output = signal[..., -self.out_len:] 138 | 139 | return output 140 | 141 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 142 | return self.random_crop(x) if x.shape[-1] > self.out_len else x 143 | 144 | 145 | class RandomPadding(it.AbstractTransform): 146 | 147 | def __init__(self, out_len: int = 88200, train: bool = True): 148 | super(RandomPadding, self).__init__() 149 | 150 | self.out_len = out_len 151 | self.train = train 152 | 153 | def random_pad(self, signal: torch.Tensor) -> torch.Tensor: 154 | if self.train: 155 | left = np.random.randint(0, self.out_len - signal.shape[-1]) 156 | else: 157 | left = int(round(0.5 * (self.out_len - signal.shape[-1]))) 158 | 159 | right = self.out_len - (left + signal.shape[-1]) 160 | 161 | pad_value_left = signal[..., 0].float().mean().to(signal.dtype) 162 | pad_value_right = signal[..., -1].float().mean().to(signal.dtype) 163 | output = torch.cat(( 164 | torch.zeros(signal.shape[:-1] + (left,), dtype=signal.dtype, device=signal.device).fill_(pad_value_left), 165 | signal, 166 | torch.zeros(signal.shape[:-1] + (right,), dtype=signal.dtype, device=signal.device).fill_(pad_value_right) 167 | ), dim=-1) 168 | 169 | return output 170 | 171 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 172 | return self.random_pad(x) if x.shape[-1] < self.out_len else x 173 | 174 | 175 | class RandomNoise(it.AbstractTransform): 176 | 177 | def __init__(self, snr_min_db: float = -10.0, snr_max_db: float = 100.0, p: float = 0.5): 178 | super(RandomNoise, self).__init__() 179 | 180 | self.p = p 181 | self.snr_min_db = snr_min_db 182 | self.snr_max_db = snr_max_db 183 | 184 | def random_noise(self, signal: torch.Tensor) -> torch.Tensor: 185 | target_snr = np.random.rand() * (self.snr_max_db - self.snr_min_db + 1.0) + self.snr_min_db 186 | 187 | signal_watts = torch.mean(signal ** 2, dim=(-1, -2)) 188 | signal_db = 10 * torch.log10(signal_watts) 189 | 190 | noise_db = signal_db - target_snr 191 | noise_watts = 10 ** (noise_db / 10) 192 | noise = torch.normal(0.0, noise_watts.item() ** 0.5, signal.shape) 193 | 194 | output = signal + noise 195 | 196 | return output 197 | 198 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 199 | return self.random_noise(x) if np.random.rand() <= self.p else x 200 | -------------------------------------------------------------------------------- /evaluations/c3d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/MM-Diffusion/7250222114ee1eca36e32d90e8a41f12d92651e2/evaluations/c3d/__init__.py -------------------------------------------------------------------------------- /evaluations/c3d/c3d_ft.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import collections 4 | import os 5 | 6 | from chainer import link 7 | from chainer.dataset import download 8 | from chainer.functions.activation.relu import relu 9 | from chainer.functions.activation.softmax import softmax 10 | from chainer.functions.noise.dropout import dropout 11 | from chainer.functions.pooling.max_pooling_nd import max_pooling_nd 12 | 
from chainer.initializers import constant 13 | from chainer.initializers import normal 14 | from chainer.links.connection.convolution_nd import ConvolutionND 15 | from chainer.links.connection.linear import Linear 16 | from chainer.serializers import npz 17 | 18 | 19 | class C3DVersion1(link.Chain): 20 | 21 | def __init__(self, pretrained_model='auto'): 22 | if pretrained_model: 23 | # As a sampling process is time-consuming, 24 | # we employ a zero initializer for faster computation. 25 | init = constant.Zero() 26 | conv_kwargs = {'initialW': init, 'initial_bias': init} 27 | fc_kwargs = conv_kwargs 28 | else: 29 | # employ default initializers used in the original paper 30 | conv_kwargs = { 31 | 'initialW': normal.Normal(0.01), 32 | 'initial_bias': constant.Zero(), 33 | } 34 | fc_kwargs = { 35 | 'initialW': normal.Normal(0.005), 36 | 'initial_bias': constant.One(), 37 | } 38 | super(C3DVersion1, self).__init__( 39 | conv1a=ConvolutionND(3, 3, 64, 3, 1, 1, **conv_kwargs), 40 | conv2a=ConvolutionND(3, 64, 128, 3, 1, 1, **conv_kwargs), 41 | conv3a=ConvolutionND(3, 128, 256, 3, 1, 1, **conv_kwargs), 42 | conv3b=ConvolutionND(3, 256, 256, 3, 1, 1, **conv_kwargs), 43 | conv4a=ConvolutionND(3, 256, 512, 3, 1, 1, **conv_kwargs), 44 | conv4b=ConvolutionND(3, 512, 512, 3, 1, 1, **conv_kwargs), 45 | conv5a=ConvolutionND(3, 512, 512, 3, 1, 1, **conv_kwargs), 46 | conv5b=ConvolutionND(3, 512, 512, 3, 1, 1, **conv_kwargs), 47 | fc6=Linear(512 * 4 * 4, 4096, **fc_kwargs), 48 | fc7=Linear(4096, 4096, **fc_kwargs), 49 | fc8=Linear(4096, 101, **fc_kwargs), 50 | ) 51 | if pretrained_model == 'auto': 52 | _retrieve( 53 | 'conv3d_deepnetA_ucf.npz', 54 | 'http://vlg.cs.dartmouth.edu/c3d/' 55 | 'c3d_ucf101_finetune_whole_iter_20000', 56 | self) 57 | elif pretrained_model: 58 | npz.load_npz(pretrained_model, self) 59 | 60 | self.functions = collections.OrderedDict([ 61 | ('conv1a', [self.conv1a, relu]), 62 | ('pool1', [_max_pooling_2d]), 63 | ('conv2a', [self.conv2a, relu]), 64 | ('pool2', [_max_pooling_3d]), 65 | ('conv3a', [self.conv3a, relu]), 66 | ('conv3b', [self.conv3b, relu]), 67 | ('pool3', [_max_pooling_3d]), 68 | ('conv4a', [self.conv4a, relu]), 69 | ('conv4b', [self.conv4b, relu]), 70 | ('pool4', [_max_pooling_3d]), 71 | ('conv5a', [self.conv5a, relu]), 72 | ('conv5b', [self.conv5b, relu]), 73 | ('pool5', [_max_pooling_3d]), 74 | ('fc6', [self.fc6, relu, dropout]), 75 | ('fc7', [self.fc7, relu, dropout]), 76 | ('fc8', [self.fc8]), 77 | ('prob', [softmax]), 78 | ]) 79 | 80 | @property 81 | def available_layers(self): 82 | return list(self.functions.keys()) 83 | 84 | @classmethod 85 | def convert_caffemodel_to_npz(cls, path_caffemodel, path_npz): 86 | """Converts a pre-trained caffemodel to a chainer model. 87 | 88 | Args: 89 | path_caffemodel (str): Path of the pre-trained caffemodel. 90 | path_npz (str): Path of the converted chainer model. 91 | """ 92 | 93 | # As caffe_function uses shortcut symbols, 94 | # we import it here. 
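        # A minimal usage sketch (the paths below are placeholders, not files shipped
        # with this repo): convert a downloaded caffemodel once, then load the
        # resulting .npz through the constructor.
        #
        #   C3DVersion1.convert_caffemodel_to_npz(
        #       'models/c3d_ucf101_finetune_whole_iter_20000',
        #       'models/conv3d_deepnetA_ucf.npz')
        #   c3d = C3DVersion1(pretrained_model='models/conv3d_deepnetA_ucf.npz')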
95 | from chainer.links.caffe import caffe_function 96 | caffe_pb = caffe_function.caffe_pb 97 | 98 | caffemodel = caffe_pb.NetParameter() 99 | with open(path_caffemodel, 'rb') as model_file: 100 | caffemodel.MergeFromString(model_file.read()) 101 | chainermodel = cls(pretrained_model=None) 102 | _transfer(caffemodel, chainermodel) 103 | npz.save_npz(path_npz, chainermodel, compression=False) 104 | 105 | def __call__(self, x, layers=['prob']): 106 | h = x 107 | activations = {} 108 | target_layers = set(layers) 109 | for key, funcs in self.functions.items(): 110 | if len(target_layers) == 0: 111 | break 112 | for func in funcs: 113 | h = func(h) 114 | if key in target_layers: 115 | activations[key] = h 116 | target_layers.remove(key) 117 | return activations 118 | 119 | 120 | def _max_pooling_3d(x): 121 | # print(x.data.shape) 122 | return max_pooling_nd(x, ksize=2) 123 | # return max_pooling_nd(x, ksize=2, stride=2) 124 | 125 | 126 | def _max_pooling_2d(x): 127 | return max_pooling_nd(x, ksize=(1, 2, 2)) 128 | # return max_pooling_nd(x, ksize=(1, 2, 2), stride=(1, 2, 2)) 129 | 130 | 131 | def _transfer(caffemodel, chainermodel): 132 | 133 | def transfer_layer(src, dst): 134 | dst.W.data.ravel()[:] = src.blobs[0].diff 135 | dst.b.data.ravel()[:] = src.blobs[1].diff 136 | 137 | layers = {l.name: l for l in caffemodel.layers} 138 | print([l.name for l in caffemodel.layers]) 139 | transfer_layer(layers['conv1a'], chainermodel.conv1a) 140 | transfer_layer(layers['conv2a'], chainermodel.conv2a) 141 | transfer_layer(layers['conv3a'], chainermodel.conv3a) 142 | transfer_layer(layers['conv3b'], chainermodel.conv3b) 143 | transfer_layer(layers['conv4a'], chainermodel.conv4a) 144 | transfer_layer(layers['conv4b'], chainermodel.conv4b) 145 | transfer_layer(layers['conv5a'], chainermodel.conv5a) 146 | transfer_layer(layers['conv5b'], chainermodel.conv5b) 147 | transfer_layer(layers['fc6'], chainermodel.fc6) 148 | transfer_layer(layers['fc7'], chainermodel.fc7) 149 | transfer_layer(layers['fc8'], chainermodel.fc8) 150 | 151 | 152 | def _make_npz(path_npz, url, model): 153 | import pdb; pdb.set_trace() 154 | # path_caffemodel = "/mnt/sakura201/mattya/c3d/c3d_ucf101_finetune_whole_iter_20000" 155 | path_caffemodel = "models/nc3d_ucf101_finetune_whole_iter_20000" 156 | print('Now loading caffemodel (usually it may take few minutes)') 157 | C3DVersion1.convert_caffemodel_to_npz(path_caffemodel, path_npz) 158 | npz.load_npz(path_npz, model) 159 | return n 160 | 161 | 162 | 163 | def _retrieve(name, url, model): 164 | 165 | root = download.get_dataset_directory('pfnet/chainer/models/') 166 | path = os.path.join(root, name) 167 | return download.cache_or_load_file( 168 | path, lambda path: _make_npz(path, url, model), 169 | lambda path: npz.load_npz(path, model)) 170 | -------------------------------------------------------------------------------- /evaluations/compute_fvd_kvd.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import yield_lines 2 | import click 3 | import sys; sys.path.extend(['.', 'src']) 4 | from tqdm import tqdm 5 | 6 | import argparse 7 | import numpy as np 8 | from sklearn.metrics.pairwise import polynomial_kernel 9 | import torch 10 | from util import load_data_for_worker 11 | 12 | 13 | from fvd.fvd import get_fvd_logits, frechet_distance 14 | from fvd.download import load_i3d_pretrained 15 | 16 | 17 | 18 | def polynomial_mmd(X, Y): 19 | m = X.shape[0] 20 | n = Y.shape[0] 21 | 22 | # compute kernels 23 | K_XX = 
polynomial_kernel(X) 24 | K_YY = polynomial_kernel(Y) 25 | K_XY = polynomial_kernel(X, Y) 26 | 27 | # compute mmd distance 28 | K_XX_sum = (K_XX.sum() - np.diagonal(K_XX).sum()) / (m * (m - 1)) 29 | K_YY_sum = (K_YY.sum() - np.diagonal(K_YY).sum()) / (n * (n - 1)) 30 | K_XY_sum = K_XY.sum() / (m * n) 31 | 32 | mmd = K_XX_sum + K_YY_sum - 2 * K_XY_sum 33 | 34 | return mmd 35 | 36 | 37 | 38 | def main( 39 | ): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--ref_batch", type=str, default="/data6/rld/data/UCF-101/train/ApplyEyeMakeup", help="path to reference batch npz file") 42 | parser.add_argument("--sample_batch", type=str, default="/data6/rld/data/UCF-101/train/ApplyEyeMakeup", help="path to sample batch npz file") 43 | parser.add_argument("--size", type=int, default=128, help="path to sample batch npz file") 44 | parser.add_argument("--frame_num",type=int, default=16, help="path to sample batch npz file") 45 | parser.add_argument("--sample_frame_gap",type=int, default=8, help="path to sample batch npz file") 46 | parser.add_argument("--sample_num", type=int, default=100) 47 | parser.add_argument("--batch_size", type=int, default=16) 48 | args = parser.parse_args() 49 | device = torch.device('cuda') 50 | #################### Load I3D ######################################## 51 | i3d = load_i3d_pretrained(device) 52 | #################### Compute FVD ############################### 53 | 54 | fvds = [] 55 | kvds = [] 56 | fvd, kvd = eval_fvd(i3d, args, device) 57 | fvds.append(fvd) 58 | kvds.append(kvd) 59 | 60 | fvd_mean = np.mean(fvds) 61 | kvd_mean = np.mean(kvds) 62 | fvd_std = np.std(fvds) 63 | kvd_std = np.std(kvds) 64 | 65 | print(f"Final FVD {fvd_mean:.2f} +/- {fvd_std:.2f}") 66 | print(f"Final KVD {kvd_mean:.2f} +/- {kvd_std:.2f}") 67 | 68 | def eval_fvd(i3d, args, device): 69 | sample_loader = load_data_for_worker(base_samples=args.sample_batch, image_size=args.size, \ 70 | batch_size=args.batch_size, frame_num=args.frame_num) 71 | ref_loader = load_data_for_worker(base_samples=args.ref_batch, image_size=args.size, \ 72 | batch_size = args.batch_size, frame_num=args.frame_num, frame_gap = args.sample_frame_gap) 73 | 74 | 75 | print("get real embeddings...") 76 | real_embeddings = [] 77 | for id,ref in enumerate(ref_loader): 78 | if id >=args.sample_num:break 79 | real = ref # BCTHW -> BTHWC 80 | import pdb; pdb.set_trace() 81 | real_embeddings.append(get_fvd_logits(real, i3d=i3d, device=device)) 82 | real_embeddings = torch.cat(real_embeddings) 83 | 84 | print("get fake embeddings...") 85 | fake_embeddings= [] 86 | for id, sample in enumerate(tqdm(sample_loader)): 87 | if id >=args.sample_num:break 88 | 89 | # b t h w c 90 | fake_embeddings.append(get_fvd_logits(sample, i3d=i3d, device=device)) 91 | fake_embeddings = torch.cat(fake_embeddings) 92 | 93 | 94 | 95 | fvd = frechet_distance(fake_embeddings.clone().detach(), real_embeddings.clone().detach()) 96 | kvd = polynomial_mmd(fake_embeddings.clone().detach().cpu().numpy(), real_embeddings.detach().cpu().numpy()) 97 | return fvd.item(), kvd.item() 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /evaluations/compute_video_is.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import sys 6 | 7 | import numpy as np 8 | 9 | import chainer 10 | import chainer.cuda 11 | import cv2 as cv 12 | # import cupy 13 | 
from c3d.c3d_ft import C3DVersion1 14 | from chainer import Variable 15 | from chainer import cuda 16 | from tqdm import tqdm 17 | 18 | sys.path.insert(0, '.') # isort:skip 19 | from util import load_data_for_worker 20 | 21 | def calc_inception(ys): 22 | N, C = ys.shape 23 | p_all = np.mean(ys, axis=0, keepdims=True) 24 | kl = np.sum(ys * np.log(ys + 1e-7) - ys * np.log(p_all + 1e-7)) / N 25 | return np.exp(kl) 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser(description='inception score') 30 | parser.add_argument("--ref_batch", type=str, default="/data6/rld/data/UCF-101/train/ApplyEyeMakeup", help="path to reference batch npz file") 31 | parser.add_argument("--size", type=int, default=128, help="path to sample batch npz file") 32 | parser.add_argument("--frame_num",type=int, default=16, help="path to sample batch npz file") 33 | parser.add_argument("--sample_frame_gap",type=int, default=1, help="path to sample batch npz file") 34 | parser.add_argument("--sample_num", type=int, default=100) 35 | parser.add_argument("--batch_size", type=int, default=48) 36 | parser.add_argument('--result_dir', type=str, default="./outputs/video-eval/") 37 | parser.add_argument('--devices', type=str, default='0') 38 | parser.add_argument('--mean', type=str, default='models/mean2.npz') 39 | parser.add_argument('--seed', type=int, default=0) 40 | parser.add_argument('--interpolation', type=str, default='INTER_CUBIC') 41 | args = parser.parse_args() 42 | 43 | np.random.seed(args.seed) 44 | 45 | inter_method = args.interpolation 46 | args.interpolation = getattr(cv, args.interpolation) 47 | 48 | cuda.get_device(args.devices).use() 49 | # cupy.random.seed(args.seed) 50 | xp = chainer.cuda.cupy 51 | 52 | c3dmodel = C3DVersion1() 53 | c3dmodel.to_gpu() 54 | 55 | # load model 56 | 57 | mean = np.load(args.mean)['mean'].astype('f') 58 | mean = mean.reshape((3, 1, 16, 128, 171))[:, :, :, :, 21:21 + 128] 59 | 60 | sample_loader = load_data_for_worker(base_samples=args.sample_batch, image_size=args.size, \ 61 | frame_num=args.frame_num, frame_gap= args.sample_frame_gap, batchsize=args.batchsize ) 62 | # generator 63 | ys = [] 64 | for batch_data in tqdm(sample_loader): # b f h w 3 65 | n, f, h, w, c = batch_data.shape 66 | batch_data = batch_data.reshape(n * f, h, w, c) 67 | x_ = np.zeros((n * f, 128, 128, 3)) 68 | for t in range(n * f): 69 | x_[t] = np.asarray( 70 | cv.resize(x[t], (args.size, args.size), interpolation=args.interpolation))# [n*f, 128, 128, 3] 71 | 72 | x = x_.transpose(3, 0, 1, 2).reshape(c, n, f, args.size, args.size) 73 | x = x[::-1] - mean # mean file is BGR-order while model outputs RGB-order 74 | x = x[:, :, :, 8:8 + 112, 8:8 + 112].astype('f') 75 | x = x.transpose(1, 0, 2, 3, 4) 76 | with chainer.using_config('train', False) and \ 77 | chainer.no_backprop_mode(): 78 | # C3D takes an image with BGR order 79 | y = c3dmodel(Variable(xp.asarray(x)), 80 | layers=['prob'])['prob'].data.get() 81 | ys.append(y) 82 | ys = np.asarray(ys).reshape((-1, 101)) 83 | 84 | 85 | score = calc_inception(ys) 86 | with open(f'{args.result_dir}/evaluation-{args.sample_batch}.log'.format(args.iter, inter_method), 'a+') as fp: 87 | print(args.result_dir, args.iter, args.calc_iter, args.mean, score, file=fp) 88 | print('IS score:{}'.format(score)) 89 | 90 | return 0 91 | 92 | 93 | if __name__ == '__main__': 94 | sys.exit(main()) 95 | -------------------------------------------------------------------------------- /evaluations/fvd/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/MM-Diffusion/7250222114ee1eca36e32d90e8a41f12d92651e2/evaluations/fvd/__init__.py -------------------------------------------------------------------------------- /evaluations/fvd/convert_tf_pretrained.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | import tensorflow_hub as hub 4 | import torch 5 | 6 | from src_pytorch.fvd.pytorch_i3d import InceptionI3d 7 | 8 | 9 | def convert_name(name): 10 | mapping = { 11 | 'conv_3d': 'conv3d', 12 | 'batch_norm': 'bn', 13 | 'w:0': 'weight', 14 | 'b:0': 'bias', 15 | 'moving_mean:0': 'running_mean', 16 | 'moving_variance:0': 'running_var', 17 | 'beta:0': 'bias' 18 | } 19 | 20 | segs = name.split('/') 21 | new_segs = [] 22 | i = 0 23 | while i < len(segs): 24 | seg = segs[i] 25 | if 'Mixed' in seg: 26 | new_segs.append(seg) 27 | elif 'Conv' in seg and 'Mixed' not in name: 28 | new_segs.append(seg) 29 | elif 'Branch' in seg: 30 | branch_i = int(seg.split('_')[-1]) 31 | i += 1 32 | seg = segs[i] 33 | 34 | # special case due to typo in original code 35 | if 'Mixed_5b' in name and branch_i == 2: 36 | if '1x1' in seg: 37 | new_segs.append(f'b{branch_i}a') 38 | elif '3x3' in seg: 39 | new_segs.append(f'b{branch_i}b') 40 | else: 41 | raise Exception() 42 | # Either Conv3d_{i}a_... or Conv3d_{i}b_... 43 | elif 'a' in seg: 44 | if branch_i == 0: 45 | new_segs.append('b0') 46 | else: 47 | new_segs.append(f'b{branch_i}a') 48 | elif 'b' in seg: 49 | new_segs.append(f'b{branch_i}b') 50 | else: 51 | raise Exception 52 | elif seg == 'Logits': 53 | new_segs.append('logits') 54 | i += 1 55 | elif seg in mapping: 56 | new_segs.append(mapping[seg]) 57 | else: 58 | raise Exception(f"No match found for seg {seg} in name {name}") 59 | 60 | i += 1 61 | return '.'.join(new_segs) 62 | 63 | def convert_tensor(tensor): 64 | tensor_dim = len(tensor.shape) 65 | if tensor_dim == 5: # conv or bn 66 | if all([t == 1 for t in tensor.shape[:-1]]): 67 | tensor = tensor.squeeze() 68 | else: 69 | tensor = tensor.permute(4, 3, 0, 1, 2).contiguous() 70 | elif tensor_dim == 1: # conv bias 71 | pass 72 | else: 73 | raise Exception(f"Invalid shape {tensor.shape}") 74 | return tensor 75 | 76 | n_class = int(sys.argv[1]) # 600 or 400 77 | assert n_class in [400, 600] 78 | 79 | # Converts model from https://github.com/google-research/google-research/tree/master/frechet_video_distance 80 | # to pytorch version for loading 81 | model_url = f"https://tfhub.dev/deepmind/i3d-kinetics-{n_class}/1" 82 | i3d = hub.load(model_url) 83 | name_prefix = 'RGB/inception_i3d/' 84 | 85 | print('Creating state_dict...') 86 | all_names = [] 87 | state_dict = OrderedDict() 88 | for var in i3d.variables: 89 | name = var.name[len(name_prefix):] 90 | new_name = convert_name(name) 91 | all_names.append(new_name) 92 | 93 | tensor = torch.FloatTensor(var.value().numpy()) 94 | new_tensor = convert_tensor(tensor) 95 | 96 | state_dict[new_name] = new_tensor 97 | 98 | if 'bn.bias' in new_name: 99 | new_name = new_name[:-4] + 'weight' # bn.weight 100 | new_tensor = torch.ones_like(new_tensor).float() 101 | state_dict[new_name] = new_tensor 102 | 103 | print(f'Complete state_dict with {len(state_dict)} entries') 104 | 105 | s = dict() 106 | for i, n in enumerate(all_names): 107 | s[n] = s.get(n, []) + [i] 108 | 109 | for k, v in s.items(): 110 | if len(v) > 1: 111 | print('dup', k) 112 | for i in v: 113 | 
print('\t', i3d.variables[i].name) 114 | 115 | print('Testing load_state_dict...') 116 | print('Creating model...') 117 | 118 | i3d = InceptionI3d(n_class, in_channels=3) 119 | 120 | print('Loading state_dict...') 121 | i3d.load_state_dict(state_dict) 122 | 123 | print(f'Saving state_dict as fvd/i3d_pretrained_{n_class}.pt') 124 | torch.save(state_dict, f'fvd/i3d_pretrained_{n_class}.pt') 125 | 126 | print('Done') 127 | 128 | -------------------------------------------------------------------------------- /evaluations/fvd/download.py: -------------------------------------------------------------------------------- 1 | from email.policy import strict 2 | import requests 3 | from tqdm import tqdm 4 | import os 5 | import torch 6 | import torch.distributed as dist 7 | 8 | def get_confirm_token(response): 9 | for key, value in response.cookies.items(): 10 | if key.startswith('download_warning'): 11 | return value 12 | return None 13 | 14 | 15 | def save_response_content(response, destination): 16 | CHUNK_SIZE = 8192 17 | 18 | pbar = tqdm(total=0, unit='iB', unit_scale=True) 19 | with open(destination, 'wb') as f: 20 | for chunk in response.iter_content(CHUNK_SIZE): 21 | if chunk: 22 | f.write(chunk) 23 | pbar.update(len(chunk)) 24 | pbar.close() 25 | 26 | ROOT=os.path.expanduser("~/.cache/mmdiffusion") 27 | def download(id, fname): 28 | 29 | destination = os.path.join(ROOT, fname) 30 | if os.path.exists(destination): 31 | return destination 32 | 33 | os.makedirs(ROOT, exist_ok=True) 34 | destination = os.path.join(ROOT, fname) 35 | 36 | URL = 'https://drive.google.com/uc?export=download' 37 | session = requests.Session() 38 | 39 | response = session.get(URL, params={'id': id}, stream=True) 40 | token = get_confirm_token(response) 41 | 42 | if token: 43 | params = {'id': id, 'confirm': token} 44 | response = session.get(URL, params=params, stream=True) 45 | save_response_content(response, destination) 46 | return destination 47 | 48 | 49 | _I3D_PRETRAINED_ID = '1mQK8KD8G6UWRa5t87SRMm5PVXtlpneJT' 50 | def load_i3d_pretrained(device=torch.device('cpu')): 51 | from .pytorch_i3d import InceptionI3d 52 | i3d = InceptionI3d(400, in_channels=3).to(device) 53 | 54 | if dist.get_rank()==0: 55 | filepath = download(_I3D_PRETRAINED_ID, 'i3d_pretrained_400.pt') 56 | dist.barrier() 57 | filepath = download(_I3D_PRETRAINED_ID, 'i3d_pretrained_400.pt') 58 | is_strict=True 59 | state_dict=torch.load(filepath, map_location=device) 60 | 61 | i3d.load_state_dict(state_dict, strict=is_strict) 62 | i3d.eval() 63 | return i3d 64 | 65 | def load_i3d_pretrained_classifier(device=torch.device('cpu'), num_class=400): 66 | from .pytorch_i3d import InceptionI3d_Classifier 67 | i3d = InceptionI3d_Classifier(num_class, in_channels=3).to(device) 68 | filepath = download(_I3D_PRETRAINED_ID, 'i3d_pretrained_400.pt') 69 | is_strict=True 70 | state_dict=torch.load(filepath, map_location=device) 71 | if num_class!=400: 72 | state_dict.pop("logits.conv3d.weight") 73 | state_dict.pop("logits.conv3d.bias") 74 | is_strict=False 75 | i3d.load_state_dict(state_dict, strict=is_strict) 76 | 77 | return i3d 78 | 79 | -------------------------------------------------------------------------------- /evaluations/fvd/fvd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | 5 | def preprocess_single(video, resolution, sequence_length=None): 6 | # video: TCHW, {0, ..., 255} 7 | video = video.float() / 255. 
# TCHW 8 | t, c, h, w = video.shape 9 | 10 | # temporal crop 11 | if sequence_length is not None: 12 | assert sequence_length <= t 13 | video = video[:sequence_length] 14 | 15 | # scale shorter side to resolution 16 | scale = resolution / min(h, w) 17 | if h < w: 18 | target_size = (resolution, math.ceil(w * scale)) 19 | else: 20 | target_size = (math.ceil(h * scale), resolution) 21 | video = F.interpolate(video, size=target_size, mode='bilinear', 22 | align_corners=False) 23 | 24 | # center crop 25 | t, c, h, w = video.shape 26 | w_start = (w - resolution) // 2 27 | h_start = (h - resolution) // 2 28 | video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution] 29 | video = video.permute(1, 0, 2, 3).contiguous() # CTHW 30 | 31 | video -= 0.5 32 | 33 | return video 34 | 35 | def preprocess(videos, target_resolution=224): 36 | # videos in {0, ..., 255} as np.uint8 array 37 | b, t, h, w, c = videos.shape 38 | #videos = torch.from_numpy(videos) 39 | videos = torch.stack([preprocess_single(video, target_resolution) for video in videos]) 40 | return videos * 2 # [-0.5, 0.5] -> [-1, 1] 41 | 42 | def get_fvd_logits(videos, i3d, device): 43 | videos = preprocess(videos) 44 | embeddings = get_logits(i3d, videos, device) 45 | 46 | return embeddings 47 | 48 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161 49 | def _symmetric_matrix_square_root(mat, eps=1e-10): 50 | u, s, v = torch.svd(mat) 51 | si = torch.where(s < eps, s, torch.sqrt(s)) 52 | return torch.matmul(torch.matmul(u, torch.diag(si)), v.t()) 53 | 54 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400 55 | def trace_sqrt_product(sigma, sigma_v): 56 | sqrt_sigma = _symmetric_matrix_square_root(sigma) 57 | sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma)) 58 | return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) 59 | 60 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 61 | def cov(m, rowvar=False): 62 | '''Estimate a covariance matrix given data. 63 | 64 | Covariance indicates the level to which two variables vary together. 65 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 66 | then the covariance matrix element `C_{ij}` is the covariance of 67 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 68 | 69 | Args: 70 | m: A 1-D or 2-D array containing multiple variables and observations. 71 | Each row of `m` represents a variable, and each column a single 72 | observation of all those variables. 73 | rowvar: If `rowvar` is True, then each row represents a 74 | variable, with observations in the columns. Otherwise, the 75 | relationship is transposed: each column represents a variable, 76 | while the rows contain observations. 77 | 78 | Returns: 79 | The covariance matrix of the variables. 
80 | ''' 81 | if m.dim() > 2: 82 | raise ValueError('m has more than 2 dimensions') 83 | if m.dim() < 2: 84 | m = m.view(1, -1) 85 | if not rowvar and m.size(0) != 1: 86 | m = m.t() 87 | 88 | fact = 1.0 / (m.size(1) - 1) # unbiased estimate 89 | m -= torch.mean(m, dim=1, keepdim=True) 90 | mt = m.t() # if complex: mt = m.t().conj() 91 | return fact * m.matmul(mt).squeeze() 92 | 93 | 94 | def frechet_distance(x1, x2): 95 | 96 | x1 = x1.flatten(start_dim=1) 97 | x2 = x2.flatten(start_dim=1) 98 | m, m_w = x1.mean(dim=0), x2.mean(dim=0) 99 | sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False) 100 | 101 | sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) 102 | trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component 103 | 104 | mean = torch.sum((m - m_w) ** 2) 105 | fd = trace + mean 106 | return fd 107 | 108 | 109 | def get_logits(i3d, videos, device): 110 | """ 111 | assert videos.shape[0] % 16 == 0 112 | with torch.no_grad(): 113 | logits = [] 114 | for i in range(0, videos.shape[0], 16): 115 | batch = videos[i:i + 16].to(device) 116 | logits.append(i3d(batch)) 117 | logits = torch.cat(logits, dim=0) 118 | return logits 119 | """ 120 | 121 | with torch.no_grad(): 122 | logits = i3d(videos.to(device)) 123 | return logits 124 | -------------------------------------------------------------------------------- /evaluations/util.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | 4 | from PIL import Image 5 | import cv2 6 | import blobfile as bf 7 | 8 | import numpy as np 9 | from torch.utils.data import DataLoader, Dataset 10 | from einops import rearrange, repeat 11 | from torchvision import transforms as T 12 | import torch as th 13 | 14 | def load_data( 15 | *, 16 | data_dir, 17 | frame_num, 18 | batch_size, 19 | image_size, 20 | class_cond=False, 21 | order_cond=False, 22 | deterministic=False, 23 | random_crop=False, 24 | random_flip=True, 25 | num_workers=0, 26 | frame_gap=8 27 | 28 | ): 29 | """ 30 | For a dataset, create a generator over (images, kwargs) pairs. 31 | 32 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 33 | more keys, each of which map to a batched Tensor of their own. 34 | The kwargs dict can be used for class labels, in which case the key is "y" 35 | and the values are integer tensors of class labels. 36 | 37 | :param data_dir: a dataset directory. 38 | :param batch_size: the batch size of each returned pair. 39 | :param image_size: the size to which images are resized. 40 | :param class_cond: if True, include a "y" key in returned dicts for class 41 | label. If classes are not available and this is true, an 42 | exception will be raised. 43 | :param deterministic: if True, yield results in a deterministic order. 44 | :param random_crop: if True, randomly crop the images for augmentation. 45 | :param random_flip: if True, randomly flip the images for augmentation. 46 | """ 47 | if not data_dir: 48 | raise ValueError("unspecified data directory") 49 | data_dir_splits = data_dir.split(',') 50 | 51 | all_files = [] 52 | for data_dir_split in data_dir_splits: 53 | all_files.extend(_list_video_files_recursively(data_dir_split)) 54 | 55 | print(f"len(data loader):{len(all_files)}") 56 | classes = None 57 | if class_cond: 58 | # Assume classes are the first part of the filename, 59 | # before an underscore. 
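        # Note: the label used below is the parent directory of each clip
        # (path.split("/")[-2]); e.g. a hypothetical path
        # ".../UCF-101/train/ApplyEyeMakeup/clip_0001.avi" maps to the class
        # "ApplyEyeMakeup".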
60 | class_names = [path.split("/")[-2] for path in all_files] 61 | class_labels = set(class_names) 62 | sorted_classes = {x: i for i, x in enumerate(sorted(class_labels))} 63 | classes = [sorted_classes[x] for x in class_names] 64 | 65 | 66 | print(f"len(data loader classes):{len(class_labels)}") 67 | 68 | dataset = VideoDataset( 69 | all_files, 70 | image_size, 71 | frame_num, 72 | classes=classes, 73 | shard=0, 74 | num_shards=1, 75 | random_crop=random_crop, 76 | random_flip=random_flip, 77 | frame_gap = frame_gap, 78 | order_cond = order_cond 79 | ) 80 | 81 | # train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 82 | if deterministic: 83 | loader = DataLoader( 84 | dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True 85 | ) 86 | else: 87 | loader = DataLoader( 88 | dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True 89 | ) 90 | 91 | while True: 92 | yield from loader 93 | 94 | 95 | def _list_video_files_recursively(data_dir): 96 | results = [] 97 | for entry in sorted(bf.listdir(data_dir)): 98 | full_path = bf.join(data_dir, entry) 99 | ext = entry.split(".")[-1] 100 | if "." in entry and ext.lower() in ["avi", "gif", "mp4","png"]: 101 | 102 | results.append(full_path) 103 | elif bf.isdir(full_path): 104 | 105 | results.extend(_list_video_files_recursively(full_path)) 106 | return results 107 | 108 | 109 | class VideoDataset(Dataset): 110 | def __init__( 111 | self, 112 | video_paths, 113 | resolution, 114 | frame_num, 115 | classes=None, 116 | shard=0, 117 | num_shards=1, 118 | random_crop=False, 119 | random_flip=False, 120 | frame_gap=8, 121 | order_cond=False 122 | ): 123 | super().__init__() 124 | self.resolution = resolution 125 | self.local_videos = video_paths[shard:][::num_shards] 126 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 127 | self.random_crop = random_crop 128 | self.random_flip = random_flip 129 | self.frame_num = frame_num 130 | self.frame_gap = frame_gap 131 | self.order_cond = order_cond 132 | 133 | def __len__(self): 134 | return len(self.local_videos) 135 | 136 | def resize_img(self, img):# size:64:64 137 | ''' 138 | resize img to target_size with padding 139 | ''' 140 | 141 | old_size = img.size 142 | ratio = min(float(self.resolution)/(old_size[i]) for i in range(len(old_size))) 143 | new_size = tuple([int(i*ratio) for i in old_size]) 144 | 145 | img = img.resize((new_size[1], new_size[0]),Image.BICUBIC) #cv2.resize(img,(new_size[1], new_size[0])) 146 | img = np.array(img) 147 | pad_w = self.resolution - new_size[1] 148 | pad_h = self.resolution- new_size[0] 149 | top,bottom = pad_h//2, pad_h-(pad_h//2) 150 | left,right = pad_w//2, pad_w -(pad_w//2) 151 | img_new = cv2.copyMakeBorder(img,top,bottom,left,right,cv2.BORDER_CONSTANT,None,(0,0,0)) 152 | return img_new 153 | 154 | def _get_gif(self, path): 155 | with bf.BlobFile(path, "rb") as f: 156 | pil_images = Image.open(f) 157 | pil_images = list(map(self.resize_img, seek_all_images(pil_images, channels = 3))) 158 | return pil_images 159 | 160 | def _get_vid(self, path): 161 | cap = cv2.VideoCapture(path) 162 | frames = [] 163 | count = 0 164 | while cap.isOpened(): 165 | ret, frame = cap.read() 166 | if ret == False: 167 | break 168 | 169 | img =Image.fromarray(cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)) 170 | frames.append(self.resize_img(img)) 171 | count+=1 172 | return frames 173 | 174 | def __getitem__(self, idx): 175 | path = self.local_videos[idx] 176 | # load gif 177 | 
post_fix=path.split('.')[-1] 178 | if post_fix == 'gif': 179 | video = th.tensor(np.array(self._get_gif(path))) 180 | elif post_fix in ['avi', 'mp4']: 181 | video = th.tensor(np.array(self._get_vid(path))) 182 | elif post_fix in ['png']: 183 | video = th.tensor(np.array(Image.open(path).convert('RGB'))) # H,W,C 184 | f_num = video.shape[0] // self.resolution 185 | video = rearrange(video, "(f1 h) (f2 w) c -> (f1 f2) h w c", f1=f_num, f2=f_num) 186 | 187 | video= video.permute([0,3,1,2])#[B,C,H,W] 188 | 189 | if len(video) < self.frame_num: 190 | append = self.frame_num - len(video) 191 | video = th.cat([video, video[-1:].repeat(append,1,1,1)],dim=0) 192 | elif len(video) >= self.frame_num and len(video) <= self.frame_num * self.frame_gap: 193 | indices = np.linspace(0, len(video)-1, self.frame_num) 194 | video = th.stack([video[int(indice)] for indice in indices], dim=0) 195 | else: 196 | start = random.randint(0, len(video) - self.frame_num * self.frame_gap - 1) 197 | video = video[start : start + self.frame_num * self.frame_gap : self.frame_gap] 198 | 199 | 200 | 201 | video_after_process = video#.float() / 127.5 - 1 #0-1 202 | 203 | 204 | return video_after_process 205 | 206 | def seek_all_images(img, channels = 3): 207 | 208 | i = 0 209 | while True: 210 | try: 211 | img.seek(i) 212 | yield img.convert("RGB") 213 | except EOFError: 214 | break 215 | i += 1 216 | 217 | def load_data_for_worker(base_samples, image_size, frame_num, frame_gap=1, class_cond=False, batch_size=1): 218 | if base_samples.endswith('npz'): 219 | with bf.BlobFile(base_samples, "rb") as f: 220 | obj = np.load(f) 221 | image_arr = obj["arr_0"] 222 | if class_cond: 223 | label_arr = obj["arr_1"] 224 | 225 | while True: 226 | for i in range(len(image_arr)): 227 | video= image_arr[i] 228 | #b,f,h,w,c->b,f 229 | yield video 230 | else: 231 | 232 | dataset = load_data( 233 | data_dir=base_samples, 234 | batch_size=batch_size, 235 | frame_num=frame_num, 236 | image_size=image_size, 237 | class_cond=class_cond, 238 | num_workers=4, 239 | frame_gap=frame_gap, 240 | deterministic=True 241 | ) 242 | 243 | 244 | for batchdata in dataset: 245 | # import pdb; pdb.set_trace() 246 | batchdata= batchdata.permute(0, 1, 3, 4, 2)#.astype('uint8') #[batchsize, frame, W, H, C] 247 | if batchdata.shape[1] < 16: 248 | batchdata =th.cat([batchdata, batchdata[:,-1:,].repeat(1,8,1,1,1)], dim=1) 249 | 250 | # video = th.cat([video, video[-1:].repeat(append,1,1,1)],dim=0) 251 | yield batchdata 252 | 253 | 254 | 255 | -------------------------------------------------------------------------------- /fig/MM-UNet2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/MM-Diffusion/7250222114ee1eca36e32d90e8a41f12d92651e2/fig/MM-UNet2.png -------------------------------------------------------------------------------- /fig/aist++.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/MM-Diffusion/7250222114ee1eca36e32d90e8a41f12d92651e2/fig/aist++.mp4 -------------------------------------------------------------------------------- /fig/audioset.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/MM-Diffusion/7250222114ee1eca36e32d90e8a41f12d92651e2/fig/audioset.mp4 -------------------------------------------------------------------------------- /fig/landscape.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/MM-Diffusion/7250222114ee1eca36e32d90e8a41f12d92651e2/fig/landscape.mp4 -------------------------------------------------------------------------------- /fig/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/MM-Diffusion/7250222114ee1eca36e32d90e8a41f12d92651e2/fig/teaser.png -------------------------------------------------------------------------------- /mm_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/researchmm/MM-Diffusion/7250222114ee1eca36e32d90e8a41f12d92651e2/mm_diffusion/__init__.py -------------------------------------------------------------------------------- /mm_diffusion/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch.distributed as dist 4 | import torch as th 5 | import numpy as np 6 | import glob 7 | from PIL import Image 8 | from einops import rearrange 9 | from moviepy.video.io.ImageSequenceClip import ImageSequenceClip 10 | from moviepy.audio.AudioClip import AudioArrayClip 11 | from . import logger 12 | 13 | # def sample_fn_decord(sample_fn): 14 | # sample_fn = sample_fn.replace('dpm_solver++_', '') 15 | # predict_x0, order, method = sample_fn.split('_') 16 | # predict_x0 = True if predict_x0 == "True" else False 17 | # order = int(order) 18 | # return predict_x0, order, method 19 | 20 | def delete_pkl(fake_dir): 21 | fake_paths = glob.glob(os.path.join(fake_dir, "*.pkl")) 22 | for fake_path in fake_paths: 23 | os.remove(fake_path) 24 | print(f"delete pkl from {fake_path}") 25 | return 26 | 27 | 28 | def save_audio(audio, output_path, audio_fps): 29 | audio = audio.T #[len, channel] 30 | audio = np.repeat(audio, 2, axis=1) 31 | audio_clip = AudioArrayClip(audio, fps=audio_fps) 32 | audio_clip.write_audiofile(output_path, fps=audio_fps) 33 | return 34 | 35 | def save_img(video, output_path): 36 | os.makedirs(output_path, exist_ok=True) 37 | for idx, img in enumerate(video): 38 | img_path = os.path.join(output_path, f"{idx:0>8d}.png") 39 | Image.fromarray(img).convert('RGB').save(img_path) 40 | return 41 | 42 | def save_png(img, output_path): 43 | Image.fromarray(img).convert('RGB').save(output_path) 44 | return 45 | 46 | def save_multimodal(video, audio, output_path, args): 47 | imgs = [img for img in video] 48 | audio = audio.T #[len, channel] 49 | audio = np.repeat(audio, 2, axis=1) 50 | audio_clip = AudioArrayClip(audio, fps=args.audio_fps) 51 | video_clip = ImageSequenceClip(imgs, fps=args.video_fps) 52 | video_clip = video_clip.set_audio(audio_clip) 53 | video_clip.write_videofile(output_path, args.video_fps, audio=True, audio_fps=args.audio_fps) 54 | return 55 | 56 | def save_one_image(images, save_path, row=5): 57 | images = images[:row**2, ...] 58 | assert images.shape[0] % row == 0 59 | images = np.pad(images,((0,0),(2,2),(2,2),(0,0))) 60 | images = rearrange(images, '(i j) h w c -> (i h) (j w) c', i = row) 61 | Image.fromarray(images).convert('RGB').save(save_path) 62 | return True 63 | 64 | def save_one_video(videos, save_path, row=5): 65 | videos = videos[:row**2,...]
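    # keep at most row*row clips; frames get a 2-pixel border and are tiled into a row x row grid before saving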
66 | assert videos.shape[0] % row == 0 67 | videos = np.pad(videos,((0,0),(0,0),(2,2),(2,2),(0,0))) 68 | videos = rearrange(videos, '(i j) f h w c -> f (i h) (j w) c', i = row) 69 | imgs = [Image.fromarray(img) for img in videos] 70 | imgs[0].save(save_path, save_all=True, append_images=imgs[1:], duration=100, loop=0) 71 | return True 72 | 73 | def save_video(result, output_path, video_fps=10): 74 | ext = (output_path.split('.')[-1]).lower() 75 | if ext == 'gif': 76 | imgs = [Image.fromarray(img) for img in result] 77 | imgs[0].save(output_path, save_all=True, append_images=imgs[1:], duration=int(1000/video_fps), loop=0) 78 | elif ext in ["mp4", "avi"]: 79 | imgs = [img for img in result] 80 | video_clip = ImageSequenceClip(imgs, fps=video_fps) 81 | video_clip.write_videofile(output_path, video_fps) 82 | return 83 | 84 | def set_seed_logger(args): 85 | if os.path.exists(args.output_dir)==False and dist.get_rank()==0: 86 | os.makedirs(args.output_dir) 87 | # predefining random initial seeds 88 | random.seed(args.seed) 89 | os.environ['PYTHONHASHSEED'] = str(args.seed) 90 | np.random.seed(args.seed) 91 | th.manual_seed(args.seed) 92 | th.cuda.manual_seed(args.seed) 93 | th.cuda.manual_seed_all(args.seed) # if you are using multi-GPU. 94 | th.backends.cudnn.benchmark = False 95 | th.backends.cudnn.deterministic = True 96 | 97 | if dist.get_rank() == 0: 98 | logger.log("Effective parameters:") 99 | for key in sorted(args.__dict__): 100 | logger.log(" <<< {}: {}".format(key, args.__dict__[key])) 101 | return args 102 | 103 | def set_seed_logger_random(args): 104 | ''' 105 | training or evaluation on multiple GPUs requires different randomness 106 | ''' 107 | if os.path.exists(args.output_dir)==False and dist.get_rank()==0: 108 | os.makedirs(args.output_dir) 109 | # random.seed(args.seed) 110 | # os.environ['PYTHONHASHSEED'] = str(args.seed) 111 | # np.random.seed(args.seed) 112 | # th.manual_seed(args.seed) 113 | # th.cuda.manual_seed(args.seed) 114 | # th.cuda.manual_seed_all(args.seed) # if you are using multi-GPU. 115 | th.backends.cudnn.benchmark = False 116 | th.backends.cudnn.deterministic = True 117 | 118 | if dist.get_rank() == 0: 119 | logger.log("Effective parameters:") 120 | for key in sorted(args.__dict__): 121 | logger.log(" <<< {}: {}".format(key, args.__dict__[key])) 122 | return args 123 | 124 | 125 | -------------------------------------------------------------------------------- /mm_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | import blobfile as bf 9 | from mpi4py import MPI 10 | import torch as th 11 | import torch.distributed as dist 12 | 13 | # Change this to reflect your cluster layout. 14 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 15 | 16 | GPUS_PER_NODE = 8 17 | 18 | def setup_dist(devices=None): 19 | """ 20 | Setup a distributed process group.
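    `devices` is either "G<N>", meaning use the first N GPUs of each node (e.g. "G8"), or an explicit comma-separated list of device ids (e.g. "0,1,2,3"); each MPI rank is pinned to one of these via CUDA_VISIBLE_DEVICES, selected as rank % GPUS_PER_NODE.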
21 | """ 22 | global GPUS_PER_NODE 23 | if dist.is_initialized(): 24 | return 25 | 26 | if devices.startswith("G"): 27 | GPUS_PER_NODE = int(devices[1:]) 28 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 29 | else: 30 | devices_list=devices.split(',') 31 | GPUS_PER_NODE = len(devices_list) 32 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{devices_list[MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE]}" 33 | 34 | comm = MPI.COMM_WORLD 35 | 36 | backend = "gloo" if not th.cuda.is_available() else "nccl" 37 | 38 | if backend == "gloo": 39 | hostname = "localhost" 40 | else: 41 | hostname = socket.gethostbyname(socket.getfqdn()) 42 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 43 | os.environ["RANK"] = str(comm.rank) 44 | os.environ["WORLD_SIZE"] = str(comm.size) 45 | 46 | port = comm.bcast(_find_free_port(), root=0) 47 | os.environ["MASTER_PORT"] = str(port) 48 | 49 | dist.init_process_group(backend=backend, init_method="env://") 50 | 51 | 52 | 53 | 54 | def dev(): 55 | """ 56 | Get the device to use for torch.distributed. 57 | """ 58 | 59 | if th.cuda.is_available(): 60 | return th.device("cuda") 61 | return th.device("cpu") 62 | 63 | def load_state_dict(path, **kwargs): 64 | """ 65 | Load a PyTorch file without redundant fetches across MPI ranks. 66 | """ 67 | with bf.BlobFile(path, "rb") as f: 68 | data = f.read() 69 | 70 | return th.load(io.BytesIO(data), **kwargs) 71 | 72 | def sync_params(params): 73 | """ 74 | Synchronize a sequence of Tensors across ranks from rank 0. 75 | """ 76 | for p in params: 77 | with th.no_grad(): 78 | dist.broadcast(p, 0) 79 | 80 | 81 | 82 | 83 | def _find_free_port(): 84 | try: 85 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 86 | s.bind(("", 0)) 87 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 88 | return s.getsockname()[1] 89 | finally: 90 | s.close() 91 | -------------------------------------------------------------------------------- /mm_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | from . import logger 10 | 11 | INITIAL_LOG_LOSS_SCALE = 20.0 12 | 13 | def convert_module_to_f16(l): 14 | """ 15 | Convert primitive modules to float16. 16 | """ 17 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 18 | l.weight.data = l.weight.data.half() 19 | if l.bias is not None: 20 | l.bias.data = l.bias.data.half() 21 | 22 | 23 | def convert_module_to_f32(l): 24 | """ 25 | Convert primitive modules to float32, undoing convert_module_to_f16(). 26 | """ 27 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 28 | l.weight.data = l.weight.data.float() 29 | if l.bias is not None: 30 | l.bias.data = l.bias.data.float() 31 | 32 | 33 | def make_master_params(param_groups_and_shapes): 34 | """ 35 | Copy model parameters into a (differently-shaped) list of full-precision 36 | parameters. 
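    Each entry of `param_groups_and_shapes` (one group for scalar/vector params, one for matrix params) is flattened into a single float32 master tensor; the optimizer updates these master copies, which are later copied back into the fp16 model weights.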
37 | """ 38 | master_params = [] 39 | 40 | for param_group, shape in param_groups_and_shapes: 41 | master_param = nn.Parameter( 42 | _flatten_dense_tensors( 43 | [param.detach().float() for (_, param) in param_group] 44 | ).view(shape) 45 | ) 46 | master_param.requires_grad = True 47 | master_params.append(master_param) 48 | return master_params 49 | 50 | 51 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 52 | """ 53 | Copy the gradients from the model parameters into the master parameters 54 | from make_master_params(). 55 | """ 56 | for master_param, (param_group, shape) in zip( 57 | master_params, param_groups_and_shapes 58 | ): 59 | master_param.grad = _flatten_dense_tensors( 60 | [param_grad_or_zeros(name, param) for (name, param) in param_group] 61 | ).view(shape) 62 | 63 | 64 | def master_params_to_model_params(param_groups_and_shapes, master_params): 65 | """ 66 | Copy the master parameter data back into the model parameters. 67 | """ 68 | # Without copying to a list, if a generator is passed, this will 69 | # silently not copy any parameters. 70 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 71 | for (_, param), unflat_master_param in zip( 72 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 73 | ): 74 | param.detach().copy_(unflat_master_param) 75 | 76 | 77 | def unflatten_master_params(param_group, master_param): 78 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 79 | 80 | 81 | def get_param_groups_and_shapes(named_model_params): 82 | if not isinstance(named_model_params, list): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | 93 | return [scalar_vector_named_params, matrix_named_params] 94 | 95 | 96 | def master_params_to_state_dict( 97 | model, param_groups_and_shapes, master_params, use_fp16 98 | ): 99 | if use_fp16: 100 | state_dict = model.state_dict() 101 | for master_param, (param_group, _) in zip( 102 | master_params, param_groups_and_shapes 103 | ): 104 | for (name, _), unflat_master_param in zip( 105 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 106 | ): 107 | assert name in state_dict 108 | state_dict[name] = unflat_master_param 109 | else: 110 | state_dict = model.state_dict() 111 | for i, (name, _value) in enumerate(model.named_parameters()): 112 | assert name in state_dict 113 | state_dict[name] = master_params[i] 114 | return state_dict 115 | 116 | 117 | def state_dict_to_master_params(model, state_dict, use_fp16): 118 | if use_fp16: 119 | named_model_params = [ 120 | (name, state_dict[name]) for name, _ in model.named_parameters() 121 | ] 122 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 123 | master_params = make_master_params(param_groups_and_shapes) 124 | else: 125 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 126 | return master_params 127 | 128 | 129 | def zero_master_grads(master_params): 130 | for param in master_params: 131 | param.grad = None 132 | 133 | 134 | def zero_grad(model_params): 135 | for param in model_params: 136 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 137 | if param.grad is not None: 138 | param.grad.detach_() 139 | 
param.grad.zero_() 140 | 141 | 142 | def param_grad_or_zeros(name, param): 143 | if param.grad is not None: 144 | return param.grad.data.detach() 145 | else: 146 | return th.zeros_like(param) 147 | 148 | 149 | class MixedPrecisionTrainer: 150 | def __init__( 151 | self, 152 | *, 153 | model, 154 | use_fp16=False, 155 | fp16_scale_growth=1e-3, 156 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 157 | ): 158 | self.model = model 159 | self.use_fp16 = use_fp16 160 | self.fp16_scale_growth = fp16_scale_growth 161 | 162 | 163 | self.model_params = list(self.model.parameters()) 164 | self.master_params = self.model_params 165 | self.param_groups_and_shapes = None 166 | self.lg_loss_scale = initial_lg_loss_scale 167 | 168 | if self.use_fp16: 169 | self.param_groups_and_shapes = get_param_groups_and_shapes( 170 | self.model.named_parameters() 171 | ) 172 | self.master_params = make_master_params(self.param_groups_and_shapes) 173 | self.model.convert_to_fp16() 174 | 175 | 176 | 177 | def zero_grad(self): 178 | zero_grad(self.model_params) 179 | 180 | def backward(self, loss: th.Tensor): 181 | 182 | if self.use_fp16: 183 | loss_scale = 2 ** self.lg_loss_scale 184 | (loss * loss_scale).backward() 185 | else: 186 | loss.backward() 187 | 188 | def optimize(self, opt: th.optim.Optimizer): 189 | if self.use_fp16: 190 | return self._optimize_fp16(opt) 191 | else: 192 | return self._optimize_normal(opt) 193 | 194 | def _optimize_fp16(self, opt: th.optim.Optimizer): 195 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 196 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 197 | 198 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 199 | if check_overflow(grad_norm): 200 | self.lg_loss_scale -= 1 201 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 202 | zero_master_grads(self.master_params) 203 | return False 204 | 205 | logger.logkv("current_grad_norm", grad_norm) 206 | logger.logkv("current_param_norm", param_norm) 207 | logger.logkv_mean("grad_norm", grad_norm) 208 | logger.logkv_mean("param_norm", param_norm) 209 | 210 | # self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 211 | for p in self.master_params: 212 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 213 | opt.step() 214 | zero_master_grads(self.master_params) 215 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 216 | self.lg_loss_scale += self.fp16_scale_growth 217 | return True 218 | 219 | def _optimize_normal(self, opt: th.optim.Optimizer): 220 | grad_norm, param_norm = self._compute_norms() 221 | logger.logkv("current_grad_norm", grad_norm) 222 | logger.logkv("current_param_norm", param_norm) 223 | logger.logkv_mean("grad_norm", grad_norm) 224 | logger.logkv_mean("param_norm", param_norm) 225 | opt.step() 226 | return True 227 | 228 | def _compute_norms(self, grad_scale=1.0): 229 | grad_norm = 0.0 230 | param_norm = 0.0 231 | for p in self.master_params: 232 | with th.no_grad(): 233 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 234 | if p.grad is not None: 235 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 236 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 237 | 238 | def master_params_to_state_dict(self, master_params): 239 | return master_params_to_state_dict( 240 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 241 | ) 242 | 243 | def state_dict_to_master_params(self, state_dict): 244 | return 
state_dict_to_master_params(self.model, state_dict, self.use_fp16) 245 | 246 | 247 | def check_overflow(value): 248 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 249 | -------------------------------------------------------------------------------- /mm_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from PIL import Image 4 | import blobfile as bf 5 | from mpi4py import MPI 6 | import numpy as np 7 | from torch.utils.data import DataLoader, Dataset 8 | import torch as th 9 | import cv2 10 | 11 | def load_data( 12 | *, 13 | data_dir, 14 | batch_size, 15 | image_size, 16 | deterministic=False, 17 | random_crop=False, 18 | random_flip=True, 19 | num_workers=0, 20 | ): 21 | """ 22 | For a dataset, create a generator over (images, kwargs) pairs. 23 | 24 | Each images is an NCHW float tensor. 25 | 26 | :param data_dir: a dataset directory. 27 | :param batch_size: the batch size of each returned pair. 28 | :param image_size: the size to which images are resized. 29 | :param deterministic: if True, yield results in a deterministic order. 30 | :param random_crop: if True, randomly crop the images for augmentation. 31 | :param random_flip: if True, randomly flip the images for augmentation. 32 | :param num_workers: the number of workers to use for loading data. 33 | """ 34 | if not data_dir: 35 | raise ValueError("unspecified data directory") 36 | data_dir_splits = data_dir.split(',') 37 | 38 | all_files = [] 39 | for data_dir_split in data_dir_splits: 40 | all_files.extend(_list_image_files_recursively(data_dir_split)) 41 | 42 | if MPI.COMM_WORLD.Get_rank()==0: 43 | print(f"len(data loader):{len(all_files)}") 44 | 45 | dataset = ImageDataset( 46 | image_size, 47 | all_files, 48 | shard=MPI.COMM_WORLD.Get_rank(), 49 | num_shards=MPI.COMM_WORLD.Get_size(), 50 | random_crop=random_crop, 51 | random_flip=random_flip, 52 | ) 53 | 54 | if deterministic: 55 | loader = DataLoader( 56 | dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True 57 | ) 58 | else: 59 | loader = DataLoader( 60 | dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True 61 | ) 62 | 63 | while True: 64 | yield from loader 65 | 66 | 67 | def _list_image_files_recursively(data_dir, frame_gap=1): 68 | results = [] 69 | 70 | for entry in sorted(bf.listdir(data_dir)): 71 | full_path = bf.join(data_dir, entry) 72 | ext = entry.split(".")[-1] 73 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png"]: 74 | results.append(full_path) 75 | 76 | elif bf.isdir(full_path): 77 | results.extend(_list_image_files_recursively(full_path, frame_gap)) 78 | return results 79 | 80 | 81 | class ImageDataset(Dataset): 82 | def __init__( 83 | self, 84 | resolution, 85 | image_paths, 86 | shard=0, 87 | num_shards=1, 88 | random_crop=False, 89 | random_flip=False, 90 | ): 91 | super().__init__() 92 | self.resolution = resolution 93 | self.local_images = image_paths[shard:][::num_shards] 94 | self.random_crop = random_crop 95 | self.random_flip = random_flip 96 | 97 | def __len__(self): 98 | return len(self.local_images) 99 | 100 | def resize_img(self, img):# size:64:64 101 | ''' 102 | resize img to target_size with padding 103 | ''' 104 | old_size = img.shape[:2] 105 | ratio = min(float(self.resolution)/(old_size[i]) for i in range(len(old_size))) 106 | new_size = tuple([int(i*ratio) for i in old_size]) 107 | img = cv2.resize(img, (new_size[1], new_size[0]), interpolation=cv2.INTER_CUBIC) 108 | pad_w = self.resolution - new_size[1] 109 | pad_h = self.resolution- new_size[0] 110 | top,bottom = pad_h//2, pad_h-(pad_h//2) 111 | left,right = pad_w//2, pad_w -(pad_w//2) 112 | img_new = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,None,(0,0,0)) 113 | return img_new 114 | 115 | def __getitem__(self, idx): 116 | path = self.local_images[idx] 117 | with bf.BlobFile(path, "rb") as f: 118 | pil_image = Image.open(f) 119 | pil_image.load() 120 | pil_image = pil_image.convert("RGB") 121 | arr = self.resize_img(np.array(pil_image)) 122 | if self.random_flip and random.random() < 0.5: 123 | arr = arr[:, ::-1] 124 | arr = arr.astype(np.float32) / 127.5 - 1 125 | 126 | 127 | return np.transpose(arr, [2, 0, 1]) 128 | 129 | if __name__=='__main__': 130 | 131 | dataset=load_data( 132 | data_dir="../../data/ucf101_jpg/v_ApplyEyeMakeup_g01_c01", 133 | batch_size=8, 134 | image_size=256, 135 | frame_gap=8, 136 | random_flip=True) 137 | while True: 138 | import pdb; pdb.set_trace() 139 | batch, cond = next(dataset) 140 | batch = ((batch + 1) * 127.5).clamp(0, 255).to(th.uint8) 141 | images = batch.reshape(-1,3, 256, 256) 142 | 143 | images = images.permute(0,2,3,1) 144 | for ind, image in enumerate(images): 145 | out_path = f"{ind}.jpg" 146 | Image.fromarray(image.numpy()).convert('RGB').save(out_path) 147 | 148 | -------------------------------------------------------------------------------- /mm_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 
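    # Computes, element-wise: KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))
    #   = 0.5 * (-1 + logvar2 - logvar1 + exp(logvar1 - logvar2) + (mean1 - mean2)^2 * exp(-logvar2)).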
28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /mm_diffusion/multimodal_respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | from .multimodal_gaussian_diffusion import GaussianDiffusion 4 | 5 | 6 | def space_timesteps(num_timesteps, section_counts): 7 | """ 8 | Create a list of timesteps to use from an original diffusion process, 9 | given the number of timesteps we want to take from equally-sized portions 10 | of the original process. 11 | 12 | For example, if there's 300 timesteps and the section counts are [10,15,20] 13 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 14 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 15 | 16 | If the stride is a string starting with "ddim", then the fixed striding 17 | from the DDIM paper is used, and only one section is allowed. 18 | 19 | :param num_timesteps: the number of diffusion steps in the original 20 | process to divide up. 21 | :param section_counts: either a list of numbers, or a string containing 22 | comma-separated numbers, indicating the step count 23 | per section. As a special case, use "ddimN" where N 24 | is a number of steps to use the striding from the 25 | DDIM paper. 26 | :return: a set of diffusion steps from the original process to use. 
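    Example (illustrative): space_timesteps(300, "10,15,20") returns a set of 45 indices, 10 evenly strided over [0, 100), 15 over [100, 200) and 20 over [200, 300); space_timesteps(1000, "ddim50") returns the 50 evenly spaced steps range(0, 1000, 20).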
27 | """ 28 | if isinstance(section_counts, str): 29 | if section_counts.startswith("ddim"): 30 | desired_count = int(section_counts[len("ddim") :]) 31 | for i in range(1, num_timesteps): 32 | if len(range(0, num_timesteps, i)) == desired_count: 33 | return set(range(0, num_timesteps, i)) 34 | raise ValueError( 35 | f"cannot create exactly {num_timesteps} steps with an integer stride" 36 | ) 37 | section_counts = [int(x) for x in section_counts.split(",")] 38 | size_per = num_timesteps // len(section_counts) 39 | extra = num_timesteps % len(section_counts) 40 | start_idx = 0 41 | all_steps = [] 42 | for i, section_count in enumerate(section_counts): 43 | size = size_per + (1 if i < extra else 0) 44 | if size < section_count: 45 | raise ValueError( 46 | f"cannot divide section of {size} steps into {section_count}" 47 | ) 48 | if section_count <= 1: 49 | frac_stride = 1 50 | else: 51 | frac_stride = (size - 1) / (section_count - 1) 52 | cur_idx = 0.0 53 | taken_steps = [] 54 | for _ in range(section_count): 55 | taken_steps.append(start_idx + round(cur_idx)) 56 | cur_idx += frac_stride 57 | all_steps += taken_steps 58 | start_idx += size 59 | return set(all_steps) 60 | 61 | 62 | class SpacedDiffusion(GaussianDiffusion): 63 | """ 64 | A diffusion process which can skip steps in a base diffusion process. 65 | 66 | :param use_timesteps: a collection (sequence or set) of timesteps from the 67 | original diffusion process to retain. 68 | :param kwargs: the kwargs to create the base diffusion process. 69 | """ 70 | 71 | def __init__(self, use_timesteps, **kwargs): 72 | self.use_timesteps = set(use_timesteps) 73 | self.timestep_map = [] 74 | self.original_num_steps = len(kwargs["betas"]) 75 | 76 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 77 | last_alpha_cumprod = 1.0 78 | new_betas = [] 79 | 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def multimodal_training_losses( 100 | self, model, *args, **kwargs 101 | ): # pylint: disable=signature-differs 102 | return super().multimodal_training_losses(self._wrap_model(model), *args, **kwargs) 103 | 104 | def multimodal_conditional_contrast_training_losses( 105 | self, model, *args, **kwargs 106 | ): # pylint: disable=signature-differs 107 | return super().multimodal_conditional_contrast_training_losses(self._wrap_model(model), *args, **kwargs) 108 | 109 | def condition_mean(self, cond_fn, *args, **kwargs): 110 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 111 | 112 | def condition_score(self, cond_fn, *args, **kwargs): 113 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 114 | 115 | def _wrap_model(self, model): 116 | if isinstance(model, _WrappedModel): 117 | return model 118 | return _WrappedModel( 119 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 120 | ) 121 | 122 | def 
_scale_timesteps(self, t): 123 | # Scaling is done by the wrapped model. 124 | return t 125 | 126 | 127 | class _WrappedModel: 128 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 129 | self.model = model 130 | self.timestep_map = timestep_map 131 | self.rescale_timesteps = rescale_timesteps 132 | self.original_num_steps = original_num_steps 133 | 134 | def __call__(self, video_x, audio_x, ts, **kwargs): 135 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 136 | new_ts = map_tensor[ts] 137 | if self.rescale_timesteps: 138 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 139 | return self.model(video_x, audio_x, new_ts, **kwargs) 140 | -------------------------------------------------------------------------------- /mm_diffusion/multimodal_script_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is extended from guided_diffusion: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/scripts_util.py 3 | """ 4 | 5 | 6 | import argparse 7 | from einops import rearrange 8 | from . import multimodal_gaussian_diffusion as gd 9 | from .multimodal_respace import SpacedDiffusion, space_timesteps 10 | from .multimodal_unet import MultimodalUNet 11 | 12 | def diffusion_defaults(): 13 | """ 14 | Defaults for multi-modal training. 15 | """ 16 | return dict( 17 | learn_sigma=False, 18 | diffusion_steps=1000, 19 | noise_schedule="linear", 20 | timestep_respacing="", 21 | use_kl=False, 22 | predict_xstart=False, 23 | rescale_timesteps=False, 24 | rescale_learned_sigmas=False, 25 | ) 26 | 27 | 28 | def model_defaults(): 29 | """ 30 | Defaults for multi-modal training. 31 | """ 32 | res = dict( 33 | video_size="16,3,64,64", 34 | audio_size="1,25600", 35 | num_channels=128, 36 | num_res_blocks=2, 37 | num_heads=4, 38 | num_heads_upsample=-1, 39 | num_head_channels=-1, 40 | cross_attention_resolutions="2,4,8", 41 | cross_attention_windows="1,4,8", 42 | cross_attention_shift=True, 43 | video_attention_resolutions="2,4,8", 44 | audio_attention_resolutions="-1", 45 | channel_mult="", 46 | dropout=0.0, 47 | class_cond=False, 48 | use_checkpoint=False, 49 | use_scale_shift_norm=True, 50 | resblock_updown=False, 51 | use_fp16=False, 52 | video_type="2d+1d", 53 | audio_type="1d", 54 | ) 55 | return res 56 | 57 | def model_and_diffusion_defaults(): 58 | res = model_defaults() 59 | res.update(diffusion_defaults()) 60 | return res 61 | 62 | def create_model_and_diffusion( 63 | video_size, 64 | audio_size, 65 | learn_sigma, 66 | num_channels, 67 | num_res_blocks, 68 | channel_mult, 69 | num_heads, 70 | num_head_channels, 71 | num_heads_upsample, 72 | cross_attention_resolutions, 73 | cross_attention_windows, 74 | cross_attention_shift, 75 | video_attention_resolutions, 76 | audio_attention_resolutions, 77 | dropout, 78 | diffusion_steps, 79 | noise_schedule, 80 | timestep_respacing, 81 | use_kl, 82 | predict_xstart, 83 | rescale_timesteps, 84 | rescale_learned_sigmas, 85 | use_checkpoint, 86 | use_scale_shift_norm, 87 | resblock_updown, 88 | use_fp16, 89 | video_type="2d+1d", 90 | audio_type="1d", 91 | class_cond=False 92 | ): 93 | model = create_model( 94 | video_size=video_size, 95 | audio_size=audio_size, 96 | num_channels=num_channels, 97 | num_res_blocks=num_res_blocks, 98 | channel_mult=channel_mult, 99 | learn_sigma=learn_sigma, 100 | class_cond=class_cond, 101 | use_checkpoint=use_checkpoint, 102 | cross_attention_resolutions=cross_attention_resolutions, 103 | 
cross_attention_windows=cross_attention_windows, 104 | cross_attention_shift=cross_attention_shift, 105 | video_attention_resolutions=video_attention_resolutions, 106 | audio_attention_resolutions=audio_attention_resolutions, 107 | num_heads=num_heads, 108 | num_head_channels=num_head_channels, 109 | num_heads_upsample=num_heads_upsample, 110 | use_scale_shift_norm=use_scale_shift_norm, 111 | dropout=dropout, 112 | resblock_updown=resblock_updown, 113 | use_fp16=use_fp16, 114 | video_type=video_type, 115 | audio_type=audio_type, 116 | 117 | ) 118 | diffusion = create_gaussian_diffusion( 119 | steps=diffusion_steps, 120 | learn_sigma=learn_sigma, 121 | noise_schedule=noise_schedule, 122 | use_kl=use_kl, 123 | predict_xstart=predict_xstart, 124 | rescale_timesteps=rescale_timesteps, 125 | rescale_learned_sigmas=rescale_learned_sigmas, 126 | timestep_respacing=timestep_respacing, 127 | ) 128 | return model, diffusion 129 | 130 | 131 | def create_model( 132 | video_size, 133 | audio_size, 134 | num_channels, 135 | num_res_blocks, 136 | channel_mult="", 137 | learn_sigma=False, 138 | class_cond=False, 139 | use_checkpoint=False, 140 | cross_attention_resolutions="2,4,8", 141 | video_attention_resolutions="2,4,8", 142 | audio_attention_resolutions="2,4,8", 143 | cross_attention_windows="1,4,8", 144 | cross_attention_shift=True, 145 | num_heads=1, 146 | num_head_channels=-1, 147 | num_heads_upsample=-1, 148 | use_scale_shift_norm=False, 149 | dropout=0, 150 | use_fp16=False, 151 | video_type="2d+1d", 152 | audio_type="1d", 153 | resblock_updown=True 154 | ): 155 | 156 | image_size = video_size[-1] 157 | if channel_mult == "": 158 | if image_size == 512: 159 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 160 | elif image_size == 256: 161 | channel_mult = (1, 1, 2, 2, 4, 4) 162 | elif image_size == 128: 163 | channel_mult = (1, 1, 2, 3, 4) 164 | elif image_size == 64: 165 | channel_mult = (1, 2, 3, 4) 166 | else: 167 | raise ValueError(f"unsupported image size: {image_size}") 168 | else: 169 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 170 | 171 | cross_attention_resolutions = [int(i) for i in cross_attention_resolutions.split(',')] 172 | video_attention_resolutions = [int(i) for i in video_attention_resolutions.split(',')] 173 | audio_attention_resolutions = [int(i) for i in audio_attention_resolutions.split(',')] 174 | cross_attention_windows = [int(i) for i in cross_attention_windows.split(',')] 175 | 176 | return MultimodalUNet( 177 | video_size=video_size, 178 | audio_size=audio_size, 179 | model_channels=num_channels, 180 | video_out_channels=(3 if not learn_sigma else 6), 181 | audio_out_channels=(1 if not learn_sigma else 2), 182 | num_res_blocks=num_res_blocks, 183 | cross_attention_resolutions=cross_attention_resolutions, 184 | cross_attention_windows=cross_attention_windows, 185 | cross_attention_shift=cross_attention_shift, 186 | video_attention_resolutions=video_attention_resolutions, 187 | audio_attention_resolutions=audio_attention_resolutions, 188 | video_type=video_type, 189 | audio_type= audio_type, 190 | 191 | dropout=dropout, 192 | channel_mult=channel_mult, 193 | num_classes=None, 194 | use_checkpoint=use_checkpoint, 195 | use_fp16=use_fp16, 196 | num_heads=num_heads, 197 | num_head_channels=num_head_channels, 198 | num_heads_upsample=num_heads_upsample, 199 | use_scale_shift_norm=use_scale_shift_norm, 200 | resblock_updown=resblock_updown 201 | ) 202 | 203 | 204 | def create_gaussian_diffusion( 205 | *, 206 | steps=1000, 207 | learn_sigma=False, 208 | 
sigma_small=False, 209 | noise_schedule="linear", 210 | use_kl=False, 211 | predict_xstart=False, 212 | rescale_timesteps=False, 213 | rescale_learned_sigmas=False, 214 | timestep_respacing="", 215 | ): 216 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 217 | if use_kl: 218 | loss_type = gd.LossType.RESCALED_KL 219 | elif rescale_learned_sigmas: 220 | loss_type = gd.LossType.RESCALED_MSE 221 | else: 222 | loss_type = gd.LossType.MSE 223 | if not timestep_respacing: 224 | timestep_respacing = [steps] 225 | return SpacedDiffusion( 226 | use_timesteps=space_timesteps(steps, timestep_respacing), 227 | betas=betas, 228 | model_mean_type=( 229 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 230 | ), 231 | model_var_type=( 232 | ( 233 | gd.ModelVarType.FIXED_LARGE 234 | if not sigma_small 235 | else gd.ModelVarType.FIXED_SMALL 236 | ) 237 | if not learn_sigma 238 | else gd.ModelVarType.LEARNED_RANGE 239 | ), 240 | loss_type=loss_type, 241 | rescale_timesteps=rescale_timesteps, 242 | ) 243 | 244 | 245 | def add_dict_to_argparser(parser, default_dict): 246 | for k, v in default_dict.items(): 247 | v_type = type(v) 248 | if v is None: 249 | v_type = str 250 | elif isinstance(v, bool): 251 | v_type = str2bool 252 | parser.add_argument(f"--{k}", default=v, type=v_type) 253 | 254 | 255 | def args_to_dict(args, keys): 256 | return {k: getattr(args, k) for k in keys} 257 | 258 | 259 | def str2bool(v): 260 | """ 261 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 262 | """ 263 | if isinstance(v, bool): 264 | return v 265 | if v.lower() in ("yes", "true", "t", "y", "1"): 266 | return True 267 | elif v.lower() in ("no", "false", "f", "n", "0"): 268 | return False 269 | else: 270 | raise argparse.ArgumentTypeError("boolean value expected") 271 | -------------------------------------------------------------------------------- /mm_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | from einops import rearrange, repeat 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
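# SiLU(x) = x * sigmoid(x); equivalent to torch.nn.SiLU, re-implemented below for compatibility with older PyTorch releases.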
12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | class GroupNorm32(nn.Module): 17 | def __init__(self, group, channel ): 18 | super(GroupNorm32, self).__init__() 19 | self.channel = channel 20 | self.GroupNorm = nn.GroupNorm(group, channel) 21 | 22 | def forward(self, x): 23 | rearrange_flag = False 24 | if x.shape[1] != self.channel and x.dim()==5: 25 | b,f,c,h,w =x.shape 26 | x = rearrange(x, 'b t c h w -> b c t h w') 27 | rearrange_flag = True 28 | 29 | x = self.GroupNorm(x.float()).type(x.dtype) 30 | 31 | if rearrange_flag: 32 | x = rearrange(x, 'b c t h w -> b t c h w', b=b) 33 | return x 34 | 35 | class ImgGroupNorm(nn.GroupNorm): 36 | def forward(self, x): 37 | return super().forward(x.float()).type(x.dtype) 38 | 39 | 40 | class GroupNorm32_3d(nn.Module): 41 | def __init__(self, group, channel, batch_size): 42 | super(GroupNorm32_3d, self).__init__() 43 | self.batch_size = batch_size 44 | self.GroupNorm = nn.GroupNorm(group, channel) 45 | 46 | def forward(self, x): 47 | 48 | input_cluster = True 49 | if x.shape[0] > self.batch_size: 50 | if x.dim() == 3: 51 | # b_x, c_x, w_x = x.shape 52 | h = rearrange(x, '(b t) c h -> b c h t' , b=self.batch_size) 53 | elif x.dim() == 4: 54 | # b_x, c_x, w_x, h_x = x.shape 55 | h = rearrange(x, '(b t) c h w -> b c h w t' , b=self.batch_size) 56 | elif x.dim()==5: 57 | # b_x, c_x, w_x, h_x, o_x = x.shape 58 | h = rearrange(x, '(b t) c h w o -> b c h w o t' , b=self.batch_size) 59 | else: 60 | raise NotImplementedError 61 | else: 62 | input_cluster = False 63 | 64 | h = rearrange(x, 'b t c h w -> b c h w t' ) 65 | 66 | 67 | 68 | h = self.GroupNorm.forward(h.float()).type(x.dtype) 69 | if input_cluster: 70 | if h.dim() == 5: 71 | h = rearrange(h, 'b c h w t -> (b t) c h w') 72 | elif h.dim() == 4: 73 | h = rearrange(h, 'b c h t -> (b t) c h') 74 | elif h.dim() == 6: 75 | h = rearrange(h, 'b c h w o t -> (b t) c h w o') 76 | else: 77 | raise NotImplementedError 78 | else: 79 | h = rearrange(h, 'b c h w t -> b t c h w' ) 80 | 81 | return h 82 | 83 | 84 | def conv_nd(dims, *args, **kwargs): 85 | """ 86 | Create a 1D, 2D, or 3D convolution module. 87 | """ 88 | if dims == 1: 89 | return nn.Conv1d(*args, **kwargs) 90 | elif dims == 2: 91 | return nn.Conv2d(*args, **kwargs) 92 | elif dims == 3: 93 | return nn.Conv3d(*args, **kwargs) 94 | raise ValueError(f"unsupported dimensions: {dims}") 95 | 96 | class temporal_conv(nn.Module): 97 | """ 98 | Create a 1D, 2D, or 3D convolution module. 99 | """ 100 | def __init__(self,*args, **kwargs): 101 | self.conv = nn.Conv1d(*args, **kwargs) 102 | def forward(x): 103 | 104 | return self.conv(x) 105 | 106 | 107 | 108 | def linear(*args, **kwargs): 109 | """ 110 | Create a linear module. 111 | """ 112 | return nn.Linear(*args, **kwargs) 113 | 114 | 115 | def avg_pool_nd(dims, *args, **kwargs): 116 | """ 117 | Create a 1D, 2D, or 3D average pooling module. 118 | """ 119 | if dims == 1: 120 | return nn.AvgPool1d(*args, **kwargs) 121 | elif dims == 2: 122 | return nn.AvgPool2d(*args, **kwargs) 123 | elif dims == 3: 124 | return nn.AvgPool3d(*args, **kwargs) 125 | raise ValueError(f"unsupported dimensions: {dims}") 126 | 127 | 128 | def update_ema(target_params, source_params, rate=0.99): 129 | """ 130 | Update target parameters to be closer to those of source parameters using 131 | an exponential moving average. 132 | 133 | :param target_params: the target parameter sequence. 134 | :param source_params: the source parameter sequence. 
135 | :param rate: the EMA rate (closer to 1 means slower). 136 | """ 137 | for targ, src in zip(target_params, source_params): 138 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 139 | 140 | 141 | def zero_module(module): 142 | """ 143 | Zero out the parameters of a module and return it. 144 | """ 145 | for p in module.parameters(): 146 | p.detach().zero_() 147 | return module 148 | 149 | 150 | def scale_module(module, scale): 151 | """ 152 | Scale the parameters of a module and return it. 153 | """ 154 | for p in module.parameters(): 155 | p.detach().mul_(scale) 156 | return module 157 | 158 | 159 | def mean_flat(tensor): 160 | """ 161 | Take the mean over all non-batch dimensions. 162 | """ 163 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 164 | 165 | 166 | def normalization_3d(channels, batch_size): 167 | """ 168 | Make a standard normalization layer. 169 | 170 | :param channels: number of input channels. 171 | :return: an nn.Module for normalization. 172 | """ 173 | 174 | return GroupNorm32_3d(32, channels, batch_size) 175 | def normalization(channels): 176 | """ 177 | Make a standard normalization layer. 178 | :param channels: number of input channels. 179 | :return: an nn.Module for normalization. 180 | """ 181 | return GroupNorm32(32, channels) 182 | 183 | def Imgnormalization(channels): 184 | """ 185 | Make a standard normalization layer. 186 | :param channels: number of input channels. 187 | :return: an nn.Module for normalization. 188 | """ 189 | return ImgGroupNorm(32, channels) 190 | 191 | 192 | def timestep_embedding(timesteps, dim, max_period=10000): 193 | """ 194 | Create sinusoidal timestep embeddings. 195 | 196 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 197 | These may be fractional. 198 | :param dim: the dimension of the output. 199 | :param max_period: controls the minimum frequency of the embeddings. 200 | :return: an [N x dim] Tensor of positional embeddings. 201 | """ 202 | half = dim // 2 203 | freqs = th.exp( 204 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 205 | ).to(device=timesteps.device) 206 | args = timesteps[:, None].float() * freqs[None] 207 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 208 | if dim % 2: 209 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 210 | return embedding 211 | 212 | def temporalstep_embedding(timesteps, dim, max_period=10): 213 | """ 214 | Create sinusoidal timestep embeddings. 215 | 216 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 217 | These may be fractional. 218 | :param dim: the dimension of the output. 219 | :param max_period: controls the minimum frequency of the embeddings. 220 | :return: an [N x dim] Tensor of positional embeddings. 221 | """ 222 | half = dim // 2 223 | freqs = th.exp( 224 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 225 | ).to(device=timesteps.device) 226 | args = timesteps[:, None].float() * freqs[None] 227 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 228 | if dim % 2: 229 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 230 | return embedding 231 | 232 | 233 | def checkpoint(func, inputs, params, flag): 234 | """ 235 | Evaluate a function without caching intermediate activations, allowing for 236 | reduced memory at the expense of extra compute in the backward pass. 237 | 238 | :param func: the function to evaluate. 239 | :param inputs: the argument sequence to pass to `func`. 
240 | :param params: a sequence of parameters `func` depends on but does not 241 | explicitly take as arguments. 242 | :param flag: if False, disable gradient checkpointing. 243 | """ 244 | if flag: 245 | args = tuple(inputs) + tuple(params) 246 | return CheckpointFunction.apply(func, len(inputs), *args) 247 | else: 248 | return func(*inputs) 249 | 250 | 251 | class CheckpointFunction(th.autograd.Function): 252 | @staticmethod 253 | def forward(ctx, run_function, length, *args): 254 | ctx.run_function = run_function 255 | ctx.input_tensors = list(args[:length]) 256 | ctx.input_params = list(args[length:]) 257 | with th.no_grad(): 258 | output_tensors = ctx.run_function(*ctx.input_tensors) 259 | return output_tensors 260 | 261 | @staticmethod 262 | def backward(ctx, *output_grads): 263 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 264 | with th.enable_grad(): 265 | # Fixes a bug where the first op in run_function modifies the 266 | # Tensor storage in place, which is not allowed for detach()'d 267 | # Tensors. 268 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 269 | output_tensors = ctx.run_function(*shallow_copies) 270 | input_grads = th.autograd.grad( 271 | output_tensors, 272 | ctx.input_tensors + ctx.input_params, 273 | output_grads, 274 | allow_unused=True, 275 | ) 276 | del ctx.input_tensors 277 | del ctx.input_params 278 | del output_tensors 279 | return (None, None) + input_grads 280 | -------------------------------------------------------------------------------- /mm_diffusion/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | from . import logger 23 | 24 | 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + math.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. 
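        E.g. with warmup=0.1, warmup_linear(0.05, 0.1) == 0.5 (half-way up the warmup ramp) and warmup_linear(0.55, 0.1) == 0.5 (half-way down the linear decay).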
""" 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine': warmup_cosine, 47 | 'warmup_constant': warmup_constant, 48 | 'warmup_linear': warmup_linear, 49 | } 50 | 51 | 52 | class BertAdam(Optimizer): 53 | """Implements BERT version of Adam algorithm with weight decay fix. 54 | Params: 55 | lr: learning rate 56 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 57 | t_total: total number of training steps for the learning 58 | rate schedule, -1 means constant learning rate. Default: -1 59 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 60 | b1: Adams b1. Default: 0.9 61 | b2: Adams b2. Default: 0.999 62 | e: Adams epsilon. Default: 1e-6 63 | weight_decay: Weight decay. Default: 0.01 64 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 65 | """ 66 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 67 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 68 | max_grad_norm=1.0): 69 | if lr is not required and lr < 0.0: 70 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 71 | if schedule not in SCHEDULES: 72 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 73 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 74 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 75 | if not 0.0 <= b1 < 1.0: 76 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 77 | if not 0.0 <= b2 < 1.0: 78 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 79 | if not e >= 0.0: 80 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 81 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 82 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 83 | max_grad_norm=max_grad_norm) 84 | super(BertAdam, self).__init__(params, defaults) 85 | 86 | def get_lr(self): 87 | lr = [] 88 | for group in self.param_groups: 89 | for p in group['params']: 90 | if p.grad is None: 91 | continue 92 | state = self.state[p] 93 | if len(state) == 0: 94 | return [0] 95 | if group['t_total'] != -1: 96 | schedule_fct = SCHEDULES[group['schedule']] 97 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 98 | else: 99 | lr_scheduled = group['lr'] 100 | lr.append(lr_scheduled) 101 | return lr 102 | 103 | def step(self, closure=None): 104 | """Performs a single optimization step. 105 | Arguments: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 
108 | """ 109 | loss = None 110 | if closure is not None: 111 | loss = closure() 112 | 113 | for group in self.param_groups: 114 | for p in group['params']: 115 | if p.grad is None: 116 | continue 117 | 118 | grad = p.grad.data 119 | 120 | if grad.is_sparse: 121 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 122 | 123 | state = self.state[p] 124 | 125 | # State initialization 126 | if len(state) == 0: 127 | state['step'] = 0 128 | # Exponential moving average of gradient values 129 | state['next_m'] = torch.zeros_like(p.data) 130 | # Exponential moving average of squared gradient values 131 | state['next_v'] = torch.zeros_like(p.data) 132 | 133 | next_m, next_v = state['next_m'], state['next_v'] 134 | beta1, beta2 = group['b1'], group['b2'] 135 | 136 | # Add grad clipping 137 | if group['max_grad_norm'] > 0: 138 | clip_grad_norm_(p, group['max_grad_norm']) 139 | 140 | # Decay the first and second moment running average coefficient 141 | # In-place operations to update the averages at the same time 142 | # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7 143 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1) 144 | # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7 145 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 146 | update = next_m / (next_v.sqrt() + group['e']) 147 | 148 | # Just adding the square of the weights to the loss function is *not* 149 | # the correct way of using L2 regularization/weight decay with Adam, 150 | # since that will interact with the m and v parameters in strange ways. 151 | # 152 | # Instead we want to decay the weights in a manner that doesn't interact 153 | # with the m/v parameters. This is equivalent to adding the square 154 | # of the weights to the loss with plain (non-momentum) SGD. 155 | if group['weight_decay'] > 0.0: 156 | update += group['weight_decay'] * p.data 157 | 158 | if group['t_total'] != -1: 159 | schedule_fct = SCHEDULES[group['schedule']] 160 | progress = state['step']/group['t_total'] 161 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 162 | else: 163 | lr_scheduled = group['lr'] 164 | 165 | update_with_lr = lr_scheduled * update 166 | p.data.add_(-update_with_lr) 167 | 168 | state['step'] += 1 169 | 170 | return loss -------------------------------------------------------------------------------- /mm_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 
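    Concretely, sample() normalizes weights() into a distribution p, draws timesteps from p, and returns importance weights 1 / (len(p) * p[t]) so the reweighted loss stays an unbiased estimate of the uniform objective.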
32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
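        LossSecondMomentResampler below implements this by keeping the last `history_per_term` losses per timestep and, once every timestep has a full history, weighting each timestep by the RMS of its recent losses, mixed with a small uniform probability.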
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /mm_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | if self.rescale_timesteps: 128 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | 130 | return self.model(x, new_ts, **kwargs) 131 | -------------------------------------------------------------------------------- /py_scripts/eval.py: -------------------------------------------------------------------------------- 1 | import sys,os 2 | sys.path.append(os.path.dirname (os.path.dirname (os.path.abspath (__file__)))) 3 | import argparse 4 | from mm_diffusion import dist_util, logger 5 | from mm_diffusion.evaluator import eval_multimodal 6 | from mm_diffusion.common import delete_pkl 7 | 8 | 9 | # command: mpiexec -n 4 python py_scripts/eval.py --devices 0,1,2,3 10 | 11 | def main( 12 | ): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--ref_dir", type=str, default="/home/v-zixituo/rld/data/landscape/train", help="path to reference batch npz file") 15 | parser.add_argument("--fake_dir", type=str, default="/home/v-zixituo/rld/outputs/MM-Diffusion/samples/video-sample-sr/landscape_16x64x64_bs128_res2_channel128_linear1000_att248_dropout0.1/ema_0.9999_100000.pt/original", help="path to sample batch npz file") 16 | parser.add_argument("--output_dir", type=str, default="/home/v-zixituo/rld/outputs/MM-Diffusion/video-eval/debug", help="" ) 17 | parser.add_argument("--sample_num", type=int, default=2048) 18 | parser.add_argument("--devices", type=str, default="G8") 19 | args = parser.parse_args() 20 | 21 | dist_util.setup_dist(args.devices) 22 | logger.configure(dir=args.output_dir, log_suffix="_val") 23 | 24 | metric = eval_multimodal(args.ref_dir, args.fake_dir, eval_num=args.sample_num) 25 | logger.log(f"metric:{metric}") 26 | delete_pkl(args.fake_dir) 27 | 28 | if __name__ == '__main__': 29 | main() 30 | 31 | -------------------------------------------------------------------------------- /py_scripts/image_sr_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a super-resolution model. 
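
The model is conditioned on the low-resolution input: load_superres_data below
passes each small batch to the model as model_kwargs["low_res"]. With the flags
in ssh_scripts/image_sr_train.sh (--small_size 64 --large_size 256), training
upsamples 64x64 frames to 256x256.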
3 | """ 4 | 5 | import argparse, sys, os 6 | sys.path.append(os.path.dirname (os.path.dirname (os.path.abspath (__file__)))) 7 | import torch.nn.functional as F 8 | from mm_diffusion import dist_util, logger 9 | from mm_diffusion.real_image_datasets import load_data 10 | from mm_diffusion.resample import create_named_schedule_sampler 11 | from mm_diffusion.common import set_seed_logger_random 12 | from mm_diffusion.script_util import ( 13 | image_sr_model_and_diffusion_defaults, 14 | image_sr_create_model_and_diffusion, 15 | args_to_dict, 16 | add_dict_to_argparser, 17 | ) 18 | from mm_diffusion.train_util import TrainLoop 19 | 20 | def main(): 21 | args = create_argparser().parse_args() 22 | 23 | dist_util.setup_dist(args.devices) 24 | logger.configure(args.output_dir) 25 | args = set_seed_logger_random(args) 26 | 27 | logger.log("creating model...") 28 | model, diffusion = image_sr_create_model_and_diffusion( 29 | **args_to_dict(args, image_sr_model_and_diffusion_defaults().keys()) 30 | ) 31 | model.to(dist_util.dev()) 32 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 33 | 34 | logger.log("creating data loader...") 35 | data = load_superres_data(args) 36 | 37 | logger.log("training...") 38 | TrainLoop( 39 | model=model, 40 | diffusion=diffusion, 41 | data=data, 42 | batch_size=args.batch_size, 43 | microbatch=args.microbatch, 44 | lr=args.lr, 45 | ema_rate=args.ema_rate, 46 | log_interval=args.log_interval, 47 | save_interval=args.save_interval, 48 | resume_checkpoint=args.resume_checkpoint, 49 | use_fp16=args.use_fp16, 50 | fp16_scale_growth=args.fp16_scale_growth, 51 | schedule_sampler=schedule_sampler, 52 | weight_decay=args.weight_decay, 53 | lr_anneal_steps=args.lr_anneal_steps, 54 | use_db=args.use_db, 55 | save_type=args.save_type, 56 | class_cond=args.sr_class_cond, 57 | sample_fn=args.sample_fn 58 | ).run_loop() 59 | 60 | 61 | def load_superres_data(args): 62 | data = load_data( 63 | data_dir=args.data_dir, 64 | batch_size=args.batch_size, 65 | image_size=args.large_size, 66 | class_cond=args.sr_class_cond, 67 | num_workers=args.num_workers 68 | ) 69 | for small_batch, large_batch, sr_batch, model_kwargs in data: 70 | model_kwargs["low_res"] = small_batch 71 | yield small_batch, large_batch, sr_batch, model_kwargs 72 | 73 | 74 | 75 | def create_argparser(): 76 | defaults = dict( 77 | data_dir="", 78 | schedule_sampler="uniform", 79 | lr=1e-4, 80 | weight_decay=0.0, 81 | lr_anneal_steps=0, 82 | batch_size=1, 83 | microbatch=-1, 84 | ema_rate="0.9999", 85 | log_interval=10, 86 | save_interval=10000, 87 | resume_checkpoint="", 88 | use_fp16=False, 89 | fp16_scale_growth=1e-3, 90 | frame_gap=8, 91 | num_workers=4, 92 | use_db=False, 93 | devices="0", 94 | output_dir="~/tmp", 95 | seed=42, 96 | save_type='one', 97 | sample_fn='dpm_solver' 98 | ) 99 | defaults.update(image_sr_model_and_diffusion_defaults()) 100 | parser = argparse.ArgumentParser() 101 | add_dict_to_argparser(parser, defaults) 102 | return parser 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /py_scripts/multimodal_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on audio-video pairs. 
3 | """ 4 | import sys,os 5 | sys.path.append(os.path.dirname (os.path.dirname (os.path.abspath (__file__)))) 6 | import argparse 7 | from mm_diffusion import dist_util, logger 8 | from mm_diffusion.multimodal_datasets import load_data 9 | from mm_diffusion.resample import create_named_schedule_sampler 10 | from mm_diffusion.multimodal_script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | args_to_dict, 14 | add_dict_to_argparser 15 | ) 16 | from mm_diffusion.multimodal_train_util import TrainLoop 17 | from mm_diffusion.common import set_seed_logger_random 18 | 19 | 20 | def load_training_data(args): 21 | data = load_data( 22 | data_dir=args.data_dir, 23 | batch_size=args.batch_size, 24 | video_size=args.video_size, 25 | audio_size=args.audio_size, 26 | num_workers=args.num_workers, 27 | video_fps=args.video_fps, 28 | audio_fps=args.audio_fps 29 | ) 30 | 31 | for video_batch, audio_batch in data: 32 | gt_batch = {"video": video_batch, "audio":audio_batch} 33 | 34 | yield gt_batch 35 | 36 | 37 | 38 | def main(): 39 | args = create_argparser().parse_args() 40 | args.video_size = [int(i) for i in args.video_size.split(',')] 41 | args.audio_size = [int(i) for i in args.audio_size.split(',')] 42 | logger.configure(args.output_dir) 43 | dist_util.setup_dist(args.devices) 44 | 45 | args = set_seed_logger_random(args) 46 | 47 | logger.log("creating model and diffusion...") 48 | 49 | model, diffusion = create_model_and_diffusion( 50 | **args_to_dict(args, [key for key in model_and_diffusion_defaults().keys()]) 51 | ) 52 | 53 | model.to(dist_util.dev()) 54 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 55 | 56 | logger.log("creating data loader...") 57 | 58 | data = load_training_data(args) 59 | 60 | 61 | TrainLoop( 62 | model=model, 63 | diffusion=diffusion, 64 | data=data, 65 | save_type=args.save_type, 66 | batch_size=args.batch_size, 67 | microbatch=args.microbatch, 68 | ema_rate=args.ema_rate, 69 | log_interval=args.log_interval, 70 | save_interval=args.save_interval, 71 | resume_checkpoint=args.resume_checkpoint, 72 | lr=args.lr, 73 | t_lr=args.t_lr, 74 | use_fp16=args.use_fp16, 75 | fp16_scale_growth=args.fp16_scale_growth, 76 | schedule_sampler=schedule_sampler, 77 | weight_decay=args.weight_decay, 78 | lr_anneal_steps=args.lr_anneal_steps, 79 | use_db=args.use_db, 80 | sample_fn=args.sample_fn, 81 | video_fps= args.video_fps, 82 | audio_fps= args.audio_fps, 83 | ).run_loop() 84 | 85 | 86 | def create_argparser(): 87 | defaults = dict( 88 | data_dir="", 89 | schedule_sampler="uniform", 90 | lr=0.0, 91 | t_lr=1e-4, 92 | seed=42, 93 | weight_decay=0.0, 94 | lr_anneal_steps=0, 95 | batch_size=1, 96 | num_workers=0, 97 | save_type="mp4", 98 | microbatch=-1, # -1 disables microbatches 99 | ema_rate="0.9999", # comma-separated list of EMA values 100 | log_interval=100, 101 | devices=None, 102 | save_interval=10000, 103 | output_dir="", 104 | resume_checkpoint="", 105 | use_fp16=False, 106 | fp16_scale_growth=1e-3, 107 | use_db=False, 108 | sample_fn="dpm_solver", 109 | frame_gap=1, 110 | video_fps=10, 111 | audio_fps=16000, 112 | ) 113 | defaults.update(model_and_diffusion_defaults()) 114 | parser = argparse.ArgumentParser() 115 | add_dict_to_argparser(parser, defaults) 116 | return parser 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /py_scripts/video2audio_sample.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 4 | """ 5 | import sys,os 6 | sys.path.append(os.path.dirname (os.path.dirname (os.path.abspath (__file__)))) 7 | import argparse 8 | import os 9 | import torch as th 10 | import torch.distributed as dist 11 | from mm_diffusion import dist_util, logger 12 | from mm_diffusion.multimodal_script_util import ( 13 | model_and_diffusion_defaults, 14 | create_model_and_diffusion, 15 | add_dict_to_argparser, 16 | args_to_dict 17 | ) 18 | from mm_diffusion.common import set_seed_logger_random, save_audio, save_img, save_multimodal, delete_pkl 19 | from mm_diffusion.evaluator import eval_multimodal 20 | from mm_diffusion.multimodal_datasets import load_data 21 | 22 | 23 | 24 | def load_training_data(args): 25 | data = load_data( 26 | data_dir=args.ref_path, 27 | batch_size=args.batch_size, 28 | video_size=args.video_size, 29 | audio_size=args.audio_size, 30 | num_workers=args.num_workers, 31 | video_fps=args.video_fps, 32 | audio_fps=args.audio_fps 33 | ) 34 | 35 | for video_batch, audio_batch in data: 36 | 37 | gt_batch = {"video": video_batch, "audio":audio_batch} 38 | model_kwargs = {} 39 | yield gt_batch, model_kwargs 40 | 41 | def main(): 42 | args = create_argparser().parse_args() 43 | args.video_size = [int(i) for i in args.video_size.split(',')] 44 | args.audio_size = [int(i) for i in args.audio_size.split(',')] 45 | 46 | dist_util.setup_dist(args.devices) 47 | logger.configure(args.output_dir) 48 | args = set_seed_logger_random(args) 49 | 50 | logger.log("creating model and diffusion...") 51 | model, diffusion = create_model_and_diffusion( 52 | **args_to_dict(args, [key for key in model_and_diffusion_defaults().keys()]) 53 | ) 54 | 55 | if os.path.isdir(args.model_path): 56 | model_name_list = [model_name for model_name in os.listdir(args.model_path) \ 57 | if (model_name.startswith('model') and model_name.endswith('.pt') and int(model_name.split('.')[0][5:])>= args.skip_steps)] 58 | model_name_list.sort() 59 | model_path_list = [os.path.join(args.model_path, model_name) for model_name in model_name_list[::1]] 60 | else: 61 | model_path_list = [model_path for model_path in args.model_path.split(',')] 62 | 63 | logger.log(f"models waiting to be evaluated:{model_path_list}") 64 | data = load_training_data(args) 65 | for model_path in model_path_list: 66 | model.load_state_dict_( 67 | dist_util.load_state_dict(model_path, map_location="cpu"), is_strict=args.is_strict 68 | ) 69 | 70 | model.to(dist_util.dev()) 71 | if args.use_fp16: 72 | model.convert_to_fp16() 73 | model.eval() 74 | 75 | logger.log(f"conditional sampling samples for {model_path}") 76 | model_name = model_path.split('/')[-1] 77 | 78 | groups= 0 79 | gt_save_path = os.path.join(args.output_dir, model_name, "gt") 80 | reconstruct_save_path = os.path.join(args.output_dir, model_name, "reconstract") 81 | audio_save_path = os.path.join(args.output_dir, model_name, "audio") 82 | if dist.get_rank() == 0: 83 | os.makedirs(gt_save_path, exist_ok=True) 84 | os.makedirs(reconstruct_save_path, exist_ok=True) 85 | os.makedirs(audio_save_path, exist_ok=True) 86 | 87 | while groups * args.batch_size * dist.get_world_size() < args.all_save_num: 88 | 89 | 90 | batch_data, model_kwargs = next(data) 91 | 92 | # save gt 93 | idx = 0 94 | for video, audio in zip(batch_data["video"], batch_data["audio"]): 95 | 
video = video.permute(0, 2, 3, 1) 96 | video = ((video + 1) * 127.5).clamp(0, 255).to(th.uint8).numpy() 97 | audio = audio.numpy() 98 | video_output_path = os.path.join(gt_save_path, f"{args.sample_fn}_samples_{groups}_{dist.get_rank()}_{idx}.mp4") 99 | save_multimodal(video, audio, video_output_path, args) 100 | idx += 1 101 | 102 | model_kwargs["video"] = batch_data["video"].to(dist_util.dev()) 103 | 104 | shape = {"video":(args.batch_size , *args.video_size), \ 105 | "audio":(args.batch_size , *args.audio_size) 106 | } 107 | 108 | if args.sample_fn == 'dpm_solver': 109 | #TODO 110 | print("dpm_solver is not implemented yet..") 111 | 112 | 113 | elif args.sample_fn == 'dpm_solver++': 114 | #TODO 115 | print("dpm_solver++ is not implemented yet..") 116 | 117 | 118 | else: 119 | sample_fn = ( 120 | diffusion.conditional_p_sample_loop if args.sample_fn=="ddpm" else diffusion.ddim_sample_loop 121 | ) 122 | 123 | sample = sample_fn( 124 | model, 125 | shape=shape, 126 | use_fp16 = args.use_fp16, 127 | clip_denoised=args.clip_denoised, 128 | model_kwargs=model_kwargs, 129 | class_scale=args.classifier_scale 130 | 131 | ) 132 | 133 | video = ((sample["video"] + 1) * 127.5).clamp(0, 255).to(th.uint8) 134 | audio = sample["audio"] 135 | video = video.permute(0, 1, 3, 4, 2) 136 | video = video.contiguous() 137 | 138 | all_videos = video.cpu().numpy() 139 | all_audios = audio.cpu().numpy() 140 | 141 | 142 | idx = 0 143 | for video, audio in zip(all_videos, all_audios): 144 | video_output_path = os.path.join(reconstruct_save_path, f"{args.sample_fn}_samples_{groups}_{dist.get_rank()}_{idx}.mp4") 145 | audio_output_path = os.path.join(audio_save_path, f"{args.sample_fn}_samples_{groups}_{dist.get_rank()}_{idx}.wav") 146 | 147 | save_multimodal(video, audio, video_output_path, args) 148 | save_audio(audio, audio_output_path, args.audio_fps) 149 | 150 | idx += 1 151 | 152 | groups += 1 153 | dist.barrier() 154 | 155 | logger.log("sampling complete") 156 | 157 | 158 | def create_argparser(): 159 | defaults = dict( 160 | clip_denoised=True, 161 | 162 | batch_size=16, 163 | sample_fn="ddpm", 164 | model_path="", 165 | output_dir="", 166 | save_type="mp4", 167 | classifier_scale=0.0, 168 | devices=None, 169 | is_strict=True, 170 | all_save_num= 1024, 171 | seed=42, 172 | video_fps=10, 173 | audio_fps=16000, 174 | ref_path = "", 175 | num_workers=4 176 | ) 177 | 178 | defaults.update(model_and_diffusion_defaults()) 179 | parser = argparse.ArgumentParser() 180 | add_dict_to_argparser(parser, defaults) 181 | return parser 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | appdirs==1.4.4 2 | audioread==3.0.0 3 | av==10.0.0 4 | blobfile==2.0.1 5 | brotlipy==0.7.0 6 | certifi==2022.12.7 7 | cffi==1.15.1 8 | charset-normalizer==2.0.4 9 | click==8.1.3 10 | cryptography==39.0.1 11 | decorator==4.4.2 12 | docker-pycreds==0.4.0 13 | einops==0.6.0 14 | einops-exts==0.0.4 15 | filelock==3.9.0 16 | flit_core==3.6.0 17 | ftfy==6.1.1 18 | gitdb==4.0.10 19 | GitPython==3.1.31 20 | idna==3.4 21 | imageio==2.26.0 22 | imageio-ffmpeg==0.4.8 23 | importlib-metadata==6.0.0 24 | joblib==1.2.0 25 | jsonpatch==1.32 26 | jsonpointer==2.3 27 | lazy_loader==0.1 28 | librosa==0.10.0 29 | llvmlite==0.39.1 30 | lpips==0.1.4 31 | lxml==4.9.2 32 | mkl-fft==1.3.1 33 | mkl-random==1.2.2 34 | mkl-service==2.4.0 35 | moviepy==1.0.3 36 | 
msgpack==1.0.4 37 | networkx==3.0 38 | numba==0.56.4 39 | numpy==1.23.5 40 | opencv-python==4.7.0.72 41 | opencv-python-headless==4.7.0.72 42 | packaging==23.0 43 | pandas==1.5.3 44 | pathtools==0.1.2 45 | Pillow==9.4.0 46 | pip==23.0.1 47 | platformdirs==3.1.0 48 | pooch==1.7.0 49 | proglog==0.1.10 50 | protobuf==4.22.1 51 | psutil==5.9.4 52 | pycparser==2.21 53 | pycryptodomex==3.17 54 | pyOpenSSL==23.0.0 55 | PySocks==1.7.1 56 | python-dateutil==2.8.2 57 | pytorch-ignite==0.4.11 58 | pytz==2022.7.1 59 | PyYAML==6.0 60 | regex==2022.10.31 61 | requests==2.28.1 62 | scikit-learn==1.2.1 63 | scipy==1.10.1 64 | sentry-sdk==1.16.0 65 | setproctitle==1.3.2 66 | setuptools==65.6.3 67 | six==1.16.0 68 | sklearn==0.0.post1 69 | smmap==5.0.0 70 | soundfile==0.12.1 71 | soxr==0.3.4 72 | termcolor==2.2.0 73 | threadpoolctl==3.1.0 74 | tornado==6.2 75 | tqdm==4.65.0 76 | typing_extensions==4.4.0 77 | urllib3==1.26.14 78 | visdom==0.2.4 79 | wandb==0.13.11 80 | wcwidth==0.2.6 81 | websocket-client==1.5.1 82 | wheel==0.38.4 83 | xmltodict==0.13.0 84 | zipp==3.15.0 -------------------------------------------------------------------------------- /ssh_scripts/audio2video_sample_sr.sh: -------------------------------------------------------------------------------- 1 | MODEL_FLAGS="--cross_attention_resolutions 2,4,8 --cross_attention_windows 1,4,8 2 | --cross_attention_shift True 3 | --class_cond False --video_attention_resolutions 2,4,8 4 | --audio_attention_resolutions -1 5 | --video_size 16,3,64,64 --audio_size 1,25600 --learn_sigma False --num_channels 128 6 | --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True 7 | --use_scale_shift_norm True" 8 | 9 | # if classifier_scale(λ) is 0, conditional generation follows the replacement based method 10 | # if classifier_scale(λ) is larger than 0, conditional generation follows the gradient based method 11 | DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear 12 | --all_save_num 64 --save_type mp4 --devices 0 13 | --batch_size 1 --is_strict True --sample_fn ddpm --classifier_scale 3.0 " 14 | 15 | SRMODEL_FLAGS="--sr_attention_resolutions 8,16,32 --large_size 256 16 | --small_size 64 --sr_learn_sigma True --sr_class_cond False 17 | --sr_num_channels 192 --sr_num_heads 4 --sr_num_res_blocks 2 18 | --sr_resblock_updown True --sr_use_scale_shift_norm True --num_workers 4" 19 | 20 | SR_DIFFUSION_FLAGS="--sr_diffusion_steps 1000 --sr_sample_fn ddim --sr_timestep_respacing ddim25 " 21 | 22 | MULTIMODAL_MODEL_PATH="/data10/rld/outputs/MM-Diffusion/models/AIST++.pt" 23 | SR_MODEL_PATH="/data10/rld/ouotputs/MM-Diffusion/models/AIST++_SR.pt" 24 | OUT_DIR="/data10/rld/outputs/MM-Diffusion/audio2video/" 25 | REF_PATH="/data10/rld/data/AIST++_crop/train" 26 | NUM_GPUS=1 27 | 28 | mpiexec -n $NUM_GPUS python py_scripts/audio2video_sample_sr.py \ 29 | $MODEL_FLAGS $DIFFUSION_FLAGS $SRMODEL_FLAGS $SR_DIFFUSION_FLAGS\ 30 | --output_dir ${OUT_DIR} --multimodal_model_path ${MULTIMODAL_MODEL_PATH} --ref_path ${REF_PATH} --sr_model_path ${SR_MODEL_PATH} 31 | -------------------------------------------------------------------------------- /ssh_scripts/image_sr_train.sh: -------------------------------------------------------------------------------- 1 | #################64 x 64 -> 256 x 256########## 2 | MODEL_FLAGS="--sr_attention_resolutions 8,16,32 --large_size 256 3 | --small_size 64 --sr_learn_sigma True --sr_class_cond False 4 | --sr_num_channels 192 --sr_num_heads 4 --sr_num_res_blocks 2 5 | --sr_resblock_updown True --use_fp16 True 
--sr_use_scale_shift_norm True" 6 | 7 | DIFFUSION_FLAGS="--sr_diffusion_steps 1000 --noise_schedule linear" # --use_kl True 8 | 9 | TRAIN_FLAGS="--lr 1e-4 --batch_size 6 --devices G8 --log_interval 100 --sample_fn ddpm 10 | --save_interval 10000 --num_workers 8 --frame_gap 1 11 | --use_db False --resume_checkpoint /data10/rld/outputs/MM-Diffusion/models/guided-diffusion_64_256_upsampler.pt" #--schedule_sampler loss-second-moment --resume_checkpoint models/256x256_diffusion_uncond.pt 12 | 13 | NUM_GPUS=8 14 | DATA_DIR="/data6/rld/data/landscape_png/train" 15 | OUT_DIR="/data6/rld/outputs/MM-Diffusion/sr-image-train/" 16 | 17 | mpiexec -n $NUM_GPUS --allow-run-as-root python py_scripts/image_sr_train.py --data_dir $DATA_DIR --output_dir ${OUT_DIR} $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS 18 | -------------------------------------------------------------------------------- /ssh_scripts/multimodal_eval.sh: -------------------------------------------------------------------------------- 1 | REF_DIR="/home/v-zixituo/rld/dataset/landscape/train" 2 | SAMPLE_DIR="/home/v-zixituo/rld/outputs/MM-Diffusion/samples/landscape/sr_mp4" 3 | OUTPUT_DIR="/home/v-zixituo/rld/outputs/MM-Diffusion/eval/debug" 4 | 5 | mpiexec -n 1 python py_scripts/eval.py --devices 0 --sample_num 2048 --ref_dir ${REF_DIR} --fake_dir ${SAMPLE_DIR} --output_dir ${OUTPUT_DIR} -------------------------------------------------------------------------------- /ssh_scripts/multimodal_sample_sr.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | MODEL_FLAGS="--cross_attention_resolutions 2,4,8 --cross_attention_windows 1,4,8 4 | --cross_attention_shift True --video_attention_resolutions 2,4,8 5 | --audio_attention_resolutions -1 6 | --video_size 16,3,64,64 --audio_size 1,25600 --learn_sigma False --num_channels 128 7 | --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True 8 | --use_scale_shift_norm True" 9 | 10 | SRMODEL_FLAGS="--sr_attention_resolutions 8,16,32 --large_size 256 11 | --small_size 64 --sr_learn_sigma True 12 | --sr_num_channels 192 --sr_num_heads 4 --sr_num_res_blocks 2 13 | --sr_resblock_updown True --use_fp16 True --sr_use_scale_shift_norm True" 14 | 15 | # Modify --devices according your GPU number 16 | DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear 17 | --all_save_num 64 --save_type mp4 --devices 0,1,2,3 18 | --batch_size 4 --is_strict True --sample_fn dpm_solver" 19 | 20 | SR_DIFFUSION_FLAGS="--sr_diffusion_steps 1000 --sr_sample_fn ddim --sr_timestep_respacing ddim25" 21 | 22 | # Modify the following paths to your own paths 23 | MULTIMODAL_MODEL_PATH="/data10/rld/outputs/MM-Diffusion/models/AIST++.pt" 24 | SR_MODEL_PATH="/data10/rld/outputs/MM-Diffusion/models/AIST++_SR.pt" 25 | OUT_DIR="/data10/rld/outputs/MM-Diffusion/samples/multimodal-sample-sr/dpm_solver" 26 | REF_PATH="/data10/rld/dataset/AIST++_crop/train" 27 | 28 | NUM_GPUS=4 29 | mpiexec -n $NUM_GPUS python3 py_scripts/multimodal_sample_sr.py \ 30 | $MODEL_FLAGS $SRMODEL_FLAGS $DIFFUSION_FLAGS $SR_DIFFUSION_FLAGS --ref_path ${REF_PATH} \ 31 | --output_dir ${OUT_DIR} --multimodal_model_path ${MULTIMODAL_MODEL_PATH} --sr_model_path ${SR_MODEL_PATH} 32 | -------------------------------------------------------------------------------- /ssh_scripts/multimodal_train.sh: -------------------------------------------------------------------------------- 1 | 2 | #################64 x 64 uncondition########################################################### 3 | 
MODEL_FLAGS="--cross_attention_resolutions 2,4,8 --cross_attention_windows 1,4,8 4 | --cross_attention_shift True --dropout 0.1 5 | --video_attention_resolutions 2,4,8 6 | --audio_attention_resolutions -1 7 | --video_size 16,3,64,64 --audio_size 1,25600 --learn_sigma False --num_channels 128 8 | --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True 9 | --use_scale_shift_norm True --num_workers 4" 10 | 11 | # Modify --devices to your own GPU ID 12 | TRAIN_FLAGS="--lr 0.0001 --batch_size 4 13 | --devices G8 --log_interval 100 --save_interval 10000 --use_db False " #--schedule_sampler loss-second-moment 14 | DIFFUSION_FLAGS="--noise_schedule linear --diffusion_steps 1000 --save_type mp4 --sample_fn ddpm" 15 | 16 | # Modify the following pathes to your own paths 17 | DATA_DIR="/home/rld/dataset/landscape/train" 18 | OUTPUT_DIR="/home/rld/outputs/MM-Diffusion/debug/" 19 | NUM_GPUS=8 20 | 21 | mpiexec -n $NUM_GPUS python3 py_scripts/multimodal_train.py --data_dir ${DATA_DIR} --output_dir ${OUTPUT_DIR} $MODEL_FLAGS $TRAIN_FLAGS $VIDEO_FLAGS $DIFFUSION_FLAGS 22 | -------------------------------------------------------------------------------- /ssh_scripts/video2audio_sample.sh: -------------------------------------------------------------------------------- 1 | MODEL_FLAGS="--cross_attention_resolutions 2,4,8 --cross_attention_windows 1,4,8 2 | --cross_attention_shift True 3 | --class_cond False --video_attention_resolutions 2,4,8 4 | --audio_attention_resolutions -1 5 | --video_size 16,3,64,64 --audio_size 1,25600 --learn_sigma False --num_channels 128 6 | --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True 7 | --use_scale_shift_norm True" 8 | 9 | # if classifier_scale(λ) is 0, conditional generation follows the replacement based method 10 | # if classifier_scale(λ) is larger than 0, conditional generation follows the gradient based method 11 | DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear 12 | --all_save_num 48 --save_type mp4 --devices 7 13 | --batch_size 2 --is_strict True --sample_fn ddpm --classifier_scale 3.0" 14 | 15 | MODEL_PATH="/data10/rld/outputs/MM-Diffusion/models/AIST++.pt" 16 | OUT_DIR="/data10/rld/outputs/MM-Diffusion/video2audio/video2audio" 17 | REF_PATH="/data10/rld/data/AIST++_crop/train" 18 | NUM_GPUS=1 19 | 20 | mpiexec -n $NUM_GPUS python3 py_scripts/video2audio_sample.py \ 21 | $MODEL_FLAGS $DIFFUSION_FLAGS \ 22 | --output_dir ${OUT_DIR} --model_path ${MODEL_PATH} --ref_path ${REF_PATH} 23 | --------------------------------------------------------------------------------