├── .gitignore ├── LICENSE.md ├── README.md ├── __pycache__ └── model_drum.cpython-38.pyc ├── cog.yaml ├── data ├── freesound │ ├── mean.mel.npy │ └── std.mel.npy ├── looperman │ ├── mean.mel.npy │ └── std.mel.npy └── looperman_four_bar │ ├── mean.mel.npy │ └── std.mel.npy ├── dataset.py ├── distributed.py ├── environment.yml ├── evaluation ├── FAD │ └── looperman_2000.stats ├── IS │ ├── attention_modules.py │ ├── best_model.ckpt │ ├── compute_is_score.sh │ ├── inception_score.py │ ├── model.py │ └── modules.py ├── NDB_JS │ ├── compute_ndb_js.py │ ├── compute_ndb_js.sh │ └── ndb.py └── nine_audio │ ├── Chillout │ └── 1365.wav │ ├── Drum_and_Bass │ └── 14.wav │ ├── Electronic │ └── 1353.wav │ ├── Hiphop │ └── 323.wav │ ├── Rap │ └── 153.wav │ ├── Rock │ └── 752.wav │ ├── Trap │ └── 1395.wav │ ├── dupstep │ └── 113.wav │ └── industrial │ └── 1606.wav ├── generate_audio.py ├── generate_looperman_four_bar.py ├── melgan ├── .ipynb_checkpoints │ └── modules-checkpoint.py ├── args.yml ├── best_netG.pt └── modules.py ├── model_drum.py ├── model_drum_four_bar.py ├── non_leaking.py ├── op ├── __init__.py ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu ├── predict.py ├── preprocess ├── collect_audio_clips.py ├── compute_mean_std.mel.py ├── extract_mel.py ├── make_dataset.py └── trim_2_seconds.py ├── scripts ├── generate_freesound.sh ├── generate_looperman_four_bar.sh ├── generate_looperman_one_bar.sh └── train.sh └── train_drum.py /.gitignore: -------------------------------------------------------------------------------- 1 | generated_looperman_one_bar 2 | generated_freesound_one_bar 3 | generated_interpolation_one_bar_2 4 | pretrained_model 5 | __pycache__/*.pyc 6 | __pycache__ 7 | __pycache__/model_drum.cpython-38.pyc 8 | evaluation/NDB_JS/__pycache__ 9 | evaluation/NDB_JS/looperman_2000.txt 10 | evaluation/NDB_JS/looper_2000/ 11 | evaluation/NDB_JS/looper_2000.zip 12 | evaluation/user_study_analysis/ 13 | evaluation/IS/freesound_styelgan2.pkl 14 | freesound_checkpoint.pt 15 | generated_audio_looperman_four_bar/ 16 | looperman_four_bar_checkpoint.pt 17 | looperman_one_bar_checkpoint.pt 18 | freesound_mel_80_320.zip 19 | mel_80_320/ 20 | freesound_checkpoint/ 21 | freesound_sample_dir/ 22 | evaluation/IS/compute_is_score_one.sh 23 | 24 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) loop-generation develope team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LoopTest 2 | [![GitHub](https://img.shields.io/github/license/allenhung1025/loop-generation?label=license)](./LICENSE.md) 3 | ![GitHub issues](https://img.shields.io/github/issues/allenhung1025/loop-generation) 4 | ![GitHub Repo stars](https://img.shields.io/github/stars/allenhung1025/loop-generation) 5 | 6 | 7 | * This is the official repository of **A Benchmarking Initiative for Audio-domain Music Generation Using the FreeSound Loop Dataset** co-authored with [Paul Chen](https://paulyuchen.com/), [Arthur Yeh](http://yentung.com/) and my supervisor [Yi-Hsuan Yang](http://mac.citi.sinica.edu.tw/~yang/). The paper has been accepted by the International Society for Music Information Retrieval Conference 2021. [[Demo Page]](https://loopgen.github.io/), [[arxiv]](https://arxiv.org/pdf/2108.01576.pdf). 8 | * We provide not only pretrained models for generating loops on your own, but also scripts for evaluating the generated loops. 9 | ## Environment 10 | ``` 11 | $ conda env create -f environment.yml 12 | ``` 13 | ## Quick Start 14 | 15 | * Generate loops from the one-bar looperman pretrained model 16 | ``` bash 17 | $ gdown --id 1GQpzWz9ycIm5wzkxLsVr-zN17GWD3_6K -O looperman_one_bar_checkpoint.pt 18 | $ bash scripts/generate_looperman_one_bar.sh 19 | ``` 20 | 21 | * Generate loops from the four-bar looperman pretrained model 22 | ``` bash 23 | $ gdown --id 19rk3vx7XM4dultTF1tN4srCpdya7uxBV -O looperman_four_bar_checkpoint.pt 24 | $ bash scripts/generate_looperman_four_bar.sh 25 | ``` 26 | 27 | * Generate loops from the freesound pretrained model 28 | ``` bash 29 | $ gdown --id 197DMCOASEMFBVi8GMahHfRwgJ0bhcUND -O freesound_checkpoint.pt 30 | $ bash scripts/generate_freesound.sh 31 | ``` 32 | ## Pretrained Checkpoints 33 | * [Looperman pretrained one-bar model](https://drive.google.com/file/d/1GQpzWz9ycIm5wzkxLsVr-zN17GWD3_6K/view?usp=sharing) 34 | * [Looperman pretrained four-bar model](https://drive.google.com/file/d/19rk3vx7XM4dultTF1tN4srCpdya7uxBV/view?usp=sharing) 35 | * [Freesound pretrained one-bar model](https://drive.google.com/file/d/197DMCOASEMFBVi8GMahHfRwgJ0bhcUND/view?usp=sharing) 36 | 37 | ## Benchmarking Freesound Loop Dataset 38 | ### Download dataset 39 | ``` bash 40 | 41 | $ gdown --id 1fQfSZgD9uWbCdID4SzVqNGhsYNXOAbK5 42 | $ unzip freesound_mel_80_320.zip 43 | 44 | ``` 45 | ### Training 46 | 47 | ``` bash 48 | $ CUDA_VISIBLE_DEVICES=2 python train_drum.py \ 49 | --size 64 --batch 8 --sample_dir freesound_sample_dir \ 50 | --checkpoint_dir freesound_checkpoint \ 51 | --iter 100000 \ 52 | mel_80_320 53 | ``` 54 | 55 | ### Generate audio 56 | ```bash 57 | $ CUDA_VISIBLE_DEVICES=2 python generate_audio.py \ 58 | --ckpt freesound_checkpoint/100000.pt \ 59 | --pics 2000 --data_path "./data/freesound" \ 60 | --store_path "./generated_freesound_one_bar" 61 | ```
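* Note: generation produces both rendered `.wav` files and normalized mel-spectrogram `.npy` files (see the paths used by the evaluation scripts below: FAD reads the `.wav` files, NDB/JS and IS read the `.npy` files). The `.npy` files are stored in normalized units, so to inspect one on the original mel scale it has to be de-normalized with the dataset statistics in `./data/freesound`, the same step `evaluation/IS/inception_score.py` performs. A minimal sketch, assuming an 80x320 mel-spectrogram and a hypothetical sample path:

``` python
import os
import numpy as np

stats_dir = "./data/freesound"                       # mean/std files shipped with this repo
mean = np.load(os.path.join(stats_dir, "mean.mel.npy")).reshape(1, 80, 1)  # per mel-bin mean
std = np.load(os.path.join(stats_dir, "std.mel.npy")).reshape(1, 80, 1)    # per mel-bin std

# hypothetical path to one generated, normalized mel-spectrogram (80 bins x 320 frames)
norm_mel = np.load("generated_freesound_one_bar/sample_0.npy")
mel = norm_mel * std + mean                          # undo the (x - mean) / std normalization
print(mel.shape)
```

The NDB/JS script compares the normalized `.npy` files directly against the looperman reference set, so no de-normalization is needed there.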
62 | ### Evaluation 63 | #### NDB_JS 64 | * 2000 looperman mel-spectrograms [link](https://drive.google.com/file/d/1aFGPYlkkAysVBWp9VacHVk2tf-b4rLIh/view?usp=sharing) 65 | ``` bash 66 | $ cd evaluation/NDB_JS 67 | $ gdown --id 1aFGPYlkkAysVBWp9VacHVk2tf-b4rLIh 68 | $ unzip looper_2000.zip # contains 2000 looperman mel-spectrograms 69 | $ rm looper_2000.zip 70 | $ bash compute_ndb_js.sh 71 | ``` 72 | #### IS 73 | * Short-Chunk CNN [checkpoint](./evaluation/IS/best_model.ckpt) 74 | ``` bash 75 | $ cd evaluation/IS 76 | $ bash compute_is_score.sh 77 | ``` 78 | #### FAD 79 | * FAD looperman ground-truth stats: [link](./evaluation/FAD/looperman_2000.stats). Follow the official [doc][fad] to install the required packages. 80 | 81 | ``` bash 82 | $ ls --color=never generated_freesound_one_bar/100000/*.wav > freesound.csv 83 | $ python -m frechet_audio_distance.create_embeddings_main --input_files freesound.csv --stats freesound.stats 84 | $ python -m frechet_audio_distance.compute_fad --background_stats ./evaluation/FAD/looperman_2000.stats --test_stats freesound.stats 85 | ``` 86 | 87 | 88 | 89 | ## Train the model with your loop dataset 90 | ### Preprocess the Loop Dataset 91 | Go to the [preprocess](./preprocess) directory, modify some settings (e.g. the data paths) in the scripts, and run them in the following order: 92 | ``` bash 93 | $ python trim_2_seconds.py # Cut each loop into single bars and stretch them to 2 seconds. 94 | $ python extract_mel.py # Extract mel-spectrograms from the 2-second audio. 95 | $ python make_dataset.py 96 | $ python compute_mean_std.mel.py 97 | ``` 98 | 99 | ### Train the Model 100 | ``` bash 101 | CUDA_VISIBLE_DEVICES=2 python train_drum.py \ 102 | --size 64 --batch 8 --sample_dir [sample_dir] \ 103 | --checkpoint_dir [checkpoint_dir] \ 104 | [mel-spectrogram dataset from the preprocessing] 105 | ``` 106 | * checkpoint_dir stores model checkpoints in the designated directory. 107 | * sample_dir stores mel-spectrograms generated by the model during training. 108 | * You should give the data directory at the end. 109 | * There is an example training [script](./scripts/train.sh). 110 | 111 | ## Vocoder 112 | We use [MelGAN][melgan] as the vocoder. We trained it on the looperman dataset and use it to synthesize audio for both the freesound and looperman models. 113 | The trained vocoder is in the [melgan](./melgan) directory. 114 | 115 | ## References 116 | The code borrows heavily from the repositories below. 117 | * [StyleGAN2 from rosinality][stylegan2] 118 | * [Official MelGAN repo][melgan] 119 | * [Official UNAGAN repo from ciaua][unagan] 120 | * [Official Short Chunk CNN repo][cnn] 121 | * [FAD official document][fad] 122 | 123 | [fad]: https://github.com/google-research/google-research/tree/master/frechet_audio_distance 124 | [cnn]: https://github.com/minzwon/sota-music-tagging-models 125 | [stylegan2]: https://github.com/rosinality/stylegan2-pytorch 126 | [unagan]: https://github.com/ciaua/unagan 127 | [melgan]: https://github.com/descriptinc/melgan-neurips 128 | 129 | ## Citation 130 | If you find this repo useful, please kindly cite it with the following information. 131 | ``` 132 | @inproceedings{allenloopgen, 133 | title={A Benchmarking Initiative for Audio-domain Music Generation using the {FreeSound Loop Dataset}}, 134 | author={Tun-Min Hung and Bo-Yu Chen and Yen-Tung Yeh and Yi-Hsuan Yang}, 135 | booktitle = {Proc. Int.
Society for Music Information Retrieval Conf.}, 136 | year={2021}, 137 | } 138 | ``` 139 | -------------------------------------------------------------------------------- /__pycache__/model_drum.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/__pycache__/model_drum.cpython-38.pyc -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | predict: "predict.py:Predictor" 2 | build: 3 | gpu: true 4 | python_version: "3.8" 5 | system_packages: 6 | - "ffmpeg" 7 | python_packages: 8 | - "cython==0.29.24" 9 | - "stylegan2-pytorch==1.8.1" 10 | - "torch==1.6.0" 11 | - "torchaudio==0.6.0" 12 | - "torchvision==0.7.0" 13 | - "librosa==0.8.0" 14 | - "numpy==1.20.1" 15 | - "scipy==1.6.2" 16 | - "tqdm==4.55.2" 17 | - "pysoundfile==0.9.0.post1" 18 | - "pyrubberband==0.3.0" 19 | - "pydub==0.23.1" 20 | - "matplotlib==3.3.4" 21 | - "lmdb==0.96" 22 | - "pillow==8.2.0" 23 | - "ninja==1.10.2" 24 | pre_install: 25 | - "pip install madmom==0.16.1" # needs to be after python_packages since madmom needs cython without requiring it 26 | -------------------------------------------------------------------------------- /data/freesound/mean.mel.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/freesound/mean.mel.npy -------------------------------------------------------------------------------- /data/freesound/std.mel.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/freesound/std.mel.npy -------------------------------------------------------------------------------- /data/looperman/mean.mel.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/looperman/mean.mel.npy -------------------------------------------------------------------------------- /data/looperman/std.mel.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/looperman/std.mel.npy -------------------------------------------------------------------------------- /data/looperman_four_bar/mean.mel.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/looperman_four_bar/mean.mel.npy -------------------------------------------------------------------------------- /data/looperman_four_bar/std.mel.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/looperman_four_bar/std.mel.npy -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import 
transforms 7 | from torch.utils import data 8 | import numpy as np 9 | import os 10 | def data_sampler(dataset, shuffle, distributed): 11 | if distributed: 12 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 13 | 14 | if shuffle: 15 | return data.RandomSampler(dataset) 16 | 17 | else: 18 | return data.SequentialSampler(dataset) 19 | 20 | class MultiResolutionDataset(Dataset): 21 | def __init__(self, path, transform, resolution=256): 22 | self.env = lmdb.open( 23 | path, 24 | max_readers=32, 25 | readonly=True, 26 | lock=False, 27 | readahead=False, 28 | meminit=False, 29 | ) 30 | 31 | if not self.env: 32 | raise IOError('Cannot open lmdb dataset', path) 33 | 34 | with self.env.begin(write=False) as txn: 35 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 36 | 37 | self.resolution = resolution 38 | self.transform = transform 39 | 40 | def __len__(self): 41 | return self.length 42 | 43 | def __getitem__(self, index): 44 | with self.env.begin(write=False) as txn: 45 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 46 | img_bytes = txn.get(key) 47 | 48 | buffer = BytesIO(img_bytes) 49 | img = Image.open(buffer) 50 | img = self.transform(img) 51 | 52 | return img 53 | class MultiResolutionDataset_drum(Dataset): 54 | def __init__(self, path, transform, resolution=None): 55 | self.path_list = [] 56 | for file in os.listdir(path): 57 | if file.endswith('.npy') == True: 58 | if file.startswith('std') == False and file.startswith('mean') == False: 59 | self.path_list.append(os.path.join(path, file)) 60 | self.resolution = resolution 61 | self.transform = transform 62 | 63 | def __len__(self): 64 | return len(self.path_list) 65 | 66 | def __getitem__(self, index): 67 | 68 | img = np.load(self.path_list[index]) 69 | img = self.transform(img) 70 | 71 | return img 72 | 73 | class MultiResolutionDataset_drum_with_filename(Dataset): 74 | def __init__(self, path, transform, resolution=None): 75 | self.path_list = [] 76 | for file in os.listdir(path): 77 | if file.endswith('.npy') == True: 78 | if file.startswith('std') == False and file.startswith('mean') == False: 79 | self.path_list.append(os.path.join(path, file)) 80 | self.resolution = resolution 81 | self.transform = transform 82 | 83 | def __len__(self): 84 | return len(self.path_list) 85 | 86 | def __getitem__(self, index): 87 | 88 | img = np.load(self.path_list[index]) 89 | img = self.transform(img) 90 | 91 | return img, self.path_list[index].split('/')[-1] 92 | 93 | class MultiResolutionDataset_drum_with_label(Dataset): 94 | def __init__(self, path, transform, label_dictionary, resolution=None): 95 | self.path_list = [] 96 | for file in os.listdir(path): 97 | if file.endswith('.npy') == True: 98 | if file.startswith('std') == False and file.startswith('mean') == False: 99 | self.path_list.append(os.path.join(path, file)) 100 | self.resolution = resolution 101 | self.transform = transform 102 | class MultiResolutionDataset_drum_with_label(Dataset): 103 | def __init__(self, path, transform, label_dictionary, resolution=None): 104 | self.path_list = [] 105 | for file in os.listdir(path): 106 | if file.endswith('.npy') == True: 107 | if file.startswith('std') == False and file.startswith('mean') == False: 108 | self.path_list.append(os.path.join(path, file)) 109 | self.resolution = resolution 110 | self.transform = transform 111 | 112 | ## read label dictionary 113 | import pickle as pickle 114 | with open(label_dictionary, 'rb') as f: 115 | self.label_dictionary = pickle.load(f) 116 | 117 
| ## genre to int dictionary 118 | self.genre_to_int = {} 119 | count = 0 120 | for _, genre in self.label_dictionary.items(): 121 | if self.genre_to_int.get(genre) == None: 122 | self.genre_to_int[genre] = count 123 | count += 1 124 | def __len__(self): 125 | return len(self.path_list) 126 | 127 | def __getitem__(self, index): 128 | 129 | img = np.load(self.path_list[index]) 130 | img = self.transform(img) 131 | 132 | file_name = self.path_list[index].split('/')[-1] 133 | label = self.genre_to_int[self.label_dictionary[file_name]] 134 | return img, label 135 | if __name__ == '__main__': 136 | transform = transforms.Compose( 137 | [ 138 | #transforms.RandomHorizontalFlip(), 139 | transforms.ToTensor(), 140 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 141 | ] 142 | ) 143 | path = '/home/allenhung/nas189/home/style-based-gan-drum/training_data_one_bar_all/mel_80_320_genre_more_than_600' 144 | label_dictionary = '/home/allenhung/nas189/home/style-based-gan-drum/training_data_one_bar_all/dict_one_bar_more_than_600.pickle' 145 | dataset = MultiResolutionDataset_drum_with_label(path, transform, label_dictionary) 146 | loader = data.DataLoader( 147 | dataset, 148 | batch_size=2, 149 | sampler=data_sampler(dataset, shuffle=True, distributed=False), 150 | drop_last=True, 151 | ) 152 | for data in loader: 153 | import pdb; pdb.set_trace() 154 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def reduce_sum(tensor): 45 | if not dist.is_available(): 46 | return tensor 47 | 48 | if not dist.is_initialized(): 49 | return tensor 50 | 51 | tensor = tensor.clone() 52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 53 | 54 | return tensor 55 | 56 | 57 | def gather_grad(params): 58 | world_size = get_world_size() 59 | 60 | if world_size == 1: 61 | return 62 | 63 | for param in params: 64 | if param.grad is not None: 65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 66 | param.grad.data.div_(world_size) 67 | 68 | 69 | def all_gather(data): 70 | world_size = get_world_size() 71 | 72 | if world_size == 1: 73 | return [data] 74 | 75 | buffer = pickle.dumps(data) 76 | storage = torch.ByteStorage.from_buffer(buffer) 77 | tensor = torch.ByteTensor(storage).to('cuda') 78 | 79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 81 | dist.all_gather(size_list, local_size) 82 | size_list = [int(size.item()) for size in size_list] 83 | max_size = max(size_list) 84 | 85 | tensor_list = [] 86 | for _ in size_list: 87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 88 | 89 | if local_size != max_size: 90 | padding = 
torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 91 | tensor = torch.cat((tensor, padding), 0) 92 | 93 | dist.all_gather(tensor_list, tensor) 94 | 95 | data_list = [] 96 | 97 | for size, tensor in zip(size_list, tensor_list): 98 | buffer = tensor.cpu().numpy().tobytes()[:size] 99 | data_list.append(pickle.loads(buffer)) 100 | 101 | return data_list 102 | 103 | 104 | def reduce_loss_dict(loss_dict): 105 | world_size = get_world_size() 106 | 107 | if world_size < 2: 108 | return loss_dict 109 | 110 | with torch.no_grad(): 111 | keys = [] 112 | losses = [] 113 | 114 | for k in sorted(loss_dict.keys()): 115 | keys.append(k) 116 | losses.append(loss_dict[k]) 117 | 118 | losses = torch.stack(losses, 0) 119 | dist.reduce(losses, dst=0) 120 | 121 | if dist.get_rank() == 0: 122 | losses /= world_size 123 | 124 | reduced_losses = {k: v for k, v in zip(keys, losses)} 125 | 126 | return reduced_losses 127 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: melgan_suc 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_gnu 10 | - absl-py=0.12.0=py38h06a4308_0 11 | - appdirs=1.4.4=py_0 12 | - audioread=2.1.9=py38h578d9bd_0 13 | - backcall=0.2.0=pyhd3eb1b0_0 14 | - blas=1.0=mkl 15 | - brotlipy=0.7.0=py38h27cfd23_1003 16 | - bzip2=1.0.8=h7b6447c_0 17 | - c-ares=1.17.1=h27cfd23_0 18 | - ca-certificates=2020.10.14=0 19 | - cairo=1.16.0=h7979940_1007 20 | - certifi=2020.6.20=py38_0 21 | - cffi=1.14.5=py38h261ae71_0 22 | - chardet=4.0.0=py38h06a4308_1003 23 | - click=7.1.2=py_0 24 | - coverage=5.5=py38h27cfd23_2 25 | - cryptography=3.4.7=py38hd23ed53_0 26 | - cudatoolkit=10.2.89=hfd86e86_1 27 | - cycler=0.10.0=py38_0 28 | - dbus=1.13.18=hb2f20db_0 29 | - decorator=5.0.6=pyhd3eb1b0_0 30 | - dlib=19.21.1=py38h2c889cf_0 31 | - expat=2.2.10=he6710b0_2 32 | - ffmpeg=4.3.1=h3215721_1 33 | - flask=1.1.2=py_0 34 | - flask-wtf=0.14.3=py_0 35 | - fontconfig=2.13.1=hba837de_1004 36 | - freetype=2.10.4=h5ab3b9f_0 37 | - gettext=0.21.0=hf68c758_0 38 | - glib=2.66.6=ha03b18c_2 39 | - glib-tools=2.66.6=ha03b18c_2 40 | - gmp=6.2.1=h2531618_2 41 | - gnutls=3.6.15=he1e5248_0 42 | - graphite2=1.3.14=h23475e2_0 43 | - grpcio=1.36.1=py38h2157cd5_1 44 | - gst-plugins-base=1.14.5=h0935bb2_2 45 | - gstreamer=1.18.3=h3560a44_0 46 | - harfbuzz=2.7.4=h5cf4720_0 47 | - hdf5=1.10.6=hb1b8bf9_0 48 | - icu=68.1=h2531618_0 49 | - idna=2.10=pyhd3eb1b0_0 50 | - importlib-metadata=3.10.0=py38h06a4308_0 51 | - intel-openmp=2021.2.0=h06a4308_610 52 | - ipykernel=5.3.4=py38h5ca1d4c_0 53 | - ipython=7.22.0=py38hb070fc8_0 54 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 55 | - itsdangerous=1.1.0=py_0 56 | - jasper=1.900.1=hd497a04_4 57 | - jedi=0.17.0=py38_0 58 | - jinja2=2.11.2=py_0 59 | - joblib=1.0.1=pyhd3eb1b0_0 60 | - jpeg=9d=h36c2ea0_0 61 | - jupyter_client=6.1.12=pyhd3eb1b0_0 62 | - jupyter_core=4.7.1=py38h06a4308_0 63 | - kiwisolver=1.3.1=py38h2531618_0 64 | - krb5=1.17.1=h173b8e3_0 65 | - lame=3.100=h7b6447c_0 66 | - lcms2=2.12=h3be6417_0 67 | - ld_impl_linux-64=2.33.1=h53a641e_7 68 | - libblas=3.9.0=1_h6e990d7_netlib 69 | - libcblas=3.9.0=3_h893e4fe_netlib 70 | - libclang=11.0.1=default_ha53f305_1 71 | - libedit=3.1.20210216=h27cfd23_1 72 | - libevent=2.1.10=hcdb4288_3 73 | - libffi=3.3=he6710b0_2 74 | - libflac=1.3.3=h9c3ff4c_1 75 | - libgcc-ng=9.3.0=h5dbcf3e_17 76 | - 
libgfortran-ng=7.5.0=hae1eefd_17 77 | - libgfortran4=7.5.0=hae1eefd_17 78 | - libglib=2.66.6=hdb14261_2 79 | - libgomp=9.3.0=h5dbcf3e_17 80 | - libiconv=1.16=h516909a_0 81 | - libidn2=2.3.0=h27cfd23_0 82 | - liblapack=3.9.0=3_h893e4fe_netlib 83 | - liblapacke=3.9.0=3_h893e4fe_netlib 84 | - libllvm10=10.0.1=hbcb73fb_5 85 | - libllvm11=11.0.1=hf817b99_0 86 | - libogg=1.3.4=h7f98852_0 87 | - libopencv=4.5.0=py38h703c3c0_7 88 | - libopus=1.3.1=h7b6447c_0 89 | - libpng=1.6.37=hbc83047_0 90 | - libpq=12.3=h255efa7_3 91 | - libprotobuf=3.14.0=h8c45485_0 92 | - librosa=0.8.0=pyh9f0ad1d_0 93 | - libsndfile=1.0.30=h9c3ff4c_1 94 | - libsodium=1.0.18=h7b6447c_0 95 | - libstdcxx-ng=9.3.0=h2ae2ef3_17 96 | - libtasn1=4.16.0=h27cfd23_0 97 | - libtiff=4.1.0=h2733197_1 98 | - libunistring=0.9.10=h27cfd23_0 99 | - libuuid=2.32.1=h7f98852_1000 100 | - libvorbis=1.3.7=h7b6447c_0 101 | - libwebp-base=1.2.0=h27cfd23_0 102 | - libxcb=1.14=h7b6447c_0 103 | - libxkbcommon=1.0.3=he3ba5ed_0 104 | - libxml2=2.9.10=h72842e0_3 105 | - llvmlite=0.36.0=py38h612dafd_4 106 | - lz4-c=1.9.3=h2531618_0 107 | - markdown=3.3.4=py38h06a4308_0 108 | - markupsafe=1.1.1=py38h7b6447c_0 109 | - matplotlib-base=3.3.4=py38h62a2d02_0 110 | - mkl=2021.2.0=h06a4308_296 111 | - mkl-service=2.3.0=py38h27cfd23_1 112 | - mkl_fft=1.3.0=py38h42c9631_2 113 | - mkl_random=1.2.1=py38ha9443f7_2 114 | - mysql-common=8.0.22=ha770c72_1 115 | - mysql-libs=8.0.22=h1fd7589_1 116 | - ncurses=6.2=he6710b0_1 117 | - nettle=3.7.2=hbbd107a_1 118 | - ninja=1.10.2=hff7bd54_1 119 | - nspr=4.29=h9c3ff4c_1 120 | - nss=3.61=hb5efdd6_0 121 | - numba=0.53.1=py38ha9443f7_0 122 | - numpy=1.20.1=py38h93e21f0_0 123 | - numpy-base=1.20.1=py38h7d8b39e_0 124 | - olefile=0.46=py_0 125 | - opencv=4.5.0=py38h578d9bd_7 126 | - openh264=2.1.1=h8b12597_0 127 | - openssl=1.1.1k=h27cfd23_0 128 | - packaging=20.9=pyhd3eb1b0_0 129 | - parso=0.8.2=pyhd3eb1b0_0 130 | - pcre=8.44=he6710b0_0 131 | - pexpect=4.8.0=pyhd3eb1b0_3 132 | - pickleshare=0.7.5=pyhd3eb1b0_1003 133 | - pillow=8.2.0=py38he98fc37_0 134 | - pip=21.0.1=py38h06a4308_0 135 | - pixman=0.40.0=h7b6447c_0 136 | - pooch=1.3.0=pyhd3eb1b0_0 137 | - prompt-toolkit=3.0.17=pyh06a4308_0 138 | - protobuf=3.14.0=py38h2531618_1 139 | - ptyprocess=0.7.0=pyhd3eb1b0_2 140 | - py-opencv=4.5.0=py38h81c977d_7 141 | - pycparser=2.20=py_2 142 | - pydub=0.23.1=py_0 143 | - pygments=2.8.1=pyhd3eb1b0_0 144 | - pyopenssl=20.0.1=pyhd3eb1b0_1 145 | - pyparsing=2.4.7=pyhd3eb1b0_0 146 | - pysocks=1.7.1=py38h06a4308_0 147 | - python=3.8.8=hdb3f193_5 148 | - python-dateutil=2.8.1=pyhd3eb1b0_0 149 | - python-lmdb=0.96=py38h950e882_1 150 | - python_abi=3.8=1_cp38 151 | - pytorch=1.6.0=py3.8_cuda10.2.89_cudnn7.6.5_0 152 | - pyyaml=5.3.1=py38h8df0ef7_1 153 | - pyzmq=20.0.0=py38h2531618_1 154 | - qt=5.12.9=h9d6b050_2 155 | - readline=8.1=h27cfd23_0 156 | - requests=2.25.1=pyhd3eb1b0_0 157 | - resampy=0.2.2=py_0 158 | - scikit-learn=0.24.1=py38ha9443f7_0 159 | - scipy=1.6.2=py38had2a1c9_1 160 | - setuptools=52.0.0=py38h06a4308_0 161 | - six=1.15.0=py38h06a4308_0 162 | - sqlite=3.35.4=hdfb4753_0 163 | - tbb=2020.3=hfd86e86_0 164 | - tensorboard=1.15.0=py38_0 165 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 166 | - tk=8.6.10=hbc83047_0 167 | - torchaudio=0.6.0=py38 168 | - torchvision=0.7.0=py38_cu102 169 | - tornado=6.1=py38h27cfd23_0 170 | - tqdm=4.55.2=pyhd8ed1ab_0 171 | - traitlets=5.0.5=pyhd3eb1b0_0 172 | - urllib3=1.26.4=pyhd3eb1b0_0 173 | - wcwidth=0.2.5=py_0 174 | - werkzeug=1.0.1=pyhd3eb1b0_0 175 | - wheel=0.36.2=pyhd3eb1b0_0 176 | - wtforms=2.3.3=py_0 177 | - 
x264=1!152.20180806=h14c3975_0 178 | - xorg-kbproto=1.0.7=h7f98852_1002 179 | - xorg-libice=1.0.10=h516909a_0 180 | - xorg-libsm=1.2.3=h84519dc_1000 181 | - xorg-libx11=1.6.12=h516909a_0 182 | - xorg-libxext=1.3.4=h516909a_0 183 | - xorg-libxrender=0.9.10=h516909a_1002 184 | - xorg-renderproto=0.11.1=h14c3975_1002 185 | - xorg-xextproto=7.3.0=h7f98852_1002 186 | - xorg-xproto=7.0.31=h27cfd23_1007 187 | - xz=5.2.5=h7b6447c_0 188 | - yaml=0.2.5=h7b6447c_0 189 | - zeromq=4.3.4=h2531618_0 190 | - zipp=3.4.1=pyhd3eb1b0_0 191 | - zlib=1.2.11=h7b6447c_3 192 | - zstd=1.4.9=haebb681_0 193 | - pip: 194 | - aim==2.3.0 195 | - aimrecords==0.0.7 196 | - anytree==2.8.0 197 | - base58==2.0.1 198 | - contrastive-learner==0.1.1 199 | - cython==0.29.21 200 | - docker==5.0.0 201 | - einops==0.3.0 202 | - fire==0.4.0 203 | - future==0.18.2 204 | - gitdb==4.0.7 205 | - gitpython==3.1.14 206 | - kornia==0.5.1 207 | - madmom==0.16.1 208 | - mido==1.2.9 209 | - psutil==5.8.0 210 | - py==1.10.0 211 | - py3nvml==0.2.6 212 | - pyrser==0.2.0 213 | - pyrubberband==0.3.0 214 | - pysoundfile==0.9.0.post1 215 | - retry==0.9.2 216 | - smmap==4.0.0 217 | - stylegan2-pytorch==1.8.1 218 | - termcolor==1.1.0 219 | - timm==0.4.5 220 | - vector-quantize-pytorch==0.1.0 221 | - websocket-client==0.59.0 222 | - xmltodict==0.12.0 223 | prefix: /home/allenhung/miniconda3/envs/melgan_suc 224 | -------------------------------------------------------------------------------- /evaluation/FAD/looperman_2000.stats: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/FAD/looperman_2000.stats -------------------------------------------------------------------------------- /evaluation/IS/attention_modules.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Code adopted from https://github.com/huggingface/pytorch-pretrained-BERT 3 | 4 | import math 5 | import copy 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | # Gelu 11 | def gelu(x): 12 | """Implementation of the gelu activation function. 13 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 14 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 15 | Also see https://arxiv.org/abs/1606.08415 16 | """ 17 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 18 | 19 | # LayerNorm 20 | try: 21 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 22 | except ImportError: 23 | #print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") 24 | class BertLayerNorm(nn.Module): 25 | def __init__(self, hidden_size, eps=1e-12): 26 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
27 | """ 28 | super(BertLayerNorm, self).__init__() 29 | self.weight = nn.Parameter(torch.ones(hidden_size)) 30 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 31 | self.variance_epsilon = eps 32 | 33 | def forward(self, x): 34 | u = x.mean(-1, keepdim=True) 35 | s = (x - u).pow(2).mean(-1, keepdim=True) 36 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 37 | return self.weight * x + self.bias 38 | 39 | 40 | class BertConfig(object): 41 | def __init__(self, 42 | vocab_size, 43 | hidden_size=768, 44 | num_hidden_layers=12, 45 | num_attention_heads=12, 46 | intermediate_size=3072, 47 | hidden_act="gelu", 48 | hidden_dropout_prob=0.1, 49 | max_position_embeddings=512, 50 | attention_probs_dropout_prob=0.1, 51 | type_vocab_size=2): 52 | self.vocab_size = vocab_size 53 | self.hidden_size = hidden_size 54 | self.num_hidden_layers = num_hidden_layers 55 | self.num_attention_heads = num_attention_heads 56 | self.hidden_act = hidden_act 57 | self.intermediate_size = intermediate_size 58 | self.hidden_dropout_prob = hidden_dropout_prob 59 | self.max_position_embeddings = max_position_embeddings 60 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 61 | self.type_vocab_size = type_vocab_size 62 | 63 | 64 | class BertSelfAttention(nn.Module): 65 | def __init__(self, config): 66 | super(BertSelfAttention, self).__init__() 67 | if config.hidden_size % config.num_attention_heads != 0: 68 | raise ValueError( 69 | "The hidden size (%d) is not a multiple of the number of attention " 70 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 71 | self.num_attention_heads = config.num_attention_heads 72 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 73 | self.all_head_size = self.num_attention_heads * self.attention_head_size 74 | 75 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 76 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 77 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 78 | 79 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 80 | 81 | def transpose_for_scores(self, x): 82 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 83 | x = x.view(*new_x_shape) 84 | return x.permute(0, 2, 1, 3) 85 | 86 | def forward(self, hidden_states, attention_mask): 87 | mixed_query_layer = self.query(hidden_states) 88 | mixed_key_layer = self.key(hidden_states) 89 | mixed_value_layer = self.value(hidden_states) 90 | 91 | query_layer = self.transpose_for_scores(mixed_query_layer) 92 | key_layer = self.transpose_for_scores(mixed_key_layer) 93 | value_layer = self.transpose_for_scores(mixed_value_layer) 94 | 95 | # Take the dot product between "query" and "key" to get the raw attention scores. 96 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 97 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 98 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 99 | if attention_mask is not None: 100 | attention_scores = attention_scores + attention_mask 101 | 102 | # Normalize the attention scores to probabilities. 103 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 104 | 105 | # This is actually dropping out entire tokens to attend to, which might 106 | # seem a bit unusual, but is taken from the original Transformer paper. 
107 | attention_probs = self.dropout(attention_probs) 108 | 109 | context_layer = torch.matmul(attention_probs, value_layer) 110 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 111 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 112 | context_layer = context_layer.view(*new_context_layer_shape) 113 | return context_layer 114 | 115 | 116 | class BertSelfOutput(nn.Module): 117 | def __init__(self, config): 118 | super(BertSelfOutput, self).__init__() 119 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 120 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 121 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 122 | 123 | def forward(self, hidden_states, input_tensor): 124 | hidden_states = self.dense(hidden_states) 125 | hidden_states = self.dropout(hidden_states) 126 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 127 | return hidden_states 128 | 129 | 130 | class BertAttention(nn.Module): 131 | def __init__(self, config): 132 | super(BertAttention, self).__init__() 133 | self.self = BertSelfAttention(config) 134 | self.output = BertSelfOutput(config) 135 | 136 | def forward(self, input_tensor, attention_mask): 137 | self_output = self.self(input_tensor, attention_mask) 138 | attention_output = self.output(self_output, input_tensor) 139 | return attention_output 140 | 141 | 142 | class BertIntermediate(nn.Module): 143 | def __init__(self, config): 144 | super(BertIntermediate, self).__init__() 145 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 146 | self.intermediate_act_fn = gelu 147 | 148 | def forward(self, hidden_states): 149 | hidden_states = self.dense(hidden_states) 150 | hidden_states = self.intermediate_act_fn(hidden_states) 151 | return hidden_states 152 | 153 | 154 | class BertOutput(nn.Module): 155 | def __init__(self, config): 156 | super(BertOutput, self).__init__() 157 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 158 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 159 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 160 | 161 | def forward(self, hidden_states, input_tensor): 162 | hidden_states = self.dense(hidden_states) 163 | hidden_states = self.dropout(hidden_states) 164 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 165 | return hidden_states 166 | 167 | 168 | class BertLayer(nn.Module): 169 | def __init__(self, config): 170 | super(BertLayer, self).__init__() 171 | self.attention = BertAttention(config) 172 | self.intermediate = BertIntermediate(config) 173 | self.output = BertOutput(config) 174 | 175 | def forward(self, hidden_states, attention_mask): 176 | attention_output = self.attention(hidden_states, attention_mask) 177 | intermediate_output = self.intermediate(attention_output) 178 | layer_output = self.output(intermediate_output, attention_output) 179 | return layer_output 180 | 181 | 182 | class BertEncoder(nn.Module): 183 | def __init__(self, config): 184 | super(BertEncoder, self).__init__() 185 | layer = BertLayer(config) 186 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 187 | 188 | def forward(self, hidden_states, attention_mask=None, output_all_encoded_layers=True): 189 | all_encoder_layers = [] 190 | for layer_module in self.layer: 191 | hidden_states = layer_module(hidden_states, attention_mask) 192 | if output_all_encoded_layers: 193 | all_encoder_layers.append(hidden_states) 194 | if not output_all_encoded_layers: 195 | 
all_encoder_layers.append(hidden_states) 196 | return all_encoder_layers 197 | 198 | 199 | class BertEmbeddings(nn.Module): 200 | """Construct the embeddings from word, position and token_type embeddings. 201 | """ 202 | def __init__(self, config): 203 | super(BertEmbeddings, self).__init__() 204 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 205 | 206 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 207 | # any TensorFlow checkpoint file 208 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) 209 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 210 | 211 | def forward(self, input_ids, token_type_ids=None): 212 | seq_length = input_ids.size(1) 213 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 214 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids[:, :, 0]) 215 | 216 | position_embeddings = self.position_embeddings(position_ids) 217 | 218 | embeddings = input_ids + position_embeddings 219 | #embeddings = input_ids 220 | embeddings = self.LayerNorm(embeddings) 221 | embeddings = self.dropout(embeddings) 222 | return embeddings 223 | 224 | 225 | class PositionalEncoding(nn.Module): 226 | def __init__(self, config): 227 | super(PositionalEncoding, self).__init__() 228 | emb_dim = config.hidden_size 229 | max_len = config.max_position_embeddings 230 | self.position_enc = self.position_encoding_init(max_len, emb_dim) 231 | 232 | @staticmethod 233 | def position_encoding_init(n_position, emb_dim): 234 | ''' Init the sinusoid position encoding table ''' 235 | 236 | # keep dim 0 for padding token position encoding zero vector 237 | position_enc = np.array([ 238 | [pos / np.power(10000, 2 * (j // 2) / emb_dim) for j in range(emb_dim)] 239 | if pos != 0 else np.zeros(emb_dim) for pos in range(n_position)]) 240 | 241 | position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # apply sin on 0th,2nd,4th...emb_dim 242 | position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # apply cos on 1st,3rd,5th...emb_dim 243 | return torch.from_numpy(position_enc).type(torch.FloatTensor) 244 | 245 | def forward(self, word_seq): 246 | position_encoding = self.position_enc.unsqueeze(0).expand_as(word_seq) 247 | position_encoding = position_encoding.to(word_seq.device) 248 | word_pos_encoded = word_seq + position_encoding 249 | return word_pos_encoded 250 | 251 | class BertPooler(nn.Module): 252 | def __init__(self, config): 253 | super(BertPooler, self).__init__() 254 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 255 | self.activation = nn.Tanh() 256 | 257 | def forward(self, hidden_states): 258 | # We "pool" the model by simply taking the hidden state corresponding 259 | # to the first token. 
260 | first_token_tensor = hidden_states[:, 0] 261 | pooled_output = self.dense(first_token_tensor) 262 | pooled_output = self.activation(pooled_output) 263 | return pooled_output 264 | -------------------------------------------------------------------------------- /evaluation/IS/best_model.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/IS/best_model.ckpt -------------------------------------------------------------------------------- /evaluation/IS/compute_is_score.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=2 python inception_score.py \ 3 | ./best_model.ckpt \ 4 | --data_dir ../../generated_freesound_one_bar/freesound_checkpoint/100000/mel_80_320 --classes 66 \ 5 | --mean_std_dir ../../data/freesound \ 6 | --store_every_score freesound_styelgan2.pkl 7 | 8 | -------------------------------------------------------------------------------- /evaluation/IS/inception_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.data 5 | import numpy as np 6 | from scipy.stats import entropy 7 | import argparse 8 | import sys 9 | import os 10 | from model import FCN, ShortChunkCNN_one_bar 11 | sys.path.append('../') 12 | sys.path.append('../../') 13 | from dataset import MultiResolutionDataset_drum, data_sampler 14 | from torchvision import transforms 15 | from scipy.stats import ttest_ind 16 | import pickle 17 | #from module import genre_classifier 18 | def inception_score(args, transform): 19 | #Preprocess 20 | #100 files under args.data_dir 21 | #classify into args.class classes 22 | N = len([name for name in os.listdir(args.data_dir) if os.path.isfile(os.path.join(args.data_dir, name))]) 23 | preds = np.zeros((N, args.classes)) 24 | 25 | # load checkpoint 26 | checkpoint = torch.load(args.path) 27 | #print(checkpoint['model_state_dict']) 28 | model = ShortChunkCNN_one_bar(n_class = args.classes).cuda() 29 | #print(model) 30 | #print(checkpoint['model_state_dict'].keys()) 31 | model.load_state_dict(checkpoint['model_state_dict']) 32 | model.eval() 33 | 34 | #Load data 35 | dataset = MultiResolutionDataset_drum(args.data_dir, transform) 36 | loader =torch.utils.data.DataLoader( 37 | dataset, 38 | batch_size=2, 39 | sampler=data_sampler(dataset, shuffle=True, distributed=False), 40 | drop_last=True, 41 | ) 42 | 43 | mean_fp = os.path.join(args.mean_std_dir, f'mean.mel.npy') 44 | std_fp = os.path.join(args.mean_std_dir, f'std.mel.npy') 45 | feat_dim = 80 46 | mean = torch.from_numpy(np.load(mean_fp)).float().cuda().view(1, feat_dim, 1) 47 | std = torch.from_numpy(np.load(std_fp)).float().cuda().view(1, feat_dim, 1) 48 | #Model inference 49 | for i, data in enumerate(loader): 50 | data = data.cuda() # [bs, 1, 80, 320] 51 | 52 | data = data * std + mean 53 | data = data.squeeze(1) 54 | output = model(data) # [bs, args.class] 55 | 56 | logit = F.softmax(output, dim = 1).data.cpu().numpy() # [bs, args.class] 57 | 58 | preds[i * 2 : (i + 1) * 2] = logit 59 | #KL divergence 60 | scores = [] 61 | py = np.mean(preds, axis=0) 62 | 63 | for i in range(preds.shape[0]): 64 | pyx = preds[i, :] 65 | scores.append(entropy(pyx, py)) 66 | #Data PostProcessing 67 | is_score = np.exp(np.mean(scores)) 68 | std = np.exp(np.std(scores)) 69 | every_score = 
np.exp(scores) 70 | return is_score, std, every_score 71 | 72 | 73 | 74 | 75 | 76 | 77 | if __name__ == '__main__': 78 | parser = argparse.ArgumentParser(description="compute inception score") 79 | 80 | parser.add_argument("path", type=str, help="path to the model") 81 | 82 | parser.add_argument("--data_dir", type=str, help="has 100 npy normalized file under this directory") 83 | 84 | parser.add_argument("--classes", type=int, help="number of classes") 85 | 86 | parser.add_argument("--mean_std_dir", type=str, help="directory which has mean and std npy file") 87 | parser.add_argument("--store_every_score", type=str) 88 | 89 | args = parser.parse_args() 90 | 91 | transform = transforms.Compose( 92 | [ 93 | #transforms.RandomHorizontalFlip(), 94 | transforms.ToTensor(), 95 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 96 | ] 97 | ) 98 | 99 | is_score, std, every_score = inception_score(args, transform) 100 | with open(f'{args.store_every_score}', 'wb') as f: 101 | pickle.dump(every_score, f) 102 | print(is_score, std) 103 | -------------------------------------------------------------------------------- /evaluation/IS/modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import torchaudio 6 | import sys 7 | from torch.autograd import Variable 8 | import math 9 | import librosa 10 | 11 | 12 | class Conv_1d(nn.Module): 13 | def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2): 14 | super(Conv_1d, self).__init__() 15 | self.conv = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2) 16 | self.bn = nn.BatchNorm1d(output_channels) 17 | self.relu = nn.ReLU() 18 | self.mp = nn.MaxPool1d(pooling) 19 | def forward(self, x): 20 | out = self.mp(self.relu(self.bn(self.conv(x)))) 21 | return out 22 | 23 | 24 | class Conv_2d(nn.Module): 25 | def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2): 26 | super(Conv_2d, self).__init__() 27 | self.conv = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2) 28 | self.bn = nn.BatchNorm2d(output_channels) 29 | self.relu = nn.ReLU() 30 | self.mp = nn.MaxPool2d(pooling) 31 | def forward(self, x): 32 | out = self.mp(self.relu(self.bn(self.conv(x)))) 33 | return out 34 | 35 | 36 | class Res_2d(nn.Module): 37 | def __init__(self, input_channels, output_channels, shape=3, stride=2): 38 | super(Res_2d, self).__init__() 39 | # convolution 40 | self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2) 41 | self.bn_1 = nn.BatchNorm2d(output_channels) 42 | self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=shape//2) 43 | self.bn_2 = nn.BatchNorm2d(output_channels) 44 | 45 | # residual 46 | self.diff = False 47 | if (stride != 1) or (input_channels != output_channels): 48 | self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2) 49 | self.bn_3 = nn.BatchNorm2d(output_channels) 50 | self.diff = True 51 | self.relu = nn.ReLU() 52 | 53 | def forward(self, x): 54 | # convolution 55 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x))))) 56 | 57 | # residual 58 | if self.diff: 59 | x = self.bn_3(self.conv_3(x)) 60 | out = x + out 61 | out = self.relu(out) 62 | return out 63 | 64 | 65 | class Res_2d_mp(nn.Module): 66 | def __init__(self, input_channels, output_channels, pooling=2): 67 | 
super(Res_2d_mp, self).__init__() 68 | self.conv_1 = nn.Conv2d(input_channels, output_channels, 3, padding=1) 69 | self.bn_1 = nn.BatchNorm2d(output_channels) 70 | self.conv_2 = nn.Conv2d(output_channels, output_channels, 3, padding=1) 71 | self.bn_2 = nn.BatchNorm2d(output_channels) 72 | self.relu = nn.ReLU() 73 | self.mp = nn.MaxPool2d(pooling) 74 | def forward(self, x): 75 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x))))) 76 | out = x + out 77 | out = self.mp(self.relu(out)) 78 | return out 79 | 80 | 81 | class ResSE_1d(nn.Module): 82 | def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=3): 83 | super(ResSE_1d, self).__init__() 84 | # convolution 85 | self.conv_1 = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2) 86 | self.bn_1 = nn.BatchNorm1d(output_channels) 87 | self.conv_2 = nn.Conv1d(output_channels, output_channels, shape, padding=shape//2) 88 | self.bn_2 = nn.BatchNorm1d(output_channels) 89 | 90 | # squeeze & excitation 91 | self.dense1 = nn.Linear(output_channels, output_channels) 92 | self.dense2 = nn.Linear(output_channels, output_channels) 93 | 94 | # residual 95 | self.diff = False 96 | if (stride != 1) or (input_channels != output_channels): 97 | self.conv_3 = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2) 98 | self.bn_3 = nn.BatchNorm1d(output_channels) 99 | self.diff = True 100 | self.relu = nn.ReLU() 101 | self.sigmoid = nn.Sigmoid() 102 | self.mp = nn.MaxPool1d(pooling) 103 | 104 | def forward(self, x): 105 | # convolution 106 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x))))) 107 | 108 | # squeeze & excitation 109 | se_out = nn.AvgPool1d(out.size(-1))(out) 110 | se_out = se_out.squeeze(-1) 111 | se_out = self.relu(self.dense1(se_out)) 112 | se_out = self.sigmoid(self.dense2(se_out)) 113 | se_out = se_out.unsqueeze(-1) 114 | out = torch.mul(out, se_out) 115 | 116 | # residual 117 | if self.diff: 118 | x = self.bn_3(self.conv_3(x)) 119 | out = x + out 120 | out = self.mp(self.relu(out)) 121 | return out 122 | 123 | 124 | class Conv_V(nn.Module): 125 | # vertical convolution 126 | def __init__(self, input_channels, output_channels, filter_shape): 127 | super(Conv_V, self).__init__() 128 | self.conv = nn.Conv2d(input_channels, output_channels, filter_shape, 129 | padding=(0, filter_shape[1]//2)) 130 | self.bn = nn.BatchNorm2d(output_channels) 131 | self.relu = nn.ReLU() 132 | 133 | def forward(self, x): 134 | x = self.relu(self.bn(self.conv(x))) 135 | freq = x.size(2) 136 | out = nn.MaxPool2d((freq, 1), stride=(freq, 1))(x) 137 | out = out.squeeze(2) 138 | return out 139 | 140 | 141 | class Conv_H(nn.Module): 142 | # horizontal convolution 143 | def __init__(self, input_channels, output_channels, filter_length): 144 | super(Conv_H, self).__init__() 145 | self.conv = nn.Conv1d(input_channels, output_channels, filter_length, 146 | padding=filter_length//2) 147 | self.bn = nn.BatchNorm1d(output_channels) 148 | self.relu = nn.ReLU() 149 | 150 | def forward(self, x): 151 | freq = x.size(2) 152 | out = nn.AvgPool2d((freq, 1), stride=(freq, 1))(x) 153 | out = out.squeeze(2) 154 | out = self.relu(self.bn(self.conv(out))) 155 | return out 156 | 157 | 158 | # Modules for harmonic filters 159 | def hz_to_midi(hz): 160 | return 12 * (torch.log2(hz) - np.log2(440.0)) + 69 161 | 162 | def midi_to_hz(midi): 163 | return 440.0 * (2.0 ** ((midi - 69.0)/12.0)) 164 | 165 | def note_to_midi(note): 166 | return librosa.core.note_to_midi(note) 167 | 168 | def 
hz_to_note(hz): 169 | return librosa.core.hz_to_note(hz) 170 | 171 | def initialize_filterbank(sample_rate, n_harmonic, semitone_scale): 172 | # MIDI 173 | # lowest note 174 | low_midi = note_to_midi('C1') 175 | 176 | # highest note 177 | high_note = hz_to_note(sample_rate / (2 * n_harmonic)) 178 | high_midi = note_to_midi(high_note) 179 | 180 | # number of scales 181 | level = (high_midi - low_midi) * semitone_scale 182 | midi = np.linspace(low_midi, high_midi, level + 1) 183 | hz = midi_to_hz(midi[:-1]) 184 | 185 | # stack harmonics 186 | harmonic_hz = [] 187 | for i in range(n_harmonic): 188 | harmonic_hz = np.concatenate((harmonic_hz, hz * (i+1))) 189 | 190 | return harmonic_hz, level 191 | 192 | 193 | class HarmonicSTFT(nn.Module): 194 | def __init__(self, 195 | sample_rate=16000, 196 | n_fft=513, 197 | win_length=None, 198 | hop_length=None, 199 | pad=0, 200 | power=2, 201 | normalized=False, 202 | n_harmonic=6, 203 | semitone_scale=2, 204 | bw_Q=1.0, 205 | learn_bw=None): 206 | super(HarmonicSTFT, self).__init__() 207 | 208 | # Parameters 209 | self.sample_rate = sample_rate 210 | self.n_harmonic = n_harmonic 211 | self.bw_alpha = 0.1079 212 | self.bw_beta = 24.7 213 | 214 | # Spectrogram 215 | self.spec = torchaudio.transforms.Spectrogram(n_fft=n_fft, win_length=win_length, 216 | hop_length=None, pad=0, 217 | window_fn=torch.hann_window, 218 | power=power, normalized=normalized, wkwargs=None) 219 | self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB() 220 | 221 | # Initialize the filterbank. Equally spaced in MIDI scale. 222 | harmonic_hz, self.level = initialize_filterbank(sample_rate, n_harmonic, semitone_scale) 223 | 224 | # Center frequncies to tensor 225 | self.f0 = torch.tensor(harmonic_hz.astype('float32')) 226 | 227 | # Bandwidth parameters 228 | if learn_bw == 'only_Q': 229 | self.bw_Q = nn.Parameter(torch.tensor(np.array([bw_Q]).astype('float32'))) 230 | elif learn_bw == 'fix': 231 | self.bw_Q = torch.tensor(np.array([bw_Q]).astype('float32')) 232 | 233 | def get_harmonic_fb(self): 234 | # bandwidth 235 | bw = (self.bw_alpha * self.f0 + self.bw_beta) / self.bw_Q 236 | bw = bw.unsqueeze(0) # (1, n_band) 237 | f0 = self.f0.unsqueeze(0) # (1, n_band) 238 | fft_bins = self.fft_bins.unsqueeze(1) # (n_bins, 1) 239 | 240 | up_slope = torch.matmul(fft_bins, (2/bw)) + 1 - (2 * f0 / bw) 241 | down_slope = torch.matmul(fft_bins, (-2/bw)) + 1 + (2 * f0 / bw) 242 | fb = torch.max(self.zero, torch.min(down_slope, up_slope)) 243 | return fb 244 | 245 | def to_device(self, device, n_bins): 246 | self.f0 = self.f0.to(device) 247 | self.bw_Q = self.bw_Q.to(device) 248 | # fft bins 249 | self.fft_bins = torch.linspace(0, self.sample_rate//2, n_bins) 250 | self.fft_bins = self.fft_bins.to(device) 251 | self.zero = torch.zeros(1) 252 | self.zero = self.zero.to(device) 253 | 254 | def forward(self, waveform): 255 | # stft 256 | spectrogram = self.spec(waveform) 257 | 258 | # to device 259 | self.to_device(waveform.device, spectrogram.size(1)) 260 | 261 | # triangle filter 262 | harmonic_fb = self.get_harmonic_fb() 263 | harmonic_spec = torch.matmul(spectrogram.transpose(1, 2), harmonic_fb).transpose(1, 2) 264 | 265 | # (batch, channel, length) -> (batch, harmonic, f0, length) 266 | b, c, l = harmonic_spec.size() 267 | harmonic_spec = harmonic_spec.view(b, self.n_harmonic, self.level, l) 268 | 269 | # amplitude to db 270 | harmonic_spec = self.amplitude_to_db(harmonic_spec) 271 | return harmonic_spec 272 | -------------------------------------------------------------------------------- 
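A quick shape check for the CNN building blocks defined in `evaluation/IS/modules.py` above. This sketch is not part of the original repo; it assumes it is run from `evaluation/IS/` (so `modules` is importable) and uses the 80x320 mel-spectrogram size used throughout this project:

``` python
import torch
from modules import Conv_2d, Res_2d   # building blocks defined above

x = torch.randn(4, 1, 80, 320)        # a batch of four 80-bin x 320-frame mel-spectrograms

conv = Conv_2d(1, 16, pooling=2)      # conv -> batchnorm -> ReLU -> 2x2 max-pool
h = conv(x)
print(h.shape)                        # torch.Size([4, 16, 40, 160])

res = Res_2d(16, 32, stride=2)        # residual block; stride 2 halves both spatial axes
print(res(h).shape)                   # torch.Size([4, 32, 20, 80])
```

The `ShortChunkCNN_one_bar` classifier loaded by `inception_score.py` (see `evaluation/IS/model.py`) is built from blocks like these.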
/evaluation/NDB_JS/compute_ndb_js.py: -------------------------------------------------------------------------------- 1 | import ndb 2 | import argparse 3 | import os 4 | import numpy as np 5 | 6 | if __name__ == '__main__': 7 | 8 | parser = argparse.ArgumentParser(description="compute ndb and JS divergence") 9 | 10 | parser.add_argument("--real_dir", type=str) 11 | 12 | parser.add_argument("--gen_dir", type=str) 13 | 14 | parser.add_argument("--mean_std_dir", type=str, help="directory which has mean and std npy file") 15 | 16 | args = parser.parse_args() 17 | 18 | dim = 80 * 320 19 | n_train = 2000 #Don't change this line, this is the amount of looperman dataset 20 | n_test = 2000 # You can change this line depends on how many generated audio you have in the generation directory 21 | 22 | # load train samples 23 | mean_fp = os.path.join(args.mean_std_dir, f'mean.mel.npy') 24 | std_fp = os.path.join(args.mean_std_dir, f'std.mel.npy') 25 | feat_dim = 80 26 | mean = np.load(mean_fp).reshape((1, feat_dim, 1)) 27 | std = np.load(std_fp).reshape((1, feat_dim, 1)) 28 | 29 | train_samples = np.zeros(shape = [n_train, dim]) 30 | 31 | for i, path in enumerate(os.listdir(args.real_dir)): 32 | train_path = os.path.join(args.real_dir, path) 33 | train_numpy = np.load(train_path) 34 | train_numpy = (train_numpy - mean) / std 35 | train_samples[i, :] = train_numpy.reshape((-1, )) 36 | 37 | # load test samples 38 | 39 | 40 | test_samples = np.zeros(shape = [n_test, dim]) 41 | 42 | for i, path in enumerate(os.listdir(args.gen_dir)): 43 | 44 | test_path = os.path.join(args.gen_dir, path) 45 | test_numpy = np.load(test_path) 46 | test_samples[i, :] = test_numpy.reshape((-1, )) 47 | 48 | # NDB and JSD calculation 49 | k = 100 50 | train_ndb = ndb.NDB(training_data=train_samples, number_of_bins=k, whitening=True) 51 | 52 | train_ndb.evaluate(test_samples, model_label='Test') 53 | -------------------------------------------------------------------------------- /evaluation/NDB_JS/compute_ndb_js.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python compute_ndb_js.py --mean_std_dir ../../data/looperman --real_dir ./looper_2000 --gen_dir ../../generated_freesound_one_bar/freesound_checkpoint/100000/mel_80_320 3 | -------------------------------------------------------------------------------- /evaluation/NDB_JS/ndb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.cluster import KMeans 4 | from scipy.stats import norm 5 | from matplotlib import pyplot as plt 6 | import pickle as pkl 7 | 8 | class NDB: 9 | def __init__(self, training_data=None, number_of_bins=100, significance_level=0.05, z_threshold=None, 10 | whitening=False, max_dims=None, cache_folder=None): 11 | """ 12 | NDB Evaluation Class 13 | :param training_data: Optional - the training samples - array of m x d floats (m samples of dimension d) 14 | :param number_of_bins: Number of bins (clusters) default=100 15 | :param significance_level: The statistical significance level for the two-sample test 16 | :param z_threshold: Allow defining a threshold in terms of difference/SE for defining a bin as statistically different 17 | :param whitening: Perform data whitening - subtract mean and divide by per-dimension std 18 | :param max_dims: Max dimensions to use in K-means. 
By default derived automatically from d 19 | :param bins_file: Optional - file to write / read-from the clusters (to avoid re-calculation) 20 | """ 21 | self.number_of_bins = number_of_bins 22 | self.significance_level = significance_level 23 | self.z_threshold = z_threshold 24 | self.whitening = whitening 25 | self.ndb_eps = 1e-6 26 | self.training_mean = 0.0 27 | self.training_std = 1.0 28 | self.max_dims = max_dims 29 | self.cache_folder = cache_folder 30 | self.bin_centers = None 31 | self.bin_proportions = None 32 | self.ref_sample_size = None 33 | self.used_d_indices = None 34 | self.results_file = None 35 | self.test_name = 'ndb_{}_bins_{}'.format(self.number_of_bins, 'whiten' if self.whitening else 'orig') 36 | self.cached_results = {} 37 | if self.cache_folder: 38 | self.results_file = os.path.join(cache_folder, self.test_name+'_results.pkl') 39 | if os.path.isfile(self.results_file): 40 | # print('Loading previous results from', self.results_file, ':') 41 | self.cached_results = pkl.load(open(self.results_file, 'rb')) 42 | # print(self.cached_results.keys()) 43 | if training_data is not None or cache_folder is not None: 44 | bins_file = None 45 | if cache_folder: 46 | os.makedirs(cache_folder, exist_ok=True) 47 | bins_file = os.path.join(cache_folder, self.test_name+'.pkl') 48 | self.construct_bins(training_data, bins_file) 49 | 50 | def construct_bins(self, training_samples, bins_file): 51 | """ 52 | Performs K-means clustering of the training samples 53 | :param training_samples: An array of m x d floats (m samples of dimension d) 54 | """ 55 | 56 | if self.__read_from_bins_file(bins_file): 57 | return 58 | n, d = training_samples.shape 59 | k = self.number_of_bins 60 | if self.whitening: 61 | self.training_mean = np.mean(training_samples, axis=0) 62 | self.training_std = np.std(training_samples, axis=0) + self.ndb_eps 63 | 64 | if self.max_dims is None and d > 1000: 65 | # To ran faster, perform binning on sampled data dimension (i.e. don't use all channels of all pixels) 66 | self.max_dims = d//6 67 | 68 | whitened_samples = (training_samples-self.training_mean)/self.training_std 69 | d_used = d if self.max_dims is None else min(d, self.max_dims) 70 | self.used_d_indices = np.random.choice(d, d_used, replace=False) 71 | 72 | print('Performing K-Means clustering of {} samples in dimension {} / {} to {} clusters ...'.format(n, d_used, d, k)) 73 | print('Can take a couple of minutes...') 74 | if n//k > 1000: 75 | print('Training data size should be ~500 times the number of bins (for reasonable speed and accuracy)') 76 | 77 | clusters = KMeans(n_clusters=k, max_iter=100, n_jobs=-1).fit(whitened_samples[:, self.used_d_indices]) 78 | 79 | bin_centers = np.zeros([k, d]) 80 | for i in range(k): 81 | bin_centers[i, :] = np.mean(whitened_samples[clusters.labels_ == i, :], axis=0) 82 | 83 | # Organize bins by size 84 | label_vals, label_counts = np.unique(clusters.labels_, return_counts=True) 85 | bin_order = np.argsort(-label_counts) 86 | self.bin_proportions = label_counts[bin_order] / np.sum(label_counts) 87 | self.bin_centers = bin_centers[bin_order, :] 88 | self.ref_sample_size = n 89 | self.__write_to_bins_file(bins_file) 90 | print('Done.') 91 | 92 | def evaluate(self, query_samples, model_label=None): 93 | """ 94 | Assign each sample to the nearest bin center (in L2). Pre-whiten if required. and calculate the NDB 95 | (Number of statistically Different Bins) and JS divergence scores. 
96 | :param query_samples: An array of m x d floats (m samples of dimension d) 97 | :param model_label: optional label string for the evaluated model, allows plotting results of multiple models 98 | :return: results dictionary containing NDB and JS scores and array of labels (assigned bin for each query sample) 99 | """ 100 | n = query_samples.shape[0] 101 | query_bin_proportions, query_bin_assignments = self.__calculate_bin_proportions(query_samples) 102 | # print(query_bin_proportions) 103 | different_bins = NDB.two_proportions_z_test(self.bin_proportions, self.ref_sample_size, query_bin_proportions, 104 | n, significance_level=self.significance_level, 105 | z_threshold=self.z_threshold) 106 | ndb = np.count_nonzero(different_bins) 107 | js = NDB.jensen_shannon_divergence(self.bin_proportions, query_bin_proportions) 108 | results = {'NDB': ndb, 109 | 'JS': js, 110 | 'Proportions': query_bin_proportions, 111 | 'N': n, 112 | 'Bin-Assignment': query_bin_assignments, 113 | 'Different-Bins': different_bins} 114 | 115 | if model_label: 116 | print('Results for {} samples from {}: '.format(n, model_label), end='') 117 | self.cached_results[model_label] = results 118 | if self.results_file: 119 | # print('Storing result to', self.results_file) 120 | pkl.dump(self.cached_results, open(self.results_file, 'wb')) 121 | 122 | print('NDB =', ndb, 'NDB/K =', ndb/self.number_of_bins, ', JS =', js) 123 | return results 124 | 125 | def print_results(self): 126 | print('NSB results (K={}{}):'.format(self.number_of_bins, ', data whitening' if self.whitening else '')) 127 | for model in sorted(list(self.cached_results.keys())): 128 | res = self.cached_results[model] 129 | print('%s: NDB = %d, NDB/K = %.3f, JS = %.4f' % (model, res['NDB'], res['NDB']/self.number_of_bins, res['JS'])) 130 | 131 | def plot_results(self, models_to_plot=None): 132 | """ 133 | Plot the binning proportions of different methods 134 | :param models_to_plot: optional list of model labels to plot 135 | """ 136 | K = self.number_of_bins 137 | w = 1.0 / (len(self.cached_results)+1) 138 | assert K == self.bin_proportions.size 139 | assert self.cached_results 140 | 141 | # Used for plotting only 142 | def calc_se(p1, n1, p2, n2): 143 | p = (p1 * n1 + p2 * n2) / (n1 + n2) 144 | return np.sqrt(p * (1 - p) * (1/n1 + 1/n2)) 145 | 146 | if not models_to_plot: 147 | models_to_plot = sorted(list(self.cached_results.keys())) 148 | 149 | # Visualize the standard errors using the train proportions and size and query sample size 150 | train_se = calc_se(self.bin_proportions, self.ref_sample_size, 151 | self.bin_proportions, self.cached_results[models_to_plot[0]]['N']) 152 | plt.bar(np.arange(0, K)+0.5, height=train_se*2.0, bottom=self.bin_proportions-train_se, 153 | width=1.0, label='Train$\pm$SE', color='gray') 154 | 155 | ymax = 0.0 156 | for i, model in enumerate(models_to_plot): 157 | results = self.cached_results[model] 158 | label = '%s (%i : %.4f)' % (model, results['NDB'], results['JS']) 159 | ymax = max(ymax, np.max(results['Proportions'])) 160 | if K <= 70: 161 | plt.bar(np.arange(0, K)+(i+1.0)*w, results['Proportions'], width=w, label=label) 162 | else: 163 | plt.plot(np.arange(0, K)+0.5, results['Proportions'], '--*', label=label) 164 | plt.legend(loc='best') 165 | plt.ylim((0.0, min(ymax, np.max(self.bin_proportions)*4.0))) 166 | plt.grid(True) 167 | plt.title('Binning Proportions Evaluation Results for {} bins (NDB : JS)'.format(K)) 168 | plt.show() 169 | 170 | def __calculate_bin_proportions(self, samples): 171 | if self.bin_centers is 
None: 172 | print('First run construct_bins on samples from the reference training data') 173 | assert samples.shape[1] == self.bin_centers.shape[1] 174 | n, d = samples.shape 175 | k = self.bin_centers.shape[0] 176 | D = np.zeros([n, k], dtype=samples.dtype) 177 | 178 | print('Calculating bin assignments for {} samples...'.format(n)) 179 | whitened_samples = (samples-self.training_mean)/self.training_std 180 | for i in range(k): 181 | print('.', end='', flush=True) 182 | D[:, i] = np.linalg.norm(whitened_samples[:, self.used_d_indices] - self.bin_centers[i, self.used_d_indices], 183 | ord=2, axis=1) 184 | print() 185 | labels = np.argmin(D, axis=1) 186 | probs = np.zeros([k]) 187 | label_vals, label_counts = np.unique(labels, return_counts=True) 188 | probs[label_vals] = label_counts / n 189 | return probs, labels 190 | 191 | def __read_from_bins_file(self, bins_file): 192 | if bins_file and os.path.isfile(bins_file): 193 | print('Loading binning results from', bins_file) 194 | bins_data = pkl.load(open(bins_file,'rb')) 195 | self.bin_proportions = bins_data['proportions'] 196 | self.bin_centers = bins_data['centers'] 197 | self.ref_sample_size = bins_data['n'] 198 | self.training_mean = bins_data['mean'] 199 | self.training_std = bins_data['std'] 200 | self.used_d_indices = bins_data['d_indices'] 201 | return True 202 | return False 203 | 204 | def __write_to_bins_file(self, bins_file): 205 | if bins_file: 206 | print('Caching binning results to', bins_file) 207 | bins_data = {'proportions': self.bin_proportions, 208 | 'centers': self.bin_centers, 209 | 'n': self.ref_sample_size, 210 | 'mean': self.training_mean, 211 | 'std': self.training_std, 212 | 'd_indices': self.used_d_indices} 213 | pkl.dump(bins_data, open(bins_file, 'wb')) 214 | 215 | @staticmethod 216 | def two_proportions_z_test(p1, n1, p2, n2, significance_level, z_threshold=None): 217 | # Per http://stattrek.com/hypothesis-test/difference-in-proportions.aspx 218 | # See also http://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/binotest.htm 219 | p = (p1 * n1 + p2 * n2) / (n1 + n2) 220 | se = np.sqrt(p * (1 - p) * (1/n1 + 1/n2)) 221 | z = (p1 - p2) / se 222 | # Allow defining a threshold in terms as Z (difference relative to the SE) rather than in p-values. 223 | if z_threshold is not None: 224 | return abs(z) > z_threshold 225 | p_values = 2.0 * norm.cdf(-1.0 * np.abs(z)) # Two-tailed test 226 | return p_values < significance_level 227 | 228 | @staticmethod 229 | def jensen_shannon_divergence(p, q): 230 | """ 231 | Calculates the symmetric Jensen–Shannon divergence between the two PDFs 232 | """ 233 | m = (p + q) * 0.5 234 | return 0.5 * (NDB.kl_divergence(p, m) + NDB.kl_divergence(q, m)) 235 | 236 | @staticmethod 237 | def kl_divergence(p, q): 238 | """ 239 | The Kullback–Leibler divergence. 240 | Defined only if q != 0 whenever p != 0. 
241 | """ 242 | assert np.all(np.isfinite(p)) 243 | assert np.all(np.isfinite(q)) 244 | assert not np.any(np.logical_and(p != 0, q == 0)) 245 | 246 | p_pos = (p > 0) 247 | return np.sum(p[p_pos] * np.log(p[p_pos] / q[p_pos])) 248 | 249 | 250 | if __name__ == "__main__": 251 | dim=80 * 320 252 | k=100 253 | n_train = k*10 254 | n_test = k*10 255 | 256 | train_samples = np.random.uniform(size=[n_train,dim]) 257 | ndb = NDB(training_data=train_samples, number_of_bins=k, whitening=True) 258 | 259 | test_samples = np.random.uniform(high=1.0, size=[n_test, dim]) 260 | ndb.evaluate(test_samples, model_label='Test') 261 | 262 | test_samples = np.random.uniform(high=0.9, size=[n_test, dim]) 263 | ndb.evaluate(test_samples, model_label='Good') 264 | 265 | test_samples = np.random.uniform(high=0.75, size=[n_test, dim]) 266 | ndb.evaluate(test_samples, model_label='Bad') 267 | 268 | ndb.plot_results(models_to_plot=['Test', 'Good', 'Bad']) 269 | -------------------------------------------------------------------------------- /evaluation/nine_audio/Chillout/1365.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Chillout/1365.wav -------------------------------------------------------------------------------- /evaluation/nine_audio/Drum_and_Bass/14.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Drum_and_Bass/14.wav -------------------------------------------------------------------------------- /evaluation/nine_audio/Electronic/1353.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Electronic/1353.wav -------------------------------------------------------------------------------- /evaluation/nine_audio/Hiphop/323.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Hiphop/323.wav -------------------------------------------------------------------------------- /evaluation/nine_audio/Rap/153.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Rap/153.wav -------------------------------------------------------------------------------- /evaluation/nine_audio/Rock/752.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Rock/752.wav -------------------------------------------------------------------------------- /evaluation/nine_audio/Trap/1395.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Trap/1395.wav -------------------------------------------------------------------------------- /evaluation/nine_audio/dupstep/113.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/dupstep/113.wav -------------------------------------------------------------------------------- /evaluation/nine_audio/industrial/1606.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/industrial/1606.wav -------------------------------------------------------------------------------- /generate_audio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torchvision import utils 5 | from model_drum import Generator 6 | from tqdm import tqdm 7 | 8 | import sys 9 | sys.path.append('./melgan') 10 | from modules import Generator_melgan 11 | 12 | import yaml 13 | import os 14 | 15 | import librosa 16 | 17 | import soundfile as sf 18 | 19 | import numpy as np 20 | 21 | import os 22 | def read_yaml(fp): 23 | with open(fp) as file: 24 | # return yaml.load(file) 25 | return yaml.load(file, Loader=yaml.Loader) 26 | 27 | def generate(args, g_ema, device, mean_latent): 28 | epoch = args.ckpt.split('.')[0] 29 | 30 | os.makedirs(f'{args.store_path}/{epoch}', exist_ok=True) 31 | os.makedirs(f'{args.store_path}/{epoch}/mel_80_320', exist_ok=True) 32 | feat_dim = 80 33 | mean_fp = f'{args.data_path}/mean.mel.npy' 34 | std_fp = f'{args.data_path}/std.mel.npy' 35 | mean = torch.from_numpy(np.load(mean_fp)).float().view(1, feat_dim, 1).to(device) 36 | std = torch.from_numpy(np.load(std_fp)).float().view(1, feat_dim, 1).to(device) 37 | vocoder_config_fp = './melgan/args.yml' 38 | vocoder_config = read_yaml(vocoder_config_fp) 39 | 40 | n_mel_channels = vocoder_config.n_mel_channels 41 | ngf = vocoder_config.ngf 42 | n_residual_layers = vocoder_config.n_residual_layers 43 | sr=44100 44 | 45 | vocoder = Generator_melgan(n_mel_channels, ngf, n_residual_layers).to(device) 46 | vocoder.eval() 47 | 48 | vocoder_param_fp = os.path.join('./melgan', 'best_netG.pt') 49 | vocoder.load_state_dict(torch.load(vocoder_param_fp)) 50 | 51 | 52 | with torch.no_grad(): 53 | g_ema.eval() 54 | for i in tqdm(range(args.pics)): 55 | sample_z = torch.randn(args.sample, args.latent, device=device) 56 | 57 | sample, _ = g_ema( 58 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent 59 | ) 60 | np.save(f'{args.store_path}/{epoch}/mel_80_320/{i}.npy', sample.squeeze().data.cpu().numpy()) 61 | 62 | utils.save_image( 63 | sample, 64 | f"{args.store_path}/{epoch}/{str(i).zfill(6)}.png", 65 | nrow=1, 66 | normalize=True, 67 | range=(-1, 1), 68 | ) 69 | de_norm = sample.squeeze(0) * std + mean 70 | audio_output = vocoder(de_norm) 71 | sf.write(f'{args.store_path}/{epoch}/{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr) 72 | print('generate {}th wav file'.format(i)) 73 | @torch.no_grad() 74 | def style_mixing(args, generator, step, mean_style, n_source, n_target, device, j): 75 | index = 2 76 | # create directory 77 | os.makedirs(f'./generated_interpolation_one_bar_{index}/{j}', exist_ok=True) 78 | 79 | # load melgan vocoder 80 | feat_dim = 80 81 | mean_fp = f'{args.data_path}/mean.mel.npy' 82 | std_fp = f'{args.data_path}/std.mel.npy' 83 | mean = torch.from_numpy(np.load(mean_fp)).float().view(1, feat_dim, 1).to(device) 84 | std = torch.from_numpy(np.load(std_fp)).float().view(1, feat_dim, 1).to(device) 85 | vocoder_config_fp = './melgan/args.yml' 86 | 
vocoder_config = read_yaml(vocoder_config_fp) 87 | 88 | n_mel_channels = vocoder_config.n_mel_channels 89 | ngf = vocoder_config.ngf 90 | n_residual_layers = vocoder_config.n_residual_layers 91 | sr=44100 92 | 93 | vocoder = Generator_melgan(n_mel_channels, ngf, n_residual_layers).to(device) 94 | vocoder.eval() 95 | 96 | vocoder_param_fp = os.path.join('./melgan', 'best_netG.pt') 97 | vocoder.load_state_dict(torch.load(vocoder_param_fp)) 98 | 99 | #generate spectrogram 100 | source_code = torch.randn(n_source, 512).to(device) 101 | target_code = torch.randn(n_target, 512).to(device) 102 | 103 | shape = 4 * 2 ** step 104 | alpha = 1 105 | 106 | images = [torch.ones(1, 1, 80, 320).to(device) * -1] 107 | 108 | source_image,_ = generator( 109 | [source_code], truncation=args.truncation, truncation_latent=mean_style 110 | ) 111 | target_image,_ = generator( 112 | [target_code], truncation=args.truncation, truncation_latent=mean_style 113 | ) 114 | 115 | images.append(source_image) 116 | 117 | for i in range(n_source): 118 | de_norm = source_image[i] * std + mean 119 | audio_output = vocoder(de_norm) 120 | sf.write(f'./generated_interpolation_one_bar_{index}/{j}/source_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr) 121 | 122 | for i in range(n_target): 123 | de_norm = target_image[i] * std + mean 124 | audio_output = vocoder(de_norm) 125 | sf.write(f'./generated_interpolation_one_bar_{index}/{j}/target_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr) 126 | 127 | for i in range(n_target): 128 | image, _ = generator( 129 | [target_code[i].unsqueeze(0).repeat(n_source, 1), source_code], 130 | truncation_latent=mean_style, 131 | inject_index = index 132 | ) 133 | 134 | for k in range(n_source): 135 | de_norm = image[k] * std + mean 136 | audio_output = vocoder(de_norm) 137 | sf.write(f'./generated_interpolation_one_bar_{index}/{j}/source_{k}_target_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr) 138 | 139 | images.append(target_image[i].unsqueeze(0)) 140 | images.append(image) 141 | 142 | images = torch.cat(images, 0) 143 | utils.save_image( 144 | images, f'./generated_interpolation_one_bar_{index}/{j}/sample_mixing.png', nrow=args.n_col + 1, normalize=True, range=(-1, 1) 145 | ) 146 | return images 147 | if __name__ == "__main__": 148 | device = "cuda" 149 | 150 | parser = argparse.ArgumentParser(description="Generate samples from the generator") 151 | 152 | parser.add_argument( 153 | "--size", type=int, default=64, help="output image size of the generator" 154 | ) 155 | parser.add_argument( 156 | "--sample", 157 | type=int, 158 | default=1, 159 | help="number of samples to be generated for each image", 160 | ) 161 | parser.add_argument( 162 | "--pics", type=int, default=20, help="number of images to be generated" 163 | ) 164 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio") 165 | parser.add_argument( 166 | "--truncation_mean", 167 | type=int, 168 | default=4096, 169 | help="number of vectors to calculate mean for the truncation", 170 | ) 171 | parser.add_argument( 172 | "--ckpt", 173 | type=str, 174 | default="stylegan2-ffhq-config-f.pt", 175 | help="path to the model checkpoint", 176 | ) 177 | parser.add_argument( 178 | "--data_path", 179 | type=str, 180 | help="path store the std and mean of mel", 181 | ) 182 | parser.add_argument( 183 | "--store_path", 184 | type=str, 185 | help="path store the generated audio", 186 | ) 187 | parser.add_argument( 188 | "--channel_multiplier", 189 | type=int, 190 | default=2, 191 | 
help="channel multiplier of the generator. config-f = 2, else = 1", 192 | ) 193 | parser.add_argument("--style_mixing", action = "store_true") 194 | parser.add_argument('--n_row', type=int, default=3, help='number of rows of sample matrix') 195 | parser.add_argument('--n_col', type=int, default=5, help='number of columns of sample matrix') 196 | args = parser.parse_args() 197 | 198 | args.latent = 512 199 | args.n_mlp = 8 200 | 201 | g_ema = Generator( 202 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 203 | ).to(device) 204 | checkpoint = torch.load(args.ckpt) 205 | 206 | g_ema.load_state_dict(checkpoint["g_ema"]) 207 | 208 | if args.truncation < 1: 209 | with torch.no_grad(): 210 | mean_latent = g_ema.mean_latent(args.truncation_mean) 211 | else: 212 | mean_latent = None 213 | 214 | # Generate audio 215 | generate(args, g_ema, device, mean_latent) 216 | 217 | #Style mixing 218 | if args.style_mixing == True: 219 | step = 0 220 | for j in range(20): 221 | img = style_mixing(args,g_ema, step, mean_latent, args.n_col, args.n_row, device, j) 222 | -------------------------------------------------------------------------------- /generate_looperman_four_bar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torchvision import utils 5 | from model_drum_four_bar import Generator 6 | from tqdm import tqdm 7 | 8 | import sys 9 | sys.path.append('./melgan') 10 | from modules import Generator_melgan 11 | 12 | import yaml 13 | import os 14 | 15 | import librosa 16 | 17 | import soundfile as sf 18 | 19 | import numpy as np 20 | 21 | import os 22 | def read_yaml(fp): 23 | with open(fp) as file: 24 | # return yaml.load(file) 25 | return yaml.load(file, Loader=yaml.Loader) 26 | 27 | def generate(args, g_ema, device, mean_latent): 28 | epoch = args.ckpt.split('.')[0] 29 | 30 | os.makedirs(f'{args.store_path}/{epoch}', exist_ok=True) 31 | os.makedirs(f'{args.store_path}/{epoch}/mel_80_320', exist_ok=True) 32 | feat_dim = 80 33 | mean_fp = f'{args.data_path}/mean.mel.npy' 34 | std_fp = f'{args.data_path}/std.mel.npy' 35 | mean = torch.from_numpy(np.load(mean_fp)).float().view(1, feat_dim, 1).to(device) 36 | std = torch.from_numpy(np.load(std_fp)).float().view(1, feat_dim, 1).to(device) 37 | vocoder_config_fp = './melgan/args.yml' 38 | vocoder_config = read_yaml(vocoder_config_fp) 39 | 40 | n_mel_channels = vocoder_config.n_mel_channels 41 | ngf = vocoder_config.ngf 42 | n_residual_layers = vocoder_config.n_residual_layers 43 | sr=44100 44 | 45 | vocoder = Generator_melgan(n_mel_channels, ngf, n_residual_layers).to(device) 46 | vocoder.eval() 47 | 48 | vocoder_param_fp = os.path.join('./melgan', 'best_netG.pt') 49 | vocoder.load_state_dict(torch.load(vocoder_param_fp)) 50 | 51 | 52 | with torch.no_grad(): 53 | g_ema.eval() 54 | for i in tqdm(range(args.pics)): 55 | sample_z = torch.randn(args.sample, args.latent, device=device) 56 | 57 | sample, _ = g_ema( 58 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent 59 | ) 60 | np.save(f'{args.store_path}/{epoch}/mel_80_320/{i}.npy', sample.squeeze().data.cpu().numpy()) 61 | 62 | utils.save_image( 63 | sample, 64 | f"{args.store_path}/{epoch}/{str(i).zfill(6)}.png", 65 | nrow=1, 66 | normalize=True, 67 | range=(-1, 1), 68 | ) 69 | de_norm = sample.squeeze(0) * std + mean 70 | audio_output = vocoder(de_norm) 71 | sf.write(f'{args.store_path}/{epoch}/{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr) 72 | print('generate 
{}th wav file'.format(i)) 73 | @torch.no_grad() 74 | def style_mixing(args, generator, step, mean_style, n_source, n_target, device, j): 75 | index = 5 76 | # create directory 77 | os.makedirs(f'./generated_interpolation_{index}/{j}', exist_ok=True) 78 | 79 | # load melgan vocoder 80 | feat_dim = 80 81 | mean_fp = f'{args.data_path}/mean.mel.npy' 82 | std_fp = f'{args.data_path}/std.mel.npy' 83 | mean = torch.from_numpy(np.load(mean_fp)).float().view(1, feat_dim, 1).to(device) 84 | std = torch.from_numpy(np.load(std_fp)).float().view(1, feat_dim, 1).to(device) 85 | vocoder_config_fp = './melgan/args.yml' 86 | vocoder_config = read_yaml(vocoder_config_fp) 87 | 88 | n_mel_channels = vocoder_config.n_mel_channels 89 | ngf = vocoder_config.ngf 90 | n_residual_layers = vocoder_config.n_residual_layers 91 | sr=44100 92 | 93 | vocoder = Generator_melgan(n_mel_channels, ngf, n_residual_layers).to(device) 94 | vocoder.eval() 95 | 96 | vocoder_param_fp = os.path.join('./melgan', 'best_netG.pt') 97 | vocoder.load_state_dict(torch.load(vocoder_param_fp)) 98 | 99 | #generate spectrogram 100 | source_code = torch.randn(n_source, 512).to(device) 101 | target_code = torch.randn(n_target, 512).to(device) 102 | 103 | shape = 4 * 2 ** step 104 | alpha = 1 105 | 106 | images = [torch.ones(1, 1, 80, 320).to(device) * -1] 107 | 108 | source_image,_ = generator( 109 | [source_code], truncation=args.truncation, truncation_latent=mean_style 110 | ) 111 | target_image,_ = generator( 112 | [target_code], truncation=args.truncation, truncation_latent=mean_style 113 | ) 114 | 115 | images.append(source_image) 116 | 117 | for i in range(n_source): 118 | de_norm = source_image[i] * std + mean 119 | audio_output = vocoder(de_norm) 120 | sf.write(f'./generated_interpolation_{index}/{j}/source_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr) 121 | 122 | for i in range(n_target): 123 | de_norm = target_image[i] * std + mean 124 | audio_output = vocoder(de_norm) 125 | sf.write(f'./generated_interpolation_{index}/{j}/target_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr) 126 | 127 | for i in range(n_target): 128 | image, _ = generator( 129 | [target_code[i].unsqueeze(0).repeat(n_source, 1), source_code], 130 | truncation_latent=mean_style, 131 | inject_index = index 132 | ) 133 | 134 | for k in range(n_source): 135 | de_norm = image[k] * std + mean 136 | audio_output = vocoder(de_norm) 137 | sf.write(f'./generated_interpolation_{index}/{j}/source_{k}_target_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr) 138 | 139 | images.append(target_image[i].unsqueeze(0)) 140 | images.append(image) 141 | 142 | images = torch.cat(images, 0) 143 | utils.save_image( 144 | images, f'./generated_interpolation_{index}/{j}/sample_mixing.png', nrow=args.n_col + 1, normalize=True, range=(-1, 1) 145 | ) 146 | return images 147 | if __name__ == "__main__": 148 | device = "cuda" 149 | 150 | parser = argparse.ArgumentParser(description="Generate samples from the generator") 151 | 152 | parser.add_argument( 153 | "--size", type=int, default=64, help="output image size of the generator" 154 | ) 155 | parser.add_argument( 156 | "--sample", 157 | type=int, 158 | default=1, 159 | help="number of samples to be generated for each image", 160 | ) 161 | parser.add_argument( 162 | "--pics", type=int, default=20, help="number of images to be generated" 163 | ) 164 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio") 165 | parser.add_argument( 166 | "--truncation_mean", 167 | type=int, 168 | 
default=4096, 169 | help="number of vectors to calculate mean for the truncation", 170 | ) 171 | parser.add_argument( 172 | "--ckpt", 173 | type=str, 174 | default="stylegan2-ffhq-config-f.pt", 175 | help="path to the model checkpoint", 176 | ) 177 | parser.add_argument( 178 | "--data_path", 179 | type=str, 180 | help="path store the std and mean of mel", 181 | ) 182 | parser.add_argument( 183 | "--store_path", 184 | type=str, 185 | help="path store the generated audio", 186 | ) 187 | parser.add_argument( 188 | "--channel_multiplier", 189 | type=int, 190 | default=2, 191 | help="channel multiplier of the generator. config-f = 2, else = 1", 192 | ) 193 | 194 | parser.add_argument("--style_mixing", action = "store_true") 195 | parser.add_argument('--n_row', type=int, default=3, help='number of rows of sample matrix') 196 | parser.add_argument('--n_col', type=int, default=5, help='number of columns of sample matrix') 197 | args = parser.parse_args() 198 | 199 | args.latent = 512 200 | args.n_mlp = 8 201 | 202 | g_ema = Generator( 203 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 204 | ).to(device) 205 | checkpoint = torch.load(args.ckpt) 206 | 207 | g_ema.load_state_dict(checkpoint["g_ema"]) 208 | 209 | if args.truncation < 1: 210 | with torch.no_grad(): 211 | mean_latent = g_ema.mean_latent(args.truncation_mean) 212 | else: 213 | mean_latent = None 214 | 215 | generate(args, g_ema, device, mean_latent) 216 | # Style mixing 217 | if args.style_mixing == True: 218 | step = 0 219 | for j in range(20): 220 | img = style_mixing(args,g_ema, step, mean_latent, args.n_col, args.n_row, device, j) 221 | -------------------------------------------------------------------------------- /melgan/.ipynb_checkpoints/modules-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from librosa.filters import mel as librosa_mel_fn 5 | from torch.nn.utils import weight_norm 6 | import numpy as np 7 | 8 | 9 | def weights_init(m): 10 | classname = m.__class__.__name__ 11 | if classname.find("Conv") != -1: 12 | m.weight.data.normal_(0.0, 0.02) 13 | elif classname.find("BatchNorm2d") != -1: 14 | m.weight.data.normal_(1.0, 0.02) 15 | m.bias.data.fill_(0) 16 | 17 | 18 | def WNConv1d(*args, **kwargs): 19 | return weight_norm(nn.Conv1d(*args, **kwargs)) 20 | 21 | 22 | def WNConvTranspose1d(*args, **kwargs): 23 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 24 | 25 | 26 | class Audio2Mel(nn.Module): 27 | def __init__( 28 | self, 29 | n_fft=1024, 30 | hop_length=256, 31 | win_length=1024, 32 | sampling_rate=22050, 33 | n_mel_channels=80, 34 | mel_fmin=0.0, 35 | mel_fmax=None, 36 | ): 37 | super().__init__() 38 | ############################################## 39 | # FFT Parameters # 40 | ############################################## 41 | window = torch.hann_window(win_length).float() 42 | mel_basis = librosa_mel_fn( 43 | sampling_rate, n_fft, n_mel_channels, mel_fmin, mel_fmax 44 | ) 45 | mel_basis = torch.from_numpy(mel_basis).float() 46 | self.register_buffer("mel_basis", mel_basis) 47 | self.register_buffer("window", window) 48 | self.n_fft = n_fft 49 | self.hop_length = hop_length 50 | self.win_length = win_length 51 | self.sampling_rate = sampling_rate 52 | self.n_mel_channels = n_mel_channels 53 | 54 | def forward(self, audio): 55 | p = (self.n_fft - self.hop_length) // 2 56 | audio = F.pad(audio, (p, p), "reflect").squeeze(1) 57 | fft = torch.stft( 
58 | audio, 59 | n_fft=self.n_fft, 60 | hop_length=self.hop_length, 61 | win_length=self.win_length, 62 | window=self.window, 63 | center=False, 64 | ) 65 | real_part, imag_part = fft.unbind(-1) 66 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 67 | mel_output = torch.matmul(self.mel_basis, magnitude) 68 | log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5)) 69 | return log_mel_spec 70 | 71 | 72 | class ResnetBlock(nn.Module): 73 | def __init__(self, dim, dilation=1): 74 | super().__init__() 75 | self.block = nn.Sequential( 76 | nn.LeakyReLU(0.2), 77 | nn.ReflectionPad1d(dilation), 78 | WNConv1d(dim, dim, kernel_size=3, dilation=dilation), 79 | nn.LeakyReLU(0.2), 80 | WNConv1d(dim, dim, kernel_size=1), 81 | ) 82 | self.shortcut = WNConv1d(dim, dim, kernel_size=1) 83 | 84 | def forward(self, x): 85 | return self.shortcut(x) + self.block(x) 86 | 87 | 88 | class Generator_melgan(nn.Module): 89 | def __init__(self, input_size, ngf, n_residual_layers): 90 | super().__init__() 91 | ratios = [8, 8, 2, 2] 92 | self.hop_length = np.prod(ratios) 93 | mult = int(2 ** len(ratios)) 94 | 95 | model = [ 96 | nn.ReflectionPad1d(3), 97 | WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0), 98 | ] 99 | 100 | # Upsample to raw audio scale 101 | for i, r in enumerate(ratios): 102 | model += [ 103 | nn.LeakyReLU(0.2), 104 | WNConvTranspose1d( 105 | mult * ngf, 106 | mult * ngf // 2, 107 | kernel_size=r * 2, 108 | stride=r, 109 | padding=r // 2 + r % 2, 110 | output_padding=r % 2, 111 | ), 112 | ] 113 | 114 | for j in range(n_residual_layers): 115 | model += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)] 116 | 117 | mult //= 2 118 | 119 | model += [ 120 | nn.LeakyReLU(0.2), 121 | nn.ReflectionPad1d(3), 122 | WNConv1d(ngf, 1, kernel_size=7, padding=0), 123 | nn.Tanh(), 124 | ] 125 | 126 | self.model = nn.Sequential(*model) 127 | self.apply(weights_init) 128 | 129 | def forward(self, x): 130 | return self.model(x) 131 | 132 | 133 | class NLayerDiscriminator(nn.Module): 134 | def __init__(self, ndf, n_layers, downsampling_factor): 135 | super().__init__() 136 | model = nn.ModuleDict() 137 | 138 | model["layer_0"] = nn.Sequential( 139 | nn.ReflectionPad1d(7), 140 | WNConv1d(1, ndf, kernel_size=15), 141 | nn.LeakyReLU(0.2, True), 142 | ) 143 | 144 | nf = ndf 145 | stride = downsampling_factor 146 | for n in range(1, n_layers + 1): 147 | nf_prev = nf 148 | nf = min(nf * stride, 1024) 149 | 150 | model["layer_%d" % n] = nn.Sequential( 151 | WNConv1d( 152 | nf_prev, 153 | nf, 154 | kernel_size=stride * 10 + 1, 155 | stride=stride, 156 | padding=stride * 5, 157 | groups=nf_prev // 4, 158 | ), 159 | nn.LeakyReLU(0.2, True), 160 | ) 161 | 162 | nf = min(nf * 2, 1024) 163 | model["layer_%d" % (n_layers + 1)] = nn.Sequential( 164 | WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2), 165 | nn.LeakyReLU(0.2, True), 166 | ) 167 | 168 | model["layer_%d" % (n_layers + 2)] = WNConv1d( 169 | nf, 1, kernel_size=3, stride=1, padding=1 170 | ) 171 | 172 | self.model = model 173 | 174 | def forward(self, x): 175 | results = [] 176 | for key, layer in self.model.items(): 177 | x = layer(x) 178 | results.append(x) 179 | return results 180 | 181 | 182 | class Discriminator(nn.Module): 183 | def __init__(self, num_D, ndf, n_layers, downsampling_factor): 184 | super().__init__() 185 | self.model = nn.ModuleDict() 186 | for i in range(num_D): 187 | self.model[f"disc_{i}"] = NLayerDiscriminator( 188 | ndf, n_layers, downsampling_factor 189 | ) 190 | 191 | self.downsample = nn.AvgPool1d(4, stride=2, 
padding=1, count_include_pad=False) 192 | self.apply(weights_init) 193 | 194 | def forward(self, x): 195 | results = [] 196 | for key, disc in self.model.items(): 197 | results.append(disc(x)) 198 | x = self.downsample(x) 199 | return results 200 | -------------------------------------------------------------------------------- /melgan/args.yml: -------------------------------------------------------------------------------- 1 | !!python/object:argparse.Namespace 2 | batch_size: 16 3 | cond_disc: false 4 | data_path: !!python/object/apply:pathlib.PosixPath 5 | - / 6 | - home 7 | - allenhung 8 | - nas189 9 | - home 10 | - one_bar_loop_all 11 | downsamp_factor: 4 12 | epochs: 30000 13 | lambda_feat: 10 14 | load_path: null 15 | log_interval: 100 16 | n_layers_D: 4 17 | n_mel_channels: 80 18 | n_residual_layers: 3 19 | n_test_samples: 8 20 | ndf: 16 21 | ngf: 32 22 | num_D: 3 23 | save_interval: 1000 24 | save_path: logs/one_bar_all 25 | seq_len: 8192 26 | -------------------------------------------------------------------------------- /melgan/best_netG.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/melgan/best_netG.pt -------------------------------------------------------------------------------- /melgan/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from librosa.filters import mel as librosa_mel_fn 5 | from torch.nn.utils import weight_norm 6 | import numpy as np 7 | 8 | 9 | def weights_init(m): 10 | classname = m.__class__.__name__ 11 | if classname.find("Conv") != -1: 12 | m.weight.data.normal_(0.0, 0.02) 13 | elif classname.find("BatchNorm2d") != -1: 14 | m.weight.data.normal_(1.0, 0.02) 15 | m.bias.data.fill_(0) 16 | 17 | 18 | def WNConv1d(*args, **kwargs): 19 | return weight_norm(nn.Conv1d(*args, **kwargs)) 20 | 21 | 22 | def WNConvTranspose1d(*args, **kwargs): 23 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 24 | 25 | 26 | class Audio2Mel(nn.Module): 27 | def __init__( 28 | self, 29 | n_fft=1024, 30 | hop_length=256, 31 | win_length=1024, 32 | sampling_rate=22050, 33 | n_mel_channels=80, 34 | mel_fmin=0.0, 35 | mel_fmax=None, 36 | ): 37 | super().__init__() 38 | ############################################## 39 | # FFT Parameters # 40 | ############################################## 41 | window = torch.hann_window(win_length).float() 42 | mel_basis = librosa_mel_fn( 43 | sampling_rate, n_fft, n_mel_channels, mel_fmin, mel_fmax 44 | ) 45 | mel_basis = torch.from_numpy(mel_basis).float() 46 | self.register_buffer("mel_basis", mel_basis) 47 | self.register_buffer("window", window) 48 | self.n_fft = n_fft 49 | self.hop_length = hop_length 50 | self.win_length = win_length 51 | self.sampling_rate = sampling_rate 52 | self.n_mel_channels = n_mel_channels 53 | 54 | def forward(self, audio): 55 | p = (self.n_fft - self.hop_length) // 2 56 | audio = F.pad(audio, (p, p), "reflect").squeeze(1) 57 | fft = torch.stft( 58 | audio, 59 | n_fft=self.n_fft, 60 | hop_length=self.hop_length, 61 | win_length=self.win_length, 62 | window=self.window, 63 | center=False, 64 | ) 65 | real_part, imag_part = fft.unbind(-1) 66 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 67 | mel_output = torch.matmul(self.mel_basis, magnitude) 68 | log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5)) 69 | return log_mel_spec 70 | 71 
| 72 | class ResnetBlock(nn.Module): 73 | def __init__(self, dim, dilation=1): 74 | super().__init__() 75 | self.block = nn.Sequential( 76 | nn.LeakyReLU(0.2), 77 | nn.ReflectionPad1d(dilation), 78 | WNConv1d(dim, dim, kernel_size=3, dilation=dilation), 79 | nn.LeakyReLU(0.2), 80 | WNConv1d(dim, dim, kernel_size=1), 81 | ) 82 | self.shortcut = WNConv1d(dim, dim, kernel_size=1) 83 | 84 | def forward(self, x): 85 | return self.shortcut(x) + self.block(x) 86 | 87 | 88 | class Generator_melgan(nn.Module): 89 | def __init__(self, input_size, ngf, n_residual_layers): 90 | super().__init__() 91 | ratios = [8, 8, 2, 2] 92 | self.hop_length = np.prod(ratios) 93 | mult = int(2 ** len(ratios)) 94 | 95 | model = [ 96 | nn.ReflectionPad1d(3), 97 | WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0), 98 | ] 99 | 100 | # Upsample to raw audio scale 101 | for i, r in enumerate(ratios): 102 | model += [ 103 | nn.LeakyReLU(0.2), 104 | WNConvTranspose1d( 105 | mult * ngf, 106 | mult * ngf // 2, 107 | kernel_size=r * 2, 108 | stride=r, 109 | padding=r // 2 + r % 2, 110 | output_padding=r % 2, 111 | ), 112 | ] 113 | 114 | for j in range(n_residual_layers): 115 | model += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)] 116 | 117 | mult //= 2 118 | 119 | model += [ 120 | nn.LeakyReLU(0.2), 121 | nn.ReflectionPad1d(3), 122 | WNConv1d(ngf, 1, kernel_size=7, padding=0), 123 | nn.Tanh(), 124 | ] 125 | 126 | self.model = nn.Sequential(*model) 127 | self.apply(weights_init) 128 | 129 | def forward(self, x): 130 | return self.model(x) 131 | 132 | 133 | class NLayerDiscriminator(nn.Module): 134 | def __init__(self, ndf, n_layers, downsampling_factor): 135 | super().__init__() 136 | model = nn.ModuleDict() 137 | 138 | model["layer_0"] = nn.Sequential( 139 | nn.ReflectionPad1d(7), 140 | WNConv1d(1, ndf, kernel_size=15), 141 | nn.LeakyReLU(0.2, True), 142 | ) 143 | 144 | nf = ndf 145 | stride = downsampling_factor 146 | for n in range(1, n_layers + 1): 147 | nf_prev = nf 148 | nf = min(nf * stride, 1024) 149 | 150 | model["layer_%d" % n] = nn.Sequential( 151 | WNConv1d( 152 | nf_prev, 153 | nf, 154 | kernel_size=stride * 10 + 1, 155 | stride=stride, 156 | padding=stride * 5, 157 | groups=nf_prev // 4, 158 | ), 159 | nn.LeakyReLU(0.2, True), 160 | ) 161 | 162 | nf = min(nf * 2, 1024) 163 | model["layer_%d" % (n_layers + 1)] = nn.Sequential( 164 | WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2), 165 | nn.LeakyReLU(0.2, True), 166 | ) 167 | 168 | model["layer_%d" % (n_layers + 2)] = WNConv1d( 169 | nf, 1, kernel_size=3, stride=1, padding=1 170 | ) 171 | 172 | self.model = model 173 | 174 | def forward(self, x): 175 | results = [] 176 | for key, layer in self.model.items(): 177 | x = layer(x) 178 | results.append(x) 179 | return results 180 | 181 | 182 | class Discriminator(nn.Module): 183 | def __init__(self, num_D, ndf, n_layers, downsampling_factor): 184 | super().__init__() 185 | self.model = nn.ModuleDict() 186 | for i in range(num_D): 187 | self.model[f"disc_{i}"] = NLayerDiscriminator( 188 | ndf, n_layers, downsampling_factor 189 | ) 190 | 191 | self.downsample = nn.AvgPool1d(4, stride=2, padding=1, count_include_pad=False) 192 | self.apply(weights_init) 193 | 194 | def forward(self, x): 195 | results = [] 196 | for key, disc in self.model.items(): 197 | results.append(disc(x)) 198 | x = self.downsample(x) 199 | return results 200 | -------------------------------------------------------------------------------- /model_drum.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import functools 4 | import operator 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.autograd import Function 10 | 11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 12 | 13 | 14 | class PixelNorm(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def forward(self, input): 19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 20 | 21 | 22 | def make_kernel(k): 23 | k = torch.tensor(k, dtype=torch.float32) 24 | 25 | if k.ndim == 1: 26 | k = k[None, :] * k[:, None] 27 | 28 | k /= k.sum() 29 | 30 | return k 31 | 32 | 33 | class Upsample(nn.Module): 34 | def __init__(self, kernel, factor=2): 35 | super().__init__() 36 | 37 | self.factor = factor 38 | kernel = make_kernel(kernel) * (factor ** 2) 39 | self.register_buffer("kernel", kernel) 40 | 41 | p = kernel.shape[0] - factor 42 | 43 | pad0 = (p + 1) // 2 + factor - 1 44 | pad1 = p // 2 45 | 46 | self.pad = (pad0, pad1) 47 | 48 | def forward(self, input): 49 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 50 | 51 | return out 52 | 53 | 54 | class Downsample(nn.Module): 55 | def __init__(self, kernel, factor=2): 56 | super().__init__() 57 | 58 | self.factor = factor 59 | kernel = make_kernel(kernel) 60 | self.register_buffer("kernel", kernel) 61 | 62 | p = kernel.shape[0] - factor 63 | 64 | pad0 = (p + 1) // 2 65 | pad1 = p // 2 66 | 67 | self.pad = (pad0, pad1) 68 | 69 | def forward(self, input): 70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 71 | 72 | return out 73 | 74 | 75 | class Blur(nn.Module): 76 | def __init__(self, kernel, pad, upsample_factor=1): 77 | super().__init__() 78 | 79 | kernel = make_kernel(kernel) 80 | 81 | if upsample_factor > 1: 82 | kernel = kernel * (upsample_factor ** 2) 83 | 84 | self.register_buffer("kernel", kernel) 85 | 86 | self.pad = pad 87 | 88 | def forward(self, input): 89 | out = upfirdn2d(input, self.kernel, pad=self.pad) 90 | 91 | return out 92 | 93 | 94 | class EqualConv2d(nn.Module): 95 | def __init__( 96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 97 | ): 98 | super().__init__() 99 | 100 | self.weight = nn.Parameter( 101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 102 | ) 103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 104 | 105 | self.stride = stride 106 | self.padding = padding 107 | 108 | if bias: 109 | self.bias = nn.Parameter(torch.zeros(out_channel)) 110 | 111 | else: 112 | self.bias = None 113 | 114 | def forward(self, input): 115 | out = F.conv2d( 116 | input, 117 | self.weight * self.scale, 118 | bias=self.bias, 119 | stride=self.stride, 120 | padding=self.padding, 121 | ) 122 | 123 | return out 124 | 125 | def __repr__(self): 126 | return ( 127 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 128 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 129 | ) 130 | 131 | 132 | class EqualLinear(nn.Module): 133 | def __init__( 134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 135 | ): 136 | super().__init__() 137 | 138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 139 | 140 | if bias: 141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 142 | 143 | else: 144 | self.bias = None 145 | 146 | self.activation = 
activation 147 | 148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 149 | self.lr_mul = lr_mul 150 | 151 | def forward(self, input): 152 | if self.activation: 153 | out = F.linear(input, self.weight * self.scale) 154 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 155 | 156 | else: 157 | out = F.linear( 158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 159 | ) 160 | 161 | return out 162 | 163 | def __repr__(self): 164 | return ( 165 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 166 | ) 167 | 168 | 169 | class ModulatedConv2d(nn.Module): 170 | def __init__( 171 | self, 172 | in_channel, 173 | out_channel, 174 | kernel_size, 175 | style_dim, 176 | demodulate=True, 177 | upsample=False, 178 | downsample=False, 179 | blur_kernel=[1, 3, 3, 1], 180 | ): 181 | super().__init__() 182 | 183 | self.eps = 1e-8 184 | self.kernel_size = kernel_size 185 | self.in_channel = in_channel 186 | self.out_channel = out_channel 187 | self.upsample = upsample 188 | self.downsample = downsample 189 | 190 | if upsample: 191 | factor = 2 192 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 193 | pad0 = (p + 1) // 2 + factor - 1 194 | pad1 = p // 2 + 1 195 | 196 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 197 | 198 | if downsample: 199 | factor = 2 200 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 201 | pad0 = (p + 1) // 2 202 | pad1 = p // 2 203 | 204 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 205 | 206 | fan_in = in_channel * kernel_size ** 2 207 | self.scale = 1 / math.sqrt(fan_in) 208 | self.padding = kernel_size // 2 209 | 210 | self.weight = nn.Parameter( 211 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 212 | ) 213 | 214 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 215 | 216 | self.demodulate = demodulate 217 | 218 | def __repr__(self): 219 | return ( 220 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " 221 | f"upsample={self.upsample}, downsample={self.downsample})" 222 | ) 223 | 224 | def forward(self, input, style): 225 | batch, in_channel, height, width = input.shape 226 | 227 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 228 | weight = self.scale * self.weight * style 229 | 230 | if self.demodulate: 231 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 232 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 233 | 234 | weight = weight.view( 235 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 236 | ) 237 | 238 | if self.upsample: 239 | input = input.view(1, batch * in_channel, height, width) 240 | weight = weight.view( 241 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 242 | ) 243 | weight = weight.transpose(1, 2).reshape( 244 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 245 | ) 246 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 247 | _, _, height, width = out.shape 248 | out = out.view(batch, self.out_channel, height, width) 249 | out = self.blur(out) 250 | 251 | elif self.downsample: 252 | input = self.blur(input) 253 | _, _, height, width = input.shape 254 | input = input.view(1, batch * in_channel, height, width) 255 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 256 | _, _, height, width = out.shape 257 | out = out.view(batch, self.out_channel, height, width) 258 | 259 | else: 260 | input = input.view(1, batch * in_channel, height, 
width) 261 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 262 | _, _, height, width = out.shape 263 | out = out.view(batch, self.out_channel, height, width) 264 | 265 | return out 266 | 267 | 268 | class NoiseInjection(nn.Module): 269 | def __init__(self): 270 | super().__init__() 271 | 272 | self.weight = nn.Parameter(torch.zeros(1)) 273 | 274 | def forward(self, image, noise=None): 275 | if noise is None: 276 | batch, _, height, width = image.shape 277 | noise = image.new_empty(batch, 1, height, width).normal_() 278 | 279 | return image + self.weight * noise 280 | 281 | 282 | class ConstantInput(nn.Module): 283 | def __init__(self, channel, size=4): 284 | super().__init__() 285 | 286 | self.input = nn.Parameter(torch.randn(1, channel, 5, 20)) 287 | 288 | def forward(self, input): 289 | batch = input.shape[0] 290 | out = self.input.repeat(batch, 1, 1, 1) 291 | 292 | return out 293 | 294 | 295 | class StyledConv(nn.Module): 296 | def __init__( 297 | self, 298 | in_channel, 299 | out_channel, 300 | kernel_size, 301 | style_dim, 302 | upsample=False, 303 | blur_kernel=[1, 3, 3, 1], 304 | demodulate=True, 305 | ): 306 | super().__init__() 307 | 308 | self.conv = ModulatedConv2d( 309 | in_channel, 310 | out_channel, 311 | kernel_size, 312 | style_dim, 313 | upsample=upsample, 314 | blur_kernel=blur_kernel, 315 | demodulate=demodulate, 316 | ) 317 | 318 | self.noise = NoiseInjection() 319 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 320 | # self.activate = ScaledLeakyReLU(0.2) 321 | self.activate = FusedLeakyReLU(out_channel) 322 | 323 | def forward(self, input, style, noise=None): 324 | out = self.conv(input, style) 325 | out = self.noise(out, noise=noise) 326 | # out = out + self.bias 327 | out = self.activate(out) 328 | 329 | return out 330 | 331 | 332 | class ToRGB(nn.Module): 333 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 334 | super().__init__() 335 | 336 | if upsample: 337 | self.upsample = Upsample(blur_kernel) 338 | 339 | self.conv = ModulatedConv2d(in_channel, 1, 1, style_dim, demodulate=False) 340 | self.bias = nn.Parameter(torch.zeros(1, 1, 1, 1)) 341 | 342 | def forward(self, input, style, skip=None): 343 | out = self.conv(input, style) 344 | out = out + self.bias 345 | 346 | if skip is not None: 347 | skip = self.upsample(skip) 348 | 349 | out = out + skip 350 | 351 | return out 352 | 353 | 354 | class Generator(nn.Module): 355 | def __init__( 356 | self, 357 | size, 358 | style_dim, 359 | n_mlp, 360 | channel_multiplier=2, 361 | blur_kernel=[1, 3, 3, 1], 362 | lr_mlp=0.01, 363 | ): 364 | super().__init__() 365 | 366 | self.size = size 367 | 368 | self.style_dim = style_dim 369 | 370 | layers = [PixelNorm()] 371 | 372 | for i in range(n_mlp): 373 | layers.append( 374 | EqualLinear( 375 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 376 | ) 377 | ) 378 | 379 | self.style = nn.Sequential(*layers) 380 | 381 | self.channels = { 382 | 4: 512, 383 | 8: 512, 384 | 16: 512, 385 | 32: 512, 386 | 64: 256 * channel_multiplier, 387 | 128: 128 * channel_multiplier, 388 | 256: 64 * channel_multiplier, 389 | 512: 32 * channel_multiplier, 390 | 1024: 16 * channel_multiplier, 391 | } 392 | 393 | self.input = ConstantInput(self.channels[4]) 394 | self.conv1 = StyledConv( 395 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 396 | ) 397 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 398 | 399 | self.log_size = int(math.log(size, 2)) 400 | self.num_layers 
= (self.log_size - 2) * 2 + 1 #9 401 | 402 | self.convs = nn.ModuleList() 403 | self.upsamples = nn.ModuleList() 404 | self.to_rgbs = nn.ModuleList() 405 | self.noises = nn.Module() 406 | 407 | in_channel = self.channels[4] 408 | 409 | for layer_idx in range(self.num_layers): 410 | res = (layer_idx + 5) // 2 411 | shape = [1, 1, 2 ** res, 2 ** res] 412 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 413 | 414 | for i in range(3, self.log_size + 1): 415 | out_channel = self.channels[2 ** i] 416 | 417 | self.convs.append( 418 | StyledConv( 419 | in_channel, 420 | out_channel, 421 | 3, 422 | style_dim, 423 | upsample=True, 424 | blur_kernel=blur_kernel, 425 | ) 426 | ) 427 | 428 | self.convs.append( 429 | StyledConv( 430 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 431 | ) 432 | ) 433 | 434 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 435 | 436 | in_channel = out_channel 437 | 438 | self.n_latent = self.log_size * 2 - 2 439 | 440 | def make_noise(self): 441 | device = self.input.input.device 442 | 443 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 444 | 445 | for i in range(3, self.log_size + 1): 446 | for _ in range(2): 447 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 448 | 449 | return noises 450 | 451 | def mean_latent(self, n_latent): 452 | latent_in = torch.randn( 453 | n_latent, self.style_dim, device=self.input.input.device 454 | ) 455 | latent = self.style(latent_in).mean(0, keepdim=True) 456 | 457 | return latent 458 | 459 | def get_latent(self, input): 460 | return self.style(input) 461 | 462 | def forward( 463 | self, 464 | styles, 465 | return_latents=False, 466 | inject_index=None, 467 | truncation=1, 468 | truncation_latent=None, 469 | input_is_latent=False, 470 | noise=None, 471 | randomize_noise=True, 472 | ): 473 | if not input_is_latent: 474 | styles = [self.style(s) for s in styles] 475 | 476 | if noise is None: 477 | if randomize_noise: 478 | noise = [None] * self.num_layers 479 | else: 480 | noise = [ 481 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 482 | ] 483 | 484 | if truncation < 1: 485 | style_t = [] 486 | 487 | for style in styles: 488 | style_t.append( 489 | truncation_latent + truncation * (style - truncation_latent) 490 | ) 491 | 492 | styles = style_t 493 | 494 | if len(styles) < 2: 495 | inject_index = self.n_latent 496 | 497 | if styles[0].ndim < 3: 498 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 499 | 500 | else: 501 | latent = styles[0] 502 | 503 | else: 504 | if inject_index is None: 505 | inject_index = random.randint(1, self.n_latent - 1) 506 | 507 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 508 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 509 | 510 | latent = torch.cat([latent, latent2], 1) 511 | 512 | out = self.input(latent) 513 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 514 | 515 | skip = self.to_rgb1(out, latent[:, 1]) 516 | 517 | i = 1 518 | for conv1, conv2, noise1, noise2, to_rgb in zip( 519 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 520 | ): 521 | out = conv1(out, latent[:, i], noise=noise1) 522 | out = conv2(out, latent[:, i + 1], noise=noise2) 523 | skip = to_rgb(out, latent[:, i + 2], skip) 524 | 525 | i += 2 526 | 527 | image = skip 528 | 529 | if return_latents: 530 | return image, latent 531 | 532 | else: 533 | return image, None 534 | 535 | 536 | class ConvLayer(nn.Sequential): 537 | def __init__( 538 | self, 539 | 
in_channel, 540 | out_channel, 541 | kernel_size, 542 | downsample=False, 543 | blur_kernel=[1, 3, 3, 1], 544 | bias=True, 545 | activate=True, 546 | ): 547 | layers = [] 548 | 549 | if downsample: 550 | factor = 2 551 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 552 | pad0 = (p + 1) // 2 553 | pad1 = p // 2 554 | 555 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 556 | 557 | stride = 2 558 | self.padding = 0 559 | 560 | else: 561 | stride = 1 562 | self.padding = kernel_size // 2 563 | 564 | layers.append( 565 | EqualConv2d( 566 | in_channel, 567 | out_channel, 568 | kernel_size, 569 | padding=self.padding, 570 | stride=stride, 571 | bias=bias and not activate, 572 | ) 573 | ) 574 | 575 | if activate: 576 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 577 | 578 | super().__init__(*layers) 579 | 580 | 581 | class ResBlock(nn.Module): 582 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 583 | super().__init__() 584 | 585 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 586 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 587 | 588 | self.skip = ConvLayer( 589 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 590 | ) 591 | 592 | def forward(self, input): 593 | out = self.conv1(input) 594 | out = self.conv2(out) 595 | 596 | skip = self.skip(input) 597 | out = (out + skip) / math.sqrt(2) 598 | 599 | return out 600 | 601 | 602 | class Discriminator(nn.Module): 603 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 604 | super().__init__() 605 | 606 | channels = { 607 | 4: 512, 608 | 8: 512, 609 | 16: 512, 610 | 32: 512, 611 | 64: 256 * channel_multiplier, 612 | 128: 128 * channel_multiplier, 613 | 256: 64 * channel_multiplier, 614 | 512: 32 * channel_multiplier, 615 | 1024: 16 * channel_multiplier, 616 | } 617 | 618 | convs = [ConvLayer(1, channels[size], 1)] 619 | 620 | log_size = int(math.log(size, 2)) 621 | 622 | in_channel = channels[size] 623 | 624 | for i in range(log_size, 2, -1): 625 | out_channel = channels[2 ** (i - 1)] 626 | 627 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 628 | 629 | in_channel = out_channel 630 | 631 | self.convs = nn.Sequential(*convs) 632 | 633 | self.stddev_group = 4 634 | self.stddev_feat = 1 635 | 636 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 637 | self.final_linear = nn.Sequential( 638 | EqualLinear(channels[4] * 5 * 20, channels[4], activation="fused_lrelu"), 639 | EqualLinear(channels[4], 1), 640 | ) 641 | 642 | def forward(self, input): 643 | out = self.convs(input) 644 | 645 | batch, channel, height, width = out.shape 646 | group = min(batch, self.stddev_group) 647 | stddev = out.view( 648 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 649 | ) 650 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 651 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 652 | stddev = stddev.repeat(group, 1, height, width) 653 | out = torch.cat([out, stddev], 1) 654 | 655 | out = self.final_conv(out) 656 | 657 | out = out.view(batch, -1) 658 | out = self.final_linear(out) 659 | 660 | return out 661 | 662 | if __name__ == "__main__": 663 | generator = Generator(size=64, 664 | style_dim=512, 665 | n_mlp=8, 666 | channel_multiplier=2, 667 | blur_kernel=[1, 3, 3, 1], 668 | lr_mlp=0.01) 669 | noise = torch.randn(1, 16, 512).unbind(0) 670 | fake_img, _ = generator(noise) #[16, 80, 320] 671 | dis = Discriminator(size=64) 672 | fake_output = dis(fake_img) 673 | import pdb; 
pdb.set_trace() -------------------------------------------------------------------------------- /model_drum_four_bar.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import functools 4 | import operator 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.autograd import Function 10 | 11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 12 | 13 | 14 | class PixelNorm(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def forward(self, input): 19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 20 | 21 | 22 | def make_kernel(k): 23 | k = torch.tensor(k, dtype=torch.float32) 24 | 25 | if k.ndim == 1: 26 | k = k[None, :] * k[:, None] 27 | 28 | k /= k.sum() 29 | 30 | return k 31 | 32 | 33 | class Upsample(nn.Module): 34 | def __init__(self, kernel, factor=2): 35 | super().__init__() 36 | 37 | self.factor = factor 38 | kernel = make_kernel(kernel) * (factor ** 2) 39 | self.register_buffer("kernel", kernel) 40 | 41 | p = kernel.shape[0] - factor 42 | 43 | pad0 = (p + 1) // 2 + factor - 1 44 | pad1 = p // 2 45 | 46 | self.pad = (pad0, pad1) 47 | 48 | def forward(self, input): 49 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 50 | 51 | return out 52 | 53 | 54 | class Downsample(nn.Module): 55 | def __init__(self, kernel, factor=2): 56 | super().__init__() 57 | 58 | self.factor = factor 59 | kernel = make_kernel(kernel) 60 | self.register_buffer("kernel", kernel) 61 | 62 | p = kernel.shape[0] - factor 63 | 64 | pad0 = (p + 1) // 2 65 | pad1 = p // 2 66 | 67 | self.pad = (pad0, pad1) 68 | 69 | def forward(self, input): 70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 71 | 72 | return out 73 | 74 | 75 | class Blur(nn.Module): 76 | def __init__(self, kernel, pad, upsample_factor=1): 77 | super().__init__() 78 | 79 | kernel = make_kernel(kernel) 80 | 81 | if upsample_factor > 1: 82 | kernel = kernel * (upsample_factor ** 2) 83 | 84 | self.register_buffer("kernel", kernel) 85 | 86 | self.pad = pad 87 | 88 | def forward(self, input): 89 | out = upfirdn2d(input, self.kernel, pad=self.pad) 90 | 91 | return out 92 | 93 | 94 | class EqualConv2d(nn.Module): 95 | def __init__( 96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 97 | ): 98 | super().__init__() 99 | 100 | self.weight = nn.Parameter( 101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 102 | ) 103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 104 | 105 | self.stride = stride 106 | self.padding = padding 107 | 108 | if bias: 109 | self.bias = nn.Parameter(torch.zeros(out_channel)) 110 | 111 | else: 112 | self.bias = None 113 | 114 | def forward(self, input): 115 | out = F.conv2d( 116 | input, 117 | self.weight * self.scale, 118 | bias=self.bias, 119 | stride=self.stride, 120 | padding=self.padding, 121 | ) 122 | 123 | return out 124 | 125 | def __repr__(self): 126 | return ( 127 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 128 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 129 | ) 130 | 131 | 132 | class EqualLinear(nn.Module): 133 | def __init__( 134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 135 | ): 136 | super().__init__() 137 | 138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 139 | 140 | if bias: 141 | self.bias = 
nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 142 | 143 | else: 144 | self.bias = None 145 | 146 | self.activation = activation 147 | 148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 149 | self.lr_mul = lr_mul 150 | 151 | def forward(self, input): 152 | if self.activation: 153 | out = F.linear(input, self.weight * self.scale) 154 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 155 | 156 | else: 157 | out = F.linear( 158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 159 | ) 160 | 161 | return out 162 | 163 | def __repr__(self): 164 | return ( 165 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 166 | ) 167 | 168 | 169 | class ModulatedConv2d(nn.Module): 170 | def __init__( 171 | self, 172 | in_channel, 173 | out_channel, 174 | kernel_size, 175 | style_dim, 176 | demodulate=True, 177 | upsample=False, 178 | downsample=False, 179 | blur_kernel=[1, 3, 3, 1], 180 | ): 181 | super().__init__() 182 | 183 | self.eps = 1e-8 184 | self.kernel_size = kernel_size 185 | self.in_channel = in_channel 186 | self.out_channel = out_channel 187 | self.upsample = upsample 188 | self.downsample = downsample 189 | 190 | if upsample: 191 | factor = 2 192 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 193 | pad0 = (p + 1) // 2 + factor - 1 194 | pad1 = p // 2 + 1 195 | 196 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 197 | 198 | if downsample: 199 | factor = 2 200 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 201 | pad0 = (p + 1) // 2 202 | pad1 = p // 2 203 | 204 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 205 | 206 | fan_in = in_channel * kernel_size ** 2 207 | self.scale = 1 / math.sqrt(fan_in) 208 | self.padding = kernel_size // 2 209 | 210 | self.weight = nn.Parameter( 211 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 212 | ) 213 | 214 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 215 | 216 | self.demodulate = demodulate 217 | 218 | def __repr__(self): 219 | return ( 220 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " 221 | f"upsample={self.upsample}, downsample={self.downsample})" 222 | ) 223 | 224 | def forward(self, input, style): 225 | batch, in_channel, height, width = input.shape 226 | 227 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 228 | weight = self.scale * self.weight * style 229 | 230 | if self.demodulate: 231 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 232 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 233 | 234 | weight = weight.view( 235 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 236 | ) 237 | 238 | if self.upsample: 239 | input = input.view(1, batch * in_channel, height, width) 240 | weight = weight.view( 241 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 242 | ) 243 | weight = weight.transpose(1, 2).reshape( 244 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 245 | ) 246 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 247 | _, _, height, width = out.shape 248 | out = out.view(batch, self.out_channel, height, width) 249 | out = self.blur(out) 250 | 251 | elif self.downsample: 252 | input = self.blur(input) 253 | _, _, height, width = input.shape 254 | input = input.view(1, batch * in_channel, height, width) 255 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 256 | _, _, height, width = out.shape 257 | out = 
out.view(batch, self.out_channel, height, width) 258 | 259 | else: 260 | input = input.view(1, batch * in_channel, height, width) 261 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 262 | _, _, height, width = out.shape 263 | out = out.view(batch, self.out_channel, height, width) 264 | 265 | return out 266 | 267 | 268 | class NoiseInjection(nn.Module): 269 | def __init__(self): 270 | super().__init__() 271 | 272 | self.weight = nn.Parameter(torch.zeros(1)) 273 | 274 | def forward(self, image, noise=None): 275 | if noise is None: 276 | batch, _, height, width = image.shape 277 | noise = image.new_empty(batch, 1, height, width).normal_() 278 | 279 | return image + self.weight * noise 280 | 281 | 282 | class ConstantInput(nn.Module): 283 | def __init__(self, channel, size=4): 284 | super().__init__() 285 | 286 | self.input = nn.Parameter(torch.randn(1, channel, 5, 80)) 287 | 288 | def forward(self, input): 289 | batch = input.shape[0] 290 | out = self.input.repeat(batch, 1, 1, 1) 291 | 292 | return out 293 | 294 | 295 | class StyledConv(nn.Module): 296 | def __init__( 297 | self, 298 | in_channel, 299 | out_channel, 300 | kernel_size, 301 | style_dim, 302 | upsample=False, 303 | blur_kernel=[1, 3, 3, 1], 304 | demodulate=True, 305 | ): 306 | super().__init__() 307 | 308 | self.conv = ModulatedConv2d( 309 | in_channel, 310 | out_channel, 311 | kernel_size, 312 | style_dim, 313 | upsample=upsample, 314 | blur_kernel=blur_kernel, 315 | demodulate=demodulate, 316 | ) 317 | 318 | self.noise = NoiseInjection() 319 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 320 | # self.activate = ScaledLeakyReLU(0.2) 321 | self.activate = FusedLeakyReLU(out_channel) 322 | 323 | def forward(self, input, style, noise=None): 324 | out = self.conv(input, style) 325 | out = self.noise(out, noise=noise) 326 | # out = out + self.bias 327 | out = self.activate(out) 328 | 329 | return out 330 | 331 | 332 | class ToRGB(nn.Module): 333 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 334 | super().__init__() 335 | 336 | if upsample: 337 | self.upsample = Upsample(blur_kernel) 338 | 339 | self.conv = ModulatedConv2d(in_channel, 1, 1, style_dim, demodulate=False) 340 | self.bias = nn.Parameter(torch.zeros(1, 1, 1, 1)) 341 | 342 | def forward(self, input, style, skip=None): 343 | out = self.conv(input, style) 344 | out = out + self.bias 345 | 346 | if skip is not None: 347 | skip = self.upsample(skip) 348 | 349 | out = out + skip 350 | 351 | return out 352 | 353 | 354 | class Generator(nn.Module): 355 | def __init__( 356 | self, 357 | size, 358 | style_dim, 359 | n_mlp, 360 | channel_multiplier=2, 361 | blur_kernel=[1, 3, 3, 1], 362 | lr_mlp=0.01, 363 | ): 364 | super().__init__() 365 | 366 | self.size = size 367 | 368 | self.style_dim = style_dim 369 | 370 | layers = [PixelNorm()] 371 | 372 | for i in range(n_mlp): 373 | layers.append( 374 | EqualLinear( 375 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 376 | ) 377 | ) 378 | 379 | self.style = nn.Sequential(*layers) 380 | 381 | self.channels = { 382 | 4: 512, 383 | 8: 512, 384 | 16: 512, 385 | 32: 512, 386 | 64: 256 * channel_multiplier, 387 | 128: 128 * channel_multiplier, 388 | 256: 64 * channel_multiplier, 389 | 512: 32 * channel_multiplier, 390 | 1024: 16 * channel_multiplier, 391 | } 392 | 393 | self.input = ConstantInput(self.channels[4]) 394 | self.conv1 = StyledConv( 395 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 396 | ) 397 | self.to_rgb1 = 
ToRGB(self.channels[4], style_dim, upsample=False) 398 | 399 | self.log_size = int(math.log(size, 2)) 400 | self.num_layers = (self.log_size - 2) * 2 + 1 401 | 402 | self.convs = nn.ModuleList() 403 | self.upsamples = nn.ModuleList() 404 | self.to_rgbs = nn.ModuleList() 405 | self.noises = nn.Module() 406 | 407 | in_channel = self.channels[4] 408 | 409 | for layer_idx in range(self.num_layers): 410 | res = (layer_idx + 5) // 2 411 | shape = [1, 1, 2 ** res, 2 ** res] 412 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 413 | 414 | for i in range(3, self.log_size + 1): 415 | out_channel = self.channels[2 ** i] 416 | 417 | self.convs.append( 418 | StyledConv( 419 | in_channel, 420 | out_channel, 421 | 3, 422 | style_dim, 423 | upsample=True, 424 | blur_kernel=blur_kernel, 425 | ) 426 | ) 427 | 428 | self.convs.append( 429 | StyledConv( 430 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 431 | ) 432 | ) 433 | 434 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 435 | 436 | in_channel = out_channel 437 | 438 | self.n_latent = self.log_size * 2 - 2 439 | 440 | def make_noise(self): 441 | device = self.input.input.device 442 | 443 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 444 | 445 | for i in range(3, self.log_size + 1): 446 | for _ in range(2): 447 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 448 | 449 | return noises 450 | 451 | def mean_latent(self, n_latent): 452 | latent_in = torch.randn( 453 | n_latent, self.style_dim, device=self.input.input.device 454 | ) 455 | latent = self.style(latent_in).mean(0, keepdim=True) 456 | 457 | return latent 458 | 459 | def get_latent(self, input): 460 | return self.style(input) 461 | 462 | def forward( 463 | self, 464 | styles, 465 | return_latents=False, 466 | inject_index=None, 467 | truncation=1, 468 | truncation_latent=None, 469 | input_is_latent=False, 470 | noise=None, 471 | randomize_noise=True, 472 | ): 473 | if not input_is_latent: 474 | styles = [self.style(s) for s in styles] 475 | 476 | if noise is None: 477 | if randomize_noise: 478 | noise = [None] * self.num_layers 479 | else: 480 | noise = [ 481 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 482 | ] 483 | 484 | if truncation < 1: 485 | style_t = [] 486 | 487 | for style in styles: 488 | style_t.append( 489 | truncation_latent + truncation * (style - truncation_latent) 490 | ) 491 | 492 | styles = style_t 493 | 494 | if len(styles) < 2: 495 | inject_index = self.n_latent 496 | 497 | if styles[0].ndim < 3: 498 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 499 | 500 | else: 501 | latent = styles[0] 502 | 503 | else: 504 | if inject_index is None: 505 | inject_index = random.randint(1, self.n_latent - 1) 506 | 507 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 508 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 509 | 510 | latent = torch.cat([latent, latent2], 1) 511 | 512 | out = self.input(latent) 513 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 514 | 515 | skip = self.to_rgb1(out, latent[:, 1]) 516 | 517 | i = 1 518 | for conv1, conv2, noise1, noise2, to_rgb in zip( 519 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 520 | ): 521 | out = conv1(out, latent[:, i], noise=noise1) 522 | out = conv2(out, latent[:, i + 1], noise=noise2) 523 | skip = to_rgb(out, latent[:, i + 2], skip) 524 | 525 | i += 2 526 | 527 | image = skip 528 | 529 | if return_latents: 530 | return image, latent 531 | 532 | 
else: 533 | return image, None 534 | 535 | 536 | class ConvLayer(nn.Sequential): 537 | def __init__( 538 | self, 539 | in_channel, 540 | out_channel, 541 | kernel_size, 542 | downsample=False, 543 | blur_kernel=[1, 3, 3, 1], 544 | bias=True, 545 | activate=True, 546 | ): 547 | layers = [] 548 | 549 | if downsample: 550 | factor = 2 551 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 552 | pad0 = (p + 1) // 2 553 | pad1 = p // 2 554 | 555 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 556 | 557 | stride = 2 558 | self.padding = 0 559 | 560 | else: 561 | stride = 1 562 | self.padding = kernel_size // 2 563 | 564 | layers.append( 565 | EqualConv2d( 566 | in_channel, 567 | out_channel, 568 | kernel_size, 569 | padding=self.padding, 570 | stride=stride, 571 | bias=bias and not activate, 572 | ) 573 | ) 574 | 575 | if activate: 576 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 577 | 578 | super().__init__(*layers) 579 | 580 | 581 | class ResBlock(nn.Module): 582 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 583 | super().__init__() 584 | 585 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 586 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 587 | 588 | self.skip = ConvLayer( 589 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 590 | ) 591 | 592 | def forward(self, input): 593 | out = self.conv1(input) 594 | out = self.conv2(out) 595 | 596 | skip = self.skip(input) 597 | out = (out + skip) / math.sqrt(2) 598 | 599 | return out 600 | 601 | 602 | class Discriminator(nn.Module): 603 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 604 | super().__init__() 605 | 606 | channels = { 607 | 4: 512, 608 | 8: 512, 609 | 16: 512, 610 | 32: 512, 611 | 64: 256 * channel_multiplier, 612 | 128: 128 * channel_multiplier, 613 | 256: 64 * channel_multiplier, 614 | 512: 32 * channel_multiplier, 615 | 1024: 16 * channel_multiplier, 616 | } 617 | 618 | convs = [ConvLayer(1, channels[size], 1)] 619 | 620 | log_size = int(math.log(size, 2)) 621 | 622 | in_channel = channels[size] 623 | 624 | for i in range(log_size, 2, -1): 625 | out_channel = channels[2 ** (i - 1)] 626 | 627 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 628 | 629 | in_channel = out_channel 630 | 631 | self.convs = nn.Sequential(*convs) 632 | 633 | self.stddev_group = 4 634 | self.stddev_feat = 1 635 | 636 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 637 | self.final_linear = nn.Sequential( 638 | EqualLinear(channels[4] * 5 * 80 , channels[4], activation="fused_lrelu"), 639 | EqualLinear(channels[4], 1), 640 | ) 641 | 642 | def forward(self, input): 643 | out = self.convs(input) 644 | 645 | batch, channel, height, width = out.shape 646 | group = min(batch, self.stddev_group) 647 | stddev = out.view( 648 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 649 | ) 650 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 651 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 652 | stddev = stddev.repeat(group, 1, height, width) 653 | out = torch.cat([out, stddev], 1) 654 | 655 | out = self.final_conv(out) 656 | out = out.view(batch, -1) 657 | out = self.final_linear(out) 658 | 659 | return out 660 | 661 | if __name__ == "__main__": 662 | generator = Generator(size=64, 663 | style_dim=512, 664 | n_mlp=8, 665 | channel_multiplier=2, 666 | blur_kernel=[1, 3, 3, 1], 667 | lr_mlp=0.01).cuda() 668 | noise = torch.randn(1, 2, 512).cuda().unbind(0) 669 | fake_img, _ = 
generator(noise) #[4, 80, 320] 670 | #fake_img = torch.randn(4, 1,80, 1280).cuda() 671 | dis = Discriminator(size=64).cuda() 672 | fake_output = dis(fake_img) 673 | import pdb; pdb.set_trace() 674 | -------------------------------------------------------------------------------- /non_leaking.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | from distributed import reduce_sum 7 | from op import upfirdn2d 8 | 9 | 10 | class AdaptiveAugment: 11 | def __init__(self, ada_aug_target, ada_aug_len, update_every, device): 12 | self.ada_aug_target = ada_aug_target 13 | self.ada_aug_len = ada_aug_len 14 | self.update_every = update_every 15 | 16 | self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device) 17 | self.r_t_stat = 0 18 | self.ada_aug_p = 0 19 | 20 | @torch.no_grad() 21 | def tune(self, real_pred): 22 | ada_aug_data = torch.tensor( 23 | (torch.sign(real_pred).sum().item(), real_pred.shape[0]), 24 | device=real_pred.device, 25 | ) 26 | self.ada_aug_buf += reduce_sum(ada_aug_data) 27 | 28 | if self.ada_aug_buf[1] > self.update_every - 1: 29 | pred_signs, n_pred = self.ada_aug_buf.tolist() 30 | 31 | self.r_t_stat = pred_signs / n_pred 32 | 33 | if self.r_t_stat > self.ada_aug_target: 34 | sign = 1 35 | 36 | else: 37 | sign = -1 38 | 39 | self.ada_aug_p += sign * n_pred / self.ada_aug_len 40 | self.ada_aug_p = min(1, max(0, self.ada_aug_p)) 41 | self.ada_aug_buf.mul_(0) 42 | 43 | return self.ada_aug_p 44 | 45 | 46 | SYM6 = ( 47 | 0.015404109327027373, 48 | 0.0034907120842174702, 49 | -0.11799011114819057, 50 | -0.048311742585633, 51 | 0.4910559419267466, 52 | 0.787641141030194, 53 | 0.3379294217276218, 54 | -0.07263752278646252, 55 | -0.021060292512300564, 56 | 0.04472490177066578, 57 | 0.0017677118642428036, 58 | -0.007800708325034148, 59 | ) 60 | 61 | 62 | def translate_mat(t_x, t_y): 63 | batch = t_x.shape[0] 64 | 65 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 66 | translate = torch.stack((t_x, t_y), 1) 67 | mat[:, :2, 2] = translate 68 | 69 | return mat 70 | 71 | 72 | def rotate_mat(theta): 73 | batch = theta.shape[0] 74 | 75 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 76 | sin_t = torch.sin(theta) 77 | cos_t = torch.cos(theta) 78 | rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2) 79 | mat[:, :2, :2] = rot 80 | 81 | return mat 82 | 83 | 84 | def scale_mat(s_x, s_y): 85 | batch = s_x.shape[0] 86 | 87 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 88 | mat[:, 0, 0] = s_x 89 | mat[:, 1, 1] = s_y 90 | 91 | return mat 92 | 93 | 94 | def translate3d_mat(t_x, t_y, t_z): 95 | batch = t_x.shape[0] 96 | 97 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 98 | translate = torch.stack((t_x, t_y, t_z), 1) 99 | mat[:, :3, 3] = translate 100 | 101 | return mat 102 | 103 | 104 | def rotate3d_mat(axis, theta): 105 | batch = theta.shape[0] 106 | 107 | u_x, u_y, u_z = axis 108 | 109 | eye = torch.eye(3).unsqueeze(0) 110 | cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0) 111 | outer = torch.tensor(axis) 112 | outer = (outer.unsqueeze(1) * outer).unsqueeze(0) 113 | 114 | sin_t = torch.sin(theta).view(-1, 1, 1) 115 | cos_t = torch.cos(theta).view(-1, 1, 1) 116 | 117 | rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer 118 | 119 | eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 120 | eye_4[:, :3, :3] = rot 121 | 122 | return eye_4 123 | 124 | 125 | def scale3d_mat(s_x, s_y, s_z): 126 | batch = 
s_x.shape[0] 127 | 128 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 129 | mat[:, 0, 0] = s_x 130 | mat[:, 1, 1] = s_y 131 | mat[:, 2, 2] = s_z 132 | 133 | return mat 134 | 135 | 136 | def luma_flip_mat(axis, i): 137 | batch = i.shape[0] 138 | 139 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 140 | axis = torch.tensor(axis + (0,)) 141 | flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1) 142 | 143 | return eye - flip 144 | 145 | 146 | def saturation_mat(axis, i): 147 | batch = i.shape[0] 148 | 149 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 150 | axis = torch.tensor(axis + (0,)) 151 | axis = torch.ger(axis, axis) 152 | saturate = axis + (eye - axis) * i.view(-1, 1, 1) 153 | 154 | return saturate 155 | 156 | 157 | def lognormal_sample(size, mean=0, std=1): 158 | return torch.empty(size).log_normal_(mean=mean, std=std) 159 | 160 | 161 | def category_sample(size, categories): 162 | category = torch.tensor(categories) 163 | sample = torch.randint(high=len(categories), size=(size,)) 164 | 165 | return category[sample] 166 | 167 | 168 | def uniform_sample(size, low, high): 169 | return torch.empty(size).uniform_(low, high) 170 | 171 | 172 | def normal_sample(size, mean=0, std=1): 173 | return torch.empty(size).normal_(mean, std) 174 | 175 | 176 | def bernoulli_sample(size, p): 177 | return torch.empty(size).bernoulli_(p) 178 | 179 | 180 | def random_mat_apply(p, transform, prev, eye): 181 | size = transform.shape[0] 182 | select = bernoulli_sample(size, p).view(size, 1, 1) 183 | select_transform = select * transform + (1 - select) * eye 184 | 185 | return select_transform @ prev 186 | 187 | 188 | def sample_affine(p, size, height, width): 189 | G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1) 190 | eye = G 191 | 192 | # flip 193 | param = category_sample(size, (0, 1)) 194 | Gc = scale_mat(1 - 2.0 * param, torch.ones(size)) 195 | G = random_mat_apply(p, Gc, G, eye) 196 | # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n') 197 | 198 | # 90 rotate 199 | param = category_sample(size, (0, 3)) 200 | Gc = rotate_mat(-math.pi / 2 * param) 201 | G = random_mat_apply(p, Gc, G, eye) 202 | # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n') 203 | 204 | # integer translate 205 | param = uniform_sample(size, -0.125, 0.125) 206 | param_height = torch.round(param * height) / height 207 | param_width = torch.round(param * width) / width 208 | Gc = translate_mat(param_width, param_height) 209 | G = random_mat_apply(p, Gc, G, eye) 210 | # print('integer translate', G, translate_mat(param_width, param_height), sep='\n') 211 | 212 | # isotropic scale 213 | param = lognormal_sample(size, std=0.2 * math.log(2)) 214 | Gc = scale_mat(param, param) 215 | G = random_mat_apply(p, Gc, G, eye) 216 | # print('isotropic scale', G, scale_mat(param, param), sep='\n') 217 | 218 | p_rot = 1 - math.sqrt(1 - p) 219 | 220 | # pre-rotate 221 | param = uniform_sample(size, -math.pi, math.pi) 222 | Gc = rotate_mat(-param) 223 | G = random_mat_apply(p_rot, Gc, G, eye) 224 | # print('pre-rotate', G, rotate_mat(-param), sep='\n') 225 | 226 | # anisotropic scale 227 | param = lognormal_sample(size, std=0.2 * math.log(2)) 228 | Gc = scale_mat(param, 1 / param) 229 | G = random_mat_apply(p, Gc, G, eye) 230 | # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n') 231 | 232 | # post-rotate 233 | param = uniform_sample(size, -math.pi, math.pi) 234 | Gc = rotate_mat(-param) 235 | G = random_mat_apply(p_rot, Gc, G, eye) 236 | # print('post-rotate', G, 
rotate_mat(-param), sep='\n') 237 | 238 | # fractional translate 239 | param = normal_sample(size, std=0.125) 240 | Gc = translate_mat(param, param) 241 | G = random_mat_apply(p, Gc, G, eye) 242 | # print('fractional translate', G, translate_mat(param, param), sep='\n') 243 | 244 | return G 245 | 246 | 247 | def sample_color(p, size): 248 | C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1) 249 | eye = C 250 | axis_val = 1 / math.sqrt(3) 251 | axis = (axis_val, axis_val, axis_val) 252 | 253 | # brightness 254 | param = normal_sample(size, std=0.2) 255 | Cc = translate3d_mat(param, param, param) 256 | C = random_mat_apply(p, Cc, C, eye) 257 | 258 | # contrast 259 | param = lognormal_sample(size, std=0.5 * math.log(2)) 260 | Cc = scale3d_mat(param, param, param) 261 | C = random_mat_apply(p, Cc, C, eye) 262 | 263 | # luma flip 264 | param = category_sample(size, (0, 1)) 265 | Cc = luma_flip_mat(axis, param) 266 | C = random_mat_apply(p, Cc, C, eye) 267 | 268 | # hue rotation 269 | param = uniform_sample(size, -math.pi, math.pi) 270 | Cc = rotate3d_mat(axis, param) 271 | C = random_mat_apply(p, Cc, C, eye) 272 | 273 | # saturation 274 | param = lognormal_sample(size, std=1 * math.log(2)) 275 | Cc = saturation_mat(axis, param) 276 | C = random_mat_apply(p, Cc, C, eye) 277 | 278 | return C 279 | 280 | 281 | def make_grid(shape, x0, x1, y0, y1, device): 282 | n, c, h, w = shape 283 | grid = torch.empty(n, h, w, 3, device=device) 284 | grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device) 285 | grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1) 286 | grid[:, :, :, 2] = 1 287 | 288 | return grid 289 | 290 | 291 | def affine_grid(grid, mat): 292 | n, h, w, _ = grid.shape 293 | return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2) 294 | 295 | 296 | def get_padding(G, height, width): 297 | extreme = ( 298 | G[:, :2, :] 299 | @ torch.tensor([(-1.0, -1, 1), (-1, 1, 1), (1, -1, 1), (1, 1, 1)]).t() 300 | ) 301 | 302 | size = torch.tensor((width, height)) 303 | 304 | pad_low = ( 305 | ((extreme.min(-1).values + 1) * size) 306 | .clamp(max=0) 307 | .abs() 308 | .ceil() 309 | .max(0) 310 | .values.to(torch.int64) 311 | .tolist() 312 | ) 313 | pad_high = ( 314 | (extreme.max(-1).values * size - size) 315 | .clamp(min=0) 316 | .ceil() 317 | .max(0) 318 | .values.to(torch.int64) 319 | .tolist() 320 | ) 321 | 322 | return pad_low[0], pad_high[0], pad_low[1], pad_high[1] 323 | 324 | 325 | def try_sample_affine_and_pad(img, p, pad_k, G=None): 326 | batch, _, height, width = img.shape 327 | 328 | G_try = G 329 | 330 | while True: 331 | if G is None: 332 | G_try = sample_affine(p, batch, height, width) 333 | 334 | pad_x1, pad_x2, pad_y1, pad_y2 = get_padding( 335 | torch.inverse(G_try), height, width 336 | ) 337 | 338 | try: 339 | img_pad = F.pad( 340 | img, 341 | (pad_x1 + pad_k, pad_x2 + pad_k, pad_y1 + pad_k, pad_y2 + pad_k), 342 | mode="reflect", 343 | ) 344 | 345 | except RuntimeError: 346 | continue 347 | 348 | break 349 | 350 | return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2) 351 | 352 | 353 | def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6): 354 | kernel = antialiasing_kernel 355 | len_k = len(kernel) 356 | pad_k = (len_k + 1) // 2 357 | 358 | kernel = torch.as_tensor(kernel) 359 | kernel = torch.ger(kernel, kernel).to(img) 360 | kernel_flip = torch.flip(kernel, (0, 1)) 361 | 362 | img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad( 363 | img, p, pad_k, G 364 | ) 365 | 366 | p_ux1 = pad_x1 367 | p_ux2 = pad_x2 + 1 368 
| p_uy1 = pad_y1 369 | p_uy2 = pad_y2 + 1 370 | w_p = img_pad.shape[3] - len_k + 1 371 | h_p = img_pad.shape[2] - len_k + 1 372 | h_o = img.shape[2] 373 | w_o = img.shape[3] 374 | 375 | img_2x = upfirdn2d(img_pad, kernel_flip, up=2) 376 | 377 | grid = make_grid( 378 | img_2x.shape, 379 | -2 * p_ux1 / w_o - 1, 380 | 2 * (w_p - p_ux1) / w_o - 1, 381 | -2 * p_uy1 / h_o - 1, 382 | 2 * (h_p - p_uy1) / h_o - 1, 383 | device=img_2x.device, 384 | ).to(img_2x) 385 | grid = affine_grid(grid, torch.inverse(G)[:, :2, :].to(img_2x)) 386 | grid = grid * torch.tensor( 387 | [w_o / w_p, h_o / h_p], device=grid.device 388 | ) + torch.tensor( 389 | [(w_o + 2 * p_ux1) / w_p - 1, (h_o + 2 * p_uy1) / h_p - 1], device=grid.device 390 | ) 391 | 392 | img_affine = F.grid_sample( 393 | img_2x, grid, mode="bilinear", align_corners=False, padding_mode="zeros" 394 | ) 395 | 396 | img_down = upfirdn2d(img_affine, kernel, down=2) 397 | 398 | end_y = -pad_y2 - 1 399 | if end_y == 0: 400 | end_y = img_down.shape[2] 401 | 402 | end_x = -pad_x2 - 1 403 | if end_x == 0: 404 | end_x = img_down.shape[3] 405 | 406 | img = img_down[:, :, pad_y1:end_y, pad_x1:end_x] 407 | 408 | return img, G 409 | 410 | 411 | def apply_color(img, mat): 412 | batch = img.shape[0] 413 | img = img.permute(0, 2, 3, 1) 414 | mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3) 415 | mat_add = mat[:, :3, 3].view(batch, 1, 1, 3) 416 | img = img @ mat_mul + mat_add 417 | img = img.permute(0, 3, 1, 2) 418 | 419 | return img 420 | 421 | 422 | def random_apply_color(img, p, C=None): 423 | if C is None: 424 | C = sample_color(p, img.shape[0]) 425 | 426 | img = apply_color(img, C.to(img)) 427 | 428 | return img, C 429 | 430 | 431 | def augment(img, p, transform_matrix=(None, None)): 432 | img, G = random_apply_affine(img, p, transform_matrix[0]) 433 | img, C = random_apply_color(img, p, transform_matrix[1]) 434 | 435 | return img, (G, C) 436 | -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | 
gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None, None 54 | 55 | 56 | class FusedLeakyReLUFunction(Function): 57 | @staticmethod 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | 61 | ctx.bias = bias is not None 62 | 63 | if bias is None: 64 | bias = empty 65 | 66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 67 | ctx.save_for_backward(out) 68 | ctx.negative_slope = negative_slope 69 | ctx.scale = scale 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | out, = ctx.saved_tensors 76 | 77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 79 | ) 80 | 81 | if not ctx.bias: 82 | grad_bias = None 83 | 84 | return grad_input, grad_bias, None, None 85 | 86 | 87 | class FusedLeakyReLU(nn.Module): 88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 89 | super().__init__() 90 | 91 | if bias: 92 | self.bias = nn.Parameter(torch.zeros(channel)) 93 | 94 | else: 95 | self.bias = None 96 | 97 | self.negative_slope = negative_slope 98 | self.scale = scale 99 | 100 | def forward(self, input): 101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 102 | 103 | 104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 105 | if input.device.type == "cpu": 106 | if bias is not None: 107 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 108 | return ( 109 | F.leaky_relu( 110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 111 | ) 112 | * scale 113 | ) 114 | 115 | else: 116 | return F.leaky_relu(input, negative_slope=0.2) * scale 117 | 118 | else: 119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 120 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } 
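The C++ binding above only validates its inputs and forwards to the CUDA kernel; `op/upfirdn2d.py` below wraps it in an autograd `Function` and adds a pure-PyTorch CPU fallback (`upfirdn2d_native`). As a quick orientation, here is a minimal sketch of the calling convention — it mirrors how the generator's `Upsample` block drives `upfirdn2d` (2x zero-insertion upsampling followed by a separable [1, 3, 3, 1] blur). It assumes the `op` extensions compile on import via `torch.utils.cpp_extension.load`; with a CPU tensor the dispatcher routes to the native fallback rather than the CUDA kernel. This is an illustrative sketch only, not part of the repository's training path.

``` python
# Hedged sketch: illustrates the upfirdn2d padding / output-shape convention only.
# Run from the repo root; importing `op` triggers compilation of the extensions.
import torch
from op import upfirdn2d  # op/__init__.py re-exports the dispatcher defined below

def make_blur_kernel(taps):
    k = torch.tensor(taps, dtype=torch.float32)
    k = k[None, :] * k[:, None]      # separable 1-D taps -> 2-D kernel
    return k / k.sum()

factor = 2
kernel = make_blur_kernel([1, 3, 3, 1]) * factor ** 2   # rescale, as in Upsample.__init__

p = kernel.shape[0] - factor                             # same padding rule as Upsample
pad0 = (p + 1) // 2 + factor - 1                         # -> 2
pad1 = p // 2                                            # -> 1

x = torch.randn(1, 1, 5, 80)                             # a mel-spectrogram-shaped feature map
y = upfirdn2d(x, kernel, up=factor, down=1, pad=(pad0, pad1))

# out = (in * up + pad0 + pad1 - kernel_size) // down + 1, so both spatial dims double:
print(y.shape)  # torch.Size([1, 1, 10, 160])
```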
-------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 
124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * 
p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = 
pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 
import subprocess 3 | import tempfile 4 | from pathlib import Path 5 | import os 6 | import torch 7 | import numpy as np 8 | import soundfile as sf 9 | import yaml 10 | import cog 11 | 12 | from model_drum_four_bar import Generator 13 | 14 | sys.path.append("./melgan") 15 | 16 | from modules import Generator_melgan 17 | 18 | 19 | class Predictor(cog.Predictor): 20 | def setup(self): 21 | self.device = "cuda" 22 | checkpoint_path = "checkpoint-four-bar.pt" 23 | self.latent = 512 24 | 25 | self.g_ema = Generator( 26 | size=64, 27 | style_dim=self.latent, 28 | n_mlp=8, 29 | channel_multiplier=2, 30 | ).to(self.device) 31 | checkpoint = torch.load(checkpoint_path) 32 | self.g_ema.load_state_dict(checkpoint["g_ema"]) 33 | self.g_ema.eval() 34 | 35 | data_path = "data/looperman_four_bar" 36 | feat_dim = 80 37 | mean_fp = f"{data_path}/mean.mel.npy" 38 | std_fp = f"{data_path}/std.mel.npy" 39 | self.mean = ( 40 | torch.from_numpy(np.load(mean_fp)) 41 | .float() 42 | .view(1, feat_dim, 1) 43 | .to(self.device) 44 | ) 45 | self.std = ( 46 | torch.from_numpy(np.load(std_fp)) 47 | .float() 48 | .view(1, feat_dim, 1) 49 | .to(self.device) 50 | ) 51 | vocoder_config_fp = "melgan/args.yml" 52 | vocoder_config = read_yaml(vocoder_config_fp) 53 | n_mel_channels = vocoder_config.n_mel_channels 54 | ngf = vocoder_config.ngf 55 | n_residual_layers = vocoder_config.n_residual_layers 56 | self.sr = 44100 57 | 58 | self.vocoder = Generator_melgan(n_mel_channels, ngf, n_residual_layers).to( 59 | self.device 60 | ) 61 | self.vocoder.eval() 62 | vocoder_param_fp = "melgan/best_netG.pt" 63 | self.vocoder.load_state_dict(torch.load(vocoder_param_fp)) 64 | 65 | @cog.input("seed", type=int, default=-1, help="Random seed, -1 for random") 66 | def predict(self, seed): 67 | if seed < 0: 68 | seed = int.from_bytes(os.urandom(2), "big") 69 | torch.manual_seed(seed) 70 | np.random.seed(seed) 71 | print(f"Prediction seed: {seed}") 72 | 73 | sample_z = torch.randn(1, self.latent, device=self.device) 74 | sample, _ = self.g_ema([sample_z], truncation=1, truncation_latent=None) 75 | de_norm = sample.squeeze(0) * self.std + self.mean 76 | audio_output = self.vocoder(de_norm) 77 | out_dir = Path(tempfile.mkdtemp()) 78 | wav_path = out_dir / "out.wav" 79 | mp3_path = out_dir / "out.mp3" 80 | 81 | try: 82 | sf.write( 83 | str(wav_path), audio_output.squeeze().detach().cpu().numpy(), self.sr 84 | ) 85 | subprocess.check_output( 86 | [ 87 | "ffmpeg", 88 | "-i", 89 | str(wav_path), 90 | str(mp3_path), 91 | ], 92 | ) 93 | return mp3_path 94 | finally: 95 | wav_path.unlink(missing_ok=True) 96 | 97 | 98 | def read_yaml(fp): 99 | with open(fp) as file: 100 | # return yaml.load(file) 101 | return yaml.load(file, Loader=yaml.Loader) 102 | -------------------------------------------------------------------------------- /preprocess/collect_audio_clips.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import numpy as np 3 | import librosa 4 | from pydub import AudioSegment 5 | # from shutil import copyfile 6 | from multiprocessing import Pool 7 | 8 | 9 | def process_one(args): 10 | fn, out_dir = args 11 | 12 | in_fp = os.path.join(audio_dir, f'{fn}.wav') 13 | if not os.path.exists(in_fp): 14 | print('Not exists') 15 | return 16 | 17 | duration = librosa.get_duration(filename=in_fp) 18 | 19 | # duration = song.duration_seconds 20 | num_subclips = int(duration // subclip_duration) 21 | # num_subclips = int(np.ceil(duration / subclip_duration)) 22 | 23 | try: 24 | song = 
AudioSegment.from_wav(in_fp) 25 | except Exception: 26 | print('Error in loading') 27 | return 28 | 29 | for ii in range(num_subclips): 30 | start = ii*subclip_duration 31 | end = (ii+1)*subclip_duration 32 | print(fn, start, end) 33 | 34 | out_fp = os.path.join(out_dir, f'{fn}.{start}_{end}.wav') 35 | if os.path.exists(out_fp): 36 | print('Done before') 37 | continue 38 | 39 | subclip = song[start*1000:end*1000] 40 | subclip.export(out_fp, format='wav') 41 | 42 | 43 | if __name__ == '__main__': 44 | audio_dir = '/home/allenhung/nas189/Database/Looper_man/drum_loops/audio' # Clean audios or separated audios from mixture 45 | out_dir = './training_data/drum_clips_7.9/' 46 | 47 | subclip_duration = 7.9 48 | sr = 22050 49 | ext = '.wav' 50 | 51 | # ### Process ### 52 | num_samples = int(round(subclip_duration * sr)) 53 | os.makedirs(out_dir, exist_ok=True) 54 | 55 | fns = [fn.replace(ext, '') for fn in os.listdir(audio_dir) if fn.endswith('.wav')] 56 | print(fns) 57 | pool = Pool(10) 58 | 59 | args_list = [] 60 | 61 | for fn in fns: 62 | args_list.append((fn, out_dir)) 63 | 64 | pool.map(process_one, args_list) 65 | -------------------------------------------------------------------------------- /preprocess/compute_mean_std.mel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | from sklearn.preprocessing import StandardScaler 5 | 6 | 7 | if __name__ == "__main__": 8 | 9 | feat_type = 'mel' 10 | exp_dir = '/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar_one_bar/mel_80_320' 11 | 12 | out_dir = exp_dir 13 | 14 | # ### Process ### 15 | 16 | dataset_fp = os.path.join(exp_dir, f'dataset.pkl') 17 | #feat_dir = os.path.join(exp_dir, feat_type) 18 | feat_dir = exp_dir 19 | out_fp_mean = os.path.join(out_dir, f'mean.{feat_type}.npy') 20 | out_fp_std = os.path.join(out_dir, f'std.{feat_type}.npy') 21 | 22 | with open(dataset_fp, 'rb') as f: 23 | dataset = pickle.load(f) 24 | 25 | in_fns = [fn for fn, _ in dataset] 26 | 27 | scaler = StandardScaler() 28 | 29 | for fn in in_fns: 30 | print(fn) 31 | in_fp = os.path.join(feat_dir, f'{fn}.npy') 32 | data = np.load(in_fp).T 33 | print(data.shape) 34 | print('data: ', data) 35 | scaler.partial_fit(data) 36 | print(scaler.mean_, scaler.scale_) 37 | if True in np.isnan(scaler.scale_): 38 | break 39 | 40 | mean = scaler.mean_ 41 | std = scaler.scale_ 42 | np.save(out_fp_mean, mean) 43 | np.save(out_fp_std, std) 44 | -------------------------------------------------------------------------------- /preprocess/extract_mel.py: -------------------------------------------------------------------------------- 1 | # import glob 2 | import os 3 | import librosa 4 | import numpy as np 5 | # from utils.display import * 6 | # from utils.dsp import * 7 | # import hparams as hp 8 | # from multiprocessing import Pool, cpu_count 9 | from multiprocessing import Pool 10 | from librosa.filters import mel as librosa_mel_fn 11 | import torch 12 | from torch import nn 13 | from torch.nn import functional as F 14 | 15 | import argparse 16 | ''' 17 | Modified from 18 | https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py#L26 19 | ''' 20 | 21 | 22 | class Audio2Mel(nn.Module): 23 | def __init__( 24 | self, 25 | n_fft=1024, 26 | hop_length=256, 27 | win_length=1024, 28 | sampling_rate=22050, 29 | n_mel_channels=80, 30 | mel_fmin=0.0, 31 | mel_fmax=None, 32 | ): 33 | super().__init__() 34 | ############################################## 35 | # FFT Parameters # 
36 | ############################################## 37 | window = torch.hann_window(win_length).float() 38 | mel_basis = librosa_mel_fn( 39 | sampling_rate, n_fft, n_mel_channels, mel_fmin, mel_fmax 40 | ) 41 | mel_basis = torch.from_numpy(mel_basis).float() 42 | self.register_buffer("mel_basis", mel_basis) 43 | self.register_buffer("window", window) 44 | self.n_fft = n_fft 45 | self.hop_length = hop_length 46 | self.win_length = win_length 47 | self.sampling_rate = sampling_rate 48 | self.n_mel_channels = n_mel_channels 49 | 50 | def forward(self, audio): 51 | p = (self.n_fft - self.hop_length) // 2 52 | audio = F.pad(audio, (p, p), "reflect").squeeze(1) 53 | fft = torch.stft( 54 | audio, 55 | n_fft=self.n_fft, 56 | hop_length=self.hop_length, 57 | win_length=self.win_length, 58 | window=self.window, 59 | center=False, 60 | ) 61 | real_part, imag_part = fft.unbind(-1) 62 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 63 | mel_output = torch.matmul(self.mel_basis, magnitude) 64 | log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5)) 65 | return log_mel_spec[:, :, :] 66 | 67 | 68 | def convert_file(path): 69 | y, _ = librosa.load(path, sr=sr) 70 | peak = np.abs(y).max() 71 | if peak_norm or peak > 1.0: 72 | y /= peak 73 | 74 | y = torch.from_numpy(y) 75 | y = y[None, None] 76 | mel = extract_func(y) 77 | mel = mel.numpy() 78 | mel = mel[0] 79 | print(mel.shape) 80 | 81 | return mel.astype(np.float32) 82 | 83 | 84 | def process_audios(path): 85 | id = path.split('/')[-1][:-4] 86 | 87 | out_dir = os.path.join(base_out_dir, feat_type) 88 | os.makedirs(out_dir, exist_ok=True) 89 | 90 | out_fp = os.path.join(out_dir, f'{id}.npy') 91 | 92 | if os.path.exists(out_fp): 93 | print('Done before') 94 | return id, 0 95 | 96 | try: 97 | m = convert_file(path) 98 | 99 | np.save(out_fp, m, allow_pickle=False) 100 | except Exception: 101 | return id, 0 102 | return id, m.shape[-1] 103 | 104 | 105 | if __name__ == "__main__": 106 | parser = argparse.ArgumentParser(description="compute inception score") 107 | 108 | parser.add_argument("--epoch", type=str, help="path to the model") 109 | 110 | args = parser.parse_args() 111 | base_out_dir = f'/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar_one_bar' 112 | os.makedirs(base_out_dir, exist_ok=True) 113 | clip_dir = f'/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar_one_bar' #out_dir from step1 114 | 115 | feat_type = 'mel_80_320' 116 | extension = '.wav' 117 | peak_norm = True 118 | 119 | n_fft = 1024 120 | hop_length = 275 #[241, 482, 964, 1928, 3856] 121 | win_length = 1024 122 | sampling_rate = 44100 123 | n_mel_channels = 80 #[80, 40, 20, 10, 5] 124 | 125 | # ### Process ### 126 | extract_func = Audio2Mel(n_fft, hop_length, win_length, sampling_rate, n_mel_channels) 127 | sr = sampling_rate 128 | 129 | audio_fns = [fn for fn in os.listdir(clip_dir) if fn.endswith(extension)] 130 | 131 | audio_fns = sorted(list(audio_fns)) 132 | 133 | audio_files = [os.path.join(clip_dir, fn) for fn in audio_fns] 134 | 135 | pool = Pool(processes=20) 136 | dataset = [] 137 | 138 | for i, (id, length) in enumerate(pool.imap_unordered(process_audios, audio_files), 1): 139 | print(id) 140 | if length == 0: 141 | continue 142 | dataset += [(id, length)] 143 | -------------------------------------------------------------------------------- /preprocess/make_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from multiprocessing import Pool 4 | import 
pickle 5 | 6 | 7 | def process_audios(feat_fn): 8 | feat_fp = os.path.join(feat_dir, f'{feat_fn}.npy') 9 | 10 | if os.path.exists(feat_fp): 11 | return feat_fn, np.load(feat_fp).shape[-1] 12 | else: 13 | return feat_fn, 0 14 | 15 | 16 | if __name__ == "__main__": 17 | feat_type = '' 18 | exp_dir = '/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar_one_bar/mel_80_320' # base_out_dir from step2 19 | 20 | out_fp = os.path.join(exp_dir, 'dataset.pkl') 21 | 22 | # ### Process ### 23 | feat_dir = os.path.join(exp_dir, feat_type) 24 | 25 | feat_fns = [fn.replace('.npy', '') for fn in os.listdir(feat_dir)] 26 | 27 | pool = Pool(processes=20) 28 | dataset = [] 29 | 30 | for i, (feat_fn, length) in enumerate(pool.imap_unordered(process_audios, feat_fns), 1): 31 | print(feat_fn) 32 | if length == 0: 33 | continue 34 | dataset += [(feat_fn, length)] 35 | 36 | with open(out_fp, 'wb') as f: 37 | pickle.dump(dataset, f) 38 | 39 | print(len(dataset)) 40 | -------------------------------------------------------------------------------- /preprocess/trim_2_seconds.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import pyrubberband as pyrb 3 | import madmom 4 | from madmom.features.downbeats import RNNDownBeatProcessor 5 | from madmom.features.downbeats import DBNDownBeatTrackingProcessor 6 | import os 7 | import matplotlib.pyplot as plot 8 | import numpy as np 9 | import soundfile as sf 10 | from multiprocessing import Pool 11 | loop_dir='/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar' 12 | out_dir='/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar_one_bar' 13 | os.makedirs(out_dir, exist_ok=True) 14 | def one_bar_segment(file): 15 | file_path = os.path.join(loop_dir, file) 16 | try: 17 | y, sr = librosa.core.load(file_path, sr=None) # sr = None will retrieve the original sampling rate = 44100 18 | except: 19 | print('load file failed') 20 | return 21 | try: 22 | act = RNNDownBeatProcessor()(file_path) 23 | down_beat=proc(act) # [..., 2] 2d-shape numpy array 24 | except: 25 | print('except happended') 26 | return 27 | #print(down_beat) 28 | #print(len(y) / sr) 29 | #import pdb; pdb.set_trace() 30 | #retrieve 1, 2, 3, 4, 1blocks 31 | count = 0 32 | bar_list = [] 33 | #print(file) 34 | name = file.replace('.wav', '') 35 | print(down_beat) 36 | for i in range(down_beat.shape[0]): 37 | if down_beat[i][1] == 1 and i + 4 < down_beat.shape[0] and down_beat[i+4][1] == 1: 38 | print(down_beat[i: i + 5, :]) 39 | start_time = down_beat[i][0] 40 | end_time = down_beat[i + 4][0] 41 | count += 1 42 | out_path = os.path.join(out_dir, f'{name}_{count}.wav') 43 | #print(len(y) / sr) 44 | #print(sr) 45 | y_one_bar, _ = librosa.core.load(file_path, offset=start_time, duration = end_time - start_time, sr=None) 46 | y_stretch = pyrb.time_stretch(y_one_bar, sr, (end_time - start_time) / 2) 47 | #print((end_time - start_time)) 48 | #print() 49 | sf.write(out_path, y_stretch, sr) 50 | 51 | print('save file: ', f'{name}_{count}.wav') 52 | #y, sr = librosa.core.load(out_path, sr=None) 53 | #print(librosa.get_duration(y, sr=sr)) 54 | 55 | if __name__ == '__main__': 56 | #dur_list = [] 57 | #for file in os.listdir(loop_dir): 58 | # file_path = os.path.join(loop_dir, file) 59 | # y, sr = librosa.core.load(file_path) 60 | # dur = librosa.get_duration(y, sr) 61 | # dur_list.append(dur) 62 | #num_bins = 10 63 | #plot.hist(dur_list, num_bins, density=True) 64 | #plot.savefig('./duration.png ') 65 | 66 | proc = 
DBNDownBeatTrackingProcessor(beats_per_bar=4, fps = 100) 67 | file_list = list(os.listdir(loop_dir)) 68 | #print(file_list[1]) 69 | #one_bar_segment(file_list[1]) 70 | #print(file_list[:10]) 71 | #for file in os.listdir(loop_dir): 72 | # file_list.append(file) 73 | with Pool(processes=10) as pool: 74 | pool.map(one_bar_segment, file_list) 75 | -------------------------------------------------------------------------------- /scripts/generate_freesound.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=3 python generate_audio.py \ 3 | --ckpt "freesound_checkpoint.pt" \ 4 | --pics 2000 --data_path "./data/freesound" \ 5 | --store_path "./generated_freesound_one_bar" \ 6 | --style_mixing 7 | -------------------------------------------------------------------------------- /scripts/generate_looperman_four_bar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=3 python generate_looperman_four_bar.py \ 3 | --ckpt "looperman_four_bar_checkpoint.pt" \ 4 | --pics 100 \ 5 | --data_path "./data/looperman_four_bar" \ 6 | --store_path "./generated_audio_looperman_four_bar" 7 | -------------------------------------------------------------------------------- /scripts/generate_looperman_one_bar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=3 python generate_audio.py \ 3 | --ckpt "looperman_one_bar_checkpoint.pt" \ 4 | --pics 2000 --data_path "./data/looperman/" \ 5 | --store_path "./generated_looperman_one_bar" 6 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=1 python train_drum.py \ 3 | --size 64 --batch 8 --sample_dir sample_Bandlab_Beats_one_bar \ 4 | --checkpoint_dir checkpoint_Bandlab_Beats_one_bar \ 5 | /home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Beats_one_bar/mel_80_320 6 | -------------------------------------------------------------------------------- /train_drum.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import random 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn, autograd, optim 9 | from torch.nn import functional as F 10 | from torch.utils import data 11 | import torch.distributed as dist 12 | from torchvision import transforms, utils 13 | from tqdm import tqdm 14 | 15 | try: 16 | import wandb 17 | 18 | except ImportError: 19 | wandb = None 20 | 21 | from model_drum import Generator, Discriminator 22 | from dataset import MultiResolutionDataset_drum 23 | from distributed import ( 24 | get_rank, 25 | synchronize, 26 | reduce_loss_dict, 27 | reduce_sum, 28 | get_world_size, 29 | ) 30 | from non_leaking import augment, AdaptiveAugment 31 | 32 | 33 | def data_sampler(dataset, shuffle, distributed): 34 | if distributed: 35 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 36 | 37 | if shuffle: 38 | return data.RandomSampler(dataset) 39 | 40 | else: 41 | return data.SequentialSampler(dataset) 42 | 43 | 44 | def requires_grad(model, flag=True): 45 | for p in model.parameters(): 46 | p.requires_grad = flag 47 | 48 | 49 | def accumulate(model1, model2, decay=0.999): 50 | par1 = dict(model1.named_parameters()) 51 | par2 = dict(model2.named_parameters()) 52 | 
53 | for k in par1.keys(): 54 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 55 | 56 | 57 | def sample_data(loader): 58 | while True: 59 | for batch in loader: 60 | yield batch 61 | 62 | 63 | def d_logistic_loss(real_pred, fake_pred): 64 | real_loss = F.softplus(-real_pred) 65 | fake_loss = F.softplus(fake_pred) 66 | 67 | return real_loss.mean() + fake_loss.mean() 68 | 69 | 70 | def d_r1_loss(real_pred, real_img): 71 | grad_real, = autograd.grad( 72 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 73 | ) 74 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 75 | 76 | return grad_penalty 77 | 78 | 79 | def g_nonsaturating_loss(fake_pred): 80 | loss = F.softplus(-fake_pred).mean() 81 | 82 | return loss 83 | 84 | 85 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): 86 | noise = torch.randn_like(fake_img) / math.sqrt( 87 | fake_img.shape[2] * fake_img.shape[3] 88 | ) 89 | grad, = autograd.grad( 90 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True 91 | ) 92 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 93 | 94 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 95 | 96 | path_penalty = (path_lengths - path_mean).pow(2).mean() 97 | 98 | return path_penalty, path_mean.detach(), path_lengths 99 | 100 | 101 | def make_noise(batch, latent_dim, n_noise, device): 102 | if n_noise == 1: 103 | return torch.randn(batch, latent_dim, device=device) 104 | 105 | noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0) 106 | 107 | return noises 108 | 109 | 110 | def mixing_noise(batch, latent_dim, prob, device): 111 | if prob > 0 and random.random() < prob: 112 | return make_noise(batch, latent_dim, 2, device) 113 | 114 | else: 115 | return [make_noise(batch, latent_dim, 1, device)] 116 | 117 | 118 | def set_grad_none(model, targets): 119 | for n, p in model.named_parameters(): 120 | if n in targets: 121 | p.grad = None 122 | 123 | 124 | def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): 125 | os.makedirs(args.sample_dir, exist_ok=True) 126 | os.makedirs(args.checkpoint_dir, exist_ok=True) 127 | 128 | loader = sample_data(loader) 129 | 130 | pbar = range(args.iter) 131 | 132 | if get_rank() == 0: 133 | pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) 134 | 135 | mean_path_length = 0 136 | 137 | d_loss_val = 0 138 | r1_loss = torch.tensor(0.0, device=device) 139 | g_loss_val = 0 140 | path_loss = torch.tensor(0.0, device=device) 141 | path_lengths = torch.tensor(0.0, device=device) 142 | mean_path_length_avg = 0 143 | loss_dict = {} 144 | 145 | if args.distributed: 146 | g_module = generator.module 147 | d_module = discriminator.module 148 | 149 | else: 150 | g_module = generator 151 | d_module = discriminator 152 | 153 | accum = 0.5 ** (32 / (10 * 1000)) 154 | ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 155 | r_t_stat = 0 156 | 157 | if args.augment and args.augment_p == 0: 158 | ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device) 159 | 160 | sample_z = torch.randn(args.n_sample, args.latent, device=device) 161 | 162 | for idx in pbar: 163 | i = idx + args.start_iter 164 | 165 | if i > args.iter: 166 | print("Done!") 167 | 168 | break 169 | 170 | real_img = next(loader) 171 | real_img = real_img.to(device) 172 | mean_fp = os.path.join(args.path, f'mean.mel.npy') 173 | std_fp = os.path.join(args.path, f'std.mel.npy') 174 | feat_dim = 80 175 | mean = 
torch.from_numpy(np.load(mean_fp)).float().cuda().view(1, feat_dim, 1) 176 | std = torch.from_numpy(np.load(std_fp)).float().cuda().view(1, feat_dim, 1) 177 | real_img = (real_img - mean) / std 178 | requires_grad(generator, False) 179 | requires_grad(discriminator, True) 180 | 181 | noise = mixing_noise(args.batch, args.latent, args.mixing, device) 182 | fake_img, _ = generator(noise) 183 | 184 | if args.augment: 185 | real_img_aug, _ = augment(real_img, ada_aug_p) 186 | fake_img, _ = augment(fake_img, ada_aug_p) 187 | 188 | else: 189 | real_img_aug = real_img 190 | 191 | fake_pred = discriminator(fake_img) 192 | real_pred = discriminator(real_img_aug) 193 | d_loss = d_logistic_loss(real_pred, fake_pred) 194 | 195 | loss_dict["d"] = d_loss 196 | loss_dict["real_score"] = real_pred.mean() 197 | loss_dict["fake_score"] = fake_pred.mean() 198 | 199 | discriminator.zero_grad() 200 | d_loss.backward() 201 | d_optim.step() 202 | 203 | if args.augment and args.augment_p == 0: 204 | ada_aug_p = ada_augment.tune(real_pred) 205 | r_t_stat = ada_augment.r_t_stat 206 | 207 | d_regularize = i % args.d_reg_every == 0 208 | 209 | if d_regularize: 210 | real_img.requires_grad = True 211 | real_pred = discriminator(real_img) 212 | r1_loss = d_r1_loss(real_pred, real_img) 213 | 214 | discriminator.zero_grad() 215 | (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() 216 | 217 | d_optim.step() 218 | 219 | loss_dict["r1"] = r1_loss 220 | 221 | requires_grad(generator, True) 222 | requires_grad(discriminator, False) 223 | 224 | noise = mixing_noise(args.batch, args.latent, args.mixing, device) 225 | fake_img, _ = generator(noise) 226 | 227 | if args.augment: 228 | fake_img, _ = augment(fake_img, ada_aug_p) 229 | 230 | fake_pred = discriminator(fake_img) 231 | g_loss = g_nonsaturating_loss(fake_pred) 232 | 233 | loss_dict["g"] = g_loss 234 | 235 | generator.zero_grad() 236 | g_loss.backward() 237 | g_optim.step() 238 | 239 | g_regularize = i % args.g_reg_every == 0 240 | 241 | if g_regularize: 242 | path_batch_size = max(1, args.batch // args.path_batch_shrink) 243 | noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) 244 | fake_img, latents = generator(noise, return_latents=True) 245 | 246 | path_loss, mean_path_length, path_lengths = g_path_regularize( 247 | fake_img, latents, mean_path_length 248 | ) 249 | 250 | generator.zero_grad() 251 | weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss 252 | 253 | if args.path_batch_shrink: 254 | weighted_path_loss += 0 * fake_img[0, 0, 0, 0] 255 | 256 | weighted_path_loss.backward() 257 | 258 | g_optim.step() 259 | 260 | mean_path_length_avg = ( 261 | reduce_sum(mean_path_length).item() / get_world_size() 262 | ) 263 | 264 | loss_dict["path"] = path_loss 265 | loss_dict["path_length"] = path_lengths.mean() 266 | 267 | accumulate(g_ema, g_module, accum) 268 | 269 | loss_reduced = reduce_loss_dict(loss_dict) 270 | 271 | d_loss_val = loss_reduced["d"].mean().item() 272 | g_loss_val = loss_reduced["g"].mean().item() 273 | r1_val = loss_reduced["r1"].mean().item() 274 | path_loss_val = loss_reduced["path"].mean().item() 275 | real_score_val = loss_reduced["real_score"].mean().item() 276 | fake_score_val = loss_reduced["fake_score"].mean().item() 277 | path_length_val = loss_reduced["path_length"].mean().item() 278 | 279 | if get_rank() == 0: 280 | pbar.set_description( 281 | ( 282 | f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " 283 | f"path: {path_loss_val:.4f}; mean path: 
{mean_path_length_avg:.4f}; " 284 | f"augment: {ada_aug_p:.4f}" 285 | ) 286 | ) 287 | 288 | if wandb and args.wandb: 289 | wandb.log( 290 | { 291 | "Generator": g_loss_val, 292 | "Discriminator": d_loss_val, 293 | "Augment": ada_aug_p, 294 | "Rt": r_t_stat, 295 | "R1": r1_val, 296 | "Path Length Regularization": path_loss_val, 297 | "Mean Path Length": mean_path_length, 298 | "Real Score": real_score_val, 299 | "Fake Score": fake_score_val, 300 | "Path Length": path_length_val, 301 | } 302 | ) 303 | 304 | if i % 100 == 0: 305 | with torch.no_grad(): 306 | g_ema.eval() 307 | sample, _ = g_ema([sample_z]) 308 | utils.save_image( 309 | sample, 310 | f"{args.sample_dir}/{str(i).zfill(6)}.png", 311 | nrow=int(args.n_sample ** 0.5), 312 | normalize=True, 313 | range=(-1, 1), 314 | ) 315 | 316 | if i % 10000 == 0: 317 | torch.save( 318 | { 319 | "g": g_module.state_dict(), 320 | "d": d_module.state_dict(), 321 | "g_ema": g_ema.state_dict(), 322 | "g_optim": g_optim.state_dict(), 323 | "d_optim": d_optim.state_dict(), 324 | "args": args, 325 | "ada_aug_p": ada_aug_p, 326 | }, 327 | f"{args.checkpoint_dir}/{str(i).zfill(6)}.pt", 328 | ) 329 | 330 | 331 | if __name__ == "__main__": 332 | device = "cuda" 333 | 334 | parser = argparse.ArgumentParser(description="StyleGAN2 trainer") 335 | 336 | parser.add_argument("path", type=str, help="path to the mel-spectrogram dataset directory") 337 | parser.add_argument( 338 | "--iter", type=int, default=800000, help="total training iterations" 339 | ) 340 | parser.add_argument( 341 | "--batch", type=int, default=16, help="batch size for each GPU" 342 | ) 343 | parser.add_argument( 344 | "--n_sample", 345 | type=int, 346 | default=16, 347 | help="number of samples generated during training", 348 | ) 349 | parser.add_argument( 350 | "--size", type=int, default=256, help="image size for the model" 351 | ) 352 | parser.add_argument( 353 | "--r1", type=float, default=10, help="weight of the r1 regularization" 354 | ) 355 | parser.add_argument( 356 | "--path_regularize", 357 | type=float, 358 | default=2, 359 | help="weight of the path length regularization", 360 | ) 361 | parser.add_argument( 362 | "--path_batch_shrink", 363 | type=int, 364 | default=2, 365 | help="batch size reduction factor for the path length regularization (reduces memory consumption)", 366 | ) 367 | parser.add_argument( 368 | "--d_reg_every", 369 | type=int, 370 | default=16, 371 | help="interval of applying the r1 regularization", 372 | ) 373 | parser.add_argument( 374 | "--g_reg_every", 375 | type=int, 376 | default=4, 377 | help="interval of applying the path length regularization", 378 | ) 379 | parser.add_argument( 380 | "--mixing", type=float, default=0.9, help="probability of latent code mixing" 381 | ) 382 | parser.add_argument( 383 | "--ckpt", 384 | type=str, 385 | default=None, 386 | help="path to the checkpoint to resume training", 387 | ) 388 | parser.add_argument("--lr", type=float, default=0.002, help="learning rate") 389 | parser.add_argument( 390 | "--channel_multiplier", 391 | type=int, 392 | default=2, 393 | help="channel multiplier factor for the model. 
config-f = 2, else = 1", 394 | ) 395 | parser.add_argument( 396 | "--wandb", action="store_true", help="use weights and biases logging" 397 | ) 398 | parser.add_argument( 399 | "--local_rank", type=int, default=0, help="local rank for distributed training" 400 | ) 401 | parser.add_argument( 402 | "--augment", action="store_true", help="apply non leaking augmentation" 403 | ) 404 | parser.add_argument( 405 | "--augment_p", 406 | type=float, 407 | default=0, 408 | help="probability of applying augmentation. 0 = use adaptive augmentation", 409 | ) 410 | parser.add_argument( 411 | "--ada_target", 412 | type=float, 413 | default=0.6, 414 | help="target augmentation probability for adaptive augmentation", 415 | ) 416 | parser.add_argument( 417 | "--ada_length", 418 | type=int, 419 | default=500 * 1000, 420 | help="target duration to reach augmentation probability for adaptive augmentation", 421 | ) 422 | parser.add_argument( 423 | "--ada_every", 424 | type=int, 425 | default=256, 426 | help="probability update interval of the adaptive augmentation", 427 | ) 428 | 429 | parser.add_argument( 430 | "--sample_dir", 431 | type=str, 432 | default='sample', 433 | help="sample directory", 434 | ) 435 | 436 | parser.add_argument( 437 | "--checkpoint_dir", 438 | type=str, 439 | default='checkpoint', 440 | help="checkpoint directory", 441 | ) 442 | args = parser.parse_args() 443 | 444 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 445 | args.distributed = n_gpu > 1 446 | 447 | if args.distributed: 448 | torch.cuda.set_device(args.local_rank) 449 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 450 | synchronize() 451 | 452 | args.latent = 512 453 | args.n_mlp = 8 454 | 455 | args.start_iter = 0 456 | 457 | generator = Generator( 458 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 459 | ).to(device) 460 | discriminator = Discriminator( 461 | args.size, channel_multiplier=args.channel_multiplier 462 | ).to(device) 463 | g_ema = Generator( 464 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 465 | ).to(device) 466 | g_ema.eval() 467 | accumulate(g_ema, generator, 0) 468 | 469 | g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) 470 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) 471 | 472 | g_optim = optim.Adam( 473 | generator.parameters(), 474 | lr=args.lr * g_reg_ratio, 475 | betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), 476 | ) 477 | d_optim = optim.Adam( 478 | discriminator.parameters(), 479 | lr=args.lr * d_reg_ratio, 480 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), 481 | ) 482 | 483 | if args.ckpt is not None: 484 | print("load model:", args.ckpt) 485 | 486 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 487 | 488 | try: 489 | ckpt_name = os.path.basename(args.ckpt) 490 | args.start_iter = int(os.path.splitext(ckpt_name)[0]) 491 | 492 | except ValueError: 493 | pass 494 | 495 | generator.load_state_dict(ckpt["g"]) 496 | discriminator.load_state_dict(ckpt["d"]) 497 | g_ema.load_state_dict(ckpt["g_ema"]) 498 | 499 | g_optim.load_state_dict(ckpt["g_optim"]) 500 | d_optim.load_state_dict(ckpt["d_optim"]) 501 | 502 | if args.distributed: 503 | generator = nn.parallel.DistributedDataParallel( 504 | generator, 505 | device_ids=[args.local_rank], 506 | output_device=args.local_rank, 507 | broadcast_buffers=False, 508 | ) 509 | 510 | discriminator = nn.parallel.DistributedDataParallel( 511 | discriminator, 512 | device_ids=[args.local_rank], 513
| output_device=args.local_rank, 514 | broadcast_buffers=False, 515 | ) 516 | 517 | transform = transforms.Compose( 518 | [ 519 | #transforms.RandomHorizontalFlip(), 520 | transforms.ToTensor(), 521 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 522 | ] 523 | ) 524 | 525 | dataset = MultiResolutionDataset_drum(args.path, transform) 526 | loader = data.DataLoader( 527 | dataset, 528 | batch_size=args.batch, 529 | sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), 530 | drop_last=True, 531 | ) 532 | 533 | if get_rank() == 0 and wandb is not None and args.wandb: 534 | wandb.init(project="stylegan 2") 535 | 536 | train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device) 537 | --------------------------------------------------------------------------------
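
A minimal, standalone sketch of the generation path that predict.py and the generate_*.sh scripts wrap: sample a latent vector, run the EMA generator to get a normalized 80-bin mel-spectrogram, undo the per-bin normalization with the dataset mean/std, and vocode with the bundled MelGAN. It assumes the one-bar looperman checkpoint and the data/looperman statistics referenced in the scripts above, and that model_drum.Generator takes the same (size, style_dim, n_mlp) arguments used in train_drum.py; the output file name is only a placeholder.

```python
# Hedged sketch, not a file from this repository: generate one loop from the
# one-bar looperman checkpoint, mirroring the pipeline in predict.py.
import sys

import numpy as np
import soundfile as sf
import torch
import yaml

from model_drum import Generator  # one-bar generator used by train_drum.py

sys.path.append("./melgan")
from modules import Generator_melgan  # MelGAN vocoder bundled in melgan/

device = "cuda"
latent_dim = 512

# EMA generator that outputs normalized mel-spectrograms.
g_ema = Generator(64, latent_dim, 8, channel_multiplier=2).to(device)
ckpt = torch.load("looperman_one_bar_checkpoint.pt", map_location=device)
g_ema.load_state_dict(ckpt["g_ema"])
g_ema.eval()

# Per-bin statistics used to undo the normalization applied during training.
mean = torch.from_numpy(np.load("data/looperman/mean.mel.npy")).float().view(1, 80, 1).to(device)
std = torch.from_numpy(np.load("data/looperman/std.mel.npy")).float().view(1, 80, 1).to(device)

# MelGAN vocoder configured from melgan/args.yml.
with open("melgan/args.yml") as f:
    cfg = yaml.load(f, Loader=yaml.Loader)
vocoder = Generator_melgan(cfg.n_mel_channels, cfg.ngf, cfg.n_residual_layers).to(device)
vocoder.load_state_dict(torch.load("melgan/best_netG.pt", map_location=device))
vocoder.eval()

with torch.no_grad():
    z = torch.randn(1, latent_dim, device=device)
    mel, _ = g_ema([z])                  # normalized mel-spectrogram
    mel = mel.squeeze(0) * std + mean    # de-normalize, as in predict.py
    audio = vocoder(mel).squeeze().cpu().numpy()

sf.write("generated_loop.wav", audio, 44100)  # 44.1 kHz, matching the vocoder
```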