├── .gitignore ├── README.md ├── infer_i2a.py ├── infer_mm2a.py ├── infer_v2a.py ├── requirements.txt └── ssv2a ├── __init__.py ├── data ├── __init__.py ├── dataset.py ├── detect.py ├── pairs.py ├── tpairs.py └── utils.py ├── evals ├── __init__.py ├── cluster.py ├── cs.py ├── fad.py └── ms.py └── model ├── __init__.py ├── aggregator.py ├── aldm.py ├── clap.py ├── dalle2_prior.py ├── generator.py ├── manifold.py ├── modules.py ├── pipeline.py └── remixer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | output -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gotta Hear Them All: Sound Source-Aware Vision to Audio Generation 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2411.15447-brightgreen?logo=arxiv&logoColor=white&style=flat-square)](https://arxiv.org/abs/2411.15447) [![githubio](https://img.shields.io/badge/GitHub.io-Demo_Website-blue?logo=Github&logoColor=white&style=flat-square)](https://ssv2a.github.io/SSV2A-demo/) [![Hugging Face Spaces](https://img.shields.io/badge/Gradio-Interactive_Demo-orange?logo=huggingface&logoColor=white&style=flat-square)](https://ssv2a.ngrok.io/) 4 | 5 | **Flexibly generate sounds by composing visual, text, and audio sound source prompts.** 6 | 7 | In order to run our code, please clone the repository and follow these instructions to set up a virtual environment: 8 | 9 | 1. `conda create -n SSV2A python==3.10` 10 | 2. `pip install -r requirements.txt` 11 | 12 | The `ssv2a` module provides implementations for SSV2A. We also provide scripts for major functions below. 13 | 14 | ## Scheduled Releases 15 | - [ ] Distribute the VGG Sound Single Source (VGGS3) dataset. 16 | - [x] Upload code for multimodal inference. 17 | - [x] Upload code for vision-to-audio inference. 18 | 19 | ## Pretrained Weights 20 | We provide pretrained weights of SSV2A modules at [this google drive link](https://drive.google.com/drive/folders/17SAuZ2sZrTYf21BiNKhRsEfdj-fbeQQN?usp=sharing), 21 | which has the following contents: 22 | 23 | | Files | Comment | 24 | |------------|--------------------------------------------------------------------------------------| 25 | | ssv2a.json | Configuration File of SSV2A | 26 | | ssv2a.pth | Pretrained Checkpoint of SSV2A | 27 | | agg.pth | Pretrained Checkpoint of Temporal Aggregation Module (for video-to-audio generation) | 28 | 29 | Please download them according to your usage cases. 30 | 31 | As SSV2A works with [YOLOv8](https://docs.ultralytics.com/models/yolov8/) for visual sound source detection, 32 | it also needs to include a pretrained YOLO checkpoint for inference. We recommend using [yolov8x-oi7](https://docs.ultralytics.com/datasets/detect/open-images-v7/) 33 | pretrained on the OpenImagesV7 dataset. After downloading this model, paste its path in the `"detection-model"` field in `ssv2a.json`. 34 | 35 | ## Inference 36 | There are several hyperparameters you can adjust to control the generation fidelity/diversity/relevance. We list them here: 37 | 38 | | Parameter | Default Value | Comment | 39 | |-------------------|----|-------------------------------------------------------------------------------------------------------------------------------| 40 | | `--var_samples` | 64 | Number of variational samples drawn in each generation and averaged. Higher number increases fidelity and decreases diversity. 
| 41 | `--cycle_its` | 64 | Number of Cycle Mix iterations. Higher number increases generation relevance to given conditions. | 42 | `--cycle_samples` | 64 | Number of variational samples drawn in each Cycle Mix iteration. Higher number increases fidelity and decreases diversity. | 43 | `--duration` | 10 | Length of generated audio in seconds. | 44 | `--seed` | 42 | Random seed for generation. | 45 | 46 | ### Image to Audio Generation 47 | Navigate to the root directory of this repo and execute the following script: 48 | 49 | ```shell 50 | python infer_i2a.py \ 51 | --cfg "ssv2a.json" \ 52 | --ckpt "ssv2a.pth" \ 53 | --image_dir "./images" \ 54 | --out_dir "./output" 55 | ``` 56 | Replace the arguments with the actual path names on your machine. 57 | 58 | ### Video to Audio Generation 59 | Navigate to the root directory of this repo and execute the following script: 60 | 61 | ```shell 62 | python infer_v2a.py \ 63 | --cfg "ssv2a.json" \ 64 | --ckpt "ssv2a.pth" \ 65 | --agg_ckpt "agg.pth" \ 66 | --vid_dir "./videos" \ 67 | --out_dir "./output" 68 | ``` 69 | Replace the arguments with the actual path names on your machine. 70 | 71 | ### Multimodal Sound Source Composition 72 | SSV2A accepts multimodal conditions where you describe sound sources as image, text, or audio. 73 | 74 | You first need to download the DALLE-2 Prior module to close the modality gap of text conditions in CLIP. 75 | We recommend [this version pretrained by LAION](https://huggingface.co/laion/DALLE2-PyTorch). 76 | You can also download it from [our drive](https://drive.google.com/drive/folders/17SAuZ2sZrTYf21BiNKhRsEfdj-fbeQQN?usp=sharing): 77 | 78 | | Item | File | 79 | |--------------------|------| 80 | | Configuration File | dalle2_prior_config.json | 81 | | Checkpoint | dalle2_prior.pth | 82 | 83 | When these are ready, navigate to the root directory of this repo and execute the following script: 84 | 85 | ```shell 86 | python infer_mm2a.py \ 87 | --cfg "ssv2a.json" \ 88 | --ckpt "ssv2a.pth" \ 89 | --dalle2_cfg "dalle2_prior_config.json" \ 90 | --dalle2_ckpt "dalle2_prior.pth" \ 91 | --images "talking_man.png" "dog.png" \ 92 | --texts "raining heavily" "street ambient" \ 93 | --audios "thunder.wav" \ 94 | --out_dir "./output/audio.wav" 95 | ``` 96 | 97 | Here are some argument specifications: 98 | 1. `--images` takes visual conditions as a list of `.png` or `.jpg` image files. 99 | 2. `--texts` takes text conditions as a list of strings. 100 | 3. `--audios` takes audio conditions as a list of `.wav`, `.flac`, or `.mp3` files. 101 | 102 | Note that this script, unlike our I2A and V2A codes, only supports single-sample inference instead of batched inference. 103 | We support a maximum of 64 sound source condition slots in total for generation. 104 | You can leave any modality blank for flexibility. You can also supply only one modality, such as texts. 105 | 106 | Feel free to play with this feature and let your imagination run wild :) 107 | 108 | ## Cite this work 109 | If you find our work useful, please consider citing 110 | 111 | ```bibtex 112 | @article{SSV2A, 113 | title={Gotta Hear Them All: Sound Source Aware Vision to Audio Generation}, 114 | author={Guo, Wei and Wang, Heng and Ma, Jianbo and Cai, Weidong}, 115 | journal={arXiv preprint arXiv:2411.15447}, 116 | year={2024} 117 | } 118 | ``` 119 | 120 | ## References 121 | SSV2A has made friends with several models. 122 | We list major references in our code here: 123 | 124 | 1. [AudioLDM](https://github.com/haoheliu/AudioLDM), by Haohe Liu 125 | 2. 
[AudioLDM2](https://github.com/haoheliu/AudioLDM2), by Haohe Liu 126 | 3. [LAION-Audio-630K](https://github.com/LAION-AI/audio-dataset), by LAION 127 | 4. [CLAP](https://github.com/LAION-AI/CLAP), by LAION 128 | 5. [frechet-audio-distance](https://github.com/gudgud96/frechet-audio-distance), by Hao Hao Tan 129 | 6. [DALLE2-pytorch](https://github.com/lucidrains/DALLE2-pytorch), by Phil Wang 130 | 7. [CLIP](https://github.com/openai/CLIP), by OpenAI 131 | 132 | Thank you for these excellent works! Other references are commented inline. 133 | 134 | -------------------------------------------------------------------------------- /infer_i2a.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, './SSV2A') 3 | import argparse 4 | import glob 5 | 6 | from ssv2a.model.pipeline import Pipeline, image_to_audio 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser(description='SSV2A') 10 | parser.add_argument('--cfg', type=str, help='Model Config File') 11 | parser.add_argument('--ckpt', type=str, default=None, help='Pretrained Checkpoint') 12 | parser.add_argument('--image_dir', type=str, default=None, help='Path to the image files') 13 | parser.add_argument('--out_dir', type=str, default='./output', help='Path to save the output audios to') 14 | parser.add_argument('--bs', type=int, default=64, help='batch size') 15 | parser.add_argument('--var_samples', type=int, default=64, help='variational samples') 16 | parser.add_argument('--cycle_its', type=int, default=64, help='number of Cycle Mix iterations') 17 | parser.add_argument('--cycle_samples', type=int, default=64, help='number of Cycle Mix samples') 18 | parser.add_argument('--duration', type=int, default=10, help='generation duration in seconds') 19 | parser.add_argument('--seed', type=int, default=42, help='random seed') 20 | parser.add_argument('--device', type=str, default='cuda', help='Computation Device') 21 | args = parser.parse_args() 22 | 23 | pipe = Pipeline(config=args.cfg, pretrained=args.ckpt, device=args.device) 24 | images = glob.glob(f'{args.image_dir}/*') 25 | image_to_audio(images, text="", transcription="", save_dir=args.out_dir, config=args.cfg, 26 | gen_remix=True, gen_tracks=False, emb_only=False, 27 | pretrained=args.ckpt, batch_size=args.bs, var_samples=args.var_samples, 28 | shuffle_remix=True, cycle_its=args.cycle_its, cycle_samples=args.cycle_samples, 29 | keep_data_cache=False, duration=args.duration, seed=args.seed, device=args.device) 30 | 31 | ''' 32 | python infer_i2a.py \ 33 | --cfg "/home/wguo/Repos/SDV2A/checkpoints/JS-kl00005-best/model.json" \ 34 | --ckpt "/home/wguo/Repos/SDV2A/checkpoints/JS-kl00005-best/best_val.pth" \ 35 | --image_dir "/home/wguo/Repos/SDV2A/data/samples/images" \ 36 | --bs 16 37 | ''' -------------------------------------------------------------------------------- /infer_mm2a.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, './SSV2A') 3 | import argparse 4 | 5 | from ssv2a.model.pipeline import Pipeline, srcs_to_audio 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser(description='SSV2A') 9 | parser.add_argument('--cfg', type=str, help='Model Config File') 10 | parser.add_argument('--ckpt', type=str, default=None, help='Pretrained Checkpoint') 11 | parser.add_argument('--dalle2_cfg', type=str, default=None, help='DALLE2 Prior Config File') 12 | parser.add_argument('--dalle2_ckpt', type=str, default=None, 
help='DALLE2 Prior Pretrained Checkpoint') 13 | parser.add_argument('--images', nargs='+', type=str, default=None, help='Image Conditions') 14 | parser.add_argument('--texts', nargs='+', type=str, default=None, help='Text Conditions') 15 | parser.add_argument('--audios', nargs='+', type=str, default=None, help='Audio Conditions') 16 | parser.add_argument('--out_dir', type=str, default='./output', help='Path to save the output audio to') 17 | parser.add_argument('--bs', type=int, default=64, help='batch size') 18 | parser.add_argument('--var_samples', type=int, default=64, help='variational samples') 19 | parser.add_argument('--cycle_its', type=int, default=64, help='number of Cycle Mix iterations') 20 | parser.add_argument('--cycle_samples', type=int, default=64, help='number of Cycle Mix samples') 21 | parser.add_argument('--duration', type=int, default=10, help='generation duration in seconds') 22 | parser.add_argument('--seed', type=int, default=42, help='random seed') 23 | parser.add_argument('--device', type=str, default='cuda', help='Computation Device') 24 | args = parser.parse_args() 25 | 26 | pipe = Pipeline(config=args.cfg, pretrained=args.ckpt, device=args.device) 27 | srcs = { 28 | 'image': [] if args.images is None else args.images, 29 | 'text': [] if args.texts is None else args.texts, 30 | 'audio': [] if args.audios is None else args.audios, 31 | } 32 | srcs_to_audio(srcs, args.out_dir, 33 | config=args.cfg, pretrained=args.ckpt, 34 | dalle2_cfg=args.dalle2_cfg, dalle2_ckpt=args.dalle2_ckpt, 35 | shuffle_remix=True, cycle_its=args.cycle_its, cycle_samples=args.cycle_samples, 36 | var_samples=args.var_samples, batch_size=args.bs, seed=args.seed, 37 | duration=args.duration, device=args.device) 38 | 39 | -------------------------------------------------------------------------------- /infer_v2a.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, './SSV2A') 3 | import argparse 4 | import glob 5 | 6 | from ssv2a.model.pipeline import Pipeline, image_to_audio, video_to_audio 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser(description='SSV2A') 10 | parser.add_argument('--cfg', type=str, help='Model Config File') 11 | parser.add_argument('--ckpt', type=str, default=None, help='Pretrained Checkpoint') 12 | parser.add_argument('--agg_ckpt', type=str, default=None, help='Pretrained Aggregator Checkpoint') 13 | parser.add_argument('--vid_dir', type=str, default=None, help='Path to the video files') 14 | parser.add_argument('--frames', type=int, default=64, help='Total frames to pass to Aggregator per video') 15 | parser.add_argument('--out_dir', type=str, default='./output', help='Path to save the output audios to') 16 | parser.add_argument('--bs', type=int, default=64, help='batch size') 17 | parser.add_argument('--var_samples', type=int, default=64, help='variational samples') 18 | parser.add_argument('--cycle_its', type=int, default=64, help='number of Cycle Mix iterations') 19 | parser.add_argument('--cycle_samples', type=int, default=64, help='number of Cycle Mix samples') 20 | parser.add_argument('--duration', type=int, default=10, help='generation duration in seconds') 21 | parser.add_argument('--seed', type=int, default=42, help='random seed') 22 | parser.add_argument('--device', type=str, default='cuda', help='Computation Device') 23 | args = parser.parse_args() 24 | 25 | pipe = Pipeline(config=args.cfg, pretrained=args.ckpt, device=args.device) 26 | vids = 
glob.glob(f'{args.vid_dir}/*') 27 | video_to_audio(args.cfg, args.ckpt, vids, args.agg_ckpt, args.out_dir, 28 | agg_var_samples=1, frames=args.frames, 29 | batch_size=args.bs, var_samples=args.var_samples, 30 | cycle_its=args.cycle_its, cycle_samples=args.cycle_samples, 31 | duration=args.duration, seed=args.seed, device=args.device) 32 | 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | audioldm==0.1.1 2 | audioldm2==0.1.0 3 | clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1 4 | dalle2-pytorch==1.1.0 # need this to accomodate the pretrained DALLE2 prior from LAION 5 | frechet_audio_distance==0.3.1 6 | laion_clap==1.1.6 # you must have this version, instead of 1.1.5 required by frechet-audio-distance 7 | librosa==0.9.2 8 | numpy==1.23.5 9 | scikit-video==1.1.11 10 | torch==2.4.1 11 | torchaudio==2.4.1 12 | torchvision==0.19.1 13 | ultralytics==8.2.90 14 | wandb==0.17.9 15 | wav2clip==0.1.0 16 | kneed==0.8.5 17 | -------------------------------------------------------------------------------- /ssv2a/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wguo86/SSV2A/7f01d5f7a45e6f3bf68b75ade3db30f08978c4b4/ssv2a/__init__.py -------------------------------------------------------------------------------- /ssv2a/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wguo86/SSV2A/7f01d5f7a45e6f3bf68b75ade3db30f08978c4b4/ssv2a/data/__init__.py -------------------------------------------------------------------------------- /ssv2a/data/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class ImageDataset(Dataset): 7 | def __init__(self, img_fs): 8 | self.img_fs = img_fs 9 | 10 | def __len__(self): 11 | return len(self.img_fs) 12 | 13 | def __getitem__(self, idx): 14 | img = Image.open(self.img_fs[idx]) 15 | return np.array(img), self.img_fs[idx] 16 | 17 | -------------------------------------------------------------------------------- /ssv2a/data/detect.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from ultralytics import YOLO 7 | from ultralytics.models.sam import Predictor as SAMPredictor 8 | from PIL import Image 9 | from tqdm.auto import tqdm 10 | 11 | # please link to the CaR modules from https://github.com/google-research/google-research/tree/master/clip_as_rnn 12 | from ssv2a.data.utils import read_classes, video2images, mask2bbox, elbow 13 | 14 | 15 | # detect and segment images, return all or top k segment masks > conf, optionally, save masked images to disk 16 | # default to cropping instead of segmentation if crop=True 17 | def yolo_detect(images, detection_model='yolov8x-worldv2.pt', segment_model="sam_b.pt", resize=None, crop=True, 18 | classes=None, batch_size=64, conf=.5, iou=0.5, max_det=64, top_k=None, save_dir="", device='cuda', **_): 19 | if not os.path.exists(save_dir): 20 | os.makedirs(save_dir) 21 | 22 | model = YOLO(detection_model) 23 | model.to(device) 24 | if 'world' in detection_model and classes is not None: 25 | classes = 
read_classes(classes) 26 | model.set_classes(classes) 27 | 28 | # automatically determine image size, assuming all images are the same size as image[0] 29 | if resize is not None: 30 | imgsz = resize 31 | else: 32 | sample_img = Image.open(images[0]) 33 | imgsz = sample_img.size 34 | img_area = imgsz[0] * imgsz[1] 35 | 36 | print(f"Detecting objects with {detection_model}:") 37 | segments = {} 38 | for img in images: 39 | segments[img] = [] 40 | 41 | for i in tqdm(range(0, len(images), batch_size)): 42 | e = min(len(images), i + batch_size) 43 | for img in images[i:e]: 44 | oimg = Image.open(img) 45 | if resize is not None and oimg.size != resize: 46 | oimg = oimg.resize(resize, resample=Image.Resampling.BICUBIC) 47 | oimg.save(img, 'PNG') 48 | detect_results = model.predict(images[i:e], imgsz=imgsz, conf=conf, iou=iou, max_det=max_det, 49 | augment=True, verbose=False) 50 | 51 | if crop: 52 | # print("Cropping objects:") 53 | for j, img in enumerate(images[i:e]): 54 | oimg = Image.open(img) 55 | rs = detect_results[j][:top_k] 56 | for z, r in enumerate(rs): 57 | box = r.boxes.xyxy.cpu().tolist()[0] 58 | cimg = oimg.crop(box).resize(imgsz, Image.Resampling.BICUBIC) 59 | cimg_file = Path(save_dir) / Path(images[i:e][j]).name.replace('.png', f'_{z}.png') 60 | cimg.save(cimg_file, 'PNG') 61 | locality = abs(box[2] - box[0]) * abs(box[3] - box[1]) / img_area # locality ratio 62 | segments[img].append((str(cimg_file), locality)) 63 | 64 | else: 65 | # print(f"Segmenting objects with {segment_model}:") 66 | overrides = dict(conf=.25, retina_masks=True, task="segment", mode="predict", 67 | imgsz=imgsz, model=segment_model, save=False, verbose=False, device=device) 68 | model = SAMPredictor(overrides=overrides) 69 | for j in range(len(images[i:e])): 70 | model.set_image(images[i:e][j]) 71 | img = np.array(Image.open(images[i:e][j])) 72 | rs = detect_results[j][:top_k] 73 | for z, r in enumerate(rs): 74 | mask = model(bboxes=r.boxes.xyxy)[0].masks.data.cpu().numpy() 75 | mask = np.squeeze(mask, axis=0).astype(int) 76 | mimg_file = Path(save_dir) / Path(images[i:e][j]).name.replace('.png', f'_{z}.png') 77 | Image.fromarray((img * np.expand_dims(mask, axis=2)).astype(np.uint8)).save(mimg_file, 'PNG') 78 | locality = float(np.sum(mask.astype(int))) / img_area 79 | segments[images[i]].append((str(mimg_file), locality)) 80 | model.reset_image() 81 | 82 | return segments 83 | 84 | 85 | # filter a list of single-source or multi-source videos for true positive visual frames 86 | def yolo_detect_videos(videos, signatures, save_dir, imgsz=(512, 512), fps=4, conf=.5, 87 | detection_model='yolov8x-worldv2.pt', 88 | classes=None, batchsize=64, device='cuda'): 89 | if not os.path.exists(save_dir): 90 | os.makedirs(save_dir) 91 | positives = dict([(vid, []) for vid in videos]) 92 | 93 | model = YOLO(detection_model) 94 | if 'world' in detection_model and classes is not None: 95 | model.set_classes(read_classes(classes)) 96 | signatures = set(signatures) 97 | 98 | for s in tqdm(range(0, len(videos), batchsize)): 99 | e = min(s + batchsize, len(videos)) 100 | for vid in videos[s:e]: 101 | frames, _, _, video_name = video2images(vid, fps=fps) 102 | frames = [Image.fromarray(frames[i]).resize(imgsz, Image.Resampling.BICUBIC) for i in range(len(frames))] 103 | 104 | rs = model.predict(frames, conf=conf, imgsz=imgsz, augment=True, verbose=False, device=device) 105 | for j, r in enumerate(rs): 106 | cls = set([model.names[c] for c in r.boxes.cls.cpu().tolist()]) 107 | if cls == signatures: # bingo 108 | fimg_file = 
Path(save_dir) / f'{video_name}_{j}.png' 109 | frames[j].save(fimg_file, 'PNG') 110 | positives[vid].append(fimg_file) 111 | 112 | # remove video if it doesn't contain any qualified visual frames 113 | false_pos = [] 114 | for pos in positives: 115 | if len(positives[pos]) == 0: 116 | false_pos.append(pos) 117 | for pos in false_pos: 118 | del positives[pos] 119 | 120 | return positives 121 | 122 | 123 | def detect(images, detector_cfg, save_dir='masked_images', batch_size=64, device='cuda'): 124 | detector_cfg['save_dir'] = save_dir 125 | detector_cfg['batch_size'] = batch_size 126 | detector_cfg['device'] = device 127 | 128 | if 'yolo' in detector_cfg['detection_model']: 129 | return yolo_detect(images, **detector_cfg) 130 | 131 | else: 132 | raise NotImplementedError('Detection model is unsupported.') 133 | 134 | -------------------------------------------------------------------------------- /ssv2a/data/pairs.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | from enum import Enum 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import clip 8 | import pandas as pd 9 | import torch 10 | from PIL import Image 11 | from torch.utils.data import Dataset 12 | from tqdm.auto import tqdm 13 | 14 | from ssv2a.data.utils import read_wav_file 15 | from ssv2a.model.clap import CLAP 16 | 17 | 18 | class Mode(Enum): 19 | NA = -1 20 | AUDIO = 0 21 | IMAGE = 1 22 | VIDEO = 2 23 | TEXT = 3 24 | LABEL = 4 25 | 26 | 27 | # batch embed the sources of a list of pairs with CLIP 28 | def clip_embed_pairs(pids, pairs_dir, model=None, preprocess=None, version='ViT-L/14', batch_size=64, device='cuda'): 29 | with torch.no_grad(): 30 | if model is None: 31 | model, preprocess = clip.load(version, device=device) 32 | print(f'Embedding {len(pids)} pairs into CLIP:') 33 | for locality in ['local', 'global', 'context']: 34 | for s in tqdm(range(0, len(pids), batch_size)): 35 | e = min(len(pids), s + batch_size) 36 | pairs = {pid: load_pair(pairs_dir, pid) for pid in pids[s:e]} 37 | texts = [(pid, p.get_sources(f'{locality}_srcs')) for pid, p in pairs.items() 38 | if p.mode == Mode.TEXT or p.mode == Mode.LABEL] 39 | images = [(pid, p.get_sources(f'{locality}_srcs')) for pid, p in pairs.items() 40 | if (p.mode == Mode.IMAGE or p.mode == Mode.VIDEO) and p.data[f'{locality}_clips'] is None] 41 | # embed images 42 | if len(images) > 0: 43 | imgs = [] 44 | for _, image_set in images: 45 | imgs += image_set 46 | if len(imgs) == 0: 47 | continue 48 | img_arr = [] 49 | img_sz = None 50 | for img in imgs: 51 | try: 52 | processed_img = preprocess(img).unsqueeze(0) 53 | if img_sz is None: 54 | img_sz = processed_img.shape 55 | img_arr.append(processed_img) 56 | except Exception as e: 57 | print(f'Illegal image {img}.') 58 | img_arr.append(torch.zeros(img_sz)) 59 | imgs = torch.cat(img_arr).to(device) 60 | img_embeds = model.encode_image(imgs) 61 | img_embeds = img_embeds / img_embeds.norm(dim=-1, keepdim=True) 62 | img_embeds = img_embeds.detach().cpu().numpy() 63 | img_embeds = np.nan_to_num(img_embeds) 64 | idx = 0 65 | for pid, image_set in images: 66 | step = len(image_set) 67 | pairs[pid].data[f'{locality}_clips'] = img_embeds[idx:idx + step] 68 | pairs[pid].save(pairs_dir) 69 | idx += step 70 | # embed texts 71 | if len(texts) > 0: 72 | ts = [] 73 | for _, text_set in texts: 74 | ts += text_set 75 | ts = clip.tokenize(ts).to(device) 76 | ts_embeds = model.encode_text(ts) 77 | ts_embeds = ts_embeds / ts_embeds.norm(dim=-1, keepdim=True) 78 | 
ts_embeds = ts_embeds.detach().cpu().numpy() 79 | ts_embeds = np.nan_to_num(ts_embeds) 80 | idx = 0 81 | for pid, text_set in texts: 82 | step = len(text_set) 83 | pairs[pid].data[f'{locality}_clips'] = ts_embeds[idx:idx + step] 84 | pairs[pid].save(pairs_dir) 85 | idx += step 86 | 87 | 88 | # batch embed the audios of a list of pairs with CLAP (AudioLDM2 flavor) 89 | def clap_embed_pairs(pids, pairs_dir, model=None, clap_version='audioldm-s-full-v2', 90 | duration=10, batch_size=256, sampling_rate=16000, device='cuda'): 91 | with torch.no_grad(): 92 | seg_length = int(duration * 102.4) * (sampling_rate // 100) 93 | if model is None: 94 | clap = CLAP(clap_version=clap_version, embed_mode='audio', sample_rate=sampling_rate, device=device) 95 | else: 96 | clap = model 97 | 98 | # embed 99 | print(f'Embedding {len(pids)} pairs into CLAP:') 100 | for s in tqdm(range(0, len(pids), batch_size)): 101 | e = min(len(pids), s + batch_size) 102 | pairs = [load_pair(pairs_dir, pid) for pid in pids[s:e]] 103 | wavs = np.concatenate([read_wav_file(p.data['audio'], seg_length, sampling_rate=sampling_rate) 104 | for p in pairs]) 105 | embeds = clap.model(torch.from_numpy(wavs).float()).detach().cpu().numpy() 106 | embeds = np.nan_to_num(embeds) 107 | # update pairs and save 108 | for i, p in enumerate(pairs): 109 | p.data['clap'] = embeds[i] 110 | p.save(pairs_dir) 111 | 112 | 113 | def pairs2clips(pairs, clips_type, max_length=0, device='cuda'): # local clips or global clips 114 | if max_length == 0: 115 | clips = [] 116 | for p in pairs: 117 | clips.append(torch.from_numpy(p.data[clips_type])) 118 | return torch.cat(clips).float().to(device) 119 | else: 120 | clips = torch.zeros(len(pairs), max_length, pairs[0].data[clips_type].shape[-1]) 121 | for i, p in enumerate(pairs): 122 | emb = torch.from_numpy(p.data[clips_type]) 123 | emb = emb.reshape(-1, emb.shape[-1]) 124 | clips[i, :emb.shape[0], :] = emb 125 | return clips.float().to(device) 126 | 127 | 128 | def pairs2claps(pairs, align_clips=None, device='cuda'): # optionally, align with multiple clips by duplication 129 | if align_clips: 130 | claps = [] 131 | for p in pairs: 132 | claps.append( 133 | torch.from_numpy(np.repeat(p.data['clap'], p.data[align_clips].shape[0], axis=0))) 134 | return torch.cat(claps).float().to(device) 135 | else: 136 | return torch.cat([torch.from_numpy(p.data['clap']) for p in pairs]).float().to(device) 137 | 138 | 139 | """ 140 | A pair can be image-audio, text-audio, video-audio. 141 | Notice that this can be a many-to-one pair because a video has multiple frames. 
142 | """ 143 | 144 | 145 | class Pair: 146 | def __init__(self, global_srcs, context_srcs, local_srcs, localities, aud, mode, pid): 147 | self.mode = mode 148 | self.data = { 149 | 'pid': pid, # unique id of this pair 150 | 'mode': mode.value, 151 | # if text/label, a list of strings, otherwise, a list of image files (for videos, extracted visual frames) 152 | 'global_srcs': global_srcs, 153 | 'context_srcs': context_srcs, 154 | 'local_srcs': local_srcs, 155 | 'localities': localities, 156 | 'audio': aud, 157 | 'local_clips': None, # clip embeddings of single sources 158 | 'global_clips': None, # clip embeddings of the original data 159 | 'context_clips': None, 160 | 'clap': None, # clap embedding of audio 161 | } 162 | 163 | def get_sources(self, src_type): # return preprocessed source data 164 | src = self.data[src_type] 165 | if src is None: 166 | raise NotImplementedError('Source data corrupted for this pair!') 167 | 168 | # if label, preprocess by prompt augmentation 169 | if self.mode == Mode.LABEL: 170 | prompts = [] 171 | for s in src: 172 | prompts += [f"the sound of {s}", 173 | f"the sound {s} makes", 174 | f"the audio of {s}"] 175 | return prompts 176 | elif self.mode == Mode.TEXT: 177 | return src 178 | 179 | images = [] 180 | if self.mode == Mode.IMAGE or self.mode == Mode.VIDEO: 181 | images += [Image.open(s) for s in src] 182 | return images 183 | 184 | def save(self, folder): 185 | with open(f"{folder}/{self.data['pid']}.pickle", 'wb') as fp: 186 | pickle.dump(self.data, fp, protocol=pickle.HIGHEST_PROTOCOL) 187 | 188 | def __str__(self): 189 | content = { 190 | 'pid': self.data['pid'], # unique id of this pair 191 | 'mode': self.mode.name, 192 | 'global_srcs': self.data['global_srcs'], 193 | 'context_srcs': self.data['context_srcs'], 194 | 'local_srcs': self.data['local_srcs'], 195 | 'audio': self.data['audio'] 196 | } 197 | return json.dumps(content, indent=4) 198 | 199 | 200 | # load a pair by id and return it 201 | def load_pair(pdir, pid=None): 202 | pair = Pair(None, None, None, None, None, Mode.NA, pid) 203 | if pid is not None: 204 | pdir = Path(pdir) / f'{pid}.pickle' 205 | with open(pdir, 'rb') as fp: 206 | pair.data = pickle.load(fp) 207 | pair.mode = Mode(pair.data['mode']) 208 | return pair 209 | 210 | 211 | def collate_pairs(data): 212 | return [d[0] for d in data], [d[1] for d in data] 213 | 214 | 215 | class PairDataset(Dataset): 216 | def __init__(self, pairs_meta_file=None, split='train', pairs_dfs=None, pairs_roots=None): 217 | self.split = split 218 | 219 | if pairs_meta_file is not None: 220 | self.pairs_roots = [Path(pairs_meta_file).parent / split] 221 | df = pd.read_csv(pairs_meta_file) 222 | df = df[df['split'] == split] 223 | self.pairs_dfs = [df] 224 | 225 | elif pairs_dfs is not None and pairs_roots is not None: 226 | self.pairs_dfs = pairs_dfs 227 | self.pairs_roots = pairs_roots 228 | 229 | else: 230 | raise NotImplementedError('Illegal dataset parameters.') 231 | 232 | def __len__(self): 233 | return sum([len(df) for df in self.pairs_dfs]) 234 | 235 | def __getitem__(self, idx): 236 | for df_idx, df in enumerate(self.pairs_dfs): 237 | if idx < len(df): 238 | entry = df.iloc[idx] 239 | return load_pair(self.pairs_roots[df_idx], entry['pid']), entry['category'] 240 | idx -= len(df) 241 | 242 | # merge with another PairDataset, non-destructively 243 | def merge(self, other_set): 244 | return PairDataset(split=self.split, 245 | pairs_dfs=self.pairs_dfs + other_set.pairs_dfs, 246 | pairs_roots=self.pairs_roots + other_set.pairs_roots) 247 | 248 | 
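249 | # A minimal usage sketch of the Pair structure defined above (illustrative only, not part of the original pipeline): 250 | # the pid 'demo_0', the audio file 'dog_bark.wav' and the directory './pairs' are hypothetical placeholders, and 251 | # the embedding calls further assume the CLIP/CLAP weights and a CUDA device are available. 252 | if __name__ == '__main__': 253 |     Path('./pairs').mkdir(parents=True, exist_ok=True) 254 |     pair = Pair(global_srcs=['dog'], context_srcs=['dog'], local_srcs=['dog'], localities=[1.0], 255 |                 aud='dog_bark.wav', mode=Mode.LABEL, pid='demo_0') 256 |     pair.save('./pairs')  # writes ./pairs/demo_0.pickle 257 |     clip_embed_pairs(['demo_0'], './pairs')  # fills local/global/context_clips from the augmented label prompts 258 |     clap_embed_pairs(['demo_0'], './pairs')  # fills the 'clap' embedding from the paired audio 259 |     print(load_pair('./pairs', 'demo_0'))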
-------------------------------------------------------------------------------- /ssv2a/data/tpairs.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module defines the data structure of tags-audio pairs. This is a many-to-one matching. 3 | """ 4 | import os 5 | import pickle 6 | import random 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import clip 11 | import pandas as pd 12 | import torch 13 | from tqdm.auto import tqdm 14 | 15 | from ssv2a.data.pairs import PairDataset, load_pair 16 | from ssv2a.data.utils import normalize_wav 17 | from ssv2a.model.clap import CLAP 18 | from ssv2a.model.dalle2_prior import Dalle2Prior 19 | 20 | 21 | class TagPair: 22 | def __init__(self, pid, caption, tags, aud_wave): 23 | self.pid = pid 24 | 25 | self.aud_wave = aud_wave 26 | self.aud_clap = None 27 | 28 | self.caption = caption 29 | self.caption_clip = None 30 | self.caption_clip_prior = None 31 | 32 | self.tags = tags 33 | self.tag_clips = None 34 | self.tag_clips_prior = None 35 | self.tag_claps = None 36 | 37 | self.aug_img_id = None 38 | self.aug_clip = None 39 | 40 | self.aug_tag_img_ids = None 41 | self.aug_tag_clips = None 42 | 43 | def save(self, folder): 44 | with open(f"{folder}/{self.pid}.pkl", 'wb') as fp: 45 | pickle.dump(self, fp, protocol=pickle.HIGHEST_PROTOCOL) 46 | 47 | 48 | def load_tpair(pdir, pid=None): 49 | if pid is not None: 50 | pdir = Path(pdir) / f'{pid}.pkl' 51 | with open(pdir, 'rb') as fp: 52 | return pickle.load(fp) 53 | 54 | 55 | def tpairs2aclaps(pairs, device='cuda'): 56 | return torch.cat([torch.from_numpy(p.aud_clap).unsqueeze(0) for p in pairs]).to(device) 57 | 58 | 59 | def tpairs2cclips(pairs, device='cuda'): 60 | return torch.cat([torch.from_numpy(p.caption_clip).unsqueeze(0) for p in pairs]).to(device) 61 | 62 | 63 | def tpairs2tclips(pairs, max_length=0, use_prior=False, device='cuda'): 64 | if max_length == 0: 65 | if use_prior: 66 | return torch.cat([torch.from_numpy(p.tag_clips_prior) for p in pairs]).to(device) 67 | else: 68 | return torch.cat([torch.from_numpy(p.tag_clips) for p in pairs]).to(device) 69 | else: 70 | tclips = [] 71 | for p in pairs: 72 | L, E = p.tag_clips.shape 73 | 74 | if use_prior: 75 | seq = torch.from_numpy(p.tag_clips_prior) 76 | else: 77 | seq = torch.from_numpy(p.tag_clips) 78 | 79 | # truncate and random sample (or shuffle if short) 80 | idx = random.sample(range(L), min(L, max_length)) 81 | seq = seq[idx] 82 | 83 | if L < max_length: # pad if short 84 | seq = torch.cat([seq, torch.zeros(max_length - L, E)]) 85 | 86 | tclips.append(seq) 87 | return torch.stack(tclips).to(device) 88 | 89 | 90 | class TagPairDataset(PairDataset): 91 | def __getitem__(self, idx): 92 | for df_idx, df in enumerate(self.pairs_dfs): 93 | if idx < len(df): 94 | entry = df.iloc[idx] 95 | return load_tpair(self.pairs_roots[df_idx], entry['pid']), entry['category'] 96 | idx -= len(df) 97 | 98 | def merge(self, other_set): 99 | return TagPairDataset(split=self.split, 100 | pairs_dfs=self.pairs_dfs + other_set.pairs_dfs, 101 | pairs_roots=self.pairs_roots + other_set.pairs_roots) 102 | 103 | 104 | class MixedPairDataset: 105 | def __init__(self, pairs_meta_fs=None, tpairs_meta_fs=None, split='train', pairs_dfs=None, pairs_roots=None): 106 | if pairs_meta_fs is None: 107 | pairs_meta_fs = [] 108 | self.split = split 109 | 110 | self.pairs_roots = [] 111 | self.pairs_dfs = [] 112 | 113 | for f in pairs_meta_fs: 114 | self.pairs_roots.append(Path(f).parent / split) 115 | df = pd.read_csv(f) 116 
| df = df[df['split'] == split] 117 | self.pairs_dfs.append((df, 'pairs')) 118 | 119 | for f in tpairs_meta_fs: 120 | self.pairs_roots.append(Path(f).parent / split) 121 | df = pd.read_csv(f) 122 | df = df[df['split'] == split] 123 | self.pairs_dfs.append((df, 'tpairs')) 124 | 125 | def __len__(self): 126 | return sum([len(df) for df, _ in self.pairs_dfs]) 127 | 128 | def __getitem__(self, idx): 129 | for df_idx, (df, dft) in enumerate(self.pairs_dfs): 130 | if idx < len(df): 131 | entry = df.iloc[idx] 132 | if dft == 'pairs': 133 | return load_pair(self.pairs_roots[df_idx], entry['pid']), entry['category'] 134 | else: 135 | return load_tpair(self.pairs_roots[df_idx], entry['pid']), entry['category'] 136 | idx -= len(df) 137 | 138 | 139 | def clip_embed_tpairs(pids, pdir, bs=64, clip_version='ViT-L/14', device='cuda'): 140 | with (torch.no_grad()): 141 | model, preprocess = clip.load(clip_version, device=device) 142 | 143 | # embed caption 144 | print(f'Embedding {len(pids)} captions into CLIP:') 145 | for s in tqdm(range(0, len(pids), bs)): 146 | e = min(len(pids), s + bs) 147 | pairs = [load_tpair(pdir, pid) for pid in pids[s:e]] 148 | # pairs = [p for p in pairs if p.caption_clip is None] 149 | 150 | if len(pairs) == 0: 151 | continue 152 | 153 | ts = [p.caption for p in pairs] 154 | ts = clip.tokenize(ts, truncate=True).to(device) 155 | ts = model.encode_text(ts).float() 156 | ts = ts / ts.norm(p=2, dim=-1, keepdim=True) 157 | ts = ts.detach().cpu().numpy() 158 | ts = np.nan_to_num(ts) 159 | 160 | for i, p in enumerate(pairs): 161 | p.caption_clip = ts[i] 162 | p.save(pdir) 163 | 164 | # embed tags 165 | print(f'Embedding {len(pids)} bags of tags into CLIP:') 166 | for s in tqdm(range(0, len(pids), bs)): 167 | e = min(len(pids), s + bs) 168 | pairs = [load_tpair(pdir, pid) for pid in pids[s:e]] 169 | # pairs = [p for p in pairs if p.tag_clips is None] 170 | 171 | if len(pairs) == 0: 172 | continue 173 | 174 | ts = [] 175 | for p in pairs: 176 | ts += p.tags 177 | if len(ts) == 0: 178 | continue 179 | ts = clip.tokenize(ts, truncate=True).to(device) 180 | ts = model.encode_text(ts).float() 181 | ts = ts / ts.norm(p=2, dim=-1, keepdim=True) 182 | ts = ts.detach().cpu().numpy() 183 | ts = np.nan_to_num(ts) 184 | 185 | step = 0 186 | jumps = [len(p.tags) for p in pairs] 187 | for i, p in enumerate(pairs): 188 | p.tag_clips = ts[step:step + jumps[i]] 189 | step += jumps[i] 190 | p.save(pdir) 191 | 192 | 193 | def clap_embed_tpairs(pids, pdir, bs=64, 194 | clap_version='audioldm-s-full-v2', device='cuda'): 195 | with torch.no_grad(): 196 | # seg_length = int(duration * 102.4) * 160 197 | clap = CLAP(clap_version=clap_version, embed_mode='audio', device=device) 198 | del_pids = [] 199 | 200 | # embed source audios 201 | print(f'Embedding {len(pids)} source audios into CLAP:') 202 | for s in tqdm(range(0, len(pids), bs)): 203 | e = min(len(pids), s + bs) 204 | pairs = [load_tpair(pdir, pid) for pid in pids[s:e]] 205 | pairs = [p for p in pairs if p.aud_clap is None] 206 | 207 | if len(pairs) == 0: 208 | continue 209 | 210 | for p in pairs: 211 | if p.aud_wave.shape[0] > 100: 212 | waveform = normalize_wav(p.aud_wave) 213 | waveform = waveform[None, ...] 
214 | waveform = waveform / np.max(np.abs(waveform)) 215 | waveform = np.nan_to_num(0.5 * waveform) 216 | else: 217 | print(f'Delete Short Audio: {p.pid}') 218 | os.remove(pdir / f'{p.pid}.pkl') 219 | del_pids.append(p.pid) 220 | continue 221 | 222 | embeds = clap.model(torch.from_numpy(waveform).float()).detach().cpu().numpy() 223 | embeds = np.squeeze(np.nan_to_num(embeds)) 224 | 225 | p.aud_clap = embeds 226 | p.save(pdir) 227 | 228 | # embed tags 229 | print(f'Embedding {len(pids)} bags of tags into CLAP:') 230 | for pid in del_pids: 231 | pids.remove(pid) 232 | clap = CLAP(clap_version=clap_version, embed_mode='text', device=device) 233 | for s in tqdm(range(0, len(pids), bs)): 234 | e = min(len(pids), s + bs) 235 | pairs = [load_tpair(pdir, pid) for pid in pids[s:e]] 236 | pairs = [p for p in pairs if p.tag_claps is None] 237 | 238 | if len(pairs) == 0: 239 | continue 240 | 241 | ts = [] 242 | for p in pairs: 243 | ts += p.tags 244 | rts = np.empty((len(ts), 512)) 245 | for s1 in range(0, len(ts), bs): 246 | e1 = min(len(ts), s1 + bs) 247 | rts[s1:e1] = clap.model(ts[s1:e1]).squeeze().float().detach().cpu().numpy() 248 | 249 | step = 0 250 | jumps = [len(p.tags) for p in pairs] 251 | for i, p in enumerate(pairs): 252 | p.tag_claps = rts[step:step + jumps[i]] 253 | step += jumps[i] 254 | p.save(pdir) 255 | 256 | 257 | # translate tpairs from text-audio data to image-audio data 258 | def prior_embed_tpairs(pids, pdir, cfg, ckpt, bs=64, n_samples_per_batch=2, cond_scale=1, device='cuda'): 259 | model = Dalle2Prior(cfg, ckpt, device=device) 260 | 261 | print(f'Translating {len(pids)} captions from CLIP text space to image space:') 262 | for s in tqdm(range(0, len(pids), bs)): 263 | e = min(len(pids), s + bs) 264 | pairs = [load_tpair(pdir, pid) for pid in pids[s:e]] 265 | pairs = [p for p in pairs if (not hasattr(p, 'caption_clip_prior')) or p.caption_clip_prior is None] 266 | 267 | if len(pairs) == 0: 268 | continue 269 | 270 | caps = [p.caption for p in pairs] 271 | cap_clips = model.sample(caps, n_samples_per_batch=n_samples_per_batch, cond_scale=cond_scale) 272 | cap_clips = np.nan_to_num(cap_clips.detach().cpu().float().numpy()) 273 | 274 | for i, p in enumerate(pairs): 275 | p.caption_clip_prior = cap_clips[i] 276 | p.save(pdir) 277 | 278 | print(f'Translating {len(pids)} bags of tags from CLIP text space to image space:') 279 | for s in tqdm(range(0, len(pids), bs)): 280 | e = min(len(pids), s + bs) 281 | pairs = [load_tpair(pdir, pid) for pid in pids[s:e]] 282 | pairs = [p for p in pairs if (not hasattr(p, 'tag_clips_prior')) or p.tag_clips_prior is None] 283 | 284 | if len(pairs) == 0: 285 | continue 286 | 287 | tags = [] 288 | for p in pairs: 289 | tags += p.tags 290 | if len(tags) == 0: 291 | continue 292 | 293 | tag_clips = [] 294 | for s1 in range(0, len(tags), bs): 295 | e1 = min(len(tags), s1 + bs) 296 | tag_clips.append(model.sample(tags[s1:e1], n_samples_per_batch=n_samples_per_batch, cond_scale=cond_scale)) 297 | tag_clips = np.nan_to_num(torch.cat(tag_clips, dim=0).detach().cpu().float().numpy()) 298 | 299 | step = 0 300 | jumps = [len(p.tags) for p in pairs] 301 | for i, p in enumerate(pairs): 302 | p.tag_clips_prior = tag_clips[step:step + jumps[i]] 303 | step += jumps[i] 304 | p.save(pdir) 305 | 306 | -------------------------------------------------------------------------------- /ssv2a/data/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | These helper functions mainly deal with jumping between audio 
representations. 3 | The wav manipulations are adapted from https://github.com/haoheliu/AudioLDM2 4 | -- danke 5 | """ 6 | import os 7 | import random 8 | import urllib.request 9 | from datetime import datetime 10 | from concurrent.futures import ThreadPoolExecutor 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import clip 15 | import torch 16 | import torchaudio 17 | import skvideo.io 18 | from PIL import Image, ImageDraw, ImageFilter 19 | from audioldm import get_metadata 20 | from audioldm.utils import MyProgressBar 21 | from kneed import KneeLocator 22 | from textblob import TextBlob 23 | import soundfile as sf 24 | from tqdm import tqdm 25 | 26 | from ssv2a.model.dalle2_prior import Dalle2Prior 27 | 28 | 29 | def set_seed(seed: int): 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | 35 | 36 | def get_timestamp(): 37 | return datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 38 | 39 | 40 | # find elbow point of a given score list automatically with Kneedle algorithm 41 | # as we always operate in batches, the sensitivity is fixed at 1 in hope of an optimal online Kneedle 42 | def elbow(scores, sensitivity=1, curve='convex', return_idx=True): 43 | x = np.arange(len(scores)) 44 | kneedle = KneeLocator(x, scores, S=sensitivity, curve=curve, direction='decreasing') 45 | if return_idx: 46 | if kneedle.knee is None: 47 | return len(x) 48 | return min(len(x), round(kneedle.knee)) 49 | else: 50 | if kneedle.knee_y is None: 51 | return x[-1] 52 | else: 53 | return kneedle.knee_y 54 | 55 | 56 | def random_mute(a, p=.5): # randomly mute a tensor with probability, used for classifier free guidance 57 | """ 58 | a: [*, d1, d2] 59 | """ 60 | d1 = a.shape[-2] 61 | idx = torch.rand(d1) < p 62 | rt = a.clone() 63 | rt[..., idx, :] = 0 64 | return rt, idx 65 | 66 | 67 | def get_noun_phrases(text): 68 | blob = TextBlob(text) 69 | return list(blob.noun_phrases) 70 | 71 | 72 | def read_classes(classes_file): 73 | with open(classes_file, 'r') as fp: 74 | lines = fp.readlines() 75 | classes = list([c.strip('\n').lower() for c in lines if c.strip('\n').lower() != '']) 76 | return classes 77 | 78 | 79 | def extract_central_frame(video, save_dir, size=None): 80 | videodata = skvideo.io.vread(str(video)) 81 | cf = Image.fromarray(videodata[videodata.shape[0] // 2 + videodata.shape[0] % 2, :, :, :]) 82 | if size is not None: 83 | cf = cf.resize(size, resample=Image.Resampling.BICUBIC) 84 | cf.save(Path(save_dir) / Path(video).name.replace('.mp4', '.png'), 'PNG') 85 | 86 | 87 | # given a list of video files, batch extract the central frames and save to a directory 88 | def batch_extract_central_frame(video_fs, save_dir, size=None, bs=32, num_workers=8): 89 | os.makedirs(save_dir, exist_ok=True) 90 | for s in tqdm(range(0, len(video_fs), bs)): 91 | e = min(len(video_fs), s + bs) 92 | pool = ThreadPoolExecutor(max_workers=num_workers) 93 | for f in video_fs[s:e]: 94 | pool.submit(extract_central_frame, f, save_dir, size=size) 95 | pool.shutdown(wait=True) 96 | 97 | # evenly extract n frames from a given video 98 | def extract_frames(video_fs, save_dir, size=None, frames=64): 99 | for vid in tqdm(video_fs): 100 | videodata = skvideo.io.vread(str(vid)) 101 | L = videodata.shape[0] 102 | for i, f in enumerate(np.round(np.linspace(0, L - 1, frames)).astype(int).tolist()): 103 | cf = Image.fromarray(videodata[f, :, :, :]) 104 | if size is not None: 105 | cf = cf.resize(size, resample=Image.Resampling.BICUBIC) 106 | cf.save(Path(save_dir) / 
Path(vid).name.replace('.mp4', f'_{i}.png'), 'PNG') 107 | 108 | # given a list of video files, batch extract the central frames and save to a directory 109 | def batch_extract_frames(video_fs, save_dir, size=None, frames=64, num_workers=8): 110 | os.makedirs(save_dir, exist_ok=True) 111 | pool = ThreadPoolExecutor(max_workers=num_workers) 112 | if len(video_fs) >= num_workers: 113 | workload = len(video_fs) // num_workers 114 | for s in range(0, len(video_fs), workload): 115 | e = min(len(video_fs), s + workload) 116 | pool.submit(extract_frames, video_fs[s:e], save_dir, size, frames) 117 | pool.shutdown(wait=True) 118 | else: 119 | extract_frames(video_fs, save_dir, size, frames) 120 | 121 | 122 | def get_fps(s): 123 | if s.isdigit(): 124 | return float(s) 125 | else: 126 | num, denom = s.split('/') 127 | return float(num) / float(denom) 128 | 129 | 130 | def video2images(video, fps=4): # video to image sequence, sampled by fps 131 | try: 132 | video_name = os.path.basename(video).replace('.mp4', '') 133 | videodata = skvideo.io.vread(video) 134 | videometadata = skvideo.io.ffprobe(video) 135 | frame_rate = videometadata['video']['@avg_frame_rate'] 136 | frame_num = videodata.shape[0] 137 | frames_in_sec = get_fps(frame_rate) 138 | length_in_secs = frame_num / frames_in_sec 139 | 140 | return [videodata[::int(round(frames_in_sec)/fps), :, :, :], length_in_secs, frame_num, video_name] 141 | 142 | except Exception as e: 143 | return None 144 | 145 | 146 | def clip_embed_images(images, version='ViT-L/14', batch_size=256, device='cuda'): 147 | with torch.no_grad(): 148 | model, preprocess = clip.load(version, device=device) 149 | embeds = [] 150 | for i in tqdm(range(0, len(images), batch_size)): 151 | e = min(len(images), i + batch_size) 152 | imgs = torch.cat([preprocess(Image.open(img)).unsqueeze(0).to(device) for img in images[i:e]]) 153 | embs = model.encode_image(imgs) 154 | embeds.append(embs) 155 | embeds = torch.cat(embeds).float() 156 | return embeds / embeds.norm(p=2, dim=-1, keepdim=True) 157 | 158 | 159 | def clip_embed_texts(texts, bs=256, version='ViT-L/14', device='cuda'): 160 | with torch.no_grad(): 161 | model, preprocess = clip.load(version, device=device) 162 | embeds = torch.empty(len(texts), 512) 163 | for s in tqdm(range(0, len(texts), bs)): 164 | e = min(len(texts), s + bs) 165 | ts = clip.tokenize(texts[s:e]).to(device) 166 | ts = model.encode_text(ts, normalize=True) 167 | embeds[s:e, :] = ts.detach().cpu() 168 | embeds = embeds.float() 169 | return embeds / embeds.norm(p=2, dim=-1, keepdim=True) 170 | 171 | 172 | def prior_embed_texts(texts, cfg, ckpt, bs=64, n_samples_per_batch=2, cond_scale=1, device='cuda'): 173 | with torch.no_grad(): 174 | model = Dalle2Prior(cfg, ckpt, device=device) 175 | 176 | prior_clips = torch.empty(len(texts), 768) 177 | for s in tqdm(range(0, len(texts), bs)): 178 | e = min(len(texts), s + bs) 179 | prior_clips[s:e, :] = model.sample(texts[s:e], n_samples_per_batch=n_samples_per_batch, cond_scale=cond_scale).detach().cpu() 180 | return prior_clips 181 | 182 | 183 | # given a data length dictionary and its flattened clip embeds, 184 | # unflatten it to batched sequences with sampling and padding, delay specifies the placeholder tokens in first k rows 185 | def emb2seq(jumps, emb, max_length=0, delay=0, device='cuda'): 186 | step = 0 187 | rt_emb = [] 188 | for j in jumps: 189 | rt_emb.append(emb[step:step+j]) 190 | step += j 191 | 192 | if max_length == 0: 193 | return rt_emb 194 | 195 | for i, e in enumerate(rt_emb): 196 | L, E = 
e.shape 197 | if L > max_length: 198 | idx = random.sample(range(L), max_length - delay) 199 | rt_emb[i] = e[idx] 200 | else: 201 | rt_emb[i] = torch.cat([e, torch.zeros(max_length - L - delay, E).to(device)]) 202 | rt_emb[i] = torch.cat([torch.zeros(delay, E).to(device), rt_emb[i]]) 203 | return torch.stack(rt_emb) 204 | 205 | 206 | # given two batches of embeddings, find the top-k similar ids in emb2 for ech entry of emb1 207 | def topk_sim(emb1, emb2, topk=10, bs=512, normalize=False, device='cuda'): 208 | with torch.no_grad(): 209 | emb1 = emb1.to(device) 210 | emb2 = emb2.to(device) 211 | if normalize: 212 | emb1 = emb1 / emb1.norm(p=2, dim=-1, keepdim=True) 213 | emb2 = emb2 / emb2.norm(p=2, dim=-1, keepdim=True) 214 | 215 | B = emb1.shape[0] 216 | topk_sims = torch.ones(B, topk).to(device) * -2 217 | topk_idx = torch.zeros(B, topk, dtype=torch.int64).to(device) 218 | 219 | for s in range(0, len(emb2), bs): 220 | e = min(len(emb2), s + bs) 221 | sims = torch.einsum('ai,bi->ab', emb1, emb2[s:e]) 222 | sims, idx = torch.topk(sims, k=topk, dim=-1) 223 | idx += s 224 | topk_sims = torch.cat([topk_sims, sims], dim=-1) 225 | topk_idx = torch.cat([topk_idx, idx], dim=-1) 226 | topk_sims, idx = torch.topk(topk_sims, k=topk, dim=-1) 227 | topk_idx = topk_idx.gather(-1, idx) 228 | 229 | return topk_sims, topk_idx 230 | 231 | 232 | def mask2bbox(mask, normalize=False): 233 | rows = np.any(mask, axis=1) 234 | cols = np.any(mask, axis=0) 235 | rmin, rmax = np.where(rows)[0][[0, -1]] 236 | cmin, cmax = np.where(cols)[0][[0, -1]] 237 | 238 | if normalize: 239 | return cmin / mask.shape[0], rmin / mask.shape[1], cmax / mask.shape[0], rmax / mask.shape[1] 240 | return cmin, rmin, cmax, rmax 241 | 242 | 243 | # blur the image content with ellipses filling the bounding boxes 244 | def blur_image_bbox(img, bboxes, blur_radius=15, mask_blur_radius=15): 245 | # ellipse masks 246 | mask = Image.new('L', img.size, color=0) 247 | draw = ImageDraw.Draw(mask) 248 | for bbox in bboxes: 249 | draw.ellipse(bbox, fill=255) 250 | mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur_radius)) 251 | # blur and overlay with masks 252 | overlay = img.filter(ImageFilter.GaussianBlur(radius=blur_radius)) 253 | return Image.composite(overlay, img, mask) 254 | 255 | 256 | # crop a PIL image with a normalized [x1, y1, x2, y2] bbox 257 | def crop_image_bbox(img, bbox, keep_size=False): 258 | width, height = img.size 259 | x1, y1, x2, y2 = bbox 260 | x1, y1, x2, y2 = max(0, x1), max(0, y1), max(0, x2), max(0, y2) 261 | x1, y1, x2, y2 = min(1, x1), min(1, y1), min(1, x2), min(1, y2) 262 | oimg = img.crop([round(min(x1, x2) * width), round(min(y1, y2) * height), 263 | round(max(x1, x2) * width), round(max(y1, y2) * height)]) 264 | if keep_size: # keep original size by upsampling 265 | oimg = oimg.resize((width, height), Image.Resampling.BICUBIC) 266 | return oimg 267 | 268 | 269 | def pad_wav(waveform, segment_length): 270 | waveform_length = waveform.shape[-1] 271 | assert waveform_length > 100, "Waveform is too short, %s" % waveform_length 272 | if segment_length is None or waveform_length == segment_length: 273 | return waveform 274 | elif waveform_length > segment_length: 275 | return waveform[:, :segment_length] 276 | elif waveform_length < segment_length: 277 | temp_wav = np.zeros((1, segment_length)) 278 | temp_wav[:, :waveform_length] = waveform 279 | return temp_wav 280 | 281 | 282 | def normalize_wav(waveform): 283 | waveform = waveform - np.mean(waveform) 284 | waveform = waveform / (np.max(np.abs(waveform)) + 
1e-8) 285 | return waveform * 0.5 286 | 287 | 288 | def read_wav_file(filename, segment_length=0, sampling_rate=16000): 289 | # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower 290 | waveform, sr = torchaudio.load(filename) # Faster!!! 291 | if sr != sampling_rate: 292 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sampling_rate) 293 | waveform = waveform.numpy()[0, ...] 294 | waveform = normalize_wav(waveform) 295 | waveform = waveform[None, ...] 296 | if segment_length != 0: 297 | waveform = pad_wav(waveform, segment_length) 298 | 299 | waveform = waveform / np.max(np.abs(waveform)) 300 | waveform = 0.5 * waveform 301 | 302 | waveform = np.nan_to_num(waveform) 303 | 304 | return waveform 305 | 306 | 307 | def extract_kaldi_fbank_feature(waveform, sampling_rate): 308 | norm_mean = -4.2677393 309 | norm_std = 4.5689974 310 | 311 | if sampling_rate != 16000: 312 | waveform_16k = torchaudio.functional.resample( 313 | waveform, orig_freq=sampling_rate, new_freq=16000 314 | ) 315 | else: 316 | waveform_16k = waveform 317 | 318 | waveform_16k = waveform_16k - waveform_16k.mean() 319 | fbank = torchaudio.compliance.kaldi.fbank( 320 | waveform_16k, 321 | htk_compat=True, 322 | sample_frequency=16000, 323 | use_energy=False, 324 | window_type="hanning", 325 | num_mel_bins=128, 326 | dither=0.0, 327 | frame_shift=10, 328 | ) 329 | 330 | target_len = waveform.size(0) 331 | 332 | # cut and pad 333 | n_frames = fbank.shape[0] 334 | p = target_len - n_frames 335 | if p > 0: 336 | m = torch.nn.ZeroPad2d((0, 0, 0, p)) 337 | fbank = m(fbank) 338 | elif p < 0: 339 | fbank = fbank[:target_len, :] 340 | 341 | fbank = (fbank - norm_mean) / (norm_std * 2) 342 | 343 | return {"ta_kaldi_fbank": fbank} # [1024, 128] 344 | 345 | 346 | def save_wave(waveform, savepath, name="outwav", samplerate=16000): 347 | if type(name) is not list: 348 | name = [name] * waveform.shape[0] 349 | 350 | for i in range(waveform.shape[0]): 351 | if waveform.shape[0] > 1 : 352 | fname = "%s_%s.wav" % ( 353 | os.path.basename(name[i]) 354 | if (not ".wav" in name[i]) 355 | else os.path.basename(name[i]).split(".")[0], 356 | i, 357 | ) 358 | else: 359 | fname = "%s.wav" % os.path.basename(name[i]) if (not ".wav" in name[i]) else os.path.basename(name[i]).split(".")[0] 360 | # Avoid the file name too long to be saved 361 | if len(fname) > 255: 362 | fname = f"{hex(hash(fname))}.wav" 363 | 364 | path = os.path.join( 365 | savepath, fname 366 | ) 367 | # print("Save audio to %s" % path) 368 | sf.write(path, waveform[i, 0], samplerate=samplerate) 369 | 370 | 371 | def download_audioldm_checkpoint(checkpoint_name): 372 | meta = get_metadata() 373 | if(checkpoint_name not in meta.keys()): 374 | print("The model name you provided is not supported. 
Please use one of the following: ", meta.keys()) 375 | 376 | if not os.path.exists(meta[checkpoint_name]["path"]) or os.path.getsize(meta[checkpoint_name]["path"]) < 2*10**9: 377 | os.makedirs(os.path.dirname(meta[checkpoint_name]["path"]), exist_ok=True) 378 | print(f"Downloading the main structure of {checkpoint_name} into {os.path.dirname(meta[checkpoint_name]['path'])}") 379 | 380 | urllib.request.urlretrieve(meta[checkpoint_name]["url"], meta[checkpoint_name]["path"], MyProgressBar()) 381 | print( 382 | "Weights downloaded in: {} Size: {}".format( 383 | meta[checkpoint_name]["path"], 384 | os.path.getsize(meta[checkpoint_name]["path"]), 385 | ) 386 | ) 387 | 388 | return meta[checkpoint_name]["path"] 389 | 390 | def image2video(img_fs, out_dir, duration=10, fps=24): # pad an image to a video 391 | for img_f in tqdm(img_fs): 392 | img = np.asarray(Image.open(img_f))[None, ...] 393 | img = np.repeat(img, duration * fps, axis=0) 394 | writer = skvideo.io.FFmpegWriter(os.path.join(out_dir, os.path.basename(img_f).replace('.png', '.mp4'))) 395 | for i in range(img.shape[0]): 396 | writer.writeFrame(img[i, ...]) 397 | writer.close() 398 | 399 | def batch_image2video(img_fs, out_dir, duration=10, fps=24, num_workers=16): 400 | os.makedirs(out_dir, exist_ok=True) 401 | 402 | worker_fs = [[] for _ in range(num_workers)] 403 | for i, j in enumerate(img_fs): 404 | worker_fs[i % num_workers].append(j) 405 | 406 | pool = ThreadPoolExecutor(max_workers=num_workers) 407 | for i in range(num_workers): 408 | pool.submit(image2video, worker_fs[i], out_dir, duration, fps) 409 | pool.shutdown(wait=True) 410 | 411 | -------------------------------------------------------------------------------- /ssv2a/evals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wguo86/SSV2A/7f01d5f7a45e6f3bf68b75ade3db30f08978c4b4/ssv2a/evals/__init__.py -------------------------------------------------------------------------------- /ssv2a/evals/cluster.py: -------------------------------------------------------------------------------- 1 | """ 2 | Measures the F-Ratio and Partition Coefficient of the learned manifold. 
3 | """ 4 | import random 5 | 6 | import pandas as pd 7 | import torch 8 | import numpy as np 9 | from sklearn.decomposition import PCA 10 | 11 | from ssv2a.data.pairs import PairDataset, pairs2clips, pairs2claps 12 | 13 | 14 | def get_cluster(pipeline, pairs_meta, cats, samples_per_cat=20, var_samples=1, device='cuda'): 15 | pairs = PairDataset(pairs_meta, split='test') 16 | 17 | # collect samples 18 | samples = {} 19 | for i in random.sample(range(len(pairs)), len(pairs)): 20 | pair, cat = pairs[i] 21 | if cat in cats: 22 | if cat not in samples: 23 | samples[cat] = [pair] 24 | elif cat in samples and len(samples[cat]) < samples_per_cat: 25 | samples[cat].append(pair) 26 | 27 | # get manifold embeddings 28 | raw_clips = {} 29 | raw_claps = {} 30 | fold_clips = {} 31 | fold_claps = {} 32 | gen_clips = {} 33 | gen_claps = {} 34 | pipeline.eval() 35 | with torch.no_grad(): 36 | for cat in cats: 37 | clips = pairs2clips(samples[cat], 'local_clips', device=device) 38 | claps = pairs2claps(samples[cat], align_clips='local_clips', device=device) 39 | raw_clips[cat] = clips.detach().cpu().numpy() 40 | raw_claps[cat] = claps.detach().cpu().numpy() 41 | clip = pipeline.manifold.fold_clips(clips, var_samples=var_samples, normalize=False) 42 | clap = pipeline.manifold.fold_claps(claps, var_samples=var_samples, normalize=False) 43 | fold_clips[cat] = clip.detach().cpu().numpy() 44 | fold_claps[cat] = clap.detach().cpu().numpy() 45 | gen_clips[cat] = pipeline.generator.fold2claps(clip, var_samples=var_samples).detach().cpu().numpy() 46 | gen_claps[cat] = pipeline.generator.fold2claps(clap, var_samples=var_samples).detach().cpu().numpy() 47 | return raw_clips, raw_claps, fold_clips, fold_claps, gen_clips, gen_claps 48 | 49 | 50 | def clusters2arr(clusters, cats): 51 | rt = [] 52 | for cat in cats: 53 | rt.append(clusters[cat]) 54 | return np.stack(rt) 55 | 56 | 57 | def clusters2csv(clusters, save_dir): # save the clustering to csv 58 | df = [] 59 | for k in clusters: 60 | embs = clusters[k] 61 | for i in range(len(embs)): 62 | entry = {'class': k} 63 | if i < len(embs) // 2: 64 | entry['mode'] = 'visual' 65 | else: 66 | entry['mode'] = 'audio' 67 | emb = embs[i] 68 | for j in range(len(emb)): 69 | entry[f'dim_{j}'] = emb[j] 70 | df.append(entry) 71 | df = pd.DataFrame(df) 72 | df.to_csv(save_dir, index=False) 73 | return df 74 | 75 | 76 | def pca_fit(pairs_meta, modality='local_clips', pipeline=None, n_components=512, device='cuda'): # use https://github.com/valentingol/torch_pca, faster than sklearn 77 | pairs = PairDataset(pairs_meta, split='test') 78 | embs = [] 79 | clap_embs = [] 80 | for i in random.sample(range(len(pairs)), len(pairs)): 81 | pair, _ = pairs[i] 82 | if pipeline is None: 83 | embs.append(np.squeeze(pair.data[modality])) 84 | else: 85 | embs.append(np.squeeze(pair.data['local_clips'])) 86 | clap_embs.append(np.squeeze(pair.data['clap'])) 87 | 88 | embs = np.stack(embs) 89 | if pipeline is not None: 90 | clap_embs = np.stack(clap_embs) 91 | pipeline.manifold.eval() 92 | with torch.no_grad(): 93 | fold_embs = [] 94 | for s in range(0, len(embs), 64): 95 | e = min(len(embs), s + 64) 96 | fold_embs.append( 97 | pipeline.manifold.fold_clips(torch.from_numpy(embs[s:e]).float().to(device), 98 | var_samples=1, normalize=False).detach().cpu().numpy()) 99 | fold_embs = np.concatenate(fold_embs, axis=0) 100 | 101 | fold_clap_embs = [] 102 | for s in range(0, len(clap_embs), 64): 103 | e = min(len(clap_embs), s + 64) 104 | fold_clap_embs.append( 105 | 
pipeline.manifold.fold_claps(torch.from_numpy(clap_embs[s:e]).float().to(device), 106 | var_samples=1, normalize=False).detach().cpu().numpy()) 107 | fold_clap_embs = np.concatenate(fold_clap_embs, axis=0) 108 | embs = np.concatenate([fold_embs, fold_clap_embs], axis=0) 109 | 110 | pca = PCA(n_components=n_components, svd_solver='auto') 111 | pca.fit(embs) 112 | return pca 113 | 114 | 115 | def pca_reduce(clusters, pca): 116 | cats = list(clusters.keys()) 117 | cluster_arr = clusters2arr(clusters, cats) 118 | _, S, _ = cluster_arr.shape 119 | cluster_arr = np.concatenate(cluster_arr, axis=0) 120 | 121 | cluster_arr = pca.transform(cluster_arr) 122 | 123 | new_clusters = {} 124 | for i, cat in enumerate(cats): 125 | new_clusters[cat] = cluster_arr[i*S:i*S+S] 126 | return new_clusters 127 | 128 | 129 | def get_pc(clusters): 130 | """ 131 | :param clusters: a (K, S, E) numpy array of K clusters, each holding S sample embeddings of dimension E 132 | :return: the Partition Coefficient 133 | """ 134 | # find centroids 135 | centroids = np.mean(clusters, axis=1) # (K, E) 136 | 137 | # compute membership (cosine similarity) 138 | member = np.einsum('kse,ke->ks', clusters, centroids) 139 | norms = np.linalg.norm(clusters, axis=-1) * np.expand_dims(np.linalg.norm(centroids, axis=-1), axis=-1) 140 | member /= norms # (K, S) 141 | member = member.T / 2 + 0.5 # (S, K), rescale so negative cosine similarities do not yield negative memberships 142 | 143 | # compute partition coefficient 144 | pc = np.sum(member ** 2, axis=-1) 145 | pc = np.mean(pc) 146 | return pc 147 | 148 | -------------------------------------------------------------------------------- /ssv2a/evals/cs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | import torchaudio 7 | from PIL import Image 8 | from torch.utils.data import DataLoader, Dataset 9 | from tqdm.auto import tqdm 10 | import wav2clip 11 | import clip 12 | 13 | from ssv2a.data.utils import pad_wav 14 | 15 | 16 | class WaveImageDataset(Dataset): 17 | def __init__(self, aud_imgs, sr=16000, duration=10, limit_num=None): 18 | self.data = aud_imgs 19 | self.data_idx = sorted(list(self.data.keys())) 20 | if limit_num is not None: 21 | self.data_idx = self.data_idx[:limit_num] 22 | self.sr = sr 23 | self.seg_len = int(duration * 102.4) * (sr // 100) 24 | 25 | def __getitem__(self, index): 26 | while True: 27 | try: 28 | filename = self.data_idx[index] 29 | waveform = self.read_from_file(filename) 30 | if waveform.shape[-1] < 1: 31 | raise ValueError("empty file %s" % filename) 32 | break 33 | except Exception as e: 34 | print(index, e) 35 | index = (index + 1) % len(self.data_idx) 36 | return waveform, self.data[filename], os.path.basename(filename) 37 | 38 | def __len__(self): 39 | return len(self.data_idx) 40 | 41 | def read_from_file(self, audio_file): 42 | audio, file_sr = torchaudio.load(audio_file) 43 | # Only use the first channel 44 | audio = audio[0:1, ...]
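# zero-center the waveform (remove any DC offset) before resampling and padding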
45 | audio = audio - audio.mean() 46 | 47 | if file_sr != self.sr: 48 | audio = torchaudio.functional.resample( 49 | audio, orig_freq=file_sr, new_freq=self.sr, # rolloff=0.95, lowpass_filter_width=16 50 | ) 51 | # audio = torch.FloatTensor(librosa.resample(audio.numpy(), file_sr, self.sr)) 52 | 53 | audio = pad_wav(audio.numpy(), self.seg_len) 54 | return audio 55 | 56 | 57 | def collate_waveimage(data): 58 | return np.concatenate([d[0] for d in data], dtype=np.float32), [d[1] for d in data], [d[2] for d in data] 59 | 60 | 61 | # clip score between image and audio, image can be multiple patches 62 | def get_cs(img_d, aud_d, sr=16000, duration=10, batch_size=64, device='cuda'): 63 | with torch.no_grad(): 64 | model = wav2clip.get_model().to(device) 65 | clip_model, clip_preprocess = clip.load('ViT-B/32', device=device) 66 | auds = [str(i) for i in Path(aud_d).rglob('*.wav')] 67 | aud_imgs = {} 68 | preimg = [str(i) for i in Path(img_d).rglob('*.png')] 69 | for p in auds: 70 | pat = Path(p).name.replace('.wav', '') 71 | aud_imgs[p] = [i for i in preimg if pat in Path(i).name] 72 | 73 | loader = DataLoader( 74 | WaveImageDataset(aud_imgs, sr=sr, duration=duration), 75 | batch_size=batch_size, 76 | sampler=None, 77 | num_workers=0, 78 | collate_fn=collate_waveimage 79 | ) 80 | 81 | ret_score = None 82 | total_score = 0 83 | n = 0 84 | for audio, imgs, audio_path in tqdm(loader): 85 | embedding = torch.from_numpy(wav2clip.embed_audio(audio, model)) 86 | embedding = embedding / embedding.norm(p=2, dim=-1, keepdim=True) 87 | 88 | # embed images 89 | jumps = [len(img) for img in imgs] 90 | pimgs = [] 91 | for img in imgs: 92 | pimgs += img 93 | pimgs = torch.cat([clip_preprocess(Image.open(img)).unsqueeze(0).to(device) for img in pimgs]) 94 | img_emb = clip_model.encode_image(pimgs) 95 | img_emb = img_emb / img_emb.norm(p=2, dim=-1, keepdim=True) 96 | 97 | # scoring (cosine similarity) 98 | idx = 0 99 | for i, j in enumerate(jumps): 100 | ae = embedding[i].detach().cpu().float() 101 | ie = img_emb[idx:idx+j].detach().cpu().float() 102 | sims = torch.einsum('i,bi->b', ae, ie) 103 | total_score += torch.mean(sims).numpy() 104 | idx += j 105 | n += 1 106 | 107 | return total_score / n 108 | 109 | -------------------------------------------------------------------------------- /ssv2a/evals/fad.py: -------------------------------------------------------------------------------- 1 | from frechet_audio_distance import FrechetAudioDistance 2 | 3 | 4 | # use https://github.com/gudgud96/frechet-audio-distance 5 | def get_fad(pred, target, sr=16000, model='pann', device='cuda'): 6 | if model == 'pann' or model == 'vggish': 7 | frechet = FrechetAudioDistance(model_name=model, sample_rate=sr, 8 | use_pca=False, use_activation=False, verbose=False) 9 | elif model == 'clap': 10 | frechet = FrechetAudioDistance( 11 | model_name="clap", 12 | sample_rate=48000, 13 | submodel_name="630k-audioset", # for CLAP only 14 | verbose=False, 15 | enable_fusion=False # for CLAP only 16 | ) 17 | elif model == 'encodec': 18 | frechet = FrechetAudioDistance( 19 | model_name="encodec", 20 | sample_rate=48000, 21 | channels=2, 22 | verbose=False, 23 | ) 24 | else: 25 | raise NotImplementedError('Model is not supported.') 26 | score = frechet.score(background_dir=target, eval_dir=pred) 27 | return score 28 | 29 | -------------------------------------------------------------------------------- /ssv2a/evals/ms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Matching Score (MS) 
measures generation relevance to multiple sound source conditions. 3 | Suppose we obtain the top M detected ground-truth labels and the top N labels detected by a classifier on a generated audio. 4 | Then we have 3 kinds of matchings: 5 | 1. True Positive: label is present in both ground truth and generated sets. 6 | 2. False Positive: label is not present in ground truth, but present in generation. 7 | 3. False Negative: label is present in ground truth, but not in generation. 8 | 9 | Notice we don't consider True Negatives here because we are not interested in them for a generation task. 10 | 11 | We then compute the following sub-metrics in MS: 12 | 1. Precision 13 | 2. Recall 14 | 3. F1 Score 15 | """ 16 | 17 | import numpy as np 18 | 19 | from ssv2a.data.utils import read_wav_file 20 | from pathlib import Path 21 | 22 | import torch 23 | from torch.utils.data import Dataset, DataLoader 24 | import pandas as pd 25 | from tqdm.auto import tqdm 26 | from BEATs import BEATs, BEATsConfig 27 | 28 | 29 | class AudioLabelDataset(Dataset): 30 | def __init__(self, meta_csv, aud_folder): 31 | super().__init__() 32 | self.folder = Path(aud_folder) 33 | self.df = pd.read_csv(meta_csv) 34 | 35 | def __len__(self): 36 | return len(self.df) 37 | 38 | def __getitem__(self, idx): 39 | row = self.df.iloc[idx] 40 | aud_f = self.folder / f"{row['id']}.wav" 41 | waveform = read_wav_file(aud_f, segment_length=163840, sampling_rate=16000) 42 | return waveform, row['labels'], row['id'] 43 | 44 | 45 | def collate_audiolabels(data): 46 | return torch.stack([torch.from_numpy(d[0]) for d in data]).squeeze().float(), [d[1] for d in data], [d[2] for d in data] 47 | 48 | 49 | def get_ms(gt_aud_dir, gen_aud_dir, beats_ckpt, M=10, N=50, bs=64, device='cuda'): 50 | ckpt = torch.load(beats_ckpt) 51 | 52 | cfg = BEATsConfig(ckpt['cfg']) 53 | label_dict = ckpt['label_dict'] 54 | 55 | beats = BEATs(cfg) 56 | beats.load_state_dict(ckpt['model']) 57 | beats.to(device) 58 | beats.eval() 59 | 60 | with torch.no_grad(): 61 | tps, fps, fns = [], [], [] 62 | 63 | gt_aud_fs = [str(p) for p in Path(gt_aud_dir).glob('*.wav')] 64 | ids = [p.name.replace('.wav', '') for p in Path(gt_aud_dir).glob('*.wav')] 65 | 66 | for s in tqdm(range(0, len(gt_aud_fs), bs)): 67 | e = min(len(gt_aud_fs), s + bs) 68 | wave = [] 69 | for aud_f in gt_aud_fs[s:e]: 70 | wave.append(torch.from_numpy(read_wav_file(str(aud_f), segment_length=163840, sampling_rate=16000))) 71 | wave = torch.stack(wave).squeeze().float().to(device) 72 | 73 | B = wave.shape[0] 74 | 75 | # prepare ground truth labels 76 | labels = [] 77 | padding_mask = torch.zeros(wave.shape).bool().to(device) 78 | gt_pred = beats.extract_features(wave.to(device), padding_mask=padding_mask)[0] 79 | for i, (label_prob, label_idx) in enumerate(zip(*gt_pred.topk(k=M))): 80 | lbs = [label_dict[idx.item()] for idx in label_idx] 81 | labels.append(lbs) 82 | 83 | # predict labels for generated audios 84 | gen_wave, removed = [], 0 85 | gen_aud_dir = Path(gen_aud_dir) 86 | for i, vid in enumerate(ids[s:e]): 87 | try: 88 | aud_f = next(gen_aud_dir.glob(f'*{vid}*')) 89 | gen_wave.append(torch.from_numpy(read_wav_file(str(aud_f), segment_length=163840, sampling_rate=16000))) 90 | except Exception: 91 | del labels[i - removed]; removed += 1 # keep label indices aligned after earlier removals 92 | print(f'Ignored missing generated audio {vid}.wav') 93 | gen_wave = torch.stack(gen_wave).squeeze().float().to(device) 94 | padding_mask = torch.zeros(gen_wave.shape).bool().to(device) 95 | gen_pred = beats.extract_features(gen_wave.to(device), padding_mask=padding_mask)[0] 96 | gen_labels = [] 97
| for i, (label_prob, label_idx) in enumerate(zip(*gen_pred.topk(k=N))): 98 | lbs = [label_dict[idx.item()] for idx in label_idx] 99 | gen_labels.append(lbs) 100 | 101 | # matching 102 | tp, fp, fn = [], [], [] 103 | for i in range(len(labels)): 104 | gt_set = set(labels[i]) 105 | gen_set = set(gen_labels[i]) 106 | true_pos = gt_set.intersection(gen_set) 107 | tp.append(len(true_pos)) 108 | fp.append(len(gen_set.difference(true_pos))) 109 | fn.append(len(gt_set.difference(true_pos))) 110 | tps += tp 111 | fps += fp 112 | fns += fn 113 | 114 | # compute metrics 115 | tps = np.array(tps, dtype=np.float64) 116 | fps = np.array(fps, dtype=np.float64) 117 | fns = np.array(fns, dtype=np.float64) 118 | 119 | precision = np.mean(tps / (tps + fps)).item() 120 | recall = np.mean(tps / (tps + fns)).item() 121 | f1_score = (2 * precision * recall / (precision + recall)) 122 | 123 | return precision, recall, f1_score 124 | 125 | -------------------------------------------------------------------------------- /ssv2a/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wguo86/SSV2A/7f01d5f7a45e6f3bf68b75ade3db30f08978c4b4/ssv2a/model/__init__.py -------------------------------------------------------------------------------- /ssv2a/model/aggregator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import wandb 8 | from tqdm import tqdm 9 | from torch.utils.data import Dataset, DataLoader 10 | import torch.nn.functional as F 11 | 12 | from ssv2a.model.modules import PositionalEmbedding, TransEncoder, MLP, sample_normal 13 | # from ssv2a.train.loss import kld 14 | 15 | 16 | class Aggregator(nn.Module): 17 | def __init__(self, emb_dim=512, device='cuda'): 18 | super().__init__() 19 | self.device = device 20 | 21 | self.pe = PositionalEmbedding(emb_dim, resolution=1024, inject_method='add', device=device) 22 | self.encoder = TransEncoder(num_layers=1, embed_dim=emb_dim, nhead=8, dropout=.2, exp_rate=2) 23 | self.pred_token = nn.Parameter(torch.zeros(2, emb_dim)) 24 | nn.init.normal_(self.pred_token, mean=0, std=1) 25 | self.head_mu = MLP(layers=[emb_dim] * 2, dropout=.2) 26 | self.head_sigma = MLP(layers=[emb_dim] * 2, dropout=.2) 27 | self.out_mu = nn.Linear(emb_dim, emb_dim) 28 | self.out_sigma = nn.Linear(emb_dim, emb_dim) 29 | 30 | self.to(device) 31 | self.float() 32 | 33 | def forward(self, x): 34 | """ 35 | :param x: [B, L, E] 36 | """ 37 | x = x.to(self.device) 38 | x = torch.cat([torch.tile(self.pred_token, (x.shape[0], 1, 1)), x], dim=1) 39 | x = self.pe(x) 40 | x = self.encoder(x) 41 | return self.out_mu(self.head_mu(x[:, 0, :])), self.out_sigma(self.head_sigma(x[:, 1, :])) 42 | 43 | @torch.no_grad() 44 | def sample(self, x, var_samples=1): 45 | mu, sigma = self.forward(x) 46 | clap = torch.zeros(mu.shape).to(self.device) 47 | for i in range(var_samples): 48 | clap += sample_normal(mu, sigma) / var_samples 49 | return clap / clap.norm(p=2, dim=-1, keepdim=True) 50 | 51 | 52 | class VideoCLAPDataset(Dataset): 53 | def __init__(self, claps_dir): 54 | self.clap_fs = [str(p) for p in Path(claps_dir).glob('*.npy')] 55 | 56 | def __len__(self): 57 | return len(self.clap_fs) 58 | 59 | def __getitem__(self, idx): 60 | claps = np.load(self.clap_fs[idx]) 61 | return torch.from_numpy(claps[1:]), torch.from_numpy(claps[0]) 62 | 63 | 64 | def collate_claps(data): 65 | frame_claps = 
torch.stack([d[0] for d in data]) 66 | gt_claps = torch.stack([d[1] for d in data]) 67 | return frame_claps, gt_claps 68 | 69 | 70 | class AggTrainer: 71 | def __init__(self, model:Aggregator, claps_dir, ckpt_dir, batch_size=64, var_samples=1): 72 | claps_dir = Path(claps_dir) 73 | self.name = 'N.A.' 74 | self.ckpt_dir = Path(ckpt_dir) / self.name 75 | 76 | self.model = model 77 | 78 | self.train_loader = DataLoader(VideoCLAPDataset(claps_dir / 'train'), 79 | batch_size=batch_size, collate_fn=collate_claps) 80 | self.val_loader = DataLoader(VideoCLAPDataset(claps_dir / 'val'), 81 | batch_size=batch_size, collate_fn=collate_claps) 82 | self.test_loader = DataLoader(VideoCLAPDataset(claps_dir / 'test'), 83 | batch_size=batch_size, collate_fn=collate_claps) 84 | self.var_samples = var_samples 85 | 86 | def compute_loss(self, frame_claps, gt_claps): 87 | mu, sigma = self.model(frame_claps) 88 | # kl = torch.mean(kld(mu, sigma)) 89 | kl = 0 90 | 91 | mu = mu.tile(self.var_samples, 1) 92 | sigma = sigma.tile(self.var_samples, 1) 93 | gt_claps = gt_claps.tile(self.var_samples, 1).to(self.model.device) 94 | 95 | gen_claps = sample_normal(mu, sigma) 96 | gen_claps = gen_claps / gen_claps.norm(p=2, dim=-1, keepdim=True) 97 | gen_loss = torch.mean((1 - F.cosine_similarity(gt_claps, gen_claps)) ** 2) 98 | 99 | return gen_loss + .001 * kl 100 | 101 | 102 | def train(self, epochs=64, report_interval=1): 103 | best_val_loss = 1e6 104 | 105 | # wandb 106 | run = wandb.init(project='SDV2A-Agg') 107 | self.name = run.name 108 | self.ckpt_dir = self.ckpt_dir.parent / self.name 109 | os.makedirs(self.ckpt_dir, exist_ok=True) 110 | 111 | # optimizer 112 | optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4) 113 | 114 | # epoch 115 | for epoch in tqdm(range(epochs)): 116 | wandb_log = {} 117 | 118 | # step 119 | train_loss = [] 120 | for batch, (frame_claps, gt_claps) in enumerate(self.train_loader): 121 | loss = self.compute_loss(frame_claps, gt_claps) 122 | # back propagation 123 | loss.backward() 124 | # optimize 125 | optimizer.step() 126 | optimizer.zero_grad() 127 | train_loss.append(loss.detach().cpu().item()) 128 | train_loss = np.mean(train_loss) 129 | torch.save(self.model.state_dict(), self.ckpt_dir / 'latest.pth') 130 | 131 | # evaluate 132 | val_loss = [] 133 | for batch, (frame_claps, gt_claps) in enumerate(self.val_loader): 134 | loss = self.compute_loss(frame_claps, gt_claps).detach().cpu().item() 135 | val_loss.append(loss) 136 | val_loss = np.mean(val_loss) 137 | if val_loss < best_val_loss: 138 | best_val_loss = val_loss 139 | torch.save(self.model.state_dict(), self.ckpt_dir / 'best_val.pth') 140 | 141 | test_loss = [] 142 | for batch, (frame_claps, gt_claps) in enumerate(self.test_loader): 143 | loss = self.compute_loss(frame_claps, gt_claps).detach().cpu().item() 144 | test_loss.append(loss) 145 | test_loss = np.mean(test_loss) 146 | 147 | # report 148 | if epoch % report_interval == 0: 149 | print(f"Epoch {epoch + 1} - train loss: {train_loss:.5f} " 150 | f"validation loss: {val_loss:.5f} test loss: {test_loss:.5f}") 151 | 152 | wandb_log['train_loss'] = train_loss 153 | wandb_log['val_loss'] = val_loss 154 | wandb_log['test_loss'] = test_loss 155 | wandb.log(wandb_log) 156 | 157 | wandb.finish() 158 | 159 | -------------------------------------------------------------------------------- /ssv2a/model/aldm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reimplementation of AudioLDM and AudioLDM2 to accept CLAP embeddings instead 
of texts as condition. 3 | Adapted from https://github.com/haoheliu/AudioLDM2 4 | -- danke 5 | """ 6 | import os 7 | 8 | import numpy as np 9 | import torch 10 | import yaml 11 | from transformers import logging 12 | 13 | from audioldm import get_metadata, download_checkpoint, default_audioldm_config, LatentDiffusion, seed_everything, \ 14 | read_wav_file, duration_to_latent_t_size, set_cond_audio 15 | from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution 16 | 17 | 18 | def ddpm_get_input(batch, k): 19 | fbank, log_magnitudes_stft, label_indices, fname, waveform, text, image, emb = batch 20 | ret = {} 21 | 22 | ret["fbank"] = ( 23 | fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float() 24 | ) 25 | ret["stft"] = log_magnitudes_stft.to( 26 | memory_format=torch.contiguous_format 27 | ).float() 28 | # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float() 29 | ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float() 30 | ret["text"] = list(text) 31 | ret['image'] = list(image) 32 | ret["fname"] = fname 33 | ret['emb'] = emb 34 | return ret[k] 35 | 36 | 37 | class EmbAudioLDM(LatentDiffusion): 38 | def get_learned_conditioning(self, c): 39 | if self.cond_stage_forward is None: # true 40 | if hasattr(self.cond_stage_model, "encode") and callable( 41 | self.cond_stage_model.encode 42 | ): 43 | c = self.cond_stage_model.encode(c) 44 | if isinstance(c, DiagonalGaussianDistribution): 45 | c = c.mode() 46 | else: # true 47 | if self.cond_stage_key == 'emb': 48 | if len(c.shape) == 2: 49 | c = c[:, None, :] 50 | else: 51 | # Text input is list 52 | if type(c) == list and len(c) == 1: # true 53 | c = self.cond_stage_model([c[0], c[0]]) # clap/encoders.py 54 | c = c[0:1] # [1, 1, 512]) 55 | else: 56 | c = self.cond_stage_model(c) # torch.Size([1, 1, 512]) torch.cuda.FloatTensor float32 57 | else: 58 | assert hasattr(self.cond_stage_model, self.cond_stage_forward) 59 | c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) 60 | return c 61 | 62 | @torch.no_grad() 63 | def get_input( 64 | self, 65 | batch, 66 | k, 67 | return_first_stage_encode=True, 68 | return_first_stage_outputs=False, 69 | force_c_encode=False, 70 | cond_key=None, 71 | return_original_cond=False, 72 | bs=None, 73 | ): 74 | 75 | x = ddpm_get_input(batch, k) 76 | 77 | if bs is not None: # false 78 | x = x[:bs] 79 | 80 | x = x.to(self.device) 81 | 82 | if return_first_stage_encode: # true 83 | encoder_posterior = self.encode_first_stage(x) 84 | z = self.get_first_stage_encoding(encoder_posterior).detach() # torch.Size([10, 8, 256, 16]) 85 | else: 86 | z = None 87 | 88 | if self.model.conditioning_key is not None: # film 89 | if cond_key is None: 90 | cond_key = self.cond_stage_key 91 | if cond_key != self.first_stage_key: # true 92 | if cond_key in ["caption", "coordinates_bbox"]: 93 | xc = batch[cond_key] 94 | elif cond_key == "class_label": 95 | xc = batch 96 | else: 97 | # [bs, 1, 527] 98 | xc = ddpm_get_input(batch, cond_key) # 10,512 99 | if type(xc) == torch.Tensor: # false 100 | xc = xc.to(self.device) 101 | else: 102 | xc = x 103 | if not self.cond_stage_trainable or force_c_encode: # true, true 104 | if isinstance(xc, dict) or isinstance(xc, list): 105 | c = self.get_learned_conditioning(xc) 106 | else: 107 | c = self.get_learned_conditioning(xc.to(self.device)) # 10,1,512 108 | else: 109 | c = xc 110 | 111 | if bs is not None: # false 112 | c = c[:bs] 113 | 114 | else: 115 | c = None 116 | xc = None 117 | if 
self.use_positional_encodings: 118 | pos_x, pos_y = self.compute_latent_shifts(batch) 119 | c = {"pos_x": pos_x, "pos_y": pos_y} 120 | out = [z, c] 121 | if return_first_stage_outputs: 122 | xrec = self.decode_first_stage(z) 123 | out.extend([x, xrec]) 124 | if return_original_cond: 125 | out.append(xc) 126 | return out 127 | 128 | @torch.no_grad() 129 | def generate_sample( 130 | self, 131 | batchs, 132 | ddim_steps=200, 133 | ddim_eta=1.0, 134 | x_T=None, 135 | n_candidate_gen_per_text=1, 136 | unconditional_guidance_scale=1.0, 137 | unconditional_conditioning=None, 138 | name="waveform", 139 | use_plms=False, 140 | save=False, 141 | **kwargs, 142 | ): 143 | # Generate n_candidate_gen_per_text times and select the best 144 | # Batch: audio, text, fnames 145 | assert x_T is None 146 | try: 147 | batchs = iter(batchs) 148 | except TypeError: 149 | raise ValueError("The first input argument should be an iterable object") 150 | 151 | if use_plms: 152 | assert ddim_steps is not None 153 | use_ddim = ddim_steps is not None 154 | # waveform_save_path = os.path.join(self.get_log_dir(), name) 155 | # os.makedirs(waveform_save_path, exist_ok=True) 156 | # print("Waveform save path: ", waveform_save_path) 157 | 158 | with self.ema_scope("Generate"): 159 | waves = [] 160 | for batch in batchs: 161 | z, c = self.get_input( 162 | batch, 163 | self.first_stage_key, 164 | cond_key=self.cond_stage_key, 165 | return_first_stage_outputs=False, 166 | force_c_encode=True, 167 | return_original_cond=False, 168 | bs=None, 169 | ) 170 | text = ddpm_get_input(batch, "text") 171 | 172 | # Generate multiple samples 173 | batch_size = z.shape[0] * n_candidate_gen_per_text 174 | c = torch.cat([c] * n_candidate_gen_per_text, dim=0) 175 | text = text * n_candidate_gen_per_text 176 | 177 | if unconditional_guidance_scale != 1.0: 178 | unconditional_conditioning = ( 179 | self.cond_stage_model.get_unconditional_condition(batch_size) 180 | ) 181 | 182 | samples, _ = self.sample_log( 183 | cond=c, 184 | batch_size=batch_size, 185 | x_T=x_T, 186 | ddim=use_ddim, 187 | ddim_steps=ddim_steps, 188 | eta=ddim_eta, 189 | unconditional_guidance_scale=unconditional_guidance_scale, 190 | unconditional_conditioning=unconditional_conditioning, 191 | use_plms=use_plms, 192 | ) 193 | 194 | if (torch.max(torch.abs(samples)) > 1e2): 195 | samples = torch.clip(samples, min=-10, max=10) 196 | 197 | mel = self.decode_first_stage(samples) 198 | 199 | waveform = self.mel_spectrogram_to_waveform(mel) 200 | 201 | if waveform.shape[0] > 1: 202 | similarity = self.cond_stage_model.cos_similarity( 203 | torch.FloatTensor(waveform).squeeze(1), text 204 | ) 205 | 206 | best_index = [] 207 | for i in range(z.shape[0]): 208 | candidates = similarity[i:: z.shape[0]] 209 | max_index = torch.argmax(candidates).item() 210 | best_index.append(i + max_index * z.shape[0]) 211 | 212 | waveform = waveform[best_index] 213 | # print("Similarity between generated audio and text", similarity) 214 | # print("Choose the following indexes:", best_index) 215 | 216 | waves.append(waveform) 217 | 218 | return np.concatenate(waves) 219 | 220 | 221 | def make_batch_for_emb_to_audio(emb, waveform=None, fbank=None, batchsize=1): 222 | batches = [] 223 | B = emb.shape[0] 224 | 225 | for s in range(0, B, batchsize): 226 | e = min(B, s + batchsize) 227 | bs = e - s 228 | 229 | text = [''] * bs 230 | image = [''] * bs 231 | 232 | if bs < 1: 233 | print("Warning: Batchsize must be at least 1. 
Batchsize is set to .") 234 | 235 | if fbank is None: # true 236 | fb = torch.zeros((bs, 1024, 64)) # Not used, here to keep the code format 237 | else: 238 | fb = torch.FloatTensor(fbank[s:e]) 239 | fb = fb.expand(bs, 1024, 64) 240 | assert fb.size(0) == bs 241 | 242 | stft = torch.zeros((bs, 1024, 512)) # Not used 243 | 244 | if waveform is None: 245 | wave = torch.zeros((bs, 160000)) # Not used 16kHz*10s 246 | else: 247 | wave = torch.FloatTensor(waveform[s:e]) 248 | wave = wave.expand(bs, -1) 249 | assert wave.size(0) == bs 250 | 251 | fname = [""] * bs # Not used 252 | 253 | batch = ( 254 | fb, 255 | stft, 256 | None, 257 | fname, 258 | wave, 259 | text, 260 | image, 261 | emb[s:e] 262 | ) 263 | batches.append(batch) 264 | 265 | return batches 266 | 267 | 268 | def build_audioldm( # only model_name 269 | ckpt_path=None, 270 | config=None, 271 | model_name="audioldm-s-full", 272 | device=None 273 | ): 274 | # print(f"Load AudioLDM: {model_name}") 275 | 276 | if (ckpt_path is None): 277 | ckpt_path = get_metadata()[model_name]["path"] 278 | 279 | if (not os.path.exists(ckpt_path)): 280 | download_checkpoint(model_name) 281 | 282 | if device is None: 283 | if torch.cuda.is_available(): 284 | device = torch.device("cuda") 285 | else: 286 | device = torch.device("cpu") 287 | 288 | if config is not None: 289 | assert type(config) is str 290 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) 291 | else: 292 | config = default_audioldm_config(model_name) 293 | 294 | # Use text as condition instead of using waveform during training 295 | config["model"]["params"]["device"] = device 296 | config["model"]["params"]["cond_stage_key"] = "text" 297 | 298 | # No normalization here 299 | latent_diffusion = EmbAudioLDM(**config["model"]["params"]) 300 | 301 | resume_from_checkpoint = ckpt_path 302 | checkpoint = torch.load(resume_from_checkpoint, map_location=device) 303 | latent_diffusion.load_state_dict(checkpoint["state_dict"], strict=False) 304 | 305 | latent_diffusion.eval() 306 | latent_diffusion = latent_diffusion.to(device) 307 | 308 | latent_diffusion.cond_stage_model.embed_mode = "text" 309 | return latent_diffusion 310 | 311 | 312 | def set_cond_emb(latent_diffusion): 313 | latent_diffusion.cond_stage_key = "emb" 314 | latent_diffusion.cond_stage_model.embed_mode = None 315 | return latent_diffusion 316 | 317 | 318 | def emb_to_audio( 319 | latent_diffusion, 320 | emb, 321 | original_audio_file_path=None, 322 | seed=42, 323 | ddim_steps=200, 324 | duration=10, 325 | batchsize=1, 326 | guidance_scale=2.5, 327 | n_candidate_gen_per_text=3 328 | ): 329 | seed_everything(int(seed)) 330 | logging.set_verbosity_error() 331 | waveform = None 332 | if (original_audio_file_path is not None): 333 | waveform = read_wav_file(original_audio_file_path, int(duration * 102.4) * 160) 334 | 335 | batchs = make_batch_for_emb_to_audio(emb, waveform=waveform, batchsize=batchsize) 336 | 337 | latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) 338 | 339 | if (waveform is not None): 340 | print("Generate audio that has similar content as %s" % original_audio_file_path) 341 | latent_diffusion = set_cond_audio(latent_diffusion) 342 | else: 343 | # print("Generate audio using embedding", emb.shape) 344 | latent_diffusion = set_cond_emb(latent_diffusion) 345 | 346 | with torch.no_grad(): 347 | waveform = latent_diffusion.generate_sample( 348 | batchs, 349 | unconditional_guidance_scale=guidance_scale, 350 | ddim_steps=ddim_steps, 351 | n_candidate_gen_per_text=n_candidate_gen_per_text, 
352 | duration=duration, 353 | ) 354 | return waveform 355 | 356 | -------------------------------------------------------------------------------- /ssv2a/model/clap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchaudio 4 | from audioldm.clap.encoders import CLAPAudioEmbeddingClassifierFreev2 as CLAP_ALD 5 | from tqdm.auto import tqdm 6 | 7 | from ssv2a.data.utils import download_audioldm_checkpoint, normalize_wav 8 | 9 | 10 | # a wrapper to build the clap model (AudioLDM2 flavor) 11 | class CLAP: 12 | def __init__(self, clap_version='audioldm2-full', embed_mode='audio', sample_rate=16000, device='cuda'): 13 | self.model = None 14 | ckpt_path = download_audioldm_checkpoint(clap_version) 15 | ckpt = torch.load(ckpt_path, map_location=device) 16 | self.model = CLAP_ALD( 17 | key="waveform", 18 | sampling_rate=sample_rate, 19 | embed_mode=embed_mode, 20 | unconditional_prob=0 21 | ) 22 | clap_ckpt = {} 23 | for k, v in ckpt["state_dict"].items(): 24 | if k.split('.')[0] == 'cond_stage_model': 25 | clap_ckpt[k.split('cond_stage_model.')[-1]] = v 26 | self.model.load_state_dict(clap_ckpt) 27 | self.model.eval() 28 | self.model.to(device) 29 | 30 | 31 | def clap_embed_texts(texts, version='audioldm-s-full-v2', bs=256, device='cuda'): 32 | clap = CLAP(clap_version=version, embed_mode='text', device=device) 33 | embeds = torch.zeros((len(texts) + 1, 512)) 34 | for s in tqdm(range(0, len(texts), bs)): 35 | e = min(len(texts), s + bs) 36 | emb = clap.model(texts[s:e]).squeeze().float().detach().cpu() 37 | embeds[s+1:e+1] = emb 38 | embeds = torch.nan_to_num(embeds) 39 | return embeds 40 | 41 | 42 | def clap_embed_auds(auds, clap_version='audioldm-s-full-v2', device='cuda'): 43 | with torch.no_grad(): 44 | clap = CLAP(clap_version=clap_version, embed_mode='audio', device=device) 45 | 46 | embeds = torch.empty(len(auds), 512) 47 | for i, aud in enumerate(auds): 48 | waveform, sr = torchaudio.load(aud) 49 | if sr != 16000: 50 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) 51 | waveform = waveform.numpy()[0, ...] 52 | waveform = normalize_wav(waveform) 53 | waveform = waveform[None, ...] 
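# peak-normalize to [-0.5, 0.5] and clean NaNs, mirroring read_wav_file in ssv2a/data/utils.py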
54 | waveform = waveform / np.max(np.abs(waveform)) 55 | waveform = np.nan_to_num(0.5 * waveform) 56 | 57 | embeds[i, :] = clap.model(torch.from_numpy(waveform).float().to(device)).detach().cpu().squeeze() 58 | 59 | return embeds 60 | 61 | -------------------------------------------------------------------------------- /ssv2a/model/dalle2_prior.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import clip 4 | from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig 5 | 6 | 7 | class Dalle2Prior: 8 | def __init__(self, config_path, ckpt_path, device='cuda'): 9 | prior_config = TrainDiffusionPriorConfig.from_json_path(config_path).prior 10 | self.prior = prior_config.create().to(device) 11 | self.device = device 12 | 13 | states = torch.load(ckpt_path) 14 | if 'model' in states: 15 | states = states['model'] 16 | 17 | self.prior.load_state_dict(states, strict=True) 18 | 19 | def sample(self, texts, n_samples_per_batch=2, cond_scale=1): 20 | texts = clip.tokenize(texts, truncate=True).to(self.device) 21 | clips = self.prior.sample(texts, num_samples_per_batch=n_samples_per_batch, cond_scale=cond_scale) 22 | return clips / clips.norm(p=2, dim=-1, keepdim=True) 23 | 24 | -------------------------------------------------------------------------------- /ssv2a/model/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ssv2a.model.modules import MLP, TransEncoder, PositionalEmbedding, sample_normal, MoE 5 | # from ssv2a.train.loss import rrl 6 | 7 | 8 | class Generator(nn.Module): 9 | def __init__(self, clap_dim=512, manifold_dim=128, generator=None, device='cuda', **_): 10 | super().__init__() 11 | self.device = device 12 | generator['device'] = device 13 | self.model = None 14 | self.model_id = 'generator' 15 | self.variational = generator['variational'] 16 | 17 | hidden_dim = manifold_dim 18 | self.kl_weight = generator['kl_weight'] 19 | 20 | self.arch = generator['arch'] 21 | if self.arch == 'mlp': 22 | self.model = nn.Sequential() 23 | self.model.append(nn.Linear(manifold_dim, generator['layers'][0])) 24 | self.model.append(MLP(**generator)) 25 | hidden_dim = generator['layers'][-1] 26 | 27 | elif self.arch == 'transformer': 28 | self.embed_dim = generator['embed_dim'] 29 | self.out_dim = generator['out_dim'] 30 | if generator['pe_inject'] == 'cat': 31 | self.embed_dim = 2 * self.embed_dim 32 | self.patches = generator['patches'] 33 | 34 | # learnable generator embedding, attached to start of the sequence 35 | self.gen_embedding = nn.Parameter(torch.zeros(self.out_dim // self.embed_dim, self.embed_dim)) 36 | nn.init.normal_(self.gen_embedding, mean=0.0, std=1.0) 37 | 38 | self.pos_embed = PositionalEmbedding(self.embed_dim, 39 | resolution=generator['pe_res'], 40 | inject_method=generator['pe_inject'], 41 | device=device) 42 | 43 | self.in_proj = nn.Linear(manifold_dim // self.patches, self.embed_dim) # in projection 44 | self.transformer = TransEncoder(**generator) 45 | 46 | self.model = MLP(**generator) # model is the prediction head 47 | hidden_dim = generator['layers'][-1] 48 | 49 | elif self.arch == 'moe': 50 | self.experts = generator['experts'] 51 | self.moe = MoE(generator, experts=self.experts, 52 | diverse_experts=generator['diverse_experts'], device=device) 53 | 54 | self.rrl_weight = generator['rrl_weight'] 55 | self.router = MLP([manifold_dim, (manifold_dim + self.experts) // 2, self.experts]) 56 | 57 | self.model = 
nn.Linear(generator['layers'][-1], generator['layers'][-1], bias=False) # head is out projection 58 | 59 | hidden_dim = generator['layers'][-1] 60 | 61 | if self.model is None: 62 | raise Exception('Illegal config for generator, abort.') 63 | 64 | # out projection 65 | if generator['variational']: 66 | self.out_mu = nn.Linear(hidden_dim, clap_dim) 67 | self.out_sigma = nn.Linear(hidden_dim, clap_dim) 68 | else: 69 | self.out = nn.Linear(hidden_dim, clap_dim) 70 | 71 | self.to(device) 72 | self.float() 73 | 74 | def forward(self, x): 75 | # tile embedding into [B, L, E] if arch is transformer 76 | router_reg_loss = 0 77 | 78 | if self.arch == 'transformer': 79 | x = self.in_proj(x.reshape(x.shape[0], self.patches, -1)) 80 | x = torch.cat([self.gen_embedding.tile((x.shape[0], 1, 1)), x], dim=1) 81 | x = self.pos_embed(x) 82 | x = self.transformer(x)[:, :self.gen_embedding.shape[0], :] # extract prediction token 83 | x = torch.flatten(x, start_dim=1) 84 | 85 | elif self.arch == 'moe': 86 | ws = self.router(x) 87 | ws = torch.softmax(ws, dim=-1) 88 | 89 | if self.training: # router regularization loss 90 | # router_reg_loss = rrl(ws) * self.rrl_weight 91 | router_reg_loss = 0 92 | 93 | x = self.moe(x, ws) 94 | 95 | if not self.variational: 96 | return self.out(self.model(x)), router_reg_loss 97 | x = self.model(x) 98 | mu = self.out_mu(x) 99 | log_sigma = self.out_sigma(x) 100 | 101 | return mu, log_sigma, router_reg_loss 102 | 103 | def fold2claps(self, folds, var_samples=64): 104 | if self.variational: 105 | mu, sigma, _ = self.forward(folds) 106 | gen_claps = torch.zeros(mu.shape).to(self.device) 107 | for i in range(var_samples): 108 | gen_claps += sample_normal(mu, sigma) / var_samples 109 | else: 110 | gen_claps, _ = self.forward(folds) 111 | 112 | gen_claps = gen_claps / gen_claps.norm(p=2, dim=-1, keepdim=True) 113 | return gen_claps 114 | 115 | -------------------------------------------------------------------------------- /ssv2a/model/manifold.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ssv2a.data.utils import read_classes 7 | from ssv2a.model.modules import MLP, LinearProjection, sample_normal, TransEncoder, PositionalEmbedding, MoE 8 | # from ssv2a.train.loss import contrastive_loss, kld, rrl 9 | 10 | 11 | class ManifoldEncoder(nn.Module): 12 | def __init__(self, in_dim, manifold_dim, manifold=None, device='cuda'): 13 | super().__init__() 14 | manifold['device'] = device 15 | self.variational = manifold['variational'] 16 | self.model = None 17 | hidden_dim = in_dim 18 | 19 | self.arch = manifold['arch'] 20 | if self.arch == 'mlp': 21 | self.model = nn.Sequential() 22 | self.model.append(nn.Linear(in_dim, manifold['layers'][0], bias=False)) # in projection 23 | self.model.append(MLP(**manifold)) 24 | hidden_dim = manifold['layers'][-1] 25 | 26 | elif self.arch == 'linear': 27 | self.model = LinearProjection(in_dim, manifold_dim) 28 | hidden_dim = manifold_dim 29 | 30 | elif self.arch == 'transformer': 31 | self.embed_dim = manifold['embed_dim'] 32 | self.out_dim = manifold['out_dim'] 33 | if manifold['pe_inject'] == 'cat': 34 | self.embed_dim = 2 * self.embed_dim 35 | self.patches = manifold['patches'] 36 | 37 | # learnable manifold embedding, attached to start of the sequence 38 | self.fold_embedding = nn.Parameter(torch.zeros(self.out_dim // self.embed_dim, self.embed_dim)) 39 | nn.init.normal_(self.fold_embedding, mean=0.0, std=1.0) 40 | 41 | self.pos_embed = 
PositionalEmbedding(self.embed_dim, 42 | resolution=manifold['pe_res'], 43 | inject_method=manifold['pe_inject'], 44 | device=device) 45 | 46 | self.in_proj = nn.Linear(in_dim // self.patches, self.embed_dim) # in projection 47 | self.transformer = TransEncoder(**manifold) 48 | 49 | self.model = MLP(**manifold) # model is the prediction head 50 | hidden_dim = manifold['layers'][-1] 51 | 52 | elif self.arch == 'moe': 53 | self.experts = manifold['experts'] 54 | self.moe = MoE(manifold, experts=self.experts, 55 | diverse_experts=manifold['diverse_experts'], device=device) 56 | 57 | self.rrl_weight = manifold['rrl_weight'] 58 | self.router = MLP([in_dim, (in_dim + self.experts) // 2, self.experts]) 59 | 60 | self.model = nn.Linear(manifold['layers'][-1], manifold['layers'][-1], bias=False) # head is out projection 61 | 62 | hidden_dim = manifold['layers'][-1] 63 | 64 | if self.model is None: 65 | raise Exception('Illegal config for manifold encoder, abort.') 66 | 67 | # out projection 68 | if self.variational: 69 | self.out_mu = nn.Linear(hidden_dim, manifold_dim) 70 | self.out_sigma = nn.Linear(hidden_dim, manifold_dim) 71 | elif manifold['arch'] != 'linear': 72 | self.out = nn.Linear(hidden_dim, manifold_dim) 73 | 74 | def forward(self, x): 75 | router_reg_loss = 0 76 | 77 | if self.arch == 'transformer': 78 | x = self.in_proj(x.reshape(x.shape[0], self.patches, -1)) # patching 79 | x = torch.cat([self.fold_embedding.tile((x.shape[0], 1, 1)), x], dim=1) 80 | x = self.pos_embed(x) 81 | x = self.transformer(x)[:, :self.fold_embedding.shape[0], :] # extract prediction token 82 | x = torch.flatten(x, start_dim=1) 83 | 84 | elif self.arch == 'moe': 85 | ws = self.router(x) 86 | ws = torch.softmax(ws, dim=-1) 87 | 88 | if self.training: # router regularization loss 89 | # router_reg_loss = rrl(ws) * self.rrl_weight 90 | router_reg_loss = 0 91 | 92 | x = self.moe(x, ws) 93 | 94 | elif self.arch == 'linear': 95 | return self.model(x), router_reg_loss 96 | 97 | x = self.model(x) 98 | if not self.variational: 99 | return self.out(x), router_reg_loss 100 | mu = self.out_mu(x) 101 | log_sigma = self.out_sigma(x) 102 | 103 | return mu, log_sigma, router_reg_loss 104 | 105 | 106 | class Manifold(nn.Module): 107 | def __init__(self, clip_dim=512, clap_dim=512, manifold_dim=128, classes='', manifold=None, device='cuda', **_): 108 | super().__init__() 109 | self.device = device 110 | self.model_id = 'manifold' 111 | self.variational = manifold['variational'] 112 | self.kl_weight = manifold['kl_weight'] 113 | self.clap_dim = clap_dim 114 | self.clip_encoder = ManifoldEncoder(clip_dim, manifold_dim, manifold, device=device) 115 | self.clap_encoder = ManifoldEncoder(clap_dim, manifold_dim, manifold, device=device) 116 | 117 | # randomly initialize learnable contrastive temperatures 118 | self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07)) 119 | self.self_logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07)) 120 | self.logit_scale_min = math.log(1) 121 | self.logit_scale_max = math.log(100) 122 | 123 | self.cr_weight = manifold['cr_weight'] 124 | 125 | self.to(device) 126 | self.float() 127 | 128 | def forward(self, clips, claps, contrast_mask=None, kl_weight=.001, return_loss=False, var_samples=1): 129 | kl = 0 130 | loss = None 131 | self.logit_scale.data.clamp_(self.logit_scale_min, self.logit_scale_max) 132 | self.self_logit_scale.data.clamp_(self.logit_scale_min, self.logit_scale_max) 133 | 134 | if self.variational: 135 | clip_mu, clip_log_sigma, clip_rrl = 
self.clip_encoder(clips) 136 | clap_mu, clap_log_sigma, clap_rrl = self.clap_encoder(claps) 137 | 138 | clip_embeds = torch.empty(var_samples, clip_mu.shape[0], clip_mu.shape[1]).to(self.device) 139 | clap_embeds = torch.empty(var_samples, clap_mu.shape[0], clap_mu.shape[1]).to(self.device) 140 | for i in range(var_samples): 141 | emb1 = sample_normal(clip_mu, clip_log_sigma) 142 | clip_embeds[i, :, :] = emb1 143 | emb2 = sample_normal(clap_mu, clap_log_sigma) 144 | clap_embeds[i, :, :] = emb2 145 | 146 | # calculate kl distance together since clip and clap will be blended on the manifold 147 | if return_loss: 148 | mu = torch.cat([clip_mu, clap_mu]) 149 | log_sigma = torch.cat([clip_log_sigma, clap_log_sigma]) 150 | # kl = torch.mean(kld(mu, log_sigma)) 151 | kl = 0 152 | 153 | n_clip_embeds = clip_embeds / clip_embeds.norm(dim=-1, keepdim=True) 154 | n_clap_embeds = clap_embeds / clap_embeds.norm(dim=-1, keepdim=True) 155 | n_clip_embeds_t = torch.einsum('bij->bji', n_clip_embeds) 156 | n_clap_embeds_t = torch.einsum('bij->bji', n_clap_embeds) 157 | 158 | # random permutation for monte carlo contrastive loss 159 | clip_loss = 0 160 | clap_loss = 0 161 | B = n_clip_embeds.shape[1] 162 | for i in range(var_samples): # logits: [var_samples, B, B] 163 | p1 = torch.randperm(var_samples).to(self.device) 164 | p2 = torch.randperm(var_samples).to(self.device) 165 | p3 = torch.randperm(var_samples).to(self.device) 166 | p4 = torch.randperm(var_samples).to(self.device) 167 | 168 | logits_per_clap = torch.einsum('bij,bjk->bik', n_clip_embeds[p1], n_clap_embeds_t[p2]) 169 | logits_per_clip = torch.einsum('bij,bjk->bik', n_clap_embeds[p3], n_clip_embeds_t[p4]) 170 | 171 | if contrast_mask is not None: 172 | logits_per_clap *= contrast_mask 173 | logits_per_clip *= contrast_mask 174 | 175 | logits_per_clap *= self.logit_scale.exp() 176 | logits_per_clip *= self.logit_scale.exp() 177 | 178 | clap_loss += ( 179 | nn.functional.cross_entropy( 180 | logits_per_clap.reshape(-1, B), 181 | torch.arange(B).tile(var_samples).to(self.device))) 182 | clip_loss += ( 183 | nn.functional.cross_entropy( 184 | logits_per_clip.reshape(-1, B), 185 | torch.arange(B).tile(var_samples).to(self.device))) 186 | 187 | loss = (clap_loss + clip_loss) / var_samples * .5 + kl_weight * kl + clip_rrl + clap_rrl 188 | 189 | return (clip_embeds.reshape(-1, clip_embeds.size(-1)), 190 | clap_embeds.reshape(-1, clap_embeds.size(-1)), loss) 191 | 192 | else: 193 | clip_embeds, clip_rrl = self.clip_encoder(clips) 194 | clap_embeds, clap_rrl = self.clap_encoder(claps) 195 | 196 | # normalized features 197 | n_clip_embeds = clip_embeds / clip_embeds.norm(dim=-1, keepdim=True) 198 | n_clap_embeds = clap_embeds / clap_embeds.norm(dim=-1, keepdim=True) 199 | 200 | if return_loss: 201 | # cosine similarity as logits 202 | logits_per_clap = n_clap_embeds @ n_clip_embeds.t() * self.logit_scale.exp() 203 | logits_per_clip = n_clip_embeds @ n_clap_embeds.t() * self.logit_scale.exp() 204 | 205 | if contrast_mask is not None: 206 | logits_per_clap *= contrast_mask 207 | logits_per_clip *= contrast_mask 208 | 209 | # clap_loss = contrastive_loss(logits_per_clap) 210 | # clip_loss = contrastive_loss(logits_per_clip.t()) 211 | clap_loss = 0 212 | clip_loss = 0 213 | 214 | loss = (clap_loss + clip_loss) * .5 + kl_weight * kl + clip_rrl + clap_rrl 215 | 216 | return clip_embeds, clap_embeds, loss 217 | 218 | def fold_clips(self, clips, var_samples=1, normalize=False): 219 | if self.variational: 220 | mu, sigma, _ = self.clip_encoder(clips) 221 | 
fold_clips = torch.zeros(mu.shape).to(self.device) 222 | for i in range(var_samples): # repetitive sampling 223 | fold_clips += sample_normal(mu, sigma) / var_samples 224 | else: 225 | fold_clips, _ = self.clip_encoder(clips) 226 | 227 | if normalize: 228 | fold_clips = fold_clips / fold_clips.norm(p=2, dim=-1, keepdim=True) 229 | 230 | return fold_clips 231 | 232 | def fold_claps(self, claps, var_samples=1, normalize=False): 233 | if self.variational: 234 | mu, sigma, _ = self.clap_encoder(claps) 235 | fold_claps = torch.zeros(mu.shape).to(self.device) 236 | for i in range(var_samples): # repetitive sampling 237 | fold_claps += sample_normal(mu, sigma) / var_samples 238 | else: 239 | fold_claps, _ = self.clap_encoder(claps) 240 | 241 | if normalize: 242 | fold_claps = fold_claps / fold_claps.norm(p=2, dim=-1, keepdim=True) 243 | 244 | return fold_claps 245 | 246 | -------------------------------------------------------------------------------- /ssv2a/model/modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | import torch.nn.functional as F 8 | from typing import Callable, List, Optional 9 | 10 | 11 | # generate a Gaussian sample given a batch of means and variances 12 | def sample_normal(mu, log_sigma): 13 | std = torch.exp(0.5 * log_sigma) 14 | eps = std.data.new(std.size()).normal_() 15 | return eps.mul(std) + mu 16 | 17 | 18 | # good old positional embedding 19 | # adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html 20 | class PositionalEmbedding(nn.Module): 21 | def __init__(self, embed_dims, resolution=1024, inject_method='cat', device='cuda'): 22 | super().__init__() 23 | self.inject_method = inject_method 24 | self.resolution = resolution 25 | self.device = device 26 | pe = torch.zeros(resolution, embed_dims) 27 | position = torch.arange(resolution, dtype=torch.float).unsqueeze(1) 28 | div_term = torch.exp(torch.arange(0, embed_dims, 2).float() * (-math.log(2 * resolution) / embed_dims)) 29 | pe[:, 0::2] = torch.sin(position * div_term) 30 | pe[:, 1::2] = torch.cos(position * div_term) 31 | pe = pe.unsqueeze(0).transpose(0, 1) 32 | self.register_buffer('pe', pe) 33 | 34 | def forward(self, x, *args): 35 | # x: [B, L, E] steps: [B, L] 36 | ge = self.pe[:x.shape[1]].squeeze().tile(x.shape[0], 1, 1) 37 | 38 | if self.inject_method == 'add': 39 | return x + ge 40 | return torch.cat([x, ge], dim=-1) 41 | 42 | 43 | class LocalityEmbedding(PositionalEmbedding): 44 | def forward(self, x, *args): 45 | # x: [B, L, E] localities: [B, L] 46 | localities = torch.round(args[0] * (self.resolution - 1)).int() 47 | ge = torch.empty(x.shape).to(self.device) 48 | for i in range(localities.shape[0]): 49 | ge[i, :, :] = self.pe[localities[i]].squeeze() 50 | 51 | if self.inject_method == 'add': 52 | return x + ge 53 | return torch.cat([x, ge], dim=-1) 54 | 55 | 56 | ''' 57 | Transformer 58 | ''' 59 | 60 | 61 | # efficient attention, adapted from 62 | # https://github.com/mingyuan-zhang/MotionDiffuse/blob/main/text2motion/models/transformer.py 63 | class EfficientSelfAttention(nn.Module): 64 | def __init__(self, embed_dim, nhead, dropout): 65 | super().__init__() 66 | self.nhead = nhead 67 | self.query = nn.Linear(embed_dim, embed_dim) 68 | self.key = nn.Linear(embed_dim, embed_dim) 69 | self.value = nn.Linear(embed_dim, embed_dim) 70 | self.dropout = nn.Dropout(p=dropout) 71 | self.norm = nn.LayerNorm(embed_dim) 72 | 
self.highway = nn.Linear(embed_dim, embed_dim) 73 | 74 | def forward(self, x): 75 | # x: B, T, D 76 | B, T, D = x.shape 77 | H = self.nhead 78 | 79 | # linear projections and split into multiheads (B, T, H, D//H) 80 | nx = self.norm(x) 81 | query = self.query(nx).view(B, T, H, -1) 82 | key = self.key(nx).view(B, T, H, -1) 83 | value = self.value(nx).view(B, T, H, -1) 84 | 85 | # attention (B, T, H, D//H) -> (B, T, D) 86 | query = F.softmax(query, dim=-1) 87 | key = F.softmax(key, dim=-1) 88 | attention = self.dropout(torch.einsum('bnhd,bnhl->bhdl', key, value)) 89 | y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) 90 | 91 | # residual 92 | y = self.highway(x) + y 93 | return y 94 | 95 | 96 | class EfficientCrossAttention(nn.Module): 97 | def __init__(self, embed_dim, cond_dim, nhead, dropout): 98 | super().__init__() 99 | self.nhead = nhead 100 | self.query = nn.Linear(embed_dim, embed_dim) 101 | self.key = nn.Linear(cond_dim, embed_dim) 102 | self.value = nn.Linear(cond_dim, embed_dim) 103 | self.dropout = nn.Dropout(p=dropout) 104 | self.norm1 = nn.LayerNorm(embed_dim) 105 | self.norm2 = nn.LayerNorm(cond_dim) 106 | self.highway = nn.Linear(embed_dim, embed_dim) 107 | 108 | def forward(self, x1, x2): 109 | """ 110 | x1: B, T, D 111 | x2: B, N, L 112 | """ 113 | B, T, D = x1.shape 114 | N = x2.shape[1] 115 | H = self.nhead 116 | 117 | # linear projections and split into multiheads (B, T, H, D//H) 118 | nx1 = self.norm1(x1) 119 | nx2 = self.norm2(x2) 120 | query = self.query(nx1).view(B, T, H, -1) 121 | key = self.key(nx2).view(B, N, H, -1) 122 | value = self.value(nx2).view(B, N, H, -1) 123 | 124 | # attention (B, T, H, D//H), (B, N, H, D//H) -> (B, T, D) 125 | query = F.softmax(query, dim=-1) 126 | key = F.softmax(key, dim=-1) 127 | attention = self.dropout(torch.einsum('bnhd,bnhl->bhdl', key, value)) 128 | y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) 129 | 130 | # residual 131 | y = self.highway(x1) + y 132 | return y 133 | 134 | 135 | def zero_module(module): 136 | """ 137 | Zero out the parameters of a module and return it. 
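FFN below passes its second linear layer through this so that each residual branch starts as an identity mapping at initialization.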
138 | """ 139 | for p in module.parameters(): 140 | p.detach().zero_() 141 | return module 142 | 143 | 144 | class FFN(nn.Module): 145 | def __init__(self, embed_dim, hidden_dim, dropout): 146 | super().__init__() 147 | self.linear1 = nn.Linear(embed_dim, hidden_dim) 148 | self.linear2 = zero_module(nn.Linear(hidden_dim, embed_dim)) 149 | self.activation = nn.GELU() 150 | self.dropout = nn.Dropout(p=dropout) 151 | self.norm = nn.LayerNorm(embed_dim) 152 | 153 | def forward(self, x): 154 | y = self.linear2(self.dropout(self.activation(self.linear1(self.norm(x))))) 155 | y = x + y 156 | return y 157 | 158 | 159 | class ESALayer(nn.Module): 160 | def __init__(self, embed_dim, ffn_dim, nhead, dropout): 161 | super().__init__() 162 | self.esa = EfficientSelfAttention(embed_dim, nhead, dropout) 163 | self.ffn = FFN(embed_dim, ffn_dim, dropout) 164 | 165 | def forward(self, x): 166 | return self.ffn(self.esa(x)) 167 | 168 | 169 | class ECALayer(nn.Module): 170 | def __init__(self, embed_dim, cond_dim, ffn_dim, nhead, dropout): 171 | super().__init__() 172 | self.eca = EfficientCrossAttention(embed_dim, cond_dim, nhead, dropout) 173 | self.ffn = FFN(embed_dim, ffn_dim, dropout) 174 | 175 | def forward(self, x1, x2): 176 | return self.ffn(self.eca(x1, x2)) 177 | 178 | 179 | class TransEncoder(nn.Module): 180 | def __init__(self, num_layers=3, embed_dim=512, nhead=8, dropout=.2, exp_rate=1, **_): 181 | super().__init__() 182 | self.encoder = nn.Sequential() 183 | for i in range(num_layers): 184 | self.encoder.append(ESALayer(embed_dim, embed_dim * exp_rate, nhead, dropout)) 185 | 186 | def forward(self, src): 187 | return self.encoder(src) 188 | 189 | 190 | class TransDecoder(nn.Module): 191 | def __init__(self, num_layers=6, embed_dim=512, cond_dim=512, nhead=8, dropout=.2, **_): 192 | super().__init__() 193 | self.decoder = nn.ModuleList() 194 | for i in range(num_layers): 195 | self.decoder.append(ECALayer(embed_dim, cond_dim, embed_dim * 2, nhead, dropout)) 196 | 197 | def forward(self, src, cond): 198 | for mod in self.decoder: 199 | src = mod(src, cond) 200 | return src 201 | 202 | 203 | ''' 204 | Linear Projection 205 | ''' 206 | 207 | 208 | class LinearProjection(nn.Module): 209 | def __init__(self, in_dim, out_dim, bias=True): 210 | super().__init__() 211 | self.A = nn.Parameter(torch.randn(out_dim, in_dim)) 212 | # SVD 213 | self.U = None 214 | self.S = None 215 | self.Vh = None 216 | if bias: 217 | self.b = nn.Parameter(torch.zeros(1)) 218 | else: 219 | self.b = 0 220 | 221 | def forward(self, x): 222 | return torch.einsum('ij,bj->bi', self.A, x) + self.b 223 | 224 | def svd(self): 225 | if self.U is None: # solve once 226 | with torch.no_grad(): 227 | self.U, self.S, self.Vh = torch.linalg.svd(self.A, full_matrices=False) 228 | return self.U, self.S, self.Vh 229 | 230 | def solve(self, x): 231 | # x : [B, E] 232 | self.svd() 233 | # broadcast inversion 234 | return torch.einsum('ij,bj->bi', 235 | self.Vh.t() @ torch.diag(1.0 / self.S) @ self.U.t(), 236 | x - self.b) 237 | 238 | 239 | ''' 240 | MLPs 241 | ''' 242 | 243 | 244 | # wrap a given module with residual shortcut 245 | class ResWrapper(nn.Module): 246 | def __init__(self, mod, proj): 247 | super().__init__() 248 | self.mod = mod 249 | self.proj = proj 250 | 251 | def forward(self, x): 252 | return self.proj(x) + self.mod(x) 253 | 254 | 255 | class MLP(nn.Module): 256 | def __init__(self, layers, resnet=False, dropout=.1, activation=nn.GELU, **_): # pass in layers as a list of ints 257 | super().__init__() 258 | resnet = resnet and 
len(layers) > 2 259 | self.model = nn.Sequential() 260 | idx = 0 261 | if resnet: 262 | for i in range(0, (len(layers) - 1) // 2): 263 | a, b, c = i * 2, i * 2 + 1, i * 2 + 2 264 | idx = c 265 | mod = nn.Sequential() 266 | mod.append(nn.Linear(layers[a], layers[b])) 267 | mod.append(nn.LayerNorm(layers[b])) 268 | if dropout: 269 | mod.append(nn.Dropout(p=dropout)) 270 | mod.append(activation()) 271 | mod.append(nn.Linear(layers[b], layers[c])) 272 | mod.append(nn.LayerNorm(layers[c])) 273 | if dropout: 274 | mod.append(nn.Dropout(p=dropout)) 275 | self.model.append(ResWrapper(mod, nn.Linear(layers[a], layers[c]))) 276 | self.model.append(activation()) 277 | 278 | for i in range(idx, len(layers) - 1): 279 | self.model.append(nn.Linear(layers[i], layers[i + 1])) 280 | self.model.append(nn.LayerNorm(layers[i + 1])) 281 | if dropout: 282 | self.model.append(nn.Dropout(p=dropout)) 283 | self.model.append(activation()) 284 | 285 | def forward(self, x): 286 | return self.model(x) 287 | 288 | 289 | class MoE(nn.Module): 290 | def __init__(self, config, experts=8, diverse_experts=False, device='cuda'): 291 | super().__init__() 292 | cfg = copy.deepcopy(config) 293 | self.out_dim = cfg['layers'][-1] 294 | 295 | if diverse_experts: 296 | exps = [] 297 | in_dim = cfg['layers'][0] 298 | hidden_dims = copy.deepcopy(cfg['layers'][1:-1]) 299 | for i in range(len(hidden_dims)): 300 | cfg['layers'] = [in_dim] + hidden_dims[:i] + [self.out_dim] 301 | expn = experts // len(hidden_dims) 302 | if i == len(hidden_dims) - 1: 303 | expn = experts - i * expn 304 | exps += [MLP(**cfg) for _ in range(expn)] 305 | 306 | else: 307 | exps = [MLP(**cfg) for _ in range(experts)] 308 | 309 | self.experts = nn.ModuleList(exps) 310 | 311 | self.device = device 312 | 313 | def forward(self, x, weights, sequential=False): 314 | # x: [B, E] is not sequential, else [B, L, E]; weights: [B, experts], else [B, L, experts] 315 | if not sequential: 316 | y = torch.empty(x.shape[0], len(self.experts), self.out_dim).to(self.device) # y: [B, experts, E] 317 | for i, e in enumerate(self.experts): 318 | ep = e(x) 319 | y[:, i, :] = ep 320 | y = torch.einsum('bij,bi->bj', y, weights) # y: [B, E] 321 | else: 322 | # y: [B, L, experts, E] 323 | y = torch.empty(x.shape[0], x.shape[1], len(self.experts), self.out_dim).to(self.device) 324 | for i, e in enumerate(self.experts): 325 | ep = e(x) # [B, L, E] 326 | y[:, :, i, :] = ep 327 | y = torch.einsum('bijk,bij->bk', y, weights) # y: [B, E] 328 | return y 329 | 330 | 331 | ''' 332 | The snippet below defines AudioResNet 333 | Adapted from https://pytorch.org/vision/main/_modules/torchvision/models/resnet.html 334 | Pretty much the same old resnet, but modified to fit clap audio representation (1d conv). 
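It consumes a 1-D feature sequence of shape [batch, in_ch, length], stacks Bottleneck stages over it, then average-pools and projects the result to a single out_dim vector.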
335 | ''' 336 | 337 | 338 | def conv3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv1d: 339 | """3 convolution with padding""" 340 | return nn.Conv1d( 341 | in_planes, 342 | out_planes, 343 | kernel_size=3, 344 | stride=stride, 345 | padding=dilation, 346 | groups=groups, 347 | bias=False, 348 | dilation=dilation 349 | ) 350 | 351 | 352 | def conv1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv1d: 353 | """1 convolution""" 354 | return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 355 | 356 | 357 | class Bottleneck(nn.Module): 358 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 359 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 360 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385. 361 | # This variant is also known as ResNet V1.5 and improves accuracy according to 362 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 363 | 364 | expansion: int = 4 365 | 366 | def __init__( 367 | self, 368 | inplanes: int, 369 | planes: int, 370 | stride: int = 1, 371 | downsample: Optional[nn.Module] = None, 372 | groups: int = 1, 373 | base_width: int = 64, 374 | dilation: int = 1, 375 | norm_layer: Optional[Callable[..., nn.Module]] = None, 376 | ) -> None: 377 | super().__init__() 378 | if norm_layer is None: 379 | norm_layer = nn.BatchNorm1d 380 | width = int(planes * (base_width / 64.0)) * groups 381 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 382 | self.conv1 = conv1(inplanes, width) 383 | self.bn1 = norm_layer(width) 384 | self.conv2 = conv3(width, width, stride, groups, dilation) 385 | self.bn2 = norm_layer(width) 386 | self.conv3 = conv1(width, planes * self.expansion) 387 | self.bn3 = norm_layer(planes * self.expansion) 388 | self.relu = nn.ReLU(inplace=True) 389 | self.downsample = downsample 390 | self.stride = stride 391 | 392 | def forward(self, x: Tensor) -> Tensor: 393 | identity = x 394 | 395 | out = self.conv1(x) 396 | out = self.bn1(out) 397 | out = self.relu(out) 398 | 399 | out = self.conv2(out) 400 | out = self.bn2(out) 401 | out = self.relu(out) 402 | 403 | out = self.conv3(out) 404 | out = self.bn3(out) 405 | 406 | if self.downsample is not None: 407 | identity = self.downsample(x) 408 | 409 | out += identity 410 | out = self.relu(out) 411 | 412 | return out 413 | 414 | 415 | class AudioResNet(nn.Module): 416 | def __init__( 417 | self, 418 | layers: List[int], 419 | planes: List[int], 420 | in_ch: int = 1, 421 | out_dim: int = 1024, 422 | zero_init_residual: bool = False, 423 | groups: int = 1, 424 | width_per_group: int = 64, 425 | norm_layer: Optional[Callable[..., nn.Module]] = None 426 | ) -> None: 427 | super().__init__() 428 | if norm_layer is None: 429 | norm_layer = nn.BatchNorm1d 430 | self._norm_layer = norm_layer 431 | 432 | self.inplanes = planes[0] 433 | self.dilation = 1 434 | 435 | self.groups = groups 436 | self.base_width = width_per_group 437 | self.conv1 = nn.Conv1d(in_ch, self.inplanes, kernel_size=9, stride=4, padding=4, bias=False) 438 | self.bn1 = norm_layer(self.inplanes) 439 | self.relu = nn.ReLU(inplace=True) 440 | self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 441 | 442 | self.layer1 = self._make_layer(planes[0], layers[0]) 443 | self.layers = nn.Sequential() 444 | for i in range(1, len(layers)): 445 | 
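            # each remaining stage downsamples by stride 2 via a Bottleneck stack at width planes[i]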
self.layers.append(self._make_layer(planes[i], layers[i], stride=2, dilate=False))
446 |         self.avgpool = nn.AdaptiveAvgPool1d(1)
447 |         self.fc = nn.Linear(planes[-1] * Bottleneck.expansion, out_dim)
448 | 
449 |         for m in self.modules():
450 |             if isinstance(m, nn.Conv1d):
451 |                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
452 |             elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
453 |                 nn.init.constant_(m.weight, 1)
454 |                 nn.init.constant_(m.bias, 0)
455 | 
456 |         # Zero-initialize the last BN in each residual branch,
457 |         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
458 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
459 |         if zero_init_residual:
460 |             for m in self.modules():
461 |                 if isinstance(m, Bottleneck) and m.bn3.weight is not None:
462 |                     nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
463 | 
464 |     def _make_layer(
465 |         self,
466 |         planes: int,
467 |         blocks: int,
468 |         stride: int = 1,
469 |         dilate: bool = False,
470 |     ) -> nn.Sequential:
471 |         norm_layer = self._norm_layer
472 |         downsample = None
473 |         previous_dilation = self.dilation
474 |         if dilate:
475 |             self.dilation *= stride
476 |             stride = 1
477 |         if stride != 1 or self.inplanes != planes * Bottleneck.expansion:
478 |             downsample = nn.Sequential(
479 |                 conv1(self.inplanes, planes * Bottleneck.expansion, stride),
480 |                 norm_layer(planes * Bottleneck.expansion),
481 |             )
482 | 
483 |         layers = []
484 |         layers.append(
485 |             Bottleneck(
486 |                 self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
487 |             )
488 |         )
489 |         self.inplanes = planes * Bottleneck.expansion
490 |         for _ in range(1, blocks):
491 |             layers.append(
492 |                 Bottleneck(
493 |                     self.inplanes,
494 |                     planes,
495 |                     groups=self.groups,
496 |                     base_width=self.base_width,
497 |                     dilation=self.dilation,
498 |                     norm_layer=norm_layer,
499 |                 )
500 |             )
501 | 
502 |         return nn.Sequential(*layers)
503 | 
504 |     def _forward_impl(self, x: Tensor) -> Tensor:
505 |         # See note [TorchScript super()]
506 |         x = self.conv1(x)
507 |         x = self.bn1(x)
508 |         x = self.relu(x)
509 |         x = self.maxpool(x)
510 | 
511 |         x = self.layer1(x)
512 |         x = self.layers(x)
513 | 
514 |         x = self.avgpool(x)
515 |         x = torch.flatten(x, 1)
516 |         x = self.fc(x)
517 | 
518 |         return x
519 | 
520 |     def forward(self, x: Tensor) -> Tensor:
521 |         return self._forward_impl(x)
522 | 
523 | 
--------------------------------------------------------------------------------
/ssv2a/model/pipeline.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import gc
3 | import json
4 | import os.path
5 | from pathlib import Path
6 | from shutil import rmtree
7 | 
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | from tqdm import tqdm
12 | import soundfile as sf
13 | 
14 | from ssv2a.data.detect import detect
15 | from ssv2a.data.tpairs import tpairs2tclips
16 | from ssv2a.data.utils import clip_embed_images, get_timestamp, save_wave, set_seed, emb2seq, batch_extract_frames, \
17 |     prior_embed_texts
18 | from ssv2a.model.aggregator import Aggregator
19 | from ssv2a.model.clap import clap_embed_auds
20 | from ssv2a.model.aldm import build_audioldm, emb_to_audio
21 | from ssv2a.model.generator import Generator
22 | from ssv2a.model.manifold import Manifold
23 | from ssv2a.model.remixer import Remixer
24 | 
25 | 
26 | class Pipeline(nn.Module):
27 |     def __init__(self, config, pretrained=None, device='cuda'):
28 |         super().__init__()
29 |         if not 
isinstance(config, dict): 30 | with open(config, 'r') as fp: 31 | config = json.load(fp) 32 | self.ckpt_path = Path(config['checkpoints']) 33 | self.config = config 34 | self.device = device 35 | config['device'] = device 36 | self.clip_dim = config['clip_dim'] 37 | self.clap_dim = config['clap_dim'] 38 | self.fold_dim = config['manifold_dim'] 39 | 40 | # SDV2A Manifold 41 | self.manifold = Manifold(**config) 42 | # if the generator is just a linear operation, ignore modelling 43 | if config['generator']['disabled']: 44 | self.generator = None 45 | self.skip_gen_model = True 46 | elif config['generator']['arch'] == 'linear': 47 | self.generator = self.linear_generator 48 | self.skip_gen_model = True 49 | else: 50 | self.generator = Generator(**config) 51 | self.skip_gen_model = False 52 | 53 | # SDV2A Remixer 54 | self.remixer = Remixer(**config) 55 | 56 | # if there is any pretrained, load 57 | if pretrained: 58 | self.load(pretrained) 59 | 60 | # timestamp 61 | self.timestamp = get_timestamp() 62 | 63 | def linear_generator(self, fold_clips): 64 | gen_claps = self.manifold.clap_encoder.model.solve(fold_clips) 65 | return gen_claps, 0 # add a fake kl loss term to align with other generator models 66 | 67 | def save(self, filepath): 68 | state = { 69 | 'timestamp': get_timestamp(), 70 | 'manifold_state': self.manifold.state_dict(), 71 | 'generator_state': None if self.skip_gen_model else self.generator.state_dict(), 72 | 'remixer_state': self.remixer.state_dict() 73 | } 74 | torch.save(state, filepath) 75 | 76 | def load(self, filepath): 77 | state = torch.load(filepath, map_location='cpu') 78 | mia_states = [] 79 | 80 | if 'manifold_state' in state: 81 | self.manifold.load_state_dict(state['manifold_state']) 82 | else: 83 | mia_states.append('manifold') 84 | 85 | if 'generator_state' in state and not self.skip_gen_model: 86 | self.generator.load_state_dict(state['generator_state']) 87 | else: 88 | mia_states.append('generator') 89 | 90 | if 'remixer_state' in state: 91 | self.remixer.load_state_dict(state['remixer_state']) 92 | else: 93 | mia_states.append('remixer') 94 | 95 | if len(mia_states) > 0: 96 | print(f"These states are missing in the model checkpoint supplied:\n" 97 | f"{' '.join(mia_states)}\n" 98 | f"Inference will be funky if these modules are involved without training!") 99 | 100 | self.timestamp = state['timestamp'] 101 | 102 | def __str__(self): 103 | return (f"SDV2A@{self.timestamp}" 104 | f"{json.dumps(self.config, sort_keys=True, indent=4)}") 105 | 106 | # postprocessing: cycle generation 107 | def cycle_mix(self, clips, fixed_src=None, its=1, var_samples=64, samples=16, shuffle=True): 108 | """ 109 | clips: [B, slot, E] (global clip injected at first token) 110 | """ 111 | B = clips.shape[0] 112 | rt_claps = torch.empty(B, 512).to(self.device) # [B, E] 113 | rt_scores = torch.ones(B).to(self.device) # [B] 114 | src = self.manifold.fold_clips(clips, var_samples=var_samples, normalize=False) 115 | for i in range(its + 1): # one more round to include its = 0 (no recursion) 116 | if fixed_src is not None: # reinject audio sources 117 | src_mask = torch.sum(clips.bool(), dim=-1).bool().logical_not() # zero slots 118 | src[src_mask, ...] = fixed_src[src_mask, ...] 
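            # One Cycle Mix iteration (below): decode every source slot to a CLAP code,
            # zero the global slot so it does not bias the score, draw `samples` candidate
            # remixes, and keep the candidate whose cosine similarities to the per-source
            # CLAPs have the lowest spread (std), if it beats the best seen so far; from
            # the second iteration on, the current best CLAP is folded back in as the
            # global source slot.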
119 | 120 | src_claps = self.generator.fold2claps(src, var_samples=var_samples) 121 | 122 | src_claps[:, 0, :] = 0 # suppress global clap in clap score later 123 | 124 | if i > 0: # recursion, inject the best clap as the global source code 125 | src[:, 0, :] = self.manifold.fold_claps(rt_claps, var_samples=var_samples, normalize=False) 126 | 127 | # sample remixed claps 128 | remix_claps = torch.empty(src_claps.shape[0], samples, 512).to(self.device) # [B, S, E] 129 | for j in range(samples): 130 | if self.remixer.guidance == 'generator': 131 | remix_claps[:, j, :] = self.remixer.sample(self.generator.fold2claps(src, var_samples=var_samples), 132 | clips, 133 | var_samples=var_samples, normalize=True, shuffle=shuffle) 134 | elif self.remixer.guidance == 'manifold+generator': 135 | src_code = torch.cat([src, self.generator.fold2claps(src, var_samples=var_samples)], dim=-1) 136 | remix_claps[:, j, :] = self.remixer.sample(src_code, clips, 137 | var_samples=var_samples, normalize=True, shuffle=shuffle) 138 | else: 139 | remix_claps[:, j, :] = self.remixer.sample(src, clips, 140 | var_samples=var_samples, normalize=True, shuffle=shuffle) 141 | 142 | # select the remixed clap with highest CLAP-Score, can use std or mean 143 | clap_score = torch.einsum('bmi,bni->bmn', remix_claps, src_claps) # [B, S, slot] 144 | clap_score /= (torch.einsum('bmi,bni->bmn', 145 | remix_claps.norm(dim=-1, keepdim=True), 146 | src_claps.norm(dim=-1, keepdim=True)) + 1e-6) 147 | clap_score = torch.std(clap_score, dim=-1) # [B, S, slot] -> [B, S] 148 | best_remix_claps = torch.argmin(clap_score, dim=-1) 149 | clap_score = clap_score[torch.arange(B), best_remix_claps] 150 | updates = clap_score < rt_scores 151 | rt_scores[updates] = clap_score[updates] 152 | best_remix_claps = remix_claps[torch.arange(B), best_remix_claps] # [B, S, E] -> [B, E] 153 | rt_claps[updates] = best_remix_claps[updates] 154 | 155 | return rt_claps 156 | 157 | # reconstruct an audio's clap 158 | def recon_claps(self, claps, var_samples=64): 159 | fold_claps = self.manifold.fold_claps(claps, var_samples=var_samples) 160 | gen_claps = self.generator.fold2claps(fold_claps, var_samples=var_samples) 161 | return gen_claps 162 | 163 | def clips2foldclaps(self, clips, var_samples=64): 164 | fold_clips = self.manifold.fold_clips(clips, var_samples=var_samples, normalize=False) 165 | gen_claps = self.generator.fold2claps(fold_clips, var_samples=var_samples) 166 | return gen_claps 167 | 168 | def clips2folds(self, clips, var_samples=64, normalize=False): 169 | fold_clips = self.manifold.fold_clips(clips, var_samples=var_samples, normalize=normalize) 170 | return fold_clips 171 | 172 | def claps2folds(self, claps, var_samples=64, normalize=False): 173 | fold_claps = self.manifold.fold_claps(claps, var_samples=var_samples, normalize=normalize) 174 | return fold_claps 175 | 176 | def clips2clap(self, clips, var_samples=64, normalize=False): 177 | src = self.clips2folds(clips, var_samples=var_samples, normalize=False) 178 | if self.remixer.guidance == 'generator': 179 | src = self.generator.fold2claps(src, var_samples=var_samples) 180 | elif self.remixer.guidance == 'manifold+generator': 181 | fold_gen_claps = self.generator.fold2claps(src, var_samples=var_samples) 182 | src = torch.cat([src, fold_gen_claps], dim=-1) 183 | clap = self.remixer.sample(src, clips, var_samples=var_samples, normalize=normalize) 184 | return clap 185 | 186 | def tpairs2clap(self, pairs, var_samples=64, normalize=False): 187 | clips = tpairs2tclips(pairs, 
max_length=self.remixer.slot, device=self.device) 188 | clap = self.clips2clap(clips, var_samples, normalize) 189 | return clap 190 | 191 | 192 | # in this application we recycle models to save memory, the intermediate products are saved to disk under data_cache 193 | @torch.no_grad() 194 | def image_to_audio(images, text="", transcription="", save_dir="", config=None, 195 | gen_remix=True, gen_tracks=False, emb_only=False, 196 | pretrained=None, batch_size=64, var_samples=1, 197 | shuffle_remix=True, cycle_its=3, cycle_samples=16, keep_data_cache=False, 198 | duration=10, seed=42, device='cuda'): 199 | set_seed(seed) 200 | # revert to default model config if not supplied 201 | if not os.path.exists(config): 202 | config = Path().resolve() / 'configs' / 'model.json' 203 | with open(config, 'r') as fp: 204 | config = json.load(fp) 205 | 206 | if not save_dir: 207 | save_dir = Path().resolve() / 'output' # default saving folder 208 | else: 209 | save_dir = Path(save_dir) 210 | if not os.path.exists(save_dir): 211 | os.makedirs(save_dir) 212 | if gen_tracks: 213 | os.makedirs(save_dir / 'tracks') 214 | cache_dir = save_dir / 'data_cache' 215 | 216 | # segmentation proposal 217 | if not isinstance(images, dict): 218 | local_imgs = detect(images, config['detector'], 219 | save_dir=cache_dir / 'masked_images', batch_size=batch_size, device=device) 220 | else: 221 | local_imgs = copy.deepcopy(images) 222 | images = [k for k in images] 223 | keep_data_cache = True # prevent deleting nonexistent folder 224 | 225 | # clip embed 226 | global_clips = clip_embed_images(images, batch_size=batch_size, device=device) 227 | imgs = [] 228 | for img in images: 229 | imgs += [li for li, _ in local_imgs[img]] 230 | local_clips = clip_embed_images(imgs, batch_size=batch_size, device=device) 231 | 232 | jumps = [len(local_imgs[img]) for img in local_imgs] 233 | 234 | # SDV2A 235 | model = Pipeline(copy.deepcopy(config), pretrained, device) 236 | model.eval() 237 | with torch.no_grad(): 238 | # clips to claps 239 | local_claps = model.clips2foldclaps(local_clips, var_samples=var_samples) 240 | 241 | if gen_remix: 242 | # remix 243 | remix_clips = emb2seq(jumps, local_clips, max_length=model.remixer.slot, delay=1, device=model.device) 244 | remix_clips[:, 0, :] = global_clips # blend in global clip 245 | remix_clap = model.cycle_mix(remix_clips, its=cycle_its, var_samples=var_samples, 246 | samples=cycle_samples, shuffle=shuffle_remix) 247 | 248 | del remix_clips 249 | 250 | if emb_only: 251 | if not keep_data_cache: 252 | rmtree(cache_dir) 253 | return remix_clap.detach().cpu().numpy() 254 | 255 | # clean up gpu 256 | # del global_clips, local_clips, remix_clips 257 | del local_clips 258 | 259 | audioldm_v = config['audioldm_version'] 260 | # AudioLDM 261 | model = build_audioldm(model_name=audioldm_v, device=device) 262 | if gen_tracks: 263 | local_wave = emb_to_audio(model, local_claps, batchsize=batch_size, duration=duration()) 264 | if gen_remix: 265 | waveform = emb_to_audio(model, remix_clap, batchsize=batch_size, duration=duration) 266 | 267 | # I/O 268 | if gen_tracks: 269 | local_names = [Path(img).name.replace('.png', '') for img in imgs] 270 | save_wave(local_wave, save_dir / 'tracks', name=local_names) 271 | if gen_remix: 272 | save_wave(waveform, save_dir, 273 | name=[os.path.basename(img).replace('.png', '') for img in images]) 274 | if not keep_data_cache: 275 | rmtree(cache_dir) 276 | 277 | 278 | @torch.no_grad() 279 | def video_to_claps(config, pretrained, videos, save_dir, frames=64, 
batch_size=256, var_samples=64, 280 | shuffle_remix=True, cycle_its=4, cycle_samples=64, seed=42, device='cuda'): 281 | cache_dir = Path(save_dir) / 'cache' 282 | 283 | print('Extracting frames and generate high-level audios:') 284 | result_claps = [] 285 | for s in tqdm(range(0, len(videos), batch_size)): 286 | # extract frames 287 | os.makedirs(cache_dir, exist_ok=True) 288 | e = min(len(videos), s + batch_size) 289 | batch_extract_frames(videos[s:e], cache_dir, size=(512, 512), frames=frames, num_workers=8) 290 | 291 | # get generated claps 292 | imgs = [str(p) for p in cache_dir.glob('*.png')] 293 | gen_claps = image_to_audio(imgs, save_dir=str(cache_dir), config=config, 294 | gen_remix=True, gen_tracks=False, emb_only=True, 295 | pretrained=pretrained, batch_size=64, var_samples=var_samples, 296 | shuffle_remix=shuffle_remix, cycle_its=cycle_its, cycle_samples=cycle_samples, 297 | keep_data_cache=False, seed=seed, device=device) 298 | 299 | # map to output 300 | for video_f in videos[s:e]: 301 | vid = str(os.path.basename(video_f).replace('.mp4', '')) 302 | gen_clap = [None] * frames 303 | for i, img in enumerate(imgs): 304 | if vid in img: 305 | img_idx = int(img[img.rfind('_') + 1:img.rfind('.')]) 306 | gen_clap[img_idx] = gen_claps[i] 307 | result_claps.append(np.stack(gen_clap)) 308 | 309 | rmtree(cache_dir) 310 | 311 | gc.collect() 312 | torch.cuda.empty_cache() 313 | 314 | return np.stack(result_claps) 315 | 316 | 317 | @torch.no_grad() 318 | def video_to_audio(config, pretrained, videos, agg_ckpt, save_dir, 319 | agg_var_samples=1, frames=64, batch_size=256, 320 | var_samples=64, cycle_its=4, cycle_samples=64, 321 | duration=10, seed=42, device='cuda'): 322 | os.makedirs(save_dir, exist_ok=True) 323 | 324 | claps = video_to_claps(config, pretrained, videos, save_dir, 325 | frames=frames, batch_size=batch_size, 326 | var_samples=var_samples, shuffle_remix=True, cycle_its=cycle_its, 327 | cycle_samples=cycle_samples, seed=seed, device=device) 328 | 329 | # Temporal Aggregation 330 | model = Aggregator(emb_dim=512, device=device) 331 | model.load_state_dict(torch.load(agg_ckpt)) 332 | agg_claps = [] 333 | for s in range(0, len(videos), batch_size): 334 | e = min(len(videos), s + batch_size) 335 | agg_claps.append(model.sample(torch.from_numpy(claps[s:e]), var_samples=agg_var_samples)) 336 | agg_claps = torch.cat(agg_claps, dim=0) 337 | 338 | # AudioLDM 339 | print('Low level generation with AudioLDM:') 340 | with open(config, 'r') as fp: 341 | m_config = json.load(fp) 342 | 343 | audioldm_v = m_config['audioldm_version'] 344 | # AudioLDM 345 | model = build_audioldm(model_name=audioldm_v, device=device) 346 | waveform = emb_to_audio(model, agg_claps, batchsize=batch_size, duration=duration) 347 | 348 | # I/O 349 | save_wave(waveform, save_dir, name=[os.path.basename(v).replace('.mp4', '') for v in videos]) 350 | 351 | 352 | # generate oracle audio from audioldm 353 | @torch.no_grad() 354 | def audio_to_audio(aud_dir, save_dir, aldm_version='audioldm-s-full-v2', batchsize=16, device='cuda'): 355 | os.makedirs(save_dir, exist_ok=True) 356 | auds = [str(aud) for aud in Path(aud_dir).glob('*.wav')] 357 | claps = clap_embed_auds(auds, clap_version=aldm_version, device=device) 358 | model = build_audioldm(model_name=aldm_version, device=device) 359 | waveform = emb_to_audio(model, claps, batchsize=batchsize) 360 | fns = [os.path.basename(f).replace('*.wav', '') for f in auds] 361 | save_wave(waveform, save_dir, name=fns) 362 | 363 | 364 | # generate audio from multimodal conditions 365 
| @torch.no_grad() 366 | def srcs_to_audio(srcs, save_dir, 367 | config=None, pretrained=None, 368 | dalle2_cfg='', dalle2_ckpt='', 369 | shuffle_remix=True, cycle_its=3, cycle_samples=16, 370 | var_samples=1, batch_size=64, seed=42, duration=10, device='cuda'): 371 | set_seed(seed) 372 | with open(config, 'r') as fp: 373 | config = json.load(fp) 374 | 375 | # CLIP embeds 376 | img_ks = list(srcs['image'].keys()) 377 | if img_ks: 378 | embs = clip_embed_images(img_ks, version='ViT-L/14', batch_size=batch_size, device=device).detach().cpu().numpy() 379 | for k, img_k in enumerate(img_ks): 380 | srcs['image'][img_k] = [None, embs[k]] 381 | 382 | # DALLE2 Prior embeds 383 | text_ks = list(srcs['text'].keys()) 384 | if text_ks: 385 | embs = prior_embed_texts(text_ks, cfg=dalle2_cfg, ckpt=dalle2_ckpt, bs=batch_size, 386 | n_samples_per_batch=2, cond_scale=1, device=device) 387 | embs = embs.detach().cpu().numpy() 388 | for k, text_k in enumerate(text_ks): 389 | srcs['text'][text_k] = [None, embs[k]] 390 | 391 | # CLAP embeds 392 | aud_ks = list(srcs['audio'].keys()) 393 | if aud_ks: 394 | embs = clap_embed_auds(aud_ks, clap_version='audioldm-s-full-v2', device=device).detach().cpu().numpy() 395 | for k, aud_k in enumerate(aud_ks): 396 | srcs['audio'][aud_k] = [None, embs[k]] 397 | 398 | model = Pipeline(copy.deepcopy(config), pretrained, device) 399 | model.eval() 400 | # manifold embeds 401 | for mod in ['image', 'text', 'audio']: 402 | ks = list(srcs[mod].keys()) 403 | if ks: 404 | embs = np.stack([srcs[mod][k][1] for k in ks]) 405 | for ks_s in range(0, len(ks), batch_size): 406 | ks_e = min(len(ks), ks_s + batch_size) 407 | bembs = torch.from_numpy(embs[ks_s:ks_e]).to(device) 408 | if mod == 'audio': 409 | bembs = model.manifold.fold_claps(bembs, var_samples=var_samples, normalize=False) 410 | else: 411 | bembs = model.manifold.fold_clips(bembs, var_samples=var_samples, normalize=False) 412 | bembs = bembs.detach().cpu().numpy() 413 | for z, k in enumerate(ks[ks_s:ks_e]): 414 | srcs[mod][k][0] = bembs[z] 415 | 416 | # assemble remixer input 417 | rm_src = torch.zeros(model.remixer.slot, model.fold_dim) 418 | rm_clip = torch.zeros(model.remixer.slot, model.clip_dim) 419 | 420 | stepper = 1 # reserve first row for global condition (empty) 421 | for mod in srcs: 422 | for k in srcs[mod]: 423 | rm_src[stepper, ...] = torch.from_numpy(srcs[mod][k][0]) 424 | if mod == 'audio': 425 | continue 426 | else: 427 | rm_clip[stepper, ...] = torch.from_numpy(srcs[mod][k][1]) 428 | stepper += 1 429 | 430 | rm_src = rm_src.unsqueeze(0).float().to(device) 431 | rm_clip = rm_clip.unsqueeze(0).float().to(device) 432 | 433 | # remix! 
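    # Cycle Mix consumes the CLIP slots (rm_clip) and re-injects rm_src at every
    # iteration via fixed_src, so audio prompts, which leave their CLIP slot empty,
    # still steer the remix. The result is a single CLAP embedding that AudioLDM
    # decodes into a waveform below.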
434 | remix_clap = model.cycle_mix(rm_clip, fixed_src=rm_src, 435 | its=cycle_its, var_samples=var_samples, 436 | samples=cycle_samples, shuffle=shuffle_remix) 437 | 438 | del model, rm_src, rm_clip, embs 439 | 440 | # AudioLDM 441 | audioldm_v = config['audioldm_version'] 442 | model = build_audioldm(model_name=audioldm_v, device=device) 443 | waveform = emb_to_audio(model, remix_clap, batchsize=batch_size, duration=duration) 444 | 445 | # I/O 446 | sf.write(save_dir, waveform[0, 0], samplerate=16000) 447 | 448 | -------------------------------------------------------------------------------- /ssv2a/model/remixer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ssv2a.data.utils import random_mute 7 | from ssv2a.model.modules import TransEncoder, MLP, sample_normal 8 | 9 | 10 | class PanauralGate(nn.Module): 11 | def __init__(self, gate=None, slot=64, clip_dim=512): 12 | super().__init__() 13 | self.slot = slot 14 | self.arch = gate['arch'] 15 | 16 | hidden_dim = clip_dim 17 | if self.arch == 'transformer': 18 | self.embed_dim = gate['embed_dim'] 19 | self.in_proj = nn.Linear(clip_dim, self.embed_dim) 20 | 21 | self.encoder = TransEncoder(**gate) 22 | self.pred_token = nn.Parameter(torch.zeros(1, clip_dim)) 23 | nn.init.normal_(self.pred_token, mean=0, std=1) 24 | 25 | self.head = MLP(**gate) 26 | 27 | hidden_dim = gate['layers'][-1] 28 | 29 | else: 30 | raise NotImplementedError('Architecture is not supported.') 31 | 32 | self.out = nn.Linear(hidden_dim, self.slot) 33 | 34 | def forward(self, x): 35 | """ 36 | x: [B, slot, clip_dim] 37 | """ 38 | if self.arch == 'transformer': 39 | ws = torch.cat([torch.tile(self.pred_token, (x.shape[0], 1, 1)), x], dim=1) 40 | ws = self.in_proj(x) 41 | ws = self.encoder(x)[:, 0, :] 42 | ws = self.head(ws) 43 | 44 | else: 45 | raise NotImplementedError('Architecture is not supported.') 46 | 47 | ws = self.out(ws) # [B, slot] 48 | ws = ws * torch.sum(x, dim=-1) # zero out empty conditions 49 | ws = torch.nn.functional.relu(ws) 50 | ws = ws / ws.norm(p=2, dim=-1, keepdim=True) 51 | return ws 52 | 53 | 54 | class Styler(nn.Module): 55 | def __init__(self, styler=None, slot=64, clap_dim=512, clip_dim=512, manifold_dim=512, device='cuda'): 56 | super().__init__() 57 | self.slot = slot 58 | self.variational = styler['variational'] 59 | self.arch = styler['arch'] 60 | self.device = device 61 | 62 | if self.arch == 'transformer': 63 | self.embed_dim = styler['embed_dim'] 64 | 65 | self.in_proj = nn.Linear(manifold_dim + clip_dim, self.embed_dim) 66 | 67 | self.style_encoder = TransEncoder(**styler) 68 | 69 | if self.variational: 70 | self.pred_token = nn.Parameter(torch.zeros(2, self.embed_dim)) 71 | self.head_mu = MLP(layers=[self.embed_dim] * 2, dropout=styler['dropout']) 72 | self.head_sigma = MLP(layers=[self.embed_dim] * 2, dropout=styler['dropout']) 73 | else: 74 | self.pred_token = nn.Parameter(torch.zeros(1, self.embed_dim)) 75 | self.head_mu = MLP(layers=[self.embed_dim] * 2, dropout=styler['dropout']) 76 | nn.init.normal_(self.pred_token, mean=0, std=1) 77 | 78 | hidden_dim = self.embed_dim 79 | 80 | else: 81 | raise NotImplementedError('Architecture is not supported.') 82 | 83 | if self.variational: 84 | self.out_mu = nn.Linear(hidden_dim, clap_dim) 85 | self.out_sigma = nn.Linear(hidden_dim, clap_dim) 86 | else: 87 | self.out = nn.Linear(hidden_dim, clap_dim) 88 | 89 | def forward(self, src, src_clips): 90 | """ 91 | src: [B, L, 
manifold_dim] (manifold queried by style semantics) 92 | src_clips: [B, L, clip_dim] (semantic clip embeddings) 93 | locality: [B, L] (not used for now) 94 | """ 95 | if self.arch == 'transformer': 96 | # src = torch.zeros(src.shape).to(self.device) # suppress manifold embed for ablation 97 | # src_clips = torch.zeros(src_clips.shape).to(self.device) # suppress clip embed for ablation 98 | 99 | src = torch.cat([src, src_clips], dim=-1) 100 | src = self.in_proj(src) 101 | 102 | src = torch.cat([torch.tile(self.pred_token, (src.shape[0], 1, 1)), src], dim=1) 103 | src = self.style_encoder(src) 104 | 105 | if self.variational: 106 | mu = self.head_mu(src[:, 0, :]) 107 | sigma = self.head_sigma(src[:, 1, :]) 108 | else: 109 | mu = self.head_mu(src[:, 0, :]) 110 | 111 | else: 112 | raise NotImplementedError('Architecture is not supported.') 113 | 114 | if self.variational: 115 | return self.out_mu(mu), self.out_sigma(sigma) 116 | return self.out(mu) 117 | 118 | def sample(self, src, src_clips, var_samples=1): 119 | if self.variational: 120 | mu, log_sigma = self.forward(src, src_clips) 121 | clap = torch.zeros(mu.shape).to(self.device) 122 | for i in range(var_samples): 123 | clap += sample_normal(mu, log_sigma) / var_samples 124 | else: 125 | clap = self.forward(src, src_clips) 126 | return clap 127 | 128 | 129 | class Remixer(nn.Module): 130 | def __init__(self, remixer, clip_dim=512, clap_dim=512, manifold_dim=512, device='cuda', **_): 131 | super().__init__() 132 | self.slot = remixer['slot'] 133 | self.kl_weight = remixer['styler']['kl_weight'] 134 | self.variational = remixer['styler']['variational'] 135 | self.guidance = remixer["guidance"] 136 | 137 | self.cfg = remixer['cfg'] 138 | 139 | if self.guidance == 'manifold+generator': 140 | manifold_dim += clap_dim 141 | elif self.guidance == 'generator': 142 | manifold_dim = clap_dim 143 | self.styler = Styler(styler=remixer['styler'], slot=self.slot, clap_dim=clap_dim, clip_dim=clip_dim, 144 | manifold_dim=manifold_dim, device=device) 145 | 146 | self.to(device) 147 | self.float() 148 | 149 | def forward(self, src, src_clips): 150 | """ 151 | src: [B, L, manifold_dim] (manifold queried by style semantics) 152 | style: [B, L, clip_dim] (unannotated style semantics) 153 | """ 154 | if self.training: # classifier free guidance 155 | src, src_mask_idx = random_mute(src, p=self.cfg) 156 | src_clips, src_mask_idx = random_mute(src_clips, p=self.cfg) 157 | 158 | return self.styler(src, src_clips) 159 | 160 | def sample(self, src, src_clips, var_samples=1, normalize=True, shuffle=True): 161 | if shuffle: 162 | idx = random.sample(range(self.slot), self.slot) 163 | src = src[:, idx, :] 164 | src_clips = src_clips[:, idx, :] 165 | clap = self.styler.sample(src, src_clips, var_samples=var_samples) 166 | if normalize: 167 | clap = clap / clap.norm(p=2, dim=-1, keepdim=True) 168 | return clap 169 | 170 | --------------------------------------------------------------------------------
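A minimal usage sketch (not part of the repository) of the high-level entry points defined in `ssv2a/model/pipeline.py`. The config and checkpoint names follow the README; the image, text, and audio prompts are placeholders you would replace with your own.

```python
# Sketch only: assumes ssv2a.json / ssv2a.pth (and the DALLE-2 prior files) have been
# downloaded as described in the README, and that the media paths below exist.
from ssv2a.model.pipeline import image_to_audio, srcs_to_audio

# Image-to-audio: detect sound sources in each image, Cycle Mix their CLAP codes,
# then decode waveforms with AudioLDM into ./output.
image_to_audio(
    images=["images/dog.png", "images/street.png"],  # placeholder image paths
    save_dir="output",
    config="ssv2a.json",
    pretrained="ssv2a.pth",
    duration=10,
    seed=42,
    device="cuda",
)

# Multimodal composition: each modality maps a prompt (file path or text string) to a
# placeholder value; srcs_to_audio fills in the embeddings itself and writes a single
# .wav to the path passed as save_dir.
srcs = {
    "image": {"talking_man.png": None},
    "text": {"raining heavily": None},
    "audio": {"thunder.wav": None},
}
srcs_to_audio(
    srcs,
    "output/audio.wav",
    config="ssv2a.json",
    pretrained="ssv2a.pth",
    dalle2_cfg="dalle2_prior_config.json",
    dalle2_ckpt="dalle2_prior.pth",
)
```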