├── .gitignore
├── LICENSE.md
├── README.md
├── __pycache__
│   └── model_drum.cpython-38.pyc
├── cog.yaml
├── data
│   ├── freesound
│   │   ├── mean.mel.npy
│   │   └── std.mel.npy
│   ├── looperman
│   │   ├── mean.mel.npy
│   │   └── std.mel.npy
│   └── looperman_four_bar
│       ├── mean.mel.npy
│       └── std.mel.npy
├── dataset.py
├── distributed.py
├── environment.yml
├── evaluation
│   ├── FAD
│   │   └── looperman_2000.stats
│   ├── IS
│   │   ├── attention_modules.py
│   │   ├── best_model.ckpt
│   │   ├── compute_is_score.sh
│   │   ├── inception_score.py
│   │   ├── model.py
│   │   └── modules.py
│   ├── NDB_JS
│   │   ├── compute_ndb_js.py
│   │   ├── compute_ndb_js.sh
│   │   └── ndb.py
│   └── nine_audio
│       ├── Chillout
│       │   └── 1365.wav
│       ├── Drum_and_Bass
│       │   └── 14.wav
│       ├── Electronic
│       │   └── 1353.wav
│       ├── Hiphop
│       │   └── 323.wav
│       ├── Rap
│       │   └── 153.wav
│       ├── Rock
│       │   └── 752.wav
│       ├── Trap
│       │   └── 1395.wav
│       ├── dupstep
│       │   └── 113.wav
│       └── industrial
│           └── 1606.wav
├── generate_audio.py
├── generate_looperman_four_bar.py
├── melgan
│   ├── .ipynb_checkpoints
│   │   └── modules-checkpoint.py
│   ├── args.yml
│   ├── best_netG.pt
│   └── modules.py
├── model_drum.py
├── model_drum_four_bar.py
├── non_leaking.py
├── op
│   ├── __init__.py
│   ├── fused_act.py
│   ├── fused_bias_act.cpp
│   ├── fused_bias_act_kernel.cu
│   ├── upfirdn2d.cpp
│   ├── upfirdn2d.py
│   └── upfirdn2d_kernel.cu
├── predict.py
├── preprocess
│   ├── collect_audio_clips.py
│   ├── compute_mean_std.mel.py
│   ├── extract_mel.py
│   ├── make_dataset.py
│   └── trim_2_seconds.py
├── scripts
│   ├── generate_freesound.sh
│   ├── generate_looperman_four_bar.sh
│   ├── generate_looperman_one_bar.sh
│   └── train.sh
└── train_drum.py
/.gitignore:
--------------------------------------------------------------------------------
1 | generated_looperman_one_bar
2 | generated_freesound_one_bar
3 | generated_interpolation_one_bar_2
4 | pretrained_model
5 | __pycache__/*.pyc
6 | __pycache__
7 | __pycache__/model_drum.cpython-38.pyc
8 | evaluation/NDB_JS/__pycache__
9 | evaluation/NDB_JS/looperman_2000.txt
10 | evaluation/NDB_JS/looper_2000/
11 | evaluation/NDB_JS/looper_2000.zip
12 | evaluation/user_study_analysis/
13 | evaluation/IS/freesound_styelgan2.pkl
14 | freesound_checkpoint.pt
15 | generated_audio_looperman_four_bar/
16 | looperman_four_bar_checkpoint.pt
17 | looperman_one_bar_checkpoint.pt
18 | freesound_mel_80_320.zip
19 | mel_80_320/
20 | freesound_checkpoint/
21 | freesound_sample_dir/
22 | evaluation/IS/compute_is_score_one.sh
23 |
24 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) loop-generation development team
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LoopTest
2 | [](./LICENSE.md)
3 | 
4 | 
5 |
6 |
7 | * This is the official repository of **A Benchmarking Initiative for Audio-domain Music Generation Using the FreeSound Loop Dataset**, co-authored with [Paul Chen](https://paulyuchen.com/), [Arthur Yeh](http://yentung.com/) and my supervisor [Yi-Hsuan Yang](http://mac.citi.sinica.edu.tw/~yang/). The paper has been accepted by the International Society for Music Information Retrieval Conference (ISMIR) 2021. [[Demo Page]](https://loopgen.github.io/), [[arXiv]](https://arxiv.org/pdf/2108.01576.pdf).
8 | * We provide not only pretrained models for generating loops on your own but also scripts for evaluating the generated loops.
9 | ## Environment
10 | ```
11 | $ conda env create -f environment.yml
12 | ```
13 | ## Quick Start
14 |
15 | * Generate loops from one-bar looperman pretrained model
16 | ``` bash
17 | $ gdown --id 1GQpzWz9ycIm5wzkxLsVr-zN17GWD3_6K -O looperman_one_bar_checkpoint.pt
18 | $ bash scripts/generate_looperman_one_bar.sh
19 | ```
20 |
21 | * Generate loops from four-bar looperman pretrained model
22 | ``` bash
23 | $ gdown --id 19rk3vx7XM4dultTF1tN4srCpdya7uxBV -O looperman_four_bar_checkpoint.pt
24 | $ bash scripts/generate_looperman_four_bar.sh
25 | ```
26 |
27 | * Generate loops from freesound pretrained model
28 | ``` bash
29 | $ gdown --id 197DMCOASEMFBVi8GMahHfRwgJ0bhcUND -O freesound_checkpoint.pt
30 | $ bash scripts/generate_freesound.sh
31 | ```
32 | ## Pretrained Checkpoint
33 | * [Looperman pretrained one-bar model](https://drive.google.com/file/d/1GQpzWz9ycIm5wzkxLsVr-zN17GWD3_6K/view?usp=sharing)
34 | * [Looperman pretrained four-bar model](https://drive.google.com/file/d/19rk3vx7XM4dultTF1tN4srCpdya7uxBV/view?usp=sharing)
35 | * [Freesound pretrained one-bar model](https://drive.google.com/file/d/197DMCOASEMFBVi8GMahHfRwgJ0bhcUND/view?usp=sharing)
36 |
37 | ## Benchmarking Freesound Loop Dataset
38 | ### Download dataset
39 | ``` bash
40 |
41 | $ gdown --id 1fQfSZgD9uWbCdID4SzVqNGhsYNXOAbK5
42 | $ unzip freesound_mel_80_320.zip
43 |
44 | ```
45 | ### Training
46 |
47 | ``` bash
48 | $ CUDA_VISIBLE_DEVICES=2 python train_drum.py \
49 | --size 64 --batch 8 --sample_dir freesound_sample_dir \
50 | --checkpoint_dir freesound_checkpoint \
51 | --iter 100000 \
52 | mel_80_320
53 | ```
54 |
55 | ### Generate audio
56 | ```bash
57 | $ CUDA_VISIBLE_DEVICES=2 python generate_audio.py \
58 | --ckpt freesound_checkpoint/100000.pt \
59 | --pics 2000 --data_path "./data/freesound" \
60 | --store_path "./generated_freesound_one_bar"
61 | ```
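
For programmatic use, the sketch below outlines roughly what a sampling script like `generate_audio.py` involves: draw a latent code, decode a normalized mel-spectrogram with the generator, and undo the dataset normalization with the `mean.mel.npy` / `std.mel.npy` files before vocoding. The `Generator` arguments and the `g_ema` checkpoint key follow rosinality's StyleGAN2 and are assumptions here; the denormalization mirrors the `data * std + mean` step in `evaluation/IS/inception_score.py`.
``` python
# A rough sketch, not the exact generate_audio.py logic. The Generator arguments and the
# "g_ema" checkpoint key follow rosinality's StyleGAN2 and are assumptions.
import numpy as np
import torch
from model_drum import Generator

ckpt = torch.load("freesound_checkpoint/100000.pt", map_location="cpu")
g = Generator(size=64, style_dim=512, n_mlp=8)
g.load_state_dict(ckpt["g_ema"])
g.eval()

# Undo the per-mel-bin normalization applied during preprocessing (cf. inception_score.py).
mean = np.load("./data/freesound/mean.mel.npy").reshape(1, 80, 1)
std = np.load("./data/freesound/std.mel.npy").reshape(1, 80, 1)

with torch.no_grad():
    z = torch.randn(1, 512)
    mel, _ = g([z])                                # normalized mel-spectrogram
mel = mel.squeeze(0).cpu().numpy() * std + mean    # ready to hand to the MelGAN vocoder
```
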
62 | ### Evaluation
63 | #### NDB_JS
64 | * 2000 looperman mel-spectrograms: [link](https://drive.google.com/file/d/1aFGPYlkkAysVBWp9VacHVk2tf-b4rLIh/view?usp=sharing)
65 | ``` bash
66 | $ cd evaluation/NDB_JS
67 | $ gdown --id 1aFGPYlkkAysVBWp9VacHVk2tf-b4rLIh
68 | $ unzip looper_2000.zip # contains 2000 looperman mel-spectrograms
69 | $ rm looper_2000.zip
70 | $ bash compute_ndb_js.sh
71 | ```
72 | #### IS
73 | * Short-Chunk CNN [checkpoint](./evaluation/IS/best_model.ckpt)
74 | ``` bash
75 | $ cd evaluation/IS
76 | $ bash compute_is_score.sh
77 | ```
78 | #### FAD
79 | * FAD looperman ground-truth stats: [link](./evaluation/FAD/looperman_2000.stats). Follow the official [doc][fad] to install the required packages.
80 |
81 | ``` bash
82 | $ ls --color=never generated_freesound_one_bar/100000/*.wav > freesound.csv
83 | $ python -m frechet_audio_distance.create_embeddings_main --input_files freesound.csv --stats freesound.stats
84 | $ python -m frechet_audio_distance.compute_fad --background_stats ./evaluation/FAD/looperman_2000.stats --test_stats freesound.stats
85 | ```
86 |
87 |
88 |
89 | ## Train the model with your loop dataset
90 | ### Preprocess the Loop Dataset
91 | Go to the [preprocess](./preprocess) directory, modify the settings (e.g., data paths) in the scripts, and run them in the following order:
92 | ``` bash
93 | $ python trim_2_seconds.py # Cut loops into single bars and stretch them to 2 seconds.
94 | $ python extract_mel.py # Extract mel-spectrograms from the 2-second audio.
95 | $ python make_dataset.py
96 | $ python compute_mean_std.mel.py # Compute the mean and std of the mel-spectrograms (used for normalization).
97 | ```
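
For reference, here is a minimal sketch of the kind of mel-spectrogram extraction `extract_mel.py` performs. The sample rate and hop length below are assumptions chosen only so that a 2-second clip maps to an 80x320 mel-spectrogram (the `mel_80_320` shape used throughout this repo); the values actually used are in [preprocess/extract_mel.py](./preprocess/extract_mel.py).
``` python
# Hypothetical parameters: picked so that 2 s of audio -> 80 mel bins x 320 frames.
# Check preprocess/extract_mel.py for the real settings.
import librosa
import numpy as np

sr = 16000          # assumed sample rate
hop_length = 100    # 2 s * 16000 / 100 = 320 frames
n_fft = 1024
n_mels = 80

y, _ = librosa.load("loop_2s.wav", sr=sr)
y = librosa.util.fix_length(y, size=2 * sr)     # enforce exactly 2 seconds
mel = librosa.feature.melspectrogram(
    y=y, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
)
mel = np.log(mel + 1e-10)                       # log-compress (the exact compression may differ)
np.save("loop_2s.mel.npy", mel[:, :320].astype(np.float32))   # shape (80, 320)
```
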
98 |
99 | ### Train the Model
100 | ``` bash
101 | CUDA_VISIBLE_DEVICES=2 python train_drum.py \
102 | --size 64 --batch 8 --sample_dir [sample_dir] \
103 | --checkpoint_dir [checkpoint_dir] \
104 | [mel-spectrogram dataset from the preprocessing]
105 | ```
106 | * `--checkpoint_dir` stores the model checkpoints in the designated directory.
107 | * `--sample_dir` stores the mel-spectrograms generated by the model during training.
108 | * You should give the data directory as the last (positional) argument.
109 | * There is an example training [script](./scripts/train.sh).
110 |
111 | ## Vocoder
112 | We use [MelGAN][melgan] as the vocoder. We trained the vocoder on the looperman dataset and use the same vocoder to render audio for both the freesound and looperman models.
113 | The trained vocoder is in the [melgan](./melgan) directory.
114 |
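A minimal sketch of inverting a mel-spectrogram with this vocoder, assuming `melgan/` is importable as a package and its `Generator` keeps the interface of the official MelGAN repo; the numeric hyper-parameters below are placeholders (the trained configuration is recorded in `melgan/args.yml`).
``` python
# Hypothetical sketch: the Generator arguments mirror the official MelGAN repo and the
# numeric values are placeholders (see melgan/args.yml for the trained configuration).
import torch
from melgan.modules import Generator

vocoder = Generator(80, ngf=32, n_residual_layers=3)   # 80 mel bins, as used in this repo
vocoder.load_state_dict(torch.load("melgan/best_netG.pt", map_location="cpu"))
vocoder.eval()

mel = torch.randn(1, 80, 320)      # a denormalized mel-spectrogram (80 bins x 320 frames)
with torch.no_grad():
    audio = vocoder(mel).squeeze().cpu().numpy()   # waveform samples in [-1, 1]
```
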
115 | ## References
116 | The code borrows heavily from the repositories below:
117 | * [StyleGAN2 from rosinality][stylegan2]
118 | * [Official MelGAN repo][melgan]
119 | * [Official UNAGAN repo from ciaua][unagan].
120 | * [Official Short Chunk CNN repo][cnn]
121 | * [FAD official document][fad]
122 |
123 | [fad]: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
124 | [cnn]: https://github.com/minzwon/sota-music-tagging-models
125 | [stylegan2]: https://github.com/rosinality/stylegan2-pytorch
126 | [unagan]: https://github.com/ciaua/unagan
127 | [melgan]: https://github.com/descriptinc/melgan-neurips
128 |
129 | ## Citation
130 | If you find this repo useful, please cite it with the following information.
131 | ```
132 | @inproceedings{ allenloopgen,
133 | title={A Benchmarking Initiative for Audio-domain Music Generation using the {FreeSound Loop Dataset}},
134 | author={Tun-Min Hung and Bo-Yu Chen and Yen-Tung Yeh and Yi-Hsuan Yang},
135 | booktitle = {Proc. Int. Society for Music Information Retrieval Conf.},
136 | year={2021},
137 | }
138 | ```
139 |
--------------------------------------------------------------------------------
/__pycache__/model_drum.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/__pycache__/model_drum.cpython-38.pyc
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | predict: "predict.py:Predictor"
2 | build:
3 | gpu: true
4 | python_version: "3.8"
5 | system_packages:
6 | - "ffmpeg"
7 | python_packages:
8 | - "cython==0.29.24"
9 | - "stylegan2-pytorch==1.8.1"
10 | - "torch==1.6.0"
11 | - "torchaudio==0.6.0"
12 | - "torchvision==0.7.0"
13 | - "librosa==0.8.0"
14 | - "numpy==1.20.1"
15 | - "scipy==1.6.2"
16 | - "tqdm==4.55.2"
17 | - "pysoundfile==0.9.0.post1"
18 | - "pyrubberband==0.3.0"
19 | - "pydub==0.23.1"
20 | - "matplotlib==3.3.4"
21 | - "lmdb==0.96"
22 | - "pillow==8.2.0"
23 | - "ninja==1.10.2"
24 | pre_install:
25 | - "pip install madmom==0.16.1" # needs to be after python_packages since madmom needs cython without requiring it
26 |
--------------------------------------------------------------------------------
/data/freesound/mean.mel.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/freesound/mean.mel.npy
--------------------------------------------------------------------------------
/data/freesound/std.mel.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/freesound/std.mel.npy
--------------------------------------------------------------------------------
/data/looperman/mean.mel.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/looperman/mean.mel.npy
--------------------------------------------------------------------------------
/data/looperman/std.mel.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/looperman/std.mel.npy
--------------------------------------------------------------------------------
/data/looperman_four_bar/mean.mel.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/looperman_four_bar/mean.mel.npy
--------------------------------------------------------------------------------
/data/looperman_four_bar/std.mel.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/data/looperman_four_bar/std.mel.npy
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import lmdb
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 | from torch.utils import data
8 | import numpy as np
9 | import os
10 | def data_sampler(dataset, shuffle, distributed):
11 | if distributed:
12 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
13 |
14 | if shuffle:
15 | return data.RandomSampler(dataset)
16 |
17 | else:
18 | return data.SequentialSampler(dataset)
19 |
20 | class MultiResolutionDataset(Dataset):
21 | def __init__(self, path, transform, resolution=256):
22 | self.env = lmdb.open(
23 | path,
24 | max_readers=32,
25 | readonly=True,
26 | lock=False,
27 | readahead=False,
28 | meminit=False,
29 | )
30 |
31 | if not self.env:
32 | raise IOError('Cannot open lmdb dataset', path)
33 |
34 | with self.env.begin(write=False) as txn:
35 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
36 |
37 | self.resolution = resolution
38 | self.transform = transform
39 |
40 | def __len__(self):
41 | return self.length
42 |
43 | def __getitem__(self, index):
44 | with self.env.begin(write=False) as txn:
45 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
46 | img_bytes = txn.get(key)
47 |
48 | buffer = BytesIO(img_bytes)
49 | img = Image.open(buffer)
50 | img = self.transform(img)
51 |
52 | return img
53 | class MultiResolutionDataset_drum(Dataset):
54 | def __init__(self, path, transform, resolution=None):
55 | self.path_list = []
56 | for file in os.listdir(path):
57 | if file.endswith('.npy') == True:
58 | if file.startswith('std') == False and file.startswith('mean') == False:
59 | self.path_list.append(os.path.join(path, file))
60 | self.resolution = resolution
61 | self.transform = transform
62 |
63 | def __len__(self):
64 | return len(self.path_list)
65 |
66 | def __getitem__(self, index):
67 |
68 | img = np.load(self.path_list[index])
69 | img = self.transform(img)
70 |
71 | return img
72 |
73 | class MultiResolutionDataset_drum_with_filename(Dataset):
74 | def __init__(self, path, transform, resolution=None):
75 | self.path_list = []
76 | for file in os.listdir(path):
77 | if file.endswith('.npy') == True:
78 | if file.startswith('std') == False and file.startswith('mean') == False:
79 | self.path_list.append(os.path.join(path, file))
80 | self.resolution = resolution
81 | self.transform = transform
82 |
83 | def __len__(self):
84 | return len(self.path_list)
85 |
86 | def __getitem__(self, index):
87 |
88 | img = np.load(self.path_list[index])
89 | img = self.transform(img)
90 |
91 | return img, self.path_list[index].split('/')[-1]
92 |
93 | class MultiResolutionDataset_drum_with_label(Dataset):
94 | def __init__(self, path, transform, label_dictionary, resolution=None):
95 | self.path_list = []
96 | for file in os.listdir(path):
97 | if file.endswith('.npy') == True:
98 | if file.startswith('std') == False and file.startswith('mean') == False:
99 | self.path_list.append(os.path.join(path, file))
100 | self.resolution = resolution
101 | self.transform = transform
102 | class MultiResolutionDataset_drum_with_label(Dataset):
103 | def __init__(self, path, transform, label_dictionary, resolution=None):
104 | self.path_list = []
105 | for file in os.listdir(path):
106 | if file.endswith('.npy') == True:
107 | if file.startswith('std') == False and file.startswith('mean') == False:
108 | self.path_list.append(os.path.join(path, file))
109 | self.resolution = resolution
110 | self.transform = transform
111 |
112 | ## read label dictionary
113 | import pickle as pickle
114 | with open(label_dictionary, 'rb') as f:
115 | self.label_dictionary = pickle.load(f)
116 |
117 | ## genre to int dictionary
118 | self.genre_to_int = {}
119 | count = 0
120 | for _, genre in self.label_dictionary.items():
121 | if self.genre_to_int.get(genre) == None:
122 | self.genre_to_int[genre] = count
123 | count += 1
124 | def __len__(self):
125 | return len(self.path_list)
126 |
127 | def __getitem__(self, index):
128 |
129 | img = np.load(self.path_list[index])
130 | img = self.transform(img)
131 |
132 | file_name = self.path_list[index].split('/')[-1]
133 | label = self.genre_to_int[self.label_dictionary[file_name]]
134 | return img, label
135 | if __name__ == '__main__':
136 | transform = transforms.Compose(
137 | [
138 | #transforms.RandomHorizontalFlip(),
139 | transforms.ToTensor(),
140 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
141 | ]
142 | )
143 | path = '/home/allenhung/nas189/home/style-based-gan-drum/training_data_one_bar_all/mel_80_320_genre_more_than_600'
144 | label_dictionary = '/home/allenhung/nas189/home/style-based-gan-drum/training_data_one_bar_all/dict_one_bar_more_than_600.pickle'
145 | dataset = MultiResolutionDataset_drum_with_label(path, transform, label_dictionary)
146 | loader = data.DataLoader(
147 | dataset,
148 | batch_size=2,
149 | sampler=data_sampler(dataset, shuffle=True, distributed=False),
150 | drop_last=True,
151 | )
152 | for data in loader:
153 | import pdb; pdb.set_trace()
154 |
--------------------------------------------------------------------------------
/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pickle
3 |
4 | import torch
5 | from torch import distributed as dist
6 | from torch.utils.data.sampler import Sampler
7 |
8 |
9 | def get_rank():
10 | if not dist.is_available():
11 | return 0
12 |
13 | if not dist.is_initialized():
14 | return 0
15 |
16 | return dist.get_rank()
17 |
18 |
19 | def synchronize():
20 | if not dist.is_available():
21 | return
22 |
23 | if not dist.is_initialized():
24 | return
25 |
26 | world_size = dist.get_world_size()
27 |
28 | if world_size == 1:
29 | return
30 |
31 | dist.barrier()
32 |
33 |
34 | def get_world_size():
35 | if not dist.is_available():
36 | return 1
37 |
38 | if not dist.is_initialized():
39 | return 1
40 |
41 | return dist.get_world_size()
42 |
43 |
44 | def reduce_sum(tensor):
45 | if not dist.is_available():
46 | return tensor
47 |
48 | if not dist.is_initialized():
49 | return tensor
50 |
51 | tensor = tensor.clone()
52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
53 |
54 | return tensor
55 |
56 |
57 | def gather_grad(params):
58 | world_size = get_world_size()
59 |
60 | if world_size == 1:
61 | return
62 |
63 | for param in params:
64 | if param.grad is not None:
65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
66 | param.grad.data.div_(world_size)
67 |
68 |
69 | def all_gather(data):
70 | world_size = get_world_size()
71 |
72 | if world_size == 1:
73 | return [data]
74 |
75 | buffer = pickle.dumps(data)
76 | storage = torch.ByteStorage.from_buffer(buffer)
77 | tensor = torch.ByteTensor(storage).to('cuda')
78 |
79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda')
80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
81 | dist.all_gather(size_list, local_size)
82 | size_list = [int(size.item()) for size in size_list]
83 | max_size = max(size_list)
84 |
85 | tensor_list = []
86 | for _ in size_list:
87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
88 |
89 | if local_size != max_size:
90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
91 | tensor = torch.cat((tensor, padding), 0)
92 |
93 | dist.all_gather(tensor_list, tensor)
94 |
95 | data_list = []
96 |
97 | for size, tensor in zip(size_list, tensor_list):
98 | buffer = tensor.cpu().numpy().tobytes()[:size]
99 | data_list.append(pickle.loads(buffer))
100 |
101 | return data_list
102 |
103 |
104 | def reduce_loss_dict(loss_dict):
105 | world_size = get_world_size()
106 |
107 | if world_size < 2:
108 | return loss_dict
109 |
110 | with torch.no_grad():
111 | keys = []
112 | losses = []
113 |
114 | for k in sorted(loss_dict.keys()):
115 | keys.append(k)
116 | losses.append(loss_dict[k])
117 |
118 | losses = torch.stack(losses, 0)
119 | dist.reduce(losses, dst=0)
120 |
121 | if dist.get_rank() == 0:
122 | losses /= world_size
123 |
124 | reduced_losses = {k: v for k, v in zip(keys, losses)}
125 |
126 | return reduced_losses
127 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: melgan_suc
2 | channels:
3 | - pytorch
4 | - anaconda
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=1_gnu
10 | - absl-py=0.12.0=py38h06a4308_0
11 | - appdirs=1.4.4=py_0
12 | - audioread=2.1.9=py38h578d9bd_0
13 | - backcall=0.2.0=pyhd3eb1b0_0
14 | - blas=1.0=mkl
15 | - brotlipy=0.7.0=py38h27cfd23_1003
16 | - bzip2=1.0.8=h7b6447c_0
17 | - c-ares=1.17.1=h27cfd23_0
18 | - ca-certificates=2020.10.14=0
19 | - cairo=1.16.0=h7979940_1007
20 | - certifi=2020.6.20=py38_0
21 | - cffi=1.14.5=py38h261ae71_0
22 | - chardet=4.0.0=py38h06a4308_1003
23 | - click=7.1.2=py_0
24 | - coverage=5.5=py38h27cfd23_2
25 | - cryptography=3.4.7=py38hd23ed53_0
26 | - cudatoolkit=10.2.89=hfd86e86_1
27 | - cycler=0.10.0=py38_0
28 | - dbus=1.13.18=hb2f20db_0
29 | - decorator=5.0.6=pyhd3eb1b0_0
30 | - dlib=19.21.1=py38h2c889cf_0
31 | - expat=2.2.10=he6710b0_2
32 | - ffmpeg=4.3.1=h3215721_1
33 | - flask=1.1.2=py_0
34 | - flask-wtf=0.14.3=py_0
35 | - fontconfig=2.13.1=hba837de_1004
36 | - freetype=2.10.4=h5ab3b9f_0
37 | - gettext=0.21.0=hf68c758_0
38 | - glib=2.66.6=ha03b18c_2
39 | - glib-tools=2.66.6=ha03b18c_2
40 | - gmp=6.2.1=h2531618_2
41 | - gnutls=3.6.15=he1e5248_0
42 | - graphite2=1.3.14=h23475e2_0
43 | - grpcio=1.36.1=py38h2157cd5_1
44 | - gst-plugins-base=1.14.5=h0935bb2_2
45 | - gstreamer=1.18.3=h3560a44_0
46 | - harfbuzz=2.7.4=h5cf4720_0
47 | - hdf5=1.10.6=hb1b8bf9_0
48 | - icu=68.1=h2531618_0
49 | - idna=2.10=pyhd3eb1b0_0
50 | - importlib-metadata=3.10.0=py38h06a4308_0
51 | - intel-openmp=2021.2.0=h06a4308_610
52 | - ipykernel=5.3.4=py38h5ca1d4c_0
53 | - ipython=7.22.0=py38hb070fc8_0
54 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
55 | - itsdangerous=1.1.0=py_0
56 | - jasper=1.900.1=hd497a04_4
57 | - jedi=0.17.0=py38_0
58 | - jinja2=2.11.2=py_0
59 | - joblib=1.0.1=pyhd3eb1b0_0
60 | - jpeg=9d=h36c2ea0_0
61 | - jupyter_client=6.1.12=pyhd3eb1b0_0
62 | - jupyter_core=4.7.1=py38h06a4308_0
63 | - kiwisolver=1.3.1=py38h2531618_0
64 | - krb5=1.17.1=h173b8e3_0
65 | - lame=3.100=h7b6447c_0
66 | - lcms2=2.12=h3be6417_0
67 | - ld_impl_linux-64=2.33.1=h53a641e_7
68 | - libblas=3.9.0=1_h6e990d7_netlib
69 | - libcblas=3.9.0=3_h893e4fe_netlib
70 | - libclang=11.0.1=default_ha53f305_1
71 | - libedit=3.1.20210216=h27cfd23_1
72 | - libevent=2.1.10=hcdb4288_3
73 | - libffi=3.3=he6710b0_2
74 | - libflac=1.3.3=h9c3ff4c_1
75 | - libgcc-ng=9.3.0=h5dbcf3e_17
76 | - libgfortran-ng=7.5.0=hae1eefd_17
77 | - libgfortran4=7.5.0=hae1eefd_17
78 | - libglib=2.66.6=hdb14261_2
79 | - libgomp=9.3.0=h5dbcf3e_17
80 | - libiconv=1.16=h516909a_0
81 | - libidn2=2.3.0=h27cfd23_0
82 | - liblapack=3.9.0=3_h893e4fe_netlib
83 | - liblapacke=3.9.0=3_h893e4fe_netlib
84 | - libllvm10=10.0.1=hbcb73fb_5
85 | - libllvm11=11.0.1=hf817b99_0
86 | - libogg=1.3.4=h7f98852_0
87 | - libopencv=4.5.0=py38h703c3c0_7
88 | - libopus=1.3.1=h7b6447c_0
89 | - libpng=1.6.37=hbc83047_0
90 | - libpq=12.3=h255efa7_3
91 | - libprotobuf=3.14.0=h8c45485_0
92 | - librosa=0.8.0=pyh9f0ad1d_0
93 | - libsndfile=1.0.30=h9c3ff4c_1
94 | - libsodium=1.0.18=h7b6447c_0
95 | - libstdcxx-ng=9.3.0=h2ae2ef3_17
96 | - libtasn1=4.16.0=h27cfd23_0
97 | - libtiff=4.1.0=h2733197_1
98 | - libunistring=0.9.10=h27cfd23_0
99 | - libuuid=2.32.1=h7f98852_1000
100 | - libvorbis=1.3.7=h7b6447c_0
101 | - libwebp-base=1.2.0=h27cfd23_0
102 | - libxcb=1.14=h7b6447c_0
103 | - libxkbcommon=1.0.3=he3ba5ed_0
104 | - libxml2=2.9.10=h72842e0_3
105 | - llvmlite=0.36.0=py38h612dafd_4
106 | - lz4-c=1.9.3=h2531618_0
107 | - markdown=3.3.4=py38h06a4308_0
108 | - markupsafe=1.1.1=py38h7b6447c_0
109 | - matplotlib-base=3.3.4=py38h62a2d02_0
110 | - mkl=2021.2.0=h06a4308_296
111 | - mkl-service=2.3.0=py38h27cfd23_1
112 | - mkl_fft=1.3.0=py38h42c9631_2
113 | - mkl_random=1.2.1=py38ha9443f7_2
114 | - mysql-common=8.0.22=ha770c72_1
115 | - mysql-libs=8.0.22=h1fd7589_1
116 | - ncurses=6.2=he6710b0_1
117 | - nettle=3.7.2=hbbd107a_1
118 | - ninja=1.10.2=hff7bd54_1
119 | - nspr=4.29=h9c3ff4c_1
120 | - nss=3.61=hb5efdd6_0
121 | - numba=0.53.1=py38ha9443f7_0
122 | - numpy=1.20.1=py38h93e21f0_0
123 | - numpy-base=1.20.1=py38h7d8b39e_0
124 | - olefile=0.46=py_0
125 | - opencv=4.5.0=py38h578d9bd_7
126 | - openh264=2.1.1=h8b12597_0
127 | - openssl=1.1.1k=h27cfd23_0
128 | - packaging=20.9=pyhd3eb1b0_0
129 | - parso=0.8.2=pyhd3eb1b0_0
130 | - pcre=8.44=he6710b0_0
131 | - pexpect=4.8.0=pyhd3eb1b0_3
132 | - pickleshare=0.7.5=pyhd3eb1b0_1003
133 | - pillow=8.2.0=py38he98fc37_0
134 | - pip=21.0.1=py38h06a4308_0
135 | - pixman=0.40.0=h7b6447c_0
136 | - pooch=1.3.0=pyhd3eb1b0_0
137 | - prompt-toolkit=3.0.17=pyh06a4308_0
138 | - protobuf=3.14.0=py38h2531618_1
139 | - ptyprocess=0.7.0=pyhd3eb1b0_2
140 | - py-opencv=4.5.0=py38h81c977d_7
141 | - pycparser=2.20=py_2
142 | - pydub=0.23.1=py_0
143 | - pygments=2.8.1=pyhd3eb1b0_0
144 | - pyopenssl=20.0.1=pyhd3eb1b0_1
145 | - pyparsing=2.4.7=pyhd3eb1b0_0
146 | - pysocks=1.7.1=py38h06a4308_0
147 | - python=3.8.8=hdb3f193_5
148 | - python-dateutil=2.8.1=pyhd3eb1b0_0
149 | - python-lmdb=0.96=py38h950e882_1
150 | - python_abi=3.8=1_cp38
151 | - pytorch=1.6.0=py3.8_cuda10.2.89_cudnn7.6.5_0
152 | - pyyaml=5.3.1=py38h8df0ef7_1
153 | - pyzmq=20.0.0=py38h2531618_1
154 | - qt=5.12.9=h9d6b050_2
155 | - readline=8.1=h27cfd23_0
156 | - requests=2.25.1=pyhd3eb1b0_0
157 | - resampy=0.2.2=py_0
158 | - scikit-learn=0.24.1=py38ha9443f7_0
159 | - scipy=1.6.2=py38had2a1c9_1
160 | - setuptools=52.0.0=py38h06a4308_0
161 | - six=1.15.0=py38h06a4308_0
162 | - sqlite=3.35.4=hdfb4753_0
163 | - tbb=2020.3=hfd86e86_0
164 | - tensorboard=1.15.0=py38_0
165 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
166 | - tk=8.6.10=hbc83047_0
167 | - torchaudio=0.6.0=py38
168 | - torchvision=0.7.0=py38_cu102
169 | - tornado=6.1=py38h27cfd23_0
170 | - tqdm=4.55.2=pyhd8ed1ab_0
171 | - traitlets=5.0.5=pyhd3eb1b0_0
172 | - urllib3=1.26.4=pyhd3eb1b0_0
173 | - wcwidth=0.2.5=py_0
174 | - werkzeug=1.0.1=pyhd3eb1b0_0
175 | - wheel=0.36.2=pyhd3eb1b0_0
176 | - wtforms=2.3.3=py_0
177 | - x264=1!152.20180806=h14c3975_0
178 | - xorg-kbproto=1.0.7=h7f98852_1002
179 | - xorg-libice=1.0.10=h516909a_0
180 | - xorg-libsm=1.2.3=h84519dc_1000
181 | - xorg-libx11=1.6.12=h516909a_0
182 | - xorg-libxext=1.3.4=h516909a_0
183 | - xorg-libxrender=0.9.10=h516909a_1002
184 | - xorg-renderproto=0.11.1=h14c3975_1002
185 | - xorg-xextproto=7.3.0=h7f98852_1002
186 | - xorg-xproto=7.0.31=h27cfd23_1007
187 | - xz=5.2.5=h7b6447c_0
188 | - yaml=0.2.5=h7b6447c_0
189 | - zeromq=4.3.4=h2531618_0
190 | - zipp=3.4.1=pyhd3eb1b0_0
191 | - zlib=1.2.11=h7b6447c_3
192 | - zstd=1.4.9=haebb681_0
193 | - pip:
194 | - aim==2.3.0
195 | - aimrecords==0.0.7
196 | - anytree==2.8.0
197 | - base58==2.0.1
198 | - contrastive-learner==0.1.1
199 | - cython==0.29.21
200 | - docker==5.0.0
201 | - einops==0.3.0
202 | - fire==0.4.0
203 | - future==0.18.2
204 | - gitdb==4.0.7
205 | - gitpython==3.1.14
206 | - kornia==0.5.1
207 | - madmom==0.16.1
208 | - mido==1.2.9
209 | - psutil==5.8.0
210 | - py==1.10.0
211 | - py3nvml==0.2.6
212 | - pyrser==0.2.0
213 | - pyrubberband==0.3.0
214 | - pysoundfile==0.9.0.post1
215 | - retry==0.9.2
216 | - smmap==4.0.0
217 | - stylegan2-pytorch==1.8.1
218 | - termcolor==1.1.0
219 | - timm==0.4.5
220 | - vector-quantize-pytorch==0.1.0
221 | - websocket-client==0.59.0
222 | - xmltodict==0.12.0
223 | prefix: /home/allenhung/miniconda3/envs/melgan_suc
224 |
--------------------------------------------------------------------------------
/evaluation/FAD/looperman_2000.stats:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/FAD/looperman_2000.stats
--------------------------------------------------------------------------------
/evaluation/IS/attention_modules.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # Code adopted from https://github.com/huggingface/pytorch-pretrained-BERT
3 |
4 | import math
5 | import copy
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 |
10 | # Gelu
11 | def gelu(x):
12 | """Implementation of the gelu activation function.
13 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
14 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
15 | Also see https://arxiv.org/abs/1606.08415
16 | """
17 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
18 |
19 | # LayerNorm
20 | try:
21 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
22 | except ImportError:
23 | #print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
24 | class BertLayerNorm(nn.Module):
25 | def __init__(self, hidden_size, eps=1e-12):
26 | """Construct a layernorm module in the TF style (epsilon inside the square root).
27 | """
28 | super(BertLayerNorm, self).__init__()
29 | self.weight = nn.Parameter(torch.ones(hidden_size))
30 | self.bias = nn.Parameter(torch.zeros(hidden_size))
31 | self.variance_epsilon = eps
32 |
33 | def forward(self, x):
34 | u = x.mean(-1, keepdim=True)
35 | s = (x - u).pow(2).mean(-1, keepdim=True)
36 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
37 | return self.weight * x + self.bias
38 |
39 |
40 | class BertConfig(object):
41 | def __init__(self,
42 | vocab_size,
43 | hidden_size=768,
44 | num_hidden_layers=12,
45 | num_attention_heads=12,
46 | intermediate_size=3072,
47 | hidden_act="gelu",
48 | hidden_dropout_prob=0.1,
49 | max_position_embeddings=512,
50 | attention_probs_dropout_prob=0.1,
51 | type_vocab_size=2):
52 | self.vocab_size = vocab_size
53 | self.hidden_size = hidden_size
54 | self.num_hidden_layers = num_hidden_layers
55 | self.num_attention_heads = num_attention_heads
56 | self.hidden_act = hidden_act
57 | self.intermediate_size = intermediate_size
58 | self.hidden_dropout_prob = hidden_dropout_prob
59 | self.max_position_embeddings = max_position_embeddings
60 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
61 | self.type_vocab_size = type_vocab_size
62 |
63 |
64 | class BertSelfAttention(nn.Module):
65 | def __init__(self, config):
66 | super(BertSelfAttention, self).__init__()
67 | if config.hidden_size % config.num_attention_heads != 0:
68 | raise ValueError(
69 | "The hidden size (%d) is not a multiple of the number of attention "
70 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
71 | self.num_attention_heads = config.num_attention_heads
72 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
73 | self.all_head_size = self.num_attention_heads * self.attention_head_size
74 |
75 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
76 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
77 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
78 |
79 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
80 |
81 | def transpose_for_scores(self, x):
82 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
83 | x = x.view(*new_x_shape)
84 | return x.permute(0, 2, 1, 3)
85 |
86 | def forward(self, hidden_states, attention_mask):
87 | mixed_query_layer = self.query(hidden_states)
88 | mixed_key_layer = self.key(hidden_states)
89 | mixed_value_layer = self.value(hidden_states)
90 |
91 | query_layer = self.transpose_for_scores(mixed_query_layer)
92 | key_layer = self.transpose_for_scores(mixed_key_layer)
93 | value_layer = self.transpose_for_scores(mixed_value_layer)
94 |
95 | # Take the dot product between "query" and "key" to get the raw attention scores.
96 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
97 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
98 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
99 | if attention_mask is not None:
100 | attention_scores = attention_scores + attention_mask
101 |
102 | # Normalize the attention scores to probabilities.
103 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
104 |
105 | # This is actually dropping out entire tokens to attend to, which might
106 | # seem a bit unusual, but is taken from the original Transformer paper.
107 | attention_probs = self.dropout(attention_probs)
108 |
109 | context_layer = torch.matmul(attention_probs, value_layer)
110 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
111 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
112 | context_layer = context_layer.view(*new_context_layer_shape)
113 | return context_layer
114 |
115 |
116 | class BertSelfOutput(nn.Module):
117 | def __init__(self, config):
118 | super(BertSelfOutput, self).__init__()
119 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
120 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
121 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
122 |
123 | def forward(self, hidden_states, input_tensor):
124 | hidden_states = self.dense(hidden_states)
125 | hidden_states = self.dropout(hidden_states)
126 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
127 | return hidden_states
128 |
129 |
130 | class BertAttention(nn.Module):
131 | def __init__(self, config):
132 | super(BertAttention, self).__init__()
133 | self.self = BertSelfAttention(config)
134 | self.output = BertSelfOutput(config)
135 |
136 | def forward(self, input_tensor, attention_mask):
137 | self_output = self.self(input_tensor, attention_mask)
138 | attention_output = self.output(self_output, input_tensor)
139 | return attention_output
140 |
141 |
142 | class BertIntermediate(nn.Module):
143 | def __init__(self, config):
144 | super(BertIntermediate, self).__init__()
145 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
146 | self.intermediate_act_fn = gelu
147 |
148 | def forward(self, hidden_states):
149 | hidden_states = self.dense(hidden_states)
150 | hidden_states = self.intermediate_act_fn(hidden_states)
151 | return hidden_states
152 |
153 |
154 | class BertOutput(nn.Module):
155 | def __init__(self, config):
156 | super(BertOutput, self).__init__()
157 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
158 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
159 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
160 |
161 | def forward(self, hidden_states, input_tensor):
162 | hidden_states = self.dense(hidden_states)
163 | hidden_states = self.dropout(hidden_states)
164 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
165 | return hidden_states
166 |
167 |
168 | class BertLayer(nn.Module):
169 | def __init__(self, config):
170 | super(BertLayer, self).__init__()
171 | self.attention = BertAttention(config)
172 | self.intermediate = BertIntermediate(config)
173 | self.output = BertOutput(config)
174 |
175 | def forward(self, hidden_states, attention_mask):
176 | attention_output = self.attention(hidden_states, attention_mask)
177 | intermediate_output = self.intermediate(attention_output)
178 | layer_output = self.output(intermediate_output, attention_output)
179 | return layer_output
180 |
181 |
182 | class BertEncoder(nn.Module):
183 | def __init__(self, config):
184 | super(BertEncoder, self).__init__()
185 | layer = BertLayer(config)
186 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
187 |
188 | def forward(self, hidden_states, attention_mask=None, output_all_encoded_layers=True):
189 | all_encoder_layers = []
190 | for layer_module in self.layer:
191 | hidden_states = layer_module(hidden_states, attention_mask)
192 | if output_all_encoded_layers:
193 | all_encoder_layers.append(hidden_states)
194 | if not output_all_encoded_layers:
195 | all_encoder_layers.append(hidden_states)
196 | return all_encoder_layers
197 |
198 |
199 | class BertEmbeddings(nn.Module):
200 | """Construct the embeddings from word, position and token_type embeddings.
201 | """
202 | def __init__(self, config):
203 | super(BertEmbeddings, self).__init__()
204 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
205 |
206 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
207 | # any TensorFlow checkpoint file
208 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
209 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
210 |
211 | def forward(self, input_ids, token_type_ids=None):
212 | seq_length = input_ids.size(1)
213 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
214 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids[:, :, 0])
215 |
216 | position_embeddings = self.position_embeddings(position_ids)
217 |
218 | embeddings = input_ids + position_embeddings
219 | #embeddings = input_ids
220 | embeddings = self.LayerNorm(embeddings)
221 | embeddings = self.dropout(embeddings)
222 | return embeddings
223 |
224 |
225 | class PositionalEncoding(nn.Module):
226 | def __init__(self, config):
227 | super(PositionalEncoding, self).__init__()
228 | emb_dim = config.hidden_size
229 | max_len = config.max_position_embeddings
230 | self.position_enc = self.position_encoding_init(max_len, emb_dim)
231 |
232 | @staticmethod
233 | def position_encoding_init(n_position, emb_dim):
234 | ''' Init the sinusoid position encoding table '''
235 |
236 | # keep dim 0 for padding token position encoding zero vector
237 | position_enc = np.array([
238 | [pos / np.power(10000, 2 * (j // 2) / emb_dim) for j in range(emb_dim)]
239 | if pos != 0 else np.zeros(emb_dim) for pos in range(n_position)])
240 |
241 | position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # apply sin on 0th,2nd,4th...emb_dim
242 | position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # apply cos on 1st,3rd,5th...emb_dim
243 | return torch.from_numpy(position_enc).type(torch.FloatTensor)
244 |
245 | def forward(self, word_seq):
246 | position_encoding = self.position_enc.unsqueeze(0).expand_as(word_seq)
247 | position_encoding = position_encoding.to(word_seq.device)
248 | word_pos_encoded = word_seq + position_encoding
249 | return word_pos_encoded
250 |
251 | class BertPooler(nn.Module):
252 | def __init__(self, config):
253 | super(BertPooler, self).__init__()
254 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
255 | self.activation = nn.Tanh()
256 |
257 | def forward(self, hidden_states):
258 | # We "pool" the model by simply taking the hidden state corresponding
259 | # to the first token.
260 | first_token_tensor = hidden_states[:, 0]
261 | pooled_output = self.dense(first_token_tensor)
262 | pooled_output = self.activation(pooled_output)
263 | return pooled_output
264 |
--------------------------------------------------------------------------------
/evaluation/IS/best_model.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/IS/best_model.ckpt
--------------------------------------------------------------------------------
/evaluation/IS/compute_is_score.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=2 python inception_score.py \
3 | ./best_model.ckpt \
4 | --data_dir ../../generated_freesound_one_bar/freesound_checkpoint/100000/mel_80_320 --classes 66 \
5 | --mean_std_dir ../../data/freesound \
6 | --store_every_score freesound_styelgan2.pkl
7 |
8 |
--------------------------------------------------------------------------------
/evaluation/IS/inception_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.data
5 | import numpy as np
6 | from scipy.stats import entropy
7 | import argparse
8 | import sys
9 | import os
10 | from model import FCN, ShortChunkCNN_one_bar
11 | sys.path.append('../')
12 | sys.path.append('../../')
13 | from dataset import MultiResolutionDataset_drum, data_sampler
14 | from torchvision import transforms
15 | from scipy.stats import ttest_ind
16 | import pickle
17 | #from module import genre_classifier
18 | def inception_score(args, transform):
19 | #Preprocess
20 | #100 files under args.data_dir
21 | #classify into args.class classes
22 | N = len([name for name in os.listdir(args.data_dir) if os.path.isfile(os.path.join(args.data_dir, name))])
23 | preds = np.zeros((N, args.classes))
24 |
25 | # load checkpoint
26 | checkpoint = torch.load(args.path)
27 | #print(checkpoint['model_state_dict'])
28 | model = ShortChunkCNN_one_bar(n_class = args.classes).cuda()
29 | #print(model)
30 | #print(checkpoint['model_state_dict'].keys())
31 | model.load_state_dict(checkpoint['model_state_dict'])
32 | model.eval()
33 |
34 | #Load data
35 | dataset = MultiResolutionDataset_drum(args.data_dir, transform)
36 | loader =torch.utils.data.DataLoader(
37 | dataset,
38 | batch_size=2,
39 | sampler=data_sampler(dataset, shuffle=True, distributed=False),
40 | drop_last=True,
41 | )
42 |
43 | mean_fp = os.path.join(args.mean_std_dir, f'mean.mel.npy')
44 | std_fp = os.path.join(args.mean_std_dir, f'std.mel.npy')
45 | feat_dim = 80
46 | mean = torch.from_numpy(np.load(mean_fp)).float().cuda().view(1, feat_dim, 1)
47 | std = torch.from_numpy(np.load(std_fp)).float().cuda().view(1, feat_dim, 1)
48 | #Model inference
49 | for i, data in enumerate(loader):
50 | data = data.cuda() # [bs, 1, 80, 320]
51 |
52 | data = data * std + mean
53 | data = data.squeeze(1)
54 | output = model(data) # [bs, args.class]
55 |
56 | logit = F.softmax(output, dim = 1).data.cpu().numpy() # [bs, args.class]
57 |
58 | preds[i * 2 : (i + 1) * 2] = logit
59 | #KL divergence
60 | scores = []
61 | py = np.mean(preds, axis=0)
62 |
63 | for i in range(preds.shape[0]):
64 | pyx = preds[i, :]
65 | scores.append(entropy(pyx, py))
66 | #Data PostProcessing
67 | is_score = np.exp(np.mean(scores))
68 | std = np.exp(np.std(scores))
69 | every_score = np.exp(scores)
70 | return is_score, std, every_score
71 |
72 |
73 |
74 |
75 |
76 |
77 | if __name__ == '__main__':
78 | parser = argparse.ArgumentParser(description="compute inception score")
79 |
80 | parser.add_argument("path", type=str, help="path to the model")
81 |
82 | parser.add_argument("--data_dir", type=str, help="has 100 npy normalized file under this directory")
83 |
84 | parser.add_argument("--classes", type=int, help="number of classes")
85 |
86 | parser.add_argument("--mean_std_dir", type=str, help="directory which has mean and std npy file")
87 | parser.add_argument("--store_every_score", type=str)
88 |
89 | args = parser.parse_args()
90 |
91 | transform = transforms.Compose(
92 | [
93 | #transforms.RandomHorizontalFlip(),
94 | transforms.ToTensor(),
95 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
96 | ]
97 | )
98 |
99 | is_score, std, every_score = inception_score(args, transform)
100 | with open(f'{args.store_every_score}', 'wb') as f:
101 | pickle.dump(every_score, f)
102 | print(is_score, std)
103 |
--------------------------------------------------------------------------------
/evaluation/IS/modules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | import torchaudio
6 | import sys
7 | from torch.autograd import Variable
8 | import math
9 | import librosa
10 |
11 |
12 | class Conv_1d(nn.Module):
13 | def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
14 | super(Conv_1d, self).__init__()
15 | self.conv = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
16 | self.bn = nn.BatchNorm1d(output_channels)
17 | self.relu = nn.ReLU()
18 | self.mp = nn.MaxPool1d(pooling)
19 | def forward(self, x):
20 | out = self.mp(self.relu(self.bn(self.conv(x))))
21 | return out
22 |
23 |
24 | class Conv_2d(nn.Module):
25 | def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
26 | super(Conv_2d, self).__init__()
27 | self.conv = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
28 | self.bn = nn.BatchNorm2d(output_channels)
29 | self.relu = nn.ReLU()
30 | self.mp = nn.MaxPool2d(pooling)
31 | def forward(self, x):
32 | out = self.mp(self.relu(self.bn(self.conv(x))))
33 | return out
34 |
35 |
36 | class Res_2d(nn.Module):
37 | def __init__(self, input_channels, output_channels, shape=3, stride=2):
38 | super(Res_2d, self).__init__()
39 | # convolution
40 | self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
41 | self.bn_1 = nn.BatchNorm2d(output_channels)
42 | self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=shape//2)
43 | self.bn_2 = nn.BatchNorm2d(output_channels)
44 |
45 | # residual
46 | self.diff = False
47 | if (stride != 1) or (input_channels != output_channels):
48 | self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
49 | self.bn_3 = nn.BatchNorm2d(output_channels)
50 | self.diff = True
51 | self.relu = nn.ReLU()
52 |
53 | def forward(self, x):
54 | # convolution
55 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
56 |
57 | # residual
58 | if self.diff:
59 | x = self.bn_3(self.conv_3(x))
60 | out = x + out
61 | out = self.relu(out)
62 | return out
63 |
64 |
65 | class Res_2d_mp(nn.Module):
66 | def __init__(self, input_channels, output_channels, pooling=2):
67 | super(Res_2d_mp, self).__init__()
68 | self.conv_1 = nn.Conv2d(input_channels, output_channels, 3, padding=1)
69 | self.bn_1 = nn.BatchNorm2d(output_channels)
70 | self.conv_2 = nn.Conv2d(output_channels, output_channels, 3, padding=1)
71 | self.bn_2 = nn.BatchNorm2d(output_channels)
72 | self.relu = nn.ReLU()
73 | self.mp = nn.MaxPool2d(pooling)
74 | def forward(self, x):
75 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
76 | out = x + out
77 | out = self.mp(self.relu(out))
78 | return out
79 |
80 |
81 | class ResSE_1d(nn.Module):
82 | def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=3):
83 | super(ResSE_1d, self).__init__()
84 | # convolution
85 | self.conv_1 = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
86 | self.bn_1 = nn.BatchNorm1d(output_channels)
87 | self.conv_2 = nn.Conv1d(output_channels, output_channels, shape, padding=shape//2)
88 | self.bn_2 = nn.BatchNorm1d(output_channels)
89 |
90 | # squeeze & excitation
91 | self.dense1 = nn.Linear(output_channels, output_channels)
92 | self.dense2 = nn.Linear(output_channels, output_channels)
93 |
94 | # residual
95 | self.diff = False
96 | if (stride != 1) or (input_channels != output_channels):
97 | self.conv_3 = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
98 | self.bn_3 = nn.BatchNorm1d(output_channels)
99 | self.diff = True
100 | self.relu = nn.ReLU()
101 | self.sigmoid = nn.Sigmoid()
102 | self.mp = nn.MaxPool1d(pooling)
103 |
104 | def forward(self, x):
105 | # convolution
106 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
107 |
108 | # squeeze & excitation
109 | se_out = nn.AvgPool1d(out.size(-1))(out)
110 | se_out = se_out.squeeze(-1)
111 | se_out = self.relu(self.dense1(se_out))
112 | se_out = self.sigmoid(self.dense2(se_out))
113 | se_out = se_out.unsqueeze(-1)
114 | out = torch.mul(out, se_out)
115 |
116 | # residual
117 | if self.diff:
118 | x = self.bn_3(self.conv_3(x))
119 | out = x + out
120 | out = self.mp(self.relu(out))
121 | return out
122 |
123 |
124 | class Conv_V(nn.Module):
125 | # vertical convolution
126 | def __init__(self, input_channels, output_channels, filter_shape):
127 | super(Conv_V, self).__init__()
128 | self.conv = nn.Conv2d(input_channels, output_channels, filter_shape,
129 | padding=(0, filter_shape[1]//2))
130 | self.bn = nn.BatchNorm2d(output_channels)
131 | self.relu = nn.ReLU()
132 |
133 | def forward(self, x):
134 | x = self.relu(self.bn(self.conv(x)))
135 | freq = x.size(2)
136 | out = nn.MaxPool2d((freq, 1), stride=(freq, 1))(x)
137 | out = out.squeeze(2)
138 | return out
139 |
140 |
141 | class Conv_H(nn.Module):
142 | # horizontal convolution
143 | def __init__(self, input_channels, output_channels, filter_length):
144 | super(Conv_H, self).__init__()
145 | self.conv = nn.Conv1d(input_channels, output_channels, filter_length,
146 | padding=filter_length//2)
147 | self.bn = nn.BatchNorm1d(output_channels)
148 | self.relu = nn.ReLU()
149 |
150 | def forward(self, x):
151 | freq = x.size(2)
152 | out = nn.AvgPool2d((freq, 1), stride=(freq, 1))(x)
153 | out = out.squeeze(2)
154 | out = self.relu(self.bn(self.conv(out)))
155 | return out
156 |
157 |
158 | # Modules for harmonic filters
159 | def hz_to_midi(hz):
160 | return 12 * (torch.log2(hz) - np.log2(440.0)) + 69
161 |
162 | def midi_to_hz(midi):
163 | return 440.0 * (2.0 ** ((midi - 69.0)/12.0))
164 |
165 | def note_to_midi(note):
166 | return librosa.core.note_to_midi(note)
167 |
168 | def hz_to_note(hz):
169 | return librosa.core.hz_to_note(hz)
170 |
171 | def initialize_filterbank(sample_rate, n_harmonic, semitone_scale):
172 | # MIDI
173 | # lowest note
174 | low_midi = note_to_midi('C1')
175 |
176 | # highest note
177 | high_note = hz_to_note(sample_rate / (2 * n_harmonic))
178 | high_midi = note_to_midi(high_note)
179 |
180 | # number of scales
181 | level = (high_midi - low_midi) * semitone_scale
182 | midi = np.linspace(low_midi, high_midi, level + 1)
183 | hz = midi_to_hz(midi[:-1])
184 |
185 | # stack harmonics
186 | harmonic_hz = []
187 | for i in range(n_harmonic):
188 | harmonic_hz = np.concatenate((harmonic_hz, hz * (i+1)))
189 |
190 | return harmonic_hz, level
191 |
192 |
193 | class HarmonicSTFT(nn.Module):
194 | def __init__(self,
195 | sample_rate=16000,
196 | n_fft=513,
197 | win_length=None,
198 | hop_length=None,
199 | pad=0,
200 | power=2,
201 | normalized=False,
202 | n_harmonic=6,
203 | semitone_scale=2,
204 | bw_Q=1.0,
205 | learn_bw=None):
206 | super(HarmonicSTFT, self).__init__()
207 |
208 | # Parameters
209 | self.sample_rate = sample_rate
210 | self.n_harmonic = n_harmonic
211 | self.bw_alpha = 0.1079
212 | self.bw_beta = 24.7
213 |
214 | # Spectrogram
215 | self.spec = torchaudio.transforms.Spectrogram(n_fft=n_fft, win_length=win_length,
216 | hop_length=None, pad=0,
217 | window_fn=torch.hann_window,
218 | power=power, normalized=normalized, wkwargs=None)
219 | self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
220 |
221 | # Initialize the filterbank. Equally spaced in MIDI scale.
222 | harmonic_hz, self.level = initialize_filterbank(sample_rate, n_harmonic, semitone_scale)
223 |
224 | # Center frequncies to tensor
225 | self.f0 = torch.tensor(harmonic_hz.astype('float32'))
226 |
227 | # Bandwidth parameters
228 | if learn_bw == 'only_Q':
229 | self.bw_Q = nn.Parameter(torch.tensor(np.array([bw_Q]).astype('float32')))
230 | elif learn_bw == 'fix':
231 | self.bw_Q = torch.tensor(np.array([bw_Q]).astype('float32'))
232 |
233 | def get_harmonic_fb(self):
234 | # bandwidth
235 | bw = (self.bw_alpha * self.f0 + self.bw_beta) / self.bw_Q
236 | bw = bw.unsqueeze(0) # (1, n_band)
237 | f0 = self.f0.unsqueeze(0) # (1, n_band)
238 | fft_bins = self.fft_bins.unsqueeze(1) # (n_bins, 1)
239 |
240 | up_slope = torch.matmul(fft_bins, (2/bw)) + 1 - (2 * f0 / bw)
241 | down_slope = torch.matmul(fft_bins, (-2/bw)) + 1 + (2 * f0 / bw)
242 | fb = torch.max(self.zero, torch.min(down_slope, up_slope))
243 | return fb
244 |
245 | def to_device(self, device, n_bins):
246 | self.f0 = self.f0.to(device)
247 | self.bw_Q = self.bw_Q.to(device)
248 | # fft bins
249 | self.fft_bins = torch.linspace(0, self.sample_rate//2, n_bins)
250 | self.fft_bins = self.fft_bins.to(device)
251 | self.zero = torch.zeros(1)
252 | self.zero = self.zero.to(device)
253 |
254 | def forward(self, waveform):
255 | # stft
256 | spectrogram = self.spec(waveform)
257 |
258 | # to device
259 | self.to_device(waveform.device, spectrogram.size(1))
260 |
261 | # triangle filter
262 | harmonic_fb = self.get_harmonic_fb()
263 | harmonic_spec = torch.matmul(spectrogram.transpose(1, 2), harmonic_fb).transpose(1, 2)
264 |
265 | # (batch, channel, length) -> (batch, harmonic, f0, length)
266 | b, c, l = harmonic_spec.size()
267 | harmonic_spec = harmonic_spec.view(b, self.n_harmonic, self.level, l)
268 |
269 | # amplitude to db
270 | harmonic_spec = self.amplitude_to_db(harmonic_spec)
271 | return harmonic_spec
272 |
--------------------------------------------------------------------------------
/evaluation/NDB_JS/compute_ndb_js.py:
--------------------------------------------------------------------------------
1 | import ndb
2 | import argparse
3 | import os
4 | import numpy as np
5 |
6 | if __name__ == '__main__':
7 |
8 | parser = argparse.ArgumentParser(description="compute ndb and JS divergence")
9 |
10 | parser.add_argument("--real_dir", type=str)
11 |
12 | parser.add_argument("--gen_dir", type=str)
13 |
14 | parser.add_argument("--mean_std_dir", type=str, help="directory which has mean and std npy file")
15 |
16 | args = parser.parse_args()
17 |
18 | dim = 80 * 320
19 | n_train = 2000 # Don't change this line; this is the number of looperman reference samples.
20 | n_test = 2000 # You can change this depending on how many generated samples are in the generation directory.
21 |
22 | # load train samples
23 | mean_fp = os.path.join(args.mean_std_dir, f'mean.mel.npy')
24 | std_fp = os.path.join(args.mean_std_dir, f'std.mel.npy')
25 | feat_dim = 80
26 | mean = np.load(mean_fp).reshape((1, feat_dim, 1))
27 | std = np.load(std_fp).reshape((1, feat_dim, 1))
28 |
29 | train_samples = np.zeros(shape = [n_train, dim])
30 |
31 | for i, path in enumerate(os.listdir(args.real_dir)):
32 | train_path = os.path.join(args.real_dir, path)
33 | train_numpy = np.load(train_path)
34 | train_numpy = (train_numpy - mean) / std
35 | train_samples[i, :] = train_numpy.reshape((-1, ))
36 |
37 | # load test samples
38 |
39 |
40 | test_samples = np.zeros(shape = [n_test, dim])
41 |
42 | for i, path in enumerate(os.listdir(args.gen_dir)):
43 |
44 | test_path = os.path.join(args.gen_dir, path)
45 | test_numpy = np.load(test_path)
46 | test_samples[i, :] = test_numpy.reshape((-1, ))
47 |
48 | # NDB and JSD calculation
49 | k = 100
50 | train_ndb = ndb.NDB(training_data=train_samples, number_of_bins=k, whitening=True)
51 |
52 | train_ndb.evaluate(test_samples, model_label='Test')
53 |
--------------------------------------------------------------------------------
/evaluation/NDB_JS/compute_ndb_js.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python compute_ndb_js.py --mean_std_dir ../../data/looperman --real_dir ./looper_2000 --gen_dir ../../generated_freesound_one_bar/freesound_checkpoint/100000/mel_80_320
3 |
--------------------------------------------------------------------------------
/evaluation/NDB_JS/ndb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from sklearn.cluster import KMeans
4 | from scipy.stats import norm
5 | from matplotlib import pyplot as plt
6 | import pickle as pkl
7 |
8 | class NDB:
9 | def __init__(self, training_data=None, number_of_bins=100, significance_level=0.05, z_threshold=None,
10 | whitening=False, max_dims=None, cache_folder=None):
11 | """
12 | NDB Evaluation Class
13 | :param training_data: Optional - the training samples - array of m x d floats (m samples of dimension d)
14 | :param number_of_bins: Number of bins (clusters) default=100
15 | :param significance_level: The statistical significance level for the two-sample test
16 | :param z_threshold: Allow defining a threshold in terms of difference/SE for defining a bin as statistically different
17 | :param whitening: Perform data whitening - subtract mean and divide by per-dimension std
18 | :param max_dims: Max dimensions to use in K-means. By default derived automatically from d
19 |         :param cache_folder: Optional - folder used to cache the clusters and results (to avoid re-calculation)
20 | """
21 | self.number_of_bins = number_of_bins
22 | self.significance_level = significance_level
23 | self.z_threshold = z_threshold
24 | self.whitening = whitening
25 | self.ndb_eps = 1e-6
26 | self.training_mean = 0.0
27 | self.training_std = 1.0
28 | self.max_dims = max_dims
29 | self.cache_folder = cache_folder
30 | self.bin_centers = None
31 | self.bin_proportions = None
32 | self.ref_sample_size = None
33 | self.used_d_indices = None
34 | self.results_file = None
35 | self.test_name = 'ndb_{}_bins_{}'.format(self.number_of_bins, 'whiten' if self.whitening else 'orig')
36 | self.cached_results = {}
37 | if self.cache_folder:
38 | self.results_file = os.path.join(cache_folder, self.test_name+'_results.pkl')
39 | if os.path.isfile(self.results_file):
40 | # print('Loading previous results from', self.results_file, ':')
41 | self.cached_results = pkl.load(open(self.results_file, 'rb'))
42 | # print(self.cached_results.keys())
43 | if training_data is not None or cache_folder is not None:
44 | bins_file = None
45 | if cache_folder:
46 | os.makedirs(cache_folder, exist_ok=True)
47 | bins_file = os.path.join(cache_folder, self.test_name+'.pkl')
48 | self.construct_bins(training_data, bins_file)
49 |
50 | def construct_bins(self, training_samples, bins_file):
51 | """
52 | Performs K-means clustering of the training samples
53 | :param training_samples: An array of m x d floats (m samples of dimension d)
54 | """
55 |
56 | if self.__read_from_bins_file(bins_file):
57 | return
58 | n, d = training_samples.shape
59 | k = self.number_of_bins
60 | if self.whitening:
61 | self.training_mean = np.mean(training_samples, axis=0)
62 | self.training_std = np.std(training_samples, axis=0) + self.ndb_eps
63 |
64 | if self.max_dims is None and d > 1000:
65 |             # To run faster, perform binning on a sampled subset of the dimensions (i.e. don't use all channels of all pixels)
66 | self.max_dims = d//6
67 |
68 | whitened_samples = (training_samples-self.training_mean)/self.training_std
69 | d_used = d if self.max_dims is None else min(d, self.max_dims)
70 | self.used_d_indices = np.random.choice(d, d_used, replace=False)
71 |
72 | print('Performing K-Means clustering of {} samples in dimension {} / {} to {} clusters ...'.format(n, d_used, d, k))
73 | print('Can take a couple of minutes...')
74 | if n//k > 1000:
75 | print('Training data size should be ~500 times the number of bins (for reasonable speed and accuracy)')
76 |
77 | clusters = KMeans(n_clusters=k, max_iter=100, n_jobs=-1).fit(whitened_samples[:, self.used_d_indices])
78 |
79 | bin_centers = np.zeros([k, d])
80 | for i in range(k):
81 | bin_centers[i, :] = np.mean(whitened_samples[clusters.labels_ == i, :], axis=0)
82 |
83 | # Organize bins by size
84 | label_vals, label_counts = np.unique(clusters.labels_, return_counts=True)
85 | bin_order = np.argsort(-label_counts)
86 | self.bin_proportions = label_counts[bin_order] / np.sum(label_counts)
87 | self.bin_centers = bin_centers[bin_order, :]
88 | self.ref_sample_size = n
89 | self.__write_to_bins_file(bins_file)
90 | print('Done.')
91 |
92 | def evaluate(self, query_samples, model_label=None):
93 | """
94 |         Assign each sample to the nearest bin center (in L2). Pre-whiten if required, and calculate the NDB
95 | (Number of statistically Different Bins) and JS divergence scores.
96 | :param query_samples: An array of m x d floats (m samples of dimension d)
97 | :param model_label: optional label string for the evaluated model, allows plotting results of multiple models
98 | :return: results dictionary containing NDB and JS scores and array of labels (assigned bin for each query sample)
99 | """
100 | n = query_samples.shape[0]
101 | query_bin_proportions, query_bin_assignments = self.__calculate_bin_proportions(query_samples)
102 | # print(query_bin_proportions)
103 | different_bins = NDB.two_proportions_z_test(self.bin_proportions, self.ref_sample_size, query_bin_proportions,
104 | n, significance_level=self.significance_level,
105 | z_threshold=self.z_threshold)
106 | ndb = np.count_nonzero(different_bins)
107 | js = NDB.jensen_shannon_divergence(self.bin_proportions, query_bin_proportions)
108 | results = {'NDB': ndb,
109 | 'JS': js,
110 | 'Proportions': query_bin_proportions,
111 | 'N': n,
112 | 'Bin-Assignment': query_bin_assignments,
113 | 'Different-Bins': different_bins}
114 |
115 | if model_label:
116 | print('Results for {} samples from {}: '.format(n, model_label), end='')
117 | self.cached_results[model_label] = results
118 | if self.results_file:
119 | # print('Storing result to', self.results_file)
120 | pkl.dump(self.cached_results, open(self.results_file, 'wb'))
121 |
122 | print('NDB =', ndb, 'NDB/K =', ndb/self.number_of_bins, ', JS =', js)
123 | return results
124 |
125 | def print_results(self):
126 |         print('NDB results (K={}{}):'.format(self.number_of_bins, ', data whitening' if self.whitening else ''))
127 | for model in sorted(list(self.cached_results.keys())):
128 | res = self.cached_results[model]
129 | print('%s: NDB = %d, NDB/K = %.3f, JS = %.4f' % (model, res['NDB'], res['NDB']/self.number_of_bins, res['JS']))
130 |
131 | def plot_results(self, models_to_plot=None):
132 | """
133 | Plot the binning proportions of different methods
134 | :param models_to_plot: optional list of model labels to plot
135 | """
136 | K = self.number_of_bins
137 | w = 1.0 / (len(self.cached_results)+1)
138 | assert K == self.bin_proportions.size
139 | assert self.cached_results
140 |
141 | # Used for plotting only
142 | def calc_se(p1, n1, p2, n2):
143 | p = (p1 * n1 + p2 * n2) / (n1 + n2)
144 | return np.sqrt(p * (1 - p) * (1/n1 + 1/n2))
145 |
146 | if not models_to_plot:
147 | models_to_plot = sorted(list(self.cached_results.keys()))
148 |
149 | # Visualize the standard errors using the train proportions and size and query sample size
150 | train_se = calc_se(self.bin_proportions, self.ref_sample_size,
151 | self.bin_proportions, self.cached_results[models_to_plot[0]]['N'])
152 | plt.bar(np.arange(0, K)+0.5, height=train_se*2.0, bottom=self.bin_proportions-train_se,
153 | width=1.0, label='Train$\pm$SE', color='gray')
154 |
155 | ymax = 0.0
156 | for i, model in enumerate(models_to_plot):
157 | results = self.cached_results[model]
158 | label = '%s (%i : %.4f)' % (model, results['NDB'], results['JS'])
159 | ymax = max(ymax, np.max(results['Proportions']))
160 | if K <= 70:
161 | plt.bar(np.arange(0, K)+(i+1.0)*w, results['Proportions'], width=w, label=label)
162 | else:
163 | plt.plot(np.arange(0, K)+0.5, results['Proportions'], '--*', label=label)
164 | plt.legend(loc='best')
165 | plt.ylim((0.0, min(ymax, np.max(self.bin_proportions)*4.0)))
166 | plt.grid(True)
167 | plt.title('Binning Proportions Evaluation Results for {} bins (NDB : JS)'.format(K))
168 | plt.show()
169 |
170 | def __calculate_bin_proportions(self, samples):
171 | if self.bin_centers is None:
172 | print('First run construct_bins on samples from the reference training data')
173 | assert samples.shape[1] == self.bin_centers.shape[1]
174 | n, d = samples.shape
175 | k = self.bin_centers.shape[0]
176 | D = np.zeros([n, k], dtype=samples.dtype)
177 |
178 | print('Calculating bin assignments for {} samples...'.format(n))
179 | whitened_samples = (samples-self.training_mean)/self.training_std
180 | for i in range(k):
181 | print('.', end='', flush=True)
182 | D[:, i] = np.linalg.norm(whitened_samples[:, self.used_d_indices] - self.bin_centers[i, self.used_d_indices],
183 | ord=2, axis=1)
184 | print()
185 | labels = np.argmin(D, axis=1)
186 | probs = np.zeros([k])
187 | label_vals, label_counts = np.unique(labels, return_counts=True)
188 | probs[label_vals] = label_counts / n
189 | return probs, labels
190 |
191 | def __read_from_bins_file(self, bins_file):
192 | if bins_file and os.path.isfile(bins_file):
193 | print('Loading binning results from', bins_file)
194 | bins_data = pkl.load(open(bins_file,'rb'))
195 | self.bin_proportions = bins_data['proportions']
196 | self.bin_centers = bins_data['centers']
197 | self.ref_sample_size = bins_data['n']
198 | self.training_mean = bins_data['mean']
199 | self.training_std = bins_data['std']
200 | self.used_d_indices = bins_data['d_indices']
201 | return True
202 | return False
203 |
204 | def __write_to_bins_file(self, bins_file):
205 | if bins_file:
206 | print('Caching binning results to', bins_file)
207 | bins_data = {'proportions': self.bin_proportions,
208 | 'centers': self.bin_centers,
209 | 'n': self.ref_sample_size,
210 | 'mean': self.training_mean,
211 | 'std': self.training_std,
212 | 'd_indices': self.used_d_indices}
213 | pkl.dump(bins_data, open(bins_file, 'wb'))
214 |
215 | @staticmethod
216 | def two_proportions_z_test(p1, n1, p2, n2, significance_level, z_threshold=None):
217 | # Per http://stattrek.com/hypothesis-test/difference-in-proportions.aspx
218 | # See also http://www.itl.nist.gov/div898/software/dataplot/refman1/auxillar/binotest.htm
219 | p = (p1 * n1 + p2 * n2) / (n1 + n2)
220 | se = np.sqrt(p * (1 - p) * (1/n1 + 1/n2))
221 | z = (p1 - p2) / se
222 |         # Allow defining a threshold in terms of Z (difference relative to the SE) rather than a p-value.
223 | if z_threshold is not None:
224 | return abs(z) > z_threshold
225 | p_values = 2.0 * norm.cdf(-1.0 * np.abs(z)) # Two-tailed test
226 | return p_values < significance_level
227 |
228 | @staticmethod
229 | def jensen_shannon_divergence(p, q):
230 | """
231 | Calculates the symmetric Jensen–Shannon divergence between the two PDFs
232 | """
233 | m = (p + q) * 0.5
234 | return 0.5 * (NDB.kl_divergence(p, m) + NDB.kl_divergence(q, m))
235 |
236 | @staticmethod
237 | def kl_divergence(p, q):
238 | """
239 | The Kullback–Leibler divergence.
240 | Defined only if q != 0 whenever p != 0.
241 | """
242 | assert np.all(np.isfinite(p))
243 | assert np.all(np.isfinite(q))
244 | assert not np.any(np.logical_and(p != 0, q == 0))
245 |
246 | p_pos = (p > 0)
247 | return np.sum(p[p_pos] * np.log(p[p_pos] / q[p_pos]))
248 |
249 |
250 | if __name__ == "__main__":
251 | dim=80 * 320
252 | k=100
253 | n_train = k*10
254 | n_test = k*10
255 |
256 | train_samples = np.random.uniform(size=[n_train,dim])
257 | ndb = NDB(training_data=train_samples, number_of_bins=k, whitening=True)
258 |
259 | test_samples = np.random.uniform(high=1.0, size=[n_test, dim])
260 | ndb.evaluate(test_samples, model_label='Test')
261 |
262 | test_samples = np.random.uniform(high=0.9, size=[n_test, dim])
263 | ndb.evaluate(test_samples, model_label='Good')
264 |
265 | test_samples = np.random.uniform(high=0.75, size=[n_test, dim])
266 | ndb.evaluate(test_samples, model_label='Bad')
267 |
268 | ndb.plot_results(models_to_plot=['Test', 'Good', 'Bad'])
269 |
--------------------------------------------------------------------------------
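For reference, the per-bin test in two_proportions_z_test and the score in jensen_shannon_divergence above compute, with p_1, n_1 the training bin proportions and sample size and p_2, n_2 those of the query samples:

    \hat{p} = \frac{p_1 n_1 + p_2 n_2}{n_1 + n_2}, \qquad
    \mathrm{SE} = \sqrt{\hat{p}\,(1 - \hat{p})\left(\frac{1}{n_1} + \frac{1}{n_2}\right)}, \qquad
    z = \frac{p_1 - p_2}{\mathrm{SE}}

    \mathrm{JS}(p, q) = \tfrac{1}{2}\,\mathrm{KL}\!\left(p \,\Big\|\, \tfrac{p+q}{2}\right) + \tfrac{1}{2}\,\mathrm{KL}\!\left(q \,\Big\|\, \tfrac{p+q}{2}\right)

NDB is then the number of bins whose two-sided p-value 2\,\Phi(-|z|) falls below the significance level (0.05 by default), and NDB/K is the fraction reported alongside JS.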
/evaluation/nine_audio/Chillout/1365.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Chillout/1365.wav
--------------------------------------------------------------------------------
/evaluation/nine_audio/Drum_and_Bass/14.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Drum_and_Bass/14.wav
--------------------------------------------------------------------------------
/evaluation/nine_audio/Electronic/1353.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Electronic/1353.wav
--------------------------------------------------------------------------------
/evaluation/nine_audio/Hiphop/323.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Hiphop/323.wav
--------------------------------------------------------------------------------
/evaluation/nine_audio/Rap/153.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Rap/153.wav
--------------------------------------------------------------------------------
/evaluation/nine_audio/Rock/752.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Rock/752.wav
--------------------------------------------------------------------------------
/evaluation/nine_audio/Trap/1395.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/Trap/1395.wav
--------------------------------------------------------------------------------
/evaluation/nine_audio/dupstep/113.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/dupstep/113.wav
--------------------------------------------------------------------------------
/evaluation/nine_audio/industrial/1606.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/evaluation/nine_audio/industrial/1606.wav
--------------------------------------------------------------------------------
/generate_audio.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from torchvision import utils
5 | from model_drum import Generator
6 | from tqdm import tqdm
7 |
8 | import sys
9 | sys.path.append('./melgan')
10 | from modules import Generator_melgan
11 |
12 | import yaml
13 | import os
14 |
15 | import librosa
16 |
17 | import soundfile as sf
18 |
19 | import numpy as np
20 |
21 | import os
22 | def read_yaml(fp):
23 | with open(fp) as file:
24 | # return yaml.load(file)
25 | return yaml.load(file, Loader=yaml.Loader)
26 |
27 | def generate(args, g_ema, device, mean_latent):
28 | epoch = args.ckpt.split('.')[0]
29 |
30 | os.makedirs(f'{args.store_path}/{epoch}', exist_ok=True)
31 | os.makedirs(f'{args.store_path}/{epoch}/mel_80_320', exist_ok=True)
32 | feat_dim = 80
33 | mean_fp = f'{args.data_path}/mean.mel.npy'
34 | std_fp = f'{args.data_path}/std.mel.npy'
35 | mean = torch.from_numpy(np.load(mean_fp)).float().view(1, feat_dim, 1).to(device)
36 | std = torch.from_numpy(np.load(std_fp)).float().view(1, feat_dim, 1).to(device)
37 | vocoder_config_fp = './melgan/args.yml'
38 | vocoder_config = read_yaml(vocoder_config_fp)
39 |
40 | n_mel_channels = vocoder_config.n_mel_channels
41 | ngf = vocoder_config.ngf
42 | n_residual_layers = vocoder_config.n_residual_layers
43 | sr=44100
44 |
45 | vocoder = Generator_melgan(n_mel_channels, ngf, n_residual_layers).to(device)
46 | vocoder.eval()
47 |
48 | vocoder_param_fp = os.path.join('./melgan', 'best_netG.pt')
49 | vocoder.load_state_dict(torch.load(vocoder_param_fp))
50 |
51 |
52 | with torch.no_grad():
53 | g_ema.eval()
54 | for i in tqdm(range(args.pics)):
55 | sample_z = torch.randn(args.sample, args.latent, device=device)
56 |
57 | sample, _ = g_ema(
58 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent
59 | )
60 | np.save(f'{args.store_path}/{epoch}/mel_80_320/{i}.npy', sample.squeeze().data.cpu().numpy())
61 |
62 | utils.save_image(
63 | sample,
64 | f"{args.store_path}/{epoch}/{str(i).zfill(6)}.png",
65 | nrow=1,
66 | normalize=True,
67 | range=(-1, 1),
68 | )
69 | de_norm = sample.squeeze(0) * std + mean
70 | audio_output = vocoder(de_norm)
71 | sf.write(f'{args.store_path}/{epoch}/{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr)
72 |             print('generated wav file {}'.format(i))
73 | @torch.no_grad()
74 | def style_mixing(args, generator, step, mean_style, n_source, n_target, device, j):
75 | index = 2
76 | # create directory
77 | os.makedirs(f'./generated_interpolation_one_bar_{index}/{j}', exist_ok=True)
78 |
79 | # load melgan vocoder
80 | feat_dim = 80
81 | mean_fp = f'{args.data_path}/mean.mel.npy'
82 | std_fp = f'{args.data_path}/std.mel.npy'
83 | mean = torch.from_numpy(np.load(mean_fp)).float().view(1, feat_dim, 1).to(device)
84 | std = torch.from_numpy(np.load(std_fp)).float().view(1, feat_dim, 1).to(device)
85 | vocoder_config_fp = './melgan/args.yml'
86 | vocoder_config = read_yaml(vocoder_config_fp)
87 |
88 | n_mel_channels = vocoder_config.n_mel_channels
89 | ngf = vocoder_config.ngf
90 | n_residual_layers = vocoder_config.n_residual_layers
91 | sr=44100
92 |
93 | vocoder = Generator_melgan(n_mel_channels, ngf, n_residual_layers).to(device)
94 | vocoder.eval()
95 |
96 | vocoder_param_fp = os.path.join('./melgan', 'best_netG.pt')
97 | vocoder.load_state_dict(torch.load(vocoder_param_fp))
98 |
99 | #generate spectrogram
100 | source_code = torch.randn(n_source, 512).to(device)
101 | target_code = torch.randn(n_target, 512).to(device)
102 |
103 | shape = 4 * 2 ** step
104 | alpha = 1
105 |
106 | images = [torch.ones(1, 1, 80, 320).to(device) * -1]
107 |
108 | source_image,_ = generator(
109 | [source_code], truncation=args.truncation, truncation_latent=mean_style
110 | )
111 | target_image,_ = generator(
112 | [target_code], truncation=args.truncation, truncation_latent=mean_style
113 | )
114 |
115 | images.append(source_image)
116 |
117 | for i in range(n_source):
118 | de_norm = source_image[i] * std + mean
119 | audio_output = vocoder(de_norm)
120 | sf.write(f'./generated_interpolation_one_bar_{index}/{j}/source_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr)
121 |
122 | for i in range(n_target):
123 | de_norm = target_image[i] * std + mean
124 | audio_output = vocoder(de_norm)
125 | sf.write(f'./generated_interpolation_one_bar_{index}/{j}/target_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr)
126 |
127 | for i in range(n_target):
128 | image, _ = generator(
129 | [target_code[i].unsqueeze(0).repeat(n_source, 1), source_code],
130 | truncation_latent=mean_style,
131 | inject_index = index
132 | )
133 |
134 | for k in range(n_source):
135 | de_norm = image[k] * std + mean
136 | audio_output = vocoder(de_norm)
137 | sf.write(f'./generated_interpolation_one_bar_{index}/{j}/source_{k}_target_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr)
138 |
139 | images.append(target_image[i].unsqueeze(0))
140 | images.append(image)
141 |
142 | images = torch.cat(images, 0)
143 | utils.save_image(
144 | images, f'./generated_interpolation_one_bar_{index}/{j}/sample_mixing.png', nrow=args.n_col + 1, normalize=True, range=(-1, 1)
145 | )
146 | return images
147 | if __name__ == "__main__":
148 | device = "cuda"
149 |
150 | parser = argparse.ArgumentParser(description="Generate samples from the generator")
151 |
152 | parser.add_argument(
153 | "--size", type=int, default=64, help="output image size of the generator"
154 | )
155 | parser.add_argument(
156 | "--sample",
157 | type=int,
158 | default=1,
159 | help="number of samples to be generated for each image",
160 | )
161 | parser.add_argument(
162 | "--pics", type=int, default=20, help="number of images to be generated"
163 | )
164 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
165 | parser.add_argument(
166 | "--truncation_mean",
167 | type=int,
168 | default=4096,
169 | help="number of vectors to calculate mean for the truncation",
170 | )
171 | parser.add_argument(
172 | "--ckpt",
173 | type=str,
174 | default="stylegan2-ffhq-config-f.pt",
175 | help="path to the model checkpoint",
176 | )
177 | parser.add_argument(
178 | "--data_path",
179 | type=str,
180 | help="path store the std and mean of mel",
181 | )
182 | parser.add_argument(
183 | "--store_path",
184 | type=str,
185 | help="path store the generated audio",
186 | )
187 | parser.add_argument(
188 | "--channel_multiplier",
189 | type=int,
190 | default=2,
191 | help="channel multiplier of the generator. config-f = 2, else = 1",
192 | )
193 | parser.add_argument("--style_mixing", action = "store_true")
194 | parser.add_argument('--n_row', type=int, default=3, help='number of rows of sample matrix')
195 | parser.add_argument('--n_col', type=int, default=5, help='number of columns of sample matrix')
196 | args = parser.parse_args()
197 |
198 | args.latent = 512
199 | args.n_mlp = 8
200 |
201 | g_ema = Generator(
202 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
203 | ).to(device)
204 | checkpoint = torch.load(args.ckpt)
205 |
206 | g_ema.load_state_dict(checkpoint["g_ema"])
207 |
208 | if args.truncation < 1:
209 | with torch.no_grad():
210 | mean_latent = g_ema.mean_latent(args.truncation_mean)
211 | else:
212 | mean_latent = None
213 |
214 | # Generate audio
215 | generate(args, g_ema, device, mean_latent)
216 |
217 | #Style mixing
218 |     if args.style_mixing:
219 | step = 0
220 | for j in range(20):
221 | img = style_mixing(args,g_ema, step, mean_latent, args.n_col, args.n_row, device, j)
222 |
--------------------------------------------------------------------------------
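The .npy files saved under mel_80_320 hold the generator output in the normalized mel space; only the wav path in generate() de-normalizes with the dataset mean/std before vocoding. A minimal sketch of resynthesizing one of those files outside the script, assuming the same melgan checkpoint and the looperman statistics (the mel path here is illustrative):

    import sys
    import numpy as np
    import torch
    import soundfile as sf
    import yaml

    sys.path.append('./melgan')
    from modules import Generator_melgan

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    mel_fp = './freesound_sample_dir/100000/mel_80_320/0.npy'  # illustrative
    data_path = './data/looperman'

    cfg = yaml.load(open('./melgan/args.yml'), Loader=yaml.Loader)
    vocoder = Generator_melgan(cfg.n_mel_channels, cfg.ngf, cfg.n_residual_layers).to(device)
    vocoder.load_state_dict(torch.load('./melgan/best_netG.pt', map_location=device))
    vocoder.eval()

    mean = torch.from_numpy(np.load(f'{data_path}/mean.mel.npy')).float().view(1, 80, 1).to(device)
    std = torch.from_numpy(np.load(f'{data_path}/std.mel.npy')).float().view(1, 80, 1).to(device)

    mel = torch.from_numpy(np.load(mel_fp)).float().unsqueeze(0).to(device)  # (1, 80, 320), normalized
    with torch.no_grad():
        audio = vocoder(mel * std + mean)  # de-normalize, then vocode
    sf.write('resynth.wav', audio.squeeze().detach().cpu().numpy(), 44100)
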
/generate_looperman_four_bar.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from torchvision import utils
5 | from model_drum_four_bar import Generator
6 | from tqdm import tqdm
7 |
8 | import sys
9 | sys.path.append('./melgan')
10 | from modules import Generator_melgan
11 |
12 | import yaml
13 | import os
14 |
15 | import librosa
16 |
17 | import soundfile as sf
18 |
19 | import numpy as np
20 |
21 | import os
22 | def read_yaml(fp):
23 | with open(fp) as file:
24 | # return yaml.load(file)
25 | return yaml.load(file, Loader=yaml.Loader)
26 |
27 | def generate(args, g_ema, device, mean_latent):
28 | epoch = args.ckpt.split('.')[0]
29 |
30 | os.makedirs(f'{args.store_path}/{epoch}', exist_ok=True)
31 | os.makedirs(f'{args.store_path}/{epoch}/mel_80_320', exist_ok=True)
32 | feat_dim = 80
33 | mean_fp = f'{args.data_path}/mean.mel.npy'
34 | std_fp = f'{args.data_path}/std.mel.npy'
35 | mean = torch.from_numpy(np.load(mean_fp)).float().view(1, feat_dim, 1).to(device)
36 | std = torch.from_numpy(np.load(std_fp)).float().view(1, feat_dim, 1).to(device)
37 | vocoder_config_fp = './melgan/args.yml'
38 | vocoder_config = read_yaml(vocoder_config_fp)
39 |
40 | n_mel_channels = vocoder_config.n_mel_channels
41 | ngf = vocoder_config.ngf
42 | n_residual_layers = vocoder_config.n_residual_layers
43 | sr=44100
44 |
45 | vocoder = Generator_melgan(n_mel_channels, ngf, n_residual_layers).to(device)
46 | vocoder.eval()
47 |
48 | vocoder_param_fp = os.path.join('./melgan', 'best_netG.pt')
49 | vocoder.load_state_dict(torch.load(vocoder_param_fp))
50 |
51 |
52 | with torch.no_grad():
53 | g_ema.eval()
54 | for i in tqdm(range(args.pics)):
55 | sample_z = torch.randn(args.sample, args.latent, device=device)
56 |
57 | sample, _ = g_ema(
58 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent
59 | )
60 | np.save(f'{args.store_path}/{epoch}/mel_80_320/{i}.npy', sample.squeeze().data.cpu().numpy())
61 |
62 | utils.save_image(
63 | sample,
64 | f"{args.store_path}/{epoch}/{str(i).zfill(6)}.png",
65 | nrow=1,
66 | normalize=True,
67 | range=(-1, 1),
68 | )
69 | de_norm = sample.squeeze(0) * std + mean
70 | audio_output = vocoder(de_norm)
71 | sf.write(f'{args.store_path}/{epoch}/{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr)
72 |             print('generated wav file {}'.format(i))
73 | @torch.no_grad()
74 | def style_mixing(args, generator, step, mean_style, n_source, n_target, device, j):
75 | index = 5
76 | # create directory
77 | os.makedirs(f'./generated_interpolation_{index}/{j}', exist_ok=True)
78 |
79 | # load melgan vocoder
80 | feat_dim = 80
81 | mean_fp = f'{args.data_path}/mean.mel.npy'
82 | std_fp = f'{args.data_path}/std.mel.npy'
83 | mean = torch.from_numpy(np.load(mean_fp)).float().view(1, feat_dim, 1).to(device)
84 | std = torch.from_numpy(np.load(std_fp)).float().view(1, feat_dim, 1).to(device)
85 | vocoder_config_fp = './melgan/args.yml'
86 | vocoder_config = read_yaml(vocoder_config_fp)
87 |
88 | n_mel_channels = vocoder_config.n_mel_channels
89 | ngf = vocoder_config.ngf
90 | n_residual_layers = vocoder_config.n_residual_layers
91 | sr=44100
92 |
93 | vocoder = Generator_melgan(n_mel_channels, ngf, n_residual_layers).to(device)
94 | vocoder.eval()
95 |
96 | vocoder_param_fp = os.path.join('./melgan', 'best_netG.pt')
97 | vocoder.load_state_dict(torch.load(vocoder_param_fp))
98 |
99 | #generate spectrogram
100 | source_code = torch.randn(n_source, 512).to(device)
101 | target_code = torch.randn(n_target, 512).to(device)
102 |
103 | shape = 4 * 2 ** step
104 | alpha = 1
105 |
106 | images = [torch.ones(1, 1, 80, 320).to(device) * -1]
107 |
108 | source_image,_ = generator(
109 | [source_code], truncation=args.truncation, truncation_latent=mean_style
110 | )
111 | target_image,_ = generator(
112 | [target_code], truncation=args.truncation, truncation_latent=mean_style
113 | )
114 |
115 | images.append(source_image)
116 |
117 | for i in range(n_source):
118 | de_norm = source_image[i] * std + mean
119 | audio_output = vocoder(de_norm)
120 | sf.write(f'./generated_interpolation_{index}/{j}/source_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr)
121 |
122 | for i in range(n_target):
123 | de_norm = target_image[i] * std + mean
124 | audio_output = vocoder(de_norm)
125 | sf.write(f'./generated_interpolation_{index}/{j}/target_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr)
126 |
127 | for i in range(n_target):
128 | image, _ = generator(
129 | [target_code[i].unsqueeze(0).repeat(n_source, 1), source_code],
130 | truncation_latent=mean_style,
131 | inject_index = index
132 | )
133 |
134 | for k in range(n_source):
135 | de_norm = image[k] * std + mean
136 | audio_output = vocoder(de_norm)
137 | sf.write(f'./generated_interpolation_{index}/{j}/source_{k}_target_{i}.wav', audio_output.squeeze().detach().cpu().numpy(), sr)
138 |
139 | images.append(target_image[i].unsqueeze(0))
140 | images.append(image)
141 |
142 | images = torch.cat(images, 0)
143 | utils.save_image(
144 | images, f'./generated_interpolation_{index}/{j}/sample_mixing.png', nrow=args.n_col + 1, normalize=True, range=(-1, 1)
145 | )
146 | return images
147 | if __name__ == "__main__":
148 | device = "cuda"
149 |
150 | parser = argparse.ArgumentParser(description="Generate samples from the generator")
151 |
152 | parser.add_argument(
153 | "--size", type=int, default=64, help="output image size of the generator"
154 | )
155 | parser.add_argument(
156 | "--sample",
157 | type=int,
158 | default=1,
159 | help="number of samples to be generated for each image",
160 | )
161 | parser.add_argument(
162 | "--pics", type=int, default=20, help="number of images to be generated"
163 | )
164 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
165 | parser.add_argument(
166 | "--truncation_mean",
167 | type=int,
168 | default=4096,
169 | help="number of vectors to calculate mean for the truncation",
170 | )
171 | parser.add_argument(
172 | "--ckpt",
173 | type=str,
174 | default="stylegan2-ffhq-config-f.pt",
175 | help="path to the model checkpoint",
176 | )
177 | parser.add_argument(
178 | "--data_path",
179 | type=str,
180 | help="path store the std and mean of mel",
181 | )
182 | parser.add_argument(
183 | "--store_path",
184 | type=str,
185 | help="path store the generated audio",
186 | )
187 | parser.add_argument(
188 | "--channel_multiplier",
189 | type=int,
190 | default=2,
191 | help="channel multiplier of the generator. config-f = 2, else = 1",
192 | )
193 |
194 | parser.add_argument("--style_mixing", action = "store_true")
195 | parser.add_argument('--n_row', type=int, default=3, help='number of rows of sample matrix')
196 | parser.add_argument('--n_col', type=int, default=5, help='number of columns of sample matrix')
197 | args = parser.parse_args()
198 |
199 | args.latent = 512
200 | args.n_mlp = 8
201 |
202 | g_ema = Generator(
203 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
204 | ).to(device)
205 | checkpoint = torch.load(args.ckpt)
206 |
207 | g_ema.load_state_dict(checkpoint["g_ema"])
208 |
209 | if args.truncation < 1:
210 | with torch.no_grad():
211 | mean_latent = g_ema.mean_latent(args.truncation_mean)
212 | else:
213 | mean_latent = None
214 |
215 | generate(args, g_ema, device, mean_latent)
216 | # Style mixing
217 |     if args.style_mixing:
218 | step = 0
219 | for j in range(20):
220 | img = style_mixing(args,g_ema, step, mean_latent, args.n_col, args.n_row, device, j)
221 |
--------------------------------------------------------------------------------
/melgan/.ipynb_checkpoints/modules-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | from librosa.filters import mel as librosa_mel_fn
5 | from torch.nn.utils import weight_norm
6 | import numpy as np
7 |
8 |
9 | def weights_init(m):
10 | classname = m.__class__.__name__
11 | if classname.find("Conv") != -1:
12 | m.weight.data.normal_(0.0, 0.02)
13 | elif classname.find("BatchNorm2d") != -1:
14 | m.weight.data.normal_(1.0, 0.02)
15 | m.bias.data.fill_(0)
16 |
17 |
18 | def WNConv1d(*args, **kwargs):
19 | return weight_norm(nn.Conv1d(*args, **kwargs))
20 |
21 |
22 | def WNConvTranspose1d(*args, **kwargs):
23 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
24 |
25 |
26 | class Audio2Mel(nn.Module):
27 | def __init__(
28 | self,
29 | n_fft=1024,
30 | hop_length=256,
31 | win_length=1024,
32 | sampling_rate=22050,
33 | n_mel_channels=80,
34 | mel_fmin=0.0,
35 | mel_fmax=None,
36 | ):
37 | super().__init__()
38 | ##############################################
39 | # FFT Parameters #
40 | ##############################################
41 | window = torch.hann_window(win_length).float()
42 | mel_basis = librosa_mel_fn(
43 | sampling_rate, n_fft, n_mel_channels, mel_fmin, mel_fmax
44 | )
45 | mel_basis = torch.from_numpy(mel_basis).float()
46 | self.register_buffer("mel_basis", mel_basis)
47 | self.register_buffer("window", window)
48 | self.n_fft = n_fft
49 | self.hop_length = hop_length
50 | self.win_length = win_length
51 | self.sampling_rate = sampling_rate
52 | self.n_mel_channels = n_mel_channels
53 |
54 | def forward(self, audio):
55 | p = (self.n_fft - self.hop_length) // 2
56 | audio = F.pad(audio, (p, p), "reflect").squeeze(1)
57 | fft = torch.stft(
58 | audio,
59 | n_fft=self.n_fft,
60 | hop_length=self.hop_length,
61 | win_length=self.win_length,
62 | window=self.window,
63 | center=False,
64 | )
65 | real_part, imag_part = fft.unbind(-1)
66 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
67 | mel_output = torch.matmul(self.mel_basis, magnitude)
68 | log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5))
69 | return log_mel_spec
70 |
71 |
72 | class ResnetBlock(nn.Module):
73 | def __init__(self, dim, dilation=1):
74 | super().__init__()
75 | self.block = nn.Sequential(
76 | nn.LeakyReLU(0.2),
77 | nn.ReflectionPad1d(dilation),
78 | WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
79 | nn.LeakyReLU(0.2),
80 | WNConv1d(dim, dim, kernel_size=1),
81 | )
82 | self.shortcut = WNConv1d(dim, dim, kernel_size=1)
83 |
84 | def forward(self, x):
85 | return self.shortcut(x) + self.block(x)
86 |
87 |
88 | class Generator_melgan(nn.Module):
89 | def __init__(self, input_size, ngf, n_residual_layers):
90 | super().__init__()
91 | ratios = [8, 8, 2, 2]
92 | self.hop_length = np.prod(ratios)
93 | mult = int(2 ** len(ratios))
94 |
95 | model = [
96 | nn.ReflectionPad1d(3),
97 | WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0),
98 | ]
99 |
100 | # Upsample to raw audio scale
101 | for i, r in enumerate(ratios):
102 | model += [
103 | nn.LeakyReLU(0.2),
104 | WNConvTranspose1d(
105 | mult * ngf,
106 | mult * ngf // 2,
107 | kernel_size=r * 2,
108 | stride=r,
109 | padding=r // 2 + r % 2,
110 | output_padding=r % 2,
111 | ),
112 | ]
113 |
114 | for j in range(n_residual_layers):
115 | model += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)]
116 |
117 | mult //= 2
118 |
119 | model += [
120 | nn.LeakyReLU(0.2),
121 | nn.ReflectionPad1d(3),
122 | WNConv1d(ngf, 1, kernel_size=7, padding=0),
123 | nn.Tanh(),
124 | ]
125 |
126 | self.model = nn.Sequential(*model)
127 | self.apply(weights_init)
128 |
129 | def forward(self, x):
130 | return self.model(x)
131 |
132 |
133 | class NLayerDiscriminator(nn.Module):
134 | def __init__(self, ndf, n_layers, downsampling_factor):
135 | super().__init__()
136 | model = nn.ModuleDict()
137 |
138 | model["layer_0"] = nn.Sequential(
139 | nn.ReflectionPad1d(7),
140 | WNConv1d(1, ndf, kernel_size=15),
141 | nn.LeakyReLU(0.2, True),
142 | )
143 |
144 | nf = ndf
145 | stride = downsampling_factor
146 | for n in range(1, n_layers + 1):
147 | nf_prev = nf
148 | nf = min(nf * stride, 1024)
149 |
150 | model["layer_%d" % n] = nn.Sequential(
151 | WNConv1d(
152 | nf_prev,
153 | nf,
154 | kernel_size=stride * 10 + 1,
155 | stride=stride,
156 | padding=stride * 5,
157 | groups=nf_prev // 4,
158 | ),
159 | nn.LeakyReLU(0.2, True),
160 | )
161 |
162 | nf = min(nf * 2, 1024)
163 | model["layer_%d" % (n_layers + 1)] = nn.Sequential(
164 | WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2),
165 | nn.LeakyReLU(0.2, True),
166 | )
167 |
168 | model["layer_%d" % (n_layers + 2)] = WNConv1d(
169 | nf, 1, kernel_size=3, stride=1, padding=1
170 | )
171 |
172 | self.model = model
173 |
174 | def forward(self, x):
175 | results = []
176 | for key, layer in self.model.items():
177 | x = layer(x)
178 | results.append(x)
179 | return results
180 |
181 |
182 | class Discriminator(nn.Module):
183 | def __init__(self, num_D, ndf, n_layers, downsampling_factor):
184 | super().__init__()
185 | self.model = nn.ModuleDict()
186 | for i in range(num_D):
187 | self.model[f"disc_{i}"] = NLayerDiscriminator(
188 | ndf, n_layers, downsampling_factor
189 | )
190 |
191 | self.downsample = nn.AvgPool1d(4, stride=2, padding=1, count_include_pad=False)
192 | self.apply(weights_init)
193 |
194 | def forward(self, x):
195 | results = []
196 | for key, disc in self.model.items():
197 | results.append(disc(x))
198 | x = self.downsample(x)
199 | return results
200 |
--------------------------------------------------------------------------------
/melgan/args.yml:
--------------------------------------------------------------------------------
1 | !!python/object:argparse.Namespace
2 | batch_size: 16
3 | cond_disc: false
4 | data_path: !!python/object/apply:pathlib.PosixPath
5 | - /
6 | - home
7 | - allenhung
8 | - nas189
9 | - home
10 | - one_bar_loop_all
11 | downsamp_factor: 4
12 | epochs: 30000
13 | lambda_feat: 10
14 | load_path: null
15 | log_interval: 100
16 | n_layers_D: 4
17 | n_mel_channels: 80
18 | n_residual_layers: 3
19 | n_test_samples: 8
20 | ndf: 16
21 | ngf: 32
22 | num_D: 3
23 | save_interval: 1000
24 | save_path: logs/one_bar_all
25 | seq_len: 8192
26 |
--------------------------------------------------------------------------------
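The !!python/object tags above mean args.yml is a YAML dump of the training argparse.Namespace (including a pathlib.PosixPath), so yaml.safe_load will reject it; the read_yaml helper in generate_audio.py therefore passes Loader=yaml.Loader. A minimal sketch of reading it:

    import yaml

    with open('./melgan/args.yml') as f:
        cfg = yaml.load(f, Loader=yaml.Loader)  # reconstructs the argparse.Namespace
    print(cfg.n_mel_channels, cfg.ngf, cfg.n_residual_layers)  # 80 32 3
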
/melgan/best_netG.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenhung1025/LoopTest/b105df76344e4393f1db9116bc64cbffc1dd83aa/melgan/best_netG.pt
--------------------------------------------------------------------------------
/melgan/modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | from librosa.filters import mel as librosa_mel_fn
5 | from torch.nn.utils import weight_norm
6 | import numpy as np
7 |
8 |
9 | def weights_init(m):
10 | classname = m.__class__.__name__
11 | if classname.find("Conv") != -1:
12 | m.weight.data.normal_(0.0, 0.02)
13 | elif classname.find("BatchNorm2d") != -1:
14 | m.weight.data.normal_(1.0, 0.02)
15 | m.bias.data.fill_(0)
16 |
17 |
18 | def WNConv1d(*args, **kwargs):
19 | return weight_norm(nn.Conv1d(*args, **kwargs))
20 |
21 |
22 | def WNConvTranspose1d(*args, **kwargs):
23 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
24 |
25 |
26 | class Audio2Mel(nn.Module):
27 | def __init__(
28 | self,
29 | n_fft=1024,
30 | hop_length=256,
31 | win_length=1024,
32 | sampling_rate=22050,
33 | n_mel_channels=80,
34 | mel_fmin=0.0,
35 | mel_fmax=None,
36 | ):
37 | super().__init__()
38 | ##############################################
39 | # FFT Parameters #
40 | ##############################################
41 | window = torch.hann_window(win_length).float()
42 | mel_basis = librosa_mel_fn(
43 | sampling_rate, n_fft, n_mel_channels, mel_fmin, mel_fmax
44 | )
45 | mel_basis = torch.from_numpy(mel_basis).float()
46 | self.register_buffer("mel_basis", mel_basis)
47 | self.register_buffer("window", window)
48 | self.n_fft = n_fft
49 | self.hop_length = hop_length
50 | self.win_length = win_length
51 | self.sampling_rate = sampling_rate
52 | self.n_mel_channels = n_mel_channels
53 |
54 | def forward(self, audio):
55 | p = (self.n_fft - self.hop_length) // 2
56 | audio = F.pad(audio, (p, p), "reflect").squeeze(1)
57 | fft = torch.stft(
58 | audio,
59 | n_fft=self.n_fft,
60 | hop_length=self.hop_length,
61 | win_length=self.win_length,
62 | window=self.window,
63 | center=False,
64 | )
65 | real_part, imag_part = fft.unbind(-1)
66 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
67 | mel_output = torch.matmul(self.mel_basis, magnitude)
68 | log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5))
69 | return log_mel_spec
70 |
71 |
72 | class ResnetBlock(nn.Module):
73 | def __init__(self, dim, dilation=1):
74 | super().__init__()
75 | self.block = nn.Sequential(
76 | nn.LeakyReLU(0.2),
77 | nn.ReflectionPad1d(dilation),
78 | WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
79 | nn.LeakyReLU(0.2),
80 | WNConv1d(dim, dim, kernel_size=1),
81 | )
82 | self.shortcut = WNConv1d(dim, dim, kernel_size=1)
83 |
84 | def forward(self, x):
85 | return self.shortcut(x) + self.block(x)
86 |
87 |
88 | class Generator_melgan(nn.Module):
89 | def __init__(self, input_size, ngf, n_residual_layers):
90 | super().__init__()
91 | ratios = [8, 8, 2, 2]
92 | self.hop_length = np.prod(ratios)
93 | mult = int(2 ** len(ratios))
94 |
95 | model = [
96 | nn.ReflectionPad1d(3),
97 | WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0),
98 | ]
99 |
100 | # Upsample to raw audio scale
101 | for i, r in enumerate(ratios):
102 | model += [
103 | nn.LeakyReLU(0.2),
104 | WNConvTranspose1d(
105 | mult * ngf,
106 | mult * ngf // 2,
107 | kernel_size=r * 2,
108 | stride=r,
109 | padding=r // 2 + r % 2,
110 | output_padding=r % 2,
111 | ),
112 | ]
113 |
114 | for j in range(n_residual_layers):
115 | model += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)]
116 |
117 | mult //= 2
118 |
119 | model += [
120 | nn.LeakyReLU(0.2),
121 | nn.ReflectionPad1d(3),
122 | WNConv1d(ngf, 1, kernel_size=7, padding=0),
123 | nn.Tanh(),
124 | ]
125 |
126 | self.model = nn.Sequential(*model)
127 | self.apply(weights_init)
128 |
129 | def forward(self, x):
130 | return self.model(x)
131 |
132 |
133 | class NLayerDiscriminator(nn.Module):
134 | def __init__(self, ndf, n_layers, downsampling_factor):
135 | super().__init__()
136 | model = nn.ModuleDict()
137 |
138 | model["layer_0"] = nn.Sequential(
139 | nn.ReflectionPad1d(7),
140 | WNConv1d(1, ndf, kernel_size=15),
141 | nn.LeakyReLU(0.2, True),
142 | )
143 |
144 | nf = ndf
145 | stride = downsampling_factor
146 | for n in range(1, n_layers + 1):
147 | nf_prev = nf
148 | nf = min(nf * stride, 1024)
149 |
150 | model["layer_%d" % n] = nn.Sequential(
151 | WNConv1d(
152 | nf_prev,
153 | nf,
154 | kernel_size=stride * 10 + 1,
155 | stride=stride,
156 | padding=stride * 5,
157 | groups=nf_prev // 4,
158 | ),
159 | nn.LeakyReLU(0.2, True),
160 | )
161 |
162 | nf = min(nf * 2, 1024)
163 | model["layer_%d" % (n_layers + 1)] = nn.Sequential(
164 | WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2),
165 | nn.LeakyReLU(0.2, True),
166 | )
167 |
168 | model["layer_%d" % (n_layers + 2)] = WNConv1d(
169 | nf, 1, kernel_size=3, stride=1, padding=1
170 | )
171 |
172 | self.model = model
173 |
174 | def forward(self, x):
175 | results = []
176 | for key, layer in self.model.items():
177 | x = layer(x)
178 | results.append(x)
179 | return results
180 |
181 |
182 | class Discriminator(nn.Module):
183 | def __init__(self, num_D, ndf, n_layers, downsampling_factor):
184 | super().__init__()
185 | self.model = nn.ModuleDict()
186 | for i in range(num_D):
187 | self.model[f"disc_{i}"] = NLayerDiscriminator(
188 | ndf, n_layers, downsampling_factor
189 | )
190 |
191 | self.downsample = nn.AvgPool1d(4, stride=2, padding=1, count_include_pad=False)
192 | self.apply(weights_init)
193 |
194 | def forward(self, x):
195 | results = []
196 | for key, disc in self.model.items():
197 | results.append(disc(x))
198 | x = self.downsample(x)
199 | return results
200 |
--------------------------------------------------------------------------------
/model_drum.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import functools
4 | import operator
5 |
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 | from torch.autograd import Function
10 |
11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
12 |
13 |
14 | class PixelNorm(nn.Module):
15 | def __init__(self):
16 | super().__init__()
17 |
18 | def forward(self, input):
19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
20 |
21 |
22 | def make_kernel(k):
23 | k = torch.tensor(k, dtype=torch.float32)
24 |
25 | if k.ndim == 1:
26 | k = k[None, :] * k[:, None]
27 |
28 | k /= k.sum()
29 |
30 | return k
31 |
32 |
33 | class Upsample(nn.Module):
34 | def __init__(self, kernel, factor=2):
35 | super().__init__()
36 |
37 | self.factor = factor
38 | kernel = make_kernel(kernel) * (factor ** 2)
39 | self.register_buffer("kernel", kernel)
40 |
41 | p = kernel.shape[0] - factor
42 |
43 | pad0 = (p + 1) // 2 + factor - 1
44 | pad1 = p // 2
45 |
46 | self.pad = (pad0, pad1)
47 |
48 | def forward(self, input):
49 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
50 |
51 | return out
52 |
53 |
54 | class Downsample(nn.Module):
55 | def __init__(self, kernel, factor=2):
56 | super().__init__()
57 |
58 | self.factor = factor
59 | kernel = make_kernel(kernel)
60 | self.register_buffer("kernel", kernel)
61 |
62 | p = kernel.shape[0] - factor
63 |
64 | pad0 = (p + 1) // 2
65 | pad1 = p // 2
66 |
67 | self.pad = (pad0, pad1)
68 |
69 | def forward(self, input):
70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
71 |
72 | return out
73 |
74 |
75 | class Blur(nn.Module):
76 | def __init__(self, kernel, pad, upsample_factor=1):
77 | super().__init__()
78 |
79 | kernel = make_kernel(kernel)
80 |
81 | if upsample_factor > 1:
82 | kernel = kernel * (upsample_factor ** 2)
83 |
84 | self.register_buffer("kernel", kernel)
85 |
86 | self.pad = pad
87 |
88 | def forward(self, input):
89 | out = upfirdn2d(input, self.kernel, pad=self.pad)
90 |
91 | return out
92 |
93 |
94 | class EqualConv2d(nn.Module):
95 | def __init__(
96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
97 | ):
98 | super().__init__()
99 |
100 | self.weight = nn.Parameter(
101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size)
102 | )
103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
104 |
105 | self.stride = stride
106 | self.padding = padding
107 |
108 | if bias:
109 | self.bias = nn.Parameter(torch.zeros(out_channel))
110 |
111 | else:
112 | self.bias = None
113 |
114 | def forward(self, input):
115 | out = F.conv2d(
116 | input,
117 | self.weight * self.scale,
118 | bias=self.bias,
119 | stride=self.stride,
120 | padding=self.padding,
121 | )
122 |
123 | return out
124 |
125 | def __repr__(self):
126 | return (
127 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
128 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
129 | )
130 |
131 |
132 | class EqualLinear(nn.Module):
133 | def __init__(
134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
135 | ):
136 | super().__init__()
137 |
138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
139 |
140 | if bias:
141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
142 |
143 | else:
144 | self.bias = None
145 |
146 | self.activation = activation
147 |
148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
149 | self.lr_mul = lr_mul
150 |
151 | def forward(self, input):
152 | if self.activation:
153 | out = F.linear(input, self.weight * self.scale)
154 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
155 |
156 | else:
157 | out = F.linear(
158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul
159 | )
160 |
161 | return out
162 |
163 | def __repr__(self):
164 | return (
165 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
166 | )
167 |
168 |
169 | class ModulatedConv2d(nn.Module):
170 | def __init__(
171 | self,
172 | in_channel,
173 | out_channel,
174 | kernel_size,
175 | style_dim,
176 | demodulate=True,
177 | upsample=False,
178 | downsample=False,
179 | blur_kernel=[1, 3, 3, 1],
180 | ):
181 | super().__init__()
182 |
183 | self.eps = 1e-8
184 | self.kernel_size = kernel_size
185 | self.in_channel = in_channel
186 | self.out_channel = out_channel
187 | self.upsample = upsample
188 | self.downsample = downsample
189 |
190 | if upsample:
191 | factor = 2
192 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
193 | pad0 = (p + 1) // 2 + factor - 1
194 | pad1 = p // 2 + 1
195 |
196 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
197 |
198 | if downsample:
199 | factor = 2
200 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
201 | pad0 = (p + 1) // 2
202 | pad1 = p // 2
203 |
204 | self.blur = Blur(blur_kernel, pad=(pad0, pad1))
205 |
206 | fan_in = in_channel * kernel_size ** 2
207 | self.scale = 1 / math.sqrt(fan_in)
208 | self.padding = kernel_size // 2
209 |
210 | self.weight = nn.Parameter(
211 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
212 | )
213 |
214 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
215 |
216 | self.demodulate = demodulate
217 |
218 | def __repr__(self):
219 | return (
220 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
221 | f"upsample={self.upsample}, downsample={self.downsample})"
222 | )
223 |
224 | def forward(self, input, style):
225 | batch, in_channel, height, width = input.shape
226 |
227 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
228 | weight = self.scale * self.weight * style
229 |
230 | if self.demodulate:
231 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
232 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
233 |
234 | weight = weight.view(
235 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
236 | )
237 |
238 | if self.upsample:
239 | input = input.view(1, batch * in_channel, height, width)
240 | weight = weight.view(
241 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
242 | )
243 | weight = weight.transpose(1, 2).reshape(
244 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
245 | )
246 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
247 | _, _, height, width = out.shape
248 | out = out.view(batch, self.out_channel, height, width)
249 | out = self.blur(out)
250 |
251 | elif self.downsample:
252 | input = self.blur(input)
253 | _, _, height, width = input.shape
254 | input = input.view(1, batch * in_channel, height, width)
255 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
256 | _, _, height, width = out.shape
257 | out = out.view(batch, self.out_channel, height, width)
258 |
259 | else:
260 | input = input.view(1, batch * in_channel, height, width)
261 | out = F.conv2d(input, weight, padding=self.padding, groups=batch)
262 | _, _, height, width = out.shape
263 | out = out.view(batch, self.out_channel, height, width)
264 |
265 | return out
266 |
267 |
268 | class NoiseInjection(nn.Module):
269 | def __init__(self):
270 | super().__init__()
271 |
272 | self.weight = nn.Parameter(torch.zeros(1))
273 |
274 | def forward(self, image, noise=None):
275 | if noise is None:
276 | batch, _, height, width = image.shape
277 | noise = image.new_empty(batch, 1, height, width).normal_()
278 |
279 | return image + self.weight * noise
280 |
281 |
282 | class ConstantInput(nn.Module):
283 | def __init__(self, channel, size=4):
284 | super().__init__()
285 |
286 | self.input = nn.Parameter(torch.randn(1, channel, 5, 20))
287 |
288 | def forward(self, input):
289 | batch = input.shape[0]
290 | out = self.input.repeat(batch, 1, 1, 1)
291 |
292 | return out
293 |
294 |
295 | class StyledConv(nn.Module):
296 | def __init__(
297 | self,
298 | in_channel,
299 | out_channel,
300 | kernel_size,
301 | style_dim,
302 | upsample=False,
303 | blur_kernel=[1, 3, 3, 1],
304 | demodulate=True,
305 | ):
306 | super().__init__()
307 |
308 | self.conv = ModulatedConv2d(
309 | in_channel,
310 | out_channel,
311 | kernel_size,
312 | style_dim,
313 | upsample=upsample,
314 | blur_kernel=blur_kernel,
315 | demodulate=demodulate,
316 | )
317 |
318 | self.noise = NoiseInjection()
319 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
320 | # self.activate = ScaledLeakyReLU(0.2)
321 | self.activate = FusedLeakyReLU(out_channel)
322 |
323 | def forward(self, input, style, noise=None):
324 | out = self.conv(input, style)
325 | out = self.noise(out, noise=noise)
326 | # out = out + self.bias
327 | out = self.activate(out)
328 |
329 | return out
330 |
331 |
332 | class ToRGB(nn.Module):
333 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
334 | super().__init__()
335 |
336 | if upsample:
337 | self.upsample = Upsample(blur_kernel)
338 |
339 | self.conv = ModulatedConv2d(in_channel, 1, 1, style_dim, demodulate=False)
340 | self.bias = nn.Parameter(torch.zeros(1, 1, 1, 1))
341 |
342 | def forward(self, input, style, skip=None):
343 | out = self.conv(input, style)
344 | out = out + self.bias
345 |
346 | if skip is not None:
347 | skip = self.upsample(skip)
348 |
349 | out = out + skip
350 |
351 | return out
352 |
353 |
354 | class Generator(nn.Module):
355 | def __init__(
356 | self,
357 | size,
358 | style_dim,
359 | n_mlp,
360 | channel_multiplier=2,
361 | blur_kernel=[1, 3, 3, 1],
362 | lr_mlp=0.01,
363 | ):
364 | super().__init__()
365 |
366 | self.size = size
367 |
368 | self.style_dim = style_dim
369 |
370 | layers = [PixelNorm()]
371 |
372 | for i in range(n_mlp):
373 | layers.append(
374 | EqualLinear(
375 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
376 | )
377 | )
378 |
379 | self.style = nn.Sequential(*layers)
380 |
381 | self.channels = {
382 | 4: 512,
383 | 8: 512,
384 | 16: 512,
385 | 32: 512,
386 | 64: 256 * channel_multiplier,
387 | 128: 128 * channel_multiplier,
388 | 256: 64 * channel_multiplier,
389 | 512: 32 * channel_multiplier,
390 | 1024: 16 * channel_multiplier,
391 | }
392 |
393 | self.input = ConstantInput(self.channels[4])
394 | self.conv1 = StyledConv(
395 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
396 | )
397 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
398 |
399 | self.log_size = int(math.log(size, 2))
400 | self.num_layers = (self.log_size - 2) * 2 + 1 #9
401 |
402 | self.convs = nn.ModuleList()
403 | self.upsamples = nn.ModuleList()
404 | self.to_rgbs = nn.ModuleList()
405 | self.noises = nn.Module()
406 |
407 | in_channel = self.channels[4]
408 |
409 | for layer_idx in range(self.num_layers):
410 | res = (layer_idx + 5) // 2
411 | shape = [1, 1, 2 ** res, 2 ** res]
412 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
413 |
414 | for i in range(3, self.log_size + 1):
415 | out_channel = self.channels[2 ** i]
416 |
417 | self.convs.append(
418 | StyledConv(
419 | in_channel,
420 | out_channel,
421 | 3,
422 | style_dim,
423 | upsample=True,
424 | blur_kernel=blur_kernel,
425 | )
426 | )
427 |
428 | self.convs.append(
429 | StyledConv(
430 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
431 | )
432 | )
433 |
434 | self.to_rgbs.append(ToRGB(out_channel, style_dim))
435 |
436 | in_channel = out_channel
437 |
438 | self.n_latent = self.log_size * 2 - 2
439 |
440 | def make_noise(self):
441 | device = self.input.input.device
442 |
443 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
444 |
445 | for i in range(3, self.log_size + 1):
446 | for _ in range(2):
447 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
448 |
449 | return noises
450 |
451 | def mean_latent(self, n_latent):
452 | latent_in = torch.randn(
453 | n_latent, self.style_dim, device=self.input.input.device
454 | )
455 | latent = self.style(latent_in).mean(0, keepdim=True)
456 |
457 | return latent
458 |
459 | def get_latent(self, input):
460 | return self.style(input)
461 |
462 | def forward(
463 | self,
464 | styles,
465 | return_latents=False,
466 | inject_index=None,
467 | truncation=1,
468 | truncation_latent=None,
469 | input_is_latent=False,
470 | noise=None,
471 | randomize_noise=True,
472 | ):
473 | if not input_is_latent:
474 | styles = [self.style(s) for s in styles]
475 |
476 | if noise is None:
477 | if randomize_noise:
478 | noise = [None] * self.num_layers
479 | else:
480 | noise = [
481 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
482 | ]
483 |
484 | if truncation < 1:
485 | style_t = []
486 |
487 | for style in styles:
488 | style_t.append(
489 | truncation_latent + truncation * (style - truncation_latent)
490 | )
491 |
492 | styles = style_t
493 |
494 | if len(styles) < 2:
495 | inject_index = self.n_latent
496 |
497 | if styles[0].ndim < 3:
498 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
499 |
500 | else:
501 | latent = styles[0]
502 |
503 | else:
504 | if inject_index is None:
505 | inject_index = random.randint(1, self.n_latent - 1)
506 |
507 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
508 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
509 |
510 | latent = torch.cat([latent, latent2], 1)
511 |
512 | out = self.input(latent)
513 | out = self.conv1(out, latent[:, 0], noise=noise[0])
514 |
515 | skip = self.to_rgb1(out, latent[:, 1])
516 |
517 | i = 1
518 | for conv1, conv2, noise1, noise2, to_rgb in zip(
519 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
520 | ):
521 | out = conv1(out, latent[:, i], noise=noise1)
522 | out = conv2(out, latent[:, i + 1], noise=noise2)
523 | skip = to_rgb(out, latent[:, i + 2], skip)
524 |
525 | i += 2
526 |
527 | image = skip
528 |
529 | if return_latents:
530 | return image, latent
531 |
532 | else:
533 | return image, None
534 |
535 |
536 | class ConvLayer(nn.Sequential):
537 | def __init__(
538 | self,
539 | in_channel,
540 | out_channel,
541 | kernel_size,
542 | downsample=False,
543 | blur_kernel=[1, 3, 3, 1],
544 | bias=True,
545 | activate=True,
546 | ):
547 | layers = []
548 |
549 | if downsample:
550 | factor = 2
551 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
552 | pad0 = (p + 1) // 2
553 | pad1 = p // 2
554 |
555 | layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
556 |
557 | stride = 2
558 | self.padding = 0
559 |
560 | else:
561 | stride = 1
562 | self.padding = kernel_size // 2
563 |
564 | layers.append(
565 | EqualConv2d(
566 | in_channel,
567 | out_channel,
568 | kernel_size,
569 | padding=self.padding,
570 | stride=stride,
571 | bias=bias and not activate,
572 | )
573 | )
574 |
575 | if activate:
576 | layers.append(FusedLeakyReLU(out_channel, bias=bias))
577 |
578 | super().__init__(*layers)
579 |
580 |
581 | class ResBlock(nn.Module):
582 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
583 | super().__init__()
584 |
585 | self.conv1 = ConvLayer(in_channel, in_channel, 3)
586 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
587 |
588 | self.skip = ConvLayer(
589 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False
590 | )
591 |
592 | def forward(self, input):
593 | out = self.conv1(input)
594 | out = self.conv2(out)
595 |
596 | skip = self.skip(input)
597 | out = (out + skip) / math.sqrt(2)
598 |
599 | return out
600 |
601 |
602 | class Discriminator(nn.Module):
603 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
604 | super().__init__()
605 |
606 | channels = {
607 | 4: 512,
608 | 8: 512,
609 | 16: 512,
610 | 32: 512,
611 | 64: 256 * channel_multiplier,
612 | 128: 128 * channel_multiplier,
613 | 256: 64 * channel_multiplier,
614 | 512: 32 * channel_multiplier,
615 | 1024: 16 * channel_multiplier,
616 | }
617 |
618 | convs = [ConvLayer(1, channels[size], 1)]
619 |
620 | log_size = int(math.log(size, 2))
621 |
622 | in_channel = channels[size]
623 |
624 | for i in range(log_size, 2, -1):
625 | out_channel = channels[2 ** (i - 1)]
626 |
627 | convs.append(ResBlock(in_channel, out_channel, blur_kernel))
628 |
629 | in_channel = out_channel
630 |
631 | self.convs = nn.Sequential(*convs)
632 |
633 | self.stddev_group = 4
634 | self.stddev_feat = 1
635 |
636 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
637 | self.final_linear = nn.Sequential(
638 | EqualLinear(channels[4] * 5 * 20, channels[4], activation="fused_lrelu"),
639 | EqualLinear(channels[4], 1),
640 | )
641 |
642 | def forward(self, input):
643 | out = self.convs(input)
644 |
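        # Minibatch standard deviation: append the mean per-group feature
        # stddev as an extra channel so the discriminator can sense low
        # sample diversity.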
645 | batch, channel, height, width = out.shape
646 | group = min(batch, self.stddev_group)
647 | stddev = out.view(
648 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
649 | )
650 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
651 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
652 | stddev = stddev.repeat(group, 1, height, width)
653 | out = torch.cat([out, stddev], 1)
654 |
655 | out = self.final_conv(out)
656 |
657 | out = out.view(batch, -1)
658 | out = self.final_linear(out)
659 |
660 | return out
661 |
662 | if __name__ == "__main__":
663 | generator = Generator(size=64,
664 | style_dim=512,
665 | n_mlp=8,
666 | channel_multiplier=2,
667 | blur_kernel=[1, 3, 3, 1],
668 | lr_mlp=0.01)
669 | noise = torch.randn(1, 16, 512).unbind(0)
670 | fake_img, _ = generator(noise) #[16, 80, 320]
671 | dis = Discriminator(size=64)
672 | fake_output = dis(fake_img)
673 | import pdb; pdb.set_trace()
--------------------------------------------------------------------------------
/model_drum_four_bar.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import functools
4 | import operator
5 |
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 | from torch.autograd import Function
10 |
11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
12 |
13 |
14 | class PixelNorm(nn.Module):
15 | def __init__(self):
16 | super().__init__()
17 |
18 | def forward(self, input):
19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
20 |
21 |
22 | def make_kernel(k):
23 | k = torch.tensor(k, dtype=torch.float32)
24 |
25 | if k.ndim == 1:
26 | k = k[None, :] * k[:, None]
27 |
28 | k /= k.sum()
29 |
30 | return k
31 |
32 |
33 | class Upsample(nn.Module):
34 | def __init__(self, kernel, factor=2):
35 | super().__init__()
36 |
37 | self.factor = factor
38 | kernel = make_kernel(kernel) * (factor ** 2)
39 | self.register_buffer("kernel", kernel)
40 |
41 | p = kernel.shape[0] - factor
42 |
43 | pad0 = (p + 1) // 2 + factor - 1
44 | pad1 = p // 2
45 |
46 | self.pad = (pad0, pad1)
47 |
48 | def forward(self, input):
49 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
50 |
51 | return out
52 |
53 |
54 | class Downsample(nn.Module):
55 | def __init__(self, kernel, factor=2):
56 | super().__init__()
57 |
58 | self.factor = factor
59 | kernel = make_kernel(kernel)
60 | self.register_buffer("kernel", kernel)
61 |
62 | p = kernel.shape[0] - factor
63 |
64 | pad0 = (p + 1) // 2
65 | pad1 = p // 2
66 |
67 | self.pad = (pad0, pad1)
68 |
69 | def forward(self, input):
70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
71 |
72 | return out
73 |
74 |
75 | class Blur(nn.Module):
76 | def __init__(self, kernel, pad, upsample_factor=1):
77 | super().__init__()
78 |
79 | kernel = make_kernel(kernel)
80 |
81 | if upsample_factor > 1:
82 | kernel = kernel * (upsample_factor ** 2)
83 |
84 | self.register_buffer("kernel", kernel)
85 |
86 | self.pad = pad
87 |
88 | def forward(self, input):
89 | out = upfirdn2d(input, self.kernel, pad=self.pad)
90 |
91 | return out
92 |
93 |
94 | class EqualConv2d(nn.Module):
95 | def __init__(
96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
97 | ):
98 | super().__init__()
99 |
100 | self.weight = nn.Parameter(
101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size)
102 | )
103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
104 |
105 | self.stride = stride
106 | self.padding = padding
107 |
108 | if bias:
109 | self.bias = nn.Parameter(torch.zeros(out_channel))
110 |
111 | else:
112 | self.bias = None
113 |
114 | def forward(self, input):
115 | out = F.conv2d(
116 | input,
117 | self.weight * self.scale,
118 | bias=self.bias,
119 | stride=self.stride,
120 | padding=self.padding,
121 | )
122 |
123 | return out
124 |
125 | def __repr__(self):
126 | return (
127 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
128 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
129 | )
130 |
131 |
132 | class EqualLinear(nn.Module):
133 | def __init__(
134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
135 | ):
136 | super().__init__()
137 |
138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
139 |
140 | if bias:
141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
142 |
143 | else:
144 | self.bias = None
145 |
146 | self.activation = activation
147 |
148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
149 | self.lr_mul = lr_mul
150 |
151 | def forward(self, input):
152 | if self.activation:
153 | out = F.linear(input, self.weight * self.scale)
154 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
155 |
156 | else:
157 | out = F.linear(
158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul
159 | )
160 |
161 | return out
162 |
163 | def __repr__(self):
164 | return (
165 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
166 | )
167 |
168 |
169 | class ModulatedConv2d(nn.Module):
170 | def __init__(
171 | self,
172 | in_channel,
173 | out_channel,
174 | kernel_size,
175 | style_dim,
176 | demodulate=True,
177 | upsample=False,
178 | downsample=False,
179 | blur_kernel=[1, 3, 3, 1],
180 | ):
181 | super().__init__()
182 |
183 | self.eps = 1e-8
184 | self.kernel_size = kernel_size
185 | self.in_channel = in_channel
186 | self.out_channel = out_channel
187 | self.upsample = upsample
188 | self.downsample = downsample
189 |
190 | if upsample:
191 | factor = 2
192 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
193 | pad0 = (p + 1) // 2 + factor - 1
194 | pad1 = p // 2 + 1
195 |
196 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
197 |
198 | if downsample:
199 | factor = 2
200 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
201 | pad0 = (p + 1) // 2
202 | pad1 = p // 2
203 |
204 | self.blur = Blur(blur_kernel, pad=(pad0, pad1))
205 |
206 | fan_in = in_channel * kernel_size ** 2
207 | self.scale = 1 / math.sqrt(fan_in)
208 | self.padding = kernel_size // 2
209 |
210 | self.weight = nn.Parameter(
211 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
212 | )
213 |
214 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
215 |
216 | self.demodulate = demodulate
217 |
218 | def __repr__(self):
219 | return (
220 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
221 | f"upsample={self.upsample}, downsample={self.downsample})"
222 | )
223 |
224 | def forward(self, input, style):
225 | batch, in_channel, height, width = input.shape
226 |
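        # Scale the shared conv weight per sample by the style (modulation),
        # renormalize it if demodulate is set, then run the whole batch as a
        # single grouped convolution.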
227 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
228 | weight = self.scale * self.weight * style
229 |
230 | if self.demodulate:
231 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
232 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
233 |
234 | weight = weight.view(
235 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
236 | )
237 |
238 | if self.upsample:
239 | input = input.view(1, batch * in_channel, height, width)
240 | weight = weight.view(
241 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
242 | )
243 | weight = weight.transpose(1, 2).reshape(
244 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
245 | )
246 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
247 | _, _, height, width = out.shape
248 | out = out.view(batch, self.out_channel, height, width)
249 | out = self.blur(out)
250 |
251 | elif self.downsample:
252 | input = self.blur(input)
253 | _, _, height, width = input.shape
254 | input = input.view(1, batch * in_channel, height, width)
255 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
256 | _, _, height, width = out.shape
257 | out = out.view(batch, self.out_channel, height, width)
258 |
259 | else:
260 | input = input.view(1, batch * in_channel, height, width)
261 | out = F.conv2d(input, weight, padding=self.padding, groups=batch)
262 | _, _, height, width = out.shape
263 | out = out.view(batch, self.out_channel, height, width)
264 |
265 | return out
266 |
267 |
268 | class NoiseInjection(nn.Module):
269 | def __init__(self):
270 | super().__init__()
271 |
272 | self.weight = nn.Parameter(torch.zeros(1))
273 |
274 | def forward(self, image, noise=None):
275 | if noise is None:
276 | batch, _, height, width = image.shape
277 | noise = image.new_empty(batch, 1, height, width).normal_()
278 |
279 | return image + self.weight * noise
280 |
281 |
282 | class ConstantInput(nn.Module):
283 | def __init__(self, channel, size=4):
284 | super().__init__()
285 |
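        # Learned constant input; the size argument is unused and the constant
        # is hard-coded to 5x80, so with size=64 the four upsampling stages
        # produce an 80x1280 mel-spectrogram (four bars).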
286 | self.input = nn.Parameter(torch.randn(1, channel, 5, 80))
287 |
288 | def forward(self, input):
289 | batch = input.shape[0]
290 | out = self.input.repeat(batch, 1, 1, 1)
291 |
292 | return out
293 |
294 |
295 | class StyledConv(nn.Module):
296 | def __init__(
297 | self,
298 | in_channel,
299 | out_channel,
300 | kernel_size,
301 | style_dim,
302 | upsample=False,
303 | blur_kernel=[1, 3, 3, 1],
304 | demodulate=True,
305 | ):
306 | super().__init__()
307 |
308 | self.conv = ModulatedConv2d(
309 | in_channel,
310 | out_channel,
311 | kernel_size,
312 | style_dim,
313 | upsample=upsample,
314 | blur_kernel=blur_kernel,
315 | demodulate=demodulate,
316 | )
317 |
318 | self.noise = NoiseInjection()
319 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
320 | # self.activate = ScaledLeakyReLU(0.2)
321 | self.activate = FusedLeakyReLU(out_channel)
322 |
323 | def forward(self, input, style, noise=None):
324 | out = self.conv(input, style)
325 | out = self.noise(out, noise=noise)
326 | # out = out + self.bias
327 | out = self.activate(out)
328 |
329 | return out
330 |
331 |
332 | class ToRGB(nn.Module):
333 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
334 | super().__init__()
335 |
336 | if upsample:
337 | self.upsample = Upsample(blur_kernel)
338 |
339 | self.conv = ModulatedConv2d(in_channel, 1, 1, style_dim, demodulate=False)
340 | self.bias = nn.Parameter(torch.zeros(1, 1, 1, 1))
341 |
342 | def forward(self, input, style, skip=None):
343 | out = self.conv(input, style)
344 | out = out + self.bias
345 |
346 | if skip is not None:
347 | skip = self.upsample(skip)
348 |
349 | out = out + skip
350 |
351 | return out
352 |
353 |
354 | class Generator(nn.Module):
355 | def __init__(
356 | self,
357 | size,
358 | style_dim,
359 | n_mlp,
360 | channel_multiplier=2,
361 | blur_kernel=[1, 3, 3, 1],
362 | lr_mlp=0.01,
363 | ):
364 | super().__init__()
365 |
366 | self.size = size
367 |
368 | self.style_dim = style_dim
369 |
370 | layers = [PixelNorm()]
371 |
372 | for i in range(n_mlp):
373 | layers.append(
374 | EqualLinear(
375 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
376 | )
377 | )
378 |
379 | self.style = nn.Sequential(*layers)
380 |
381 | self.channels = {
382 | 4: 512,
383 | 8: 512,
384 | 16: 512,
385 | 32: 512,
386 | 64: 256 * channel_multiplier,
387 | 128: 128 * channel_multiplier,
388 | 256: 64 * channel_multiplier,
389 | 512: 32 * channel_multiplier,
390 | 1024: 16 * channel_multiplier,
391 | }
392 |
393 | self.input = ConstantInput(self.channels[4])
394 | self.conv1 = StyledConv(
395 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
396 | )
397 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
398 |
399 | self.log_size = int(math.log(size, 2))
400 | self.num_layers = (self.log_size - 2) * 2 + 1
401 |
402 | self.convs = nn.ModuleList()
403 | self.upsamples = nn.ModuleList()
404 | self.to_rgbs = nn.ModuleList()
405 | self.noises = nn.Module()
406 |
407 | in_channel = self.channels[4]
408 |
409 | for layer_idx in range(self.num_layers):
410 | res = (layer_idx + 5) // 2
411 | shape = [1, 1, 2 ** res, 2 ** res]
412 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
413 |
414 | for i in range(3, self.log_size + 1):
415 | out_channel = self.channels[2 ** i]
416 |
417 | self.convs.append(
418 | StyledConv(
419 | in_channel,
420 | out_channel,
421 | 3,
422 | style_dim,
423 | upsample=True,
424 | blur_kernel=blur_kernel,
425 | )
426 | )
427 |
428 | self.convs.append(
429 | StyledConv(
430 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
431 | )
432 | )
433 |
434 | self.to_rgbs.append(ToRGB(out_channel, style_dim))
435 |
436 | in_channel = out_channel
437 |
438 | self.n_latent = self.log_size * 2 - 2
439 |
440 | def make_noise(self):
441 | device = self.input.input.device
442 |
443 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
444 |
445 | for i in range(3, self.log_size + 1):
446 | for _ in range(2):
447 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
448 |
449 | return noises
450 |
451 | def mean_latent(self, n_latent):
452 | latent_in = torch.randn(
453 | n_latent, self.style_dim, device=self.input.input.device
454 | )
455 | latent = self.style(latent_in).mean(0, keepdim=True)
456 |
457 | return latent
458 |
459 | def get_latent(self, input):
460 | return self.style(input)
461 |
462 | def forward(
463 | self,
464 | styles,
465 | return_latents=False,
466 | inject_index=None,
467 | truncation=1,
468 | truncation_latent=None,
469 | input_is_latent=False,
470 | noise=None,
471 | randomize_noise=True,
472 | ):
473 | if not input_is_latent:
474 | styles = [self.style(s) for s in styles]
475 |
476 | if noise is None:
477 | if randomize_noise:
478 | noise = [None] * self.num_layers
479 | else:
480 | noise = [
481 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
482 | ]
483 |
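        # Truncation trick: interpolate each style toward the mean latent to
        # trade sample diversity for fidelity.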
484 | if truncation < 1:
485 | style_t = []
486 |
487 | for style in styles:
488 | style_t.append(
489 | truncation_latent + truncation * (style - truncation_latent)
490 | )
491 |
492 | styles = style_t
493 |
494 | if len(styles) < 2:
495 | inject_index = self.n_latent
496 |
497 | if styles[0].ndim < 3:
498 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
499 |
500 | else:
501 | latent = styles[0]
502 |
503 | else:
504 | if inject_index is None:
505 | inject_index = random.randint(1, self.n_latent - 1)
506 |
507 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
508 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
509 |
510 | latent = torch.cat([latent, latent2], 1)
511 |
512 | out = self.input(latent)
513 | out = self.conv1(out, latent[:, 0], noise=noise[0])
514 |
515 | skip = self.to_rgb1(out, latent[:, 1])
516 |
517 | i = 1
518 | for conv1, conv2, noise1, noise2, to_rgb in zip(
519 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
520 | ):
521 | out = conv1(out, latent[:, i], noise=noise1)
522 | out = conv2(out, latent[:, i + 1], noise=noise2)
523 | skip = to_rgb(out, latent[:, i + 2], skip)
524 |
525 | i += 2
526 |
527 | image = skip
528 |
529 | if return_latents:
530 | return image, latent
531 |
532 | else:
533 | return image, None
534 |
535 |
536 | class ConvLayer(nn.Sequential):
537 | def __init__(
538 | self,
539 | in_channel,
540 | out_channel,
541 | kernel_size,
542 | downsample=False,
543 | blur_kernel=[1, 3, 3, 1],
544 | bias=True,
545 | activate=True,
546 | ):
547 | layers = []
548 |
549 | if downsample:
550 | factor = 2
551 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
552 | pad0 = (p + 1) // 2
553 | pad1 = p // 2
554 |
555 | layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
556 |
557 | stride = 2
558 | self.padding = 0
559 |
560 | else:
561 | stride = 1
562 | self.padding = kernel_size // 2
563 |
564 | layers.append(
565 | EqualConv2d(
566 | in_channel,
567 | out_channel,
568 | kernel_size,
569 | padding=self.padding,
570 | stride=stride,
571 | bias=bias and not activate,
572 | )
573 | )
574 |
575 | if activate:
576 | layers.append(FusedLeakyReLU(out_channel, bias=bias))
577 |
578 | super().__init__(*layers)
579 |
580 |
581 | class ResBlock(nn.Module):
582 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
583 | super().__init__()
584 |
585 | self.conv1 = ConvLayer(in_channel, in_channel, 3)
586 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
587 |
588 | self.skip = ConvLayer(
589 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False
590 | )
591 |
592 | def forward(self, input):
593 | out = self.conv1(input)
594 | out = self.conv2(out)
595 |
596 | skip = self.skip(input)
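        # Divide the residual sum by sqrt(2) to keep the activation variance
        # roughly constant across blocks.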
597 | out = (out + skip) / math.sqrt(2)
598 |
599 | return out
600 |
601 |
602 | class Discriminator(nn.Module):
603 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
604 | super().__init__()
605 |
606 | channels = {
607 | 4: 512,
608 | 8: 512,
609 | 16: 512,
610 | 32: 512,
611 | 64: 256 * channel_multiplier,
612 | 128: 128 * channel_multiplier,
613 | 256: 64 * channel_multiplier,
614 | 512: 32 * channel_multiplier,
615 | 1024: 16 * channel_multiplier,
616 | }
617 |
618 | convs = [ConvLayer(1, channels[size], 1)]
619 |
620 | log_size = int(math.log(size, 2))
621 |
622 | in_channel = channels[size]
623 |
624 | for i in range(log_size, 2, -1):
625 | out_channel = channels[2 ** (i - 1)]
626 |
627 | convs.append(ResBlock(in_channel, out_channel, blur_kernel))
628 |
629 | in_channel = out_channel
630 |
631 | self.convs = nn.Sequential(*convs)
632 |
633 | self.stddev_group = 4
634 | self.stddev_feat = 1
635 |
636 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
637 | self.final_linear = nn.Sequential(
638 | EqualLinear(channels[4] * 5 * 80 , channels[4], activation="fused_lrelu"),
639 | EqualLinear(channels[4], 1),
640 | )
641 |
642 | def forward(self, input):
643 | out = self.convs(input)
644 |
645 | batch, channel, height, width = out.shape
646 | group = min(batch, self.stddev_group)
647 | stddev = out.view(
648 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
649 | )
650 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
651 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
652 | stddev = stddev.repeat(group, 1, height, width)
653 | out = torch.cat([out, stddev], 1)
654 |
655 | out = self.final_conv(out)
656 | out = out.view(batch, -1)
657 | out = self.final_linear(out)
658 |
659 | return out
660 |
661 | if __name__ == "__main__":
662 | generator = Generator(size=64,
663 | style_dim=512,
664 | n_mlp=8,
665 | channel_multiplier=2,
666 | blur_kernel=[1, 3, 3, 1],
667 | lr_mlp=0.01).cuda()
668 | noise = torch.randn(1, 2, 512).cuda().unbind(0)
669 |     fake_img, _ = generator(noise) #[2, 80, 1280]
670 | #fake_img = torch.randn(4, 1,80, 1280).cuda()
671 | dis = Discriminator(size=64).cuda()
672 | fake_output = dis(fake_img)
673 | import pdb; pdb.set_trace()
674 |
--------------------------------------------------------------------------------
/non_leaking.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 | from distributed import reduce_sum
7 | from op import upfirdn2d
8 |
9 |
10 | class AdaptiveAugment:
11 | def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
12 | self.ada_aug_target = ada_aug_target
13 | self.ada_aug_len = ada_aug_len
14 | self.update_every = update_every
15 |
16 | self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
17 | self.r_t_stat = 0
18 | self.ada_aug_p = 0
19 |
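    # Adaptive discriminator augmentation: track the sign of the discriminator
    # outputs on real samples and nudge the augmentation probability toward
    # ada_aug_target.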
20 | @torch.no_grad()
21 | def tune(self, real_pred):
22 | ada_aug_data = torch.tensor(
23 | (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
24 | device=real_pred.device,
25 | )
26 | self.ada_aug_buf += reduce_sum(ada_aug_data)
27 |
28 | if self.ada_aug_buf[1] > self.update_every - 1:
29 | pred_signs, n_pred = self.ada_aug_buf.tolist()
30 |
31 | self.r_t_stat = pred_signs / n_pred
32 |
33 | if self.r_t_stat > self.ada_aug_target:
34 | sign = 1
35 |
36 | else:
37 | sign = -1
38 |
39 | self.ada_aug_p += sign * n_pred / self.ada_aug_len
40 | self.ada_aug_p = min(1, max(0, self.ada_aug_p))
41 | self.ada_aug_buf.mul_(0)
42 |
43 | return self.ada_aug_p
44 |
45 |
46 | SYM6 = (
47 | 0.015404109327027373,
48 | 0.0034907120842174702,
49 | -0.11799011114819057,
50 | -0.048311742585633,
51 | 0.4910559419267466,
52 | 0.787641141030194,
53 | 0.3379294217276218,
54 | -0.07263752278646252,
55 | -0.021060292512300564,
56 | 0.04472490177066578,
57 | 0.0017677118642428036,
58 | -0.007800708325034148,
59 | )
60 |
61 |
62 | def translate_mat(t_x, t_y):
63 | batch = t_x.shape[0]
64 |
65 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
66 | translate = torch.stack((t_x, t_y), 1)
67 | mat[:, :2, 2] = translate
68 |
69 | return mat
70 |
71 |
72 | def rotate_mat(theta):
73 | batch = theta.shape[0]
74 |
75 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
76 | sin_t = torch.sin(theta)
77 | cos_t = torch.cos(theta)
78 | rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
79 | mat[:, :2, :2] = rot
80 |
81 | return mat
82 |
83 |
84 | def scale_mat(s_x, s_y):
85 | batch = s_x.shape[0]
86 |
87 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)
88 | mat[:, 0, 0] = s_x
89 | mat[:, 1, 1] = s_y
90 |
91 | return mat
92 |
93 |
94 | def translate3d_mat(t_x, t_y, t_z):
95 | batch = t_x.shape[0]
96 |
97 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
98 | translate = torch.stack((t_x, t_y, t_z), 1)
99 | mat[:, :3, 3] = translate
100 |
101 | return mat
102 |
103 |
104 | def rotate3d_mat(axis, theta):
105 | batch = theta.shape[0]
106 |
107 | u_x, u_y, u_z = axis
108 |
109 | eye = torch.eye(3).unsqueeze(0)
110 | cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
111 | outer = torch.tensor(axis)
112 | outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
113 |
114 | sin_t = torch.sin(theta).view(-1, 1, 1)
115 | cos_t = torch.cos(theta).view(-1, 1, 1)
116 |
117 | rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
118 |
119 | eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
120 | eye_4[:, :3, :3] = rot
121 |
122 | return eye_4
123 |
124 |
125 | def scale3d_mat(s_x, s_y, s_z):
126 | batch = s_x.shape[0]
127 |
128 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
129 | mat[:, 0, 0] = s_x
130 | mat[:, 1, 1] = s_y
131 | mat[:, 2, 2] = s_z
132 |
133 | return mat
134 |
135 |
136 | def luma_flip_mat(axis, i):
137 | batch = i.shape[0]
138 |
139 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
140 | axis = torch.tensor(axis + (0,))
141 | flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
142 |
143 | return eye - flip
144 |
145 |
146 | def saturation_mat(axis, i):
147 | batch = i.shape[0]
148 |
149 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
150 | axis = torch.tensor(axis + (0,))
151 | axis = torch.ger(axis, axis)
152 | saturate = axis + (eye - axis) * i.view(-1, 1, 1)
153 |
154 | return saturate
155 |
156 |
157 | def lognormal_sample(size, mean=0, std=1):
158 | return torch.empty(size).log_normal_(mean=mean, std=std)
159 |
160 |
161 | def category_sample(size, categories):
162 | category = torch.tensor(categories)
163 | sample = torch.randint(high=len(categories), size=(size,))
164 |
165 | return category[sample]
166 |
167 |
168 | def uniform_sample(size, low, high):
169 | return torch.empty(size).uniform_(low, high)
170 |
171 |
172 | def normal_sample(size, mean=0, std=1):
173 | return torch.empty(size).normal_(mean, std)
174 |
175 |
176 | def bernoulli_sample(size, p):
177 | return torch.empty(size).bernoulli_(p)
178 |
179 |
180 | def random_mat_apply(p, transform, prev, eye):
181 | size = transform.shape[0]
182 | select = bernoulli_sample(size, p).view(size, 1, 1)
183 | select_transform = select * transform + (1 - select) * eye
184 |
185 | return select_transform @ prev
186 |
187 |
188 | def sample_affine(p, size, height, width):
189 | G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1)
190 | eye = G
191 |
192 | # flip
193 | param = category_sample(size, (0, 1))
194 | Gc = scale_mat(1 - 2.0 * param, torch.ones(size))
195 | G = random_mat_apply(p, Gc, G, eye)
196 | # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
197 |
198 | # 90 rotate
199 | param = category_sample(size, (0, 3))
200 | Gc = rotate_mat(-math.pi / 2 * param)
201 | G = random_mat_apply(p, Gc, G, eye)
202 | # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
203 |
204 | # integer translate
205 | param = uniform_sample(size, -0.125, 0.125)
206 | param_height = torch.round(param * height) / height
207 | param_width = torch.round(param * width) / width
208 | Gc = translate_mat(param_width, param_height)
209 | G = random_mat_apply(p, Gc, G, eye)
210 | # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
211 |
212 | # isotropic scale
213 | param = lognormal_sample(size, std=0.2 * math.log(2))
214 | Gc = scale_mat(param, param)
215 | G = random_mat_apply(p, Gc, G, eye)
216 | # print('isotropic scale', G, scale_mat(param, param), sep='\n')
217 |
218 | p_rot = 1 - math.sqrt(1 - p)
219 |
220 | # pre-rotate
221 | param = uniform_sample(size, -math.pi, math.pi)
222 | Gc = rotate_mat(-param)
223 | G = random_mat_apply(p_rot, Gc, G, eye)
224 | # print('pre-rotate', G, rotate_mat(-param), sep='\n')
225 |
226 | # anisotropic scale
227 | param = lognormal_sample(size, std=0.2 * math.log(2))
228 | Gc = scale_mat(param, 1 / param)
229 | G = random_mat_apply(p, Gc, G, eye)
230 | # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
231 |
232 | # post-rotate
233 | param = uniform_sample(size, -math.pi, math.pi)
234 | Gc = rotate_mat(-param)
235 | G = random_mat_apply(p_rot, Gc, G, eye)
236 | # print('post-rotate', G, rotate_mat(-param), sep='\n')
237 |
238 | # fractional translate
239 | param = normal_sample(size, std=0.125)
240 | Gc = translate_mat(param, param)
241 | G = random_mat_apply(p, Gc, G, eye)
242 | # print('fractional translate', G, translate_mat(param, param), sep='\n')
243 |
244 | return G
245 |
246 |
247 | def sample_color(p, size):
248 | C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
249 | eye = C
250 | axis_val = 1 / math.sqrt(3)
251 | axis = (axis_val, axis_val, axis_val)
252 |
253 | # brightness
254 | param = normal_sample(size, std=0.2)
255 | Cc = translate3d_mat(param, param, param)
256 | C = random_mat_apply(p, Cc, C, eye)
257 |
258 | # contrast
259 | param = lognormal_sample(size, std=0.5 * math.log(2))
260 | Cc = scale3d_mat(param, param, param)
261 | C = random_mat_apply(p, Cc, C, eye)
262 |
263 | # luma flip
264 | param = category_sample(size, (0, 1))
265 | Cc = luma_flip_mat(axis, param)
266 | C = random_mat_apply(p, Cc, C, eye)
267 |
268 | # hue rotation
269 | param = uniform_sample(size, -math.pi, math.pi)
270 | Cc = rotate3d_mat(axis, param)
271 | C = random_mat_apply(p, Cc, C, eye)
272 |
273 | # saturation
274 | param = lognormal_sample(size, std=1 * math.log(2))
275 | Cc = saturation_mat(axis, param)
276 | C = random_mat_apply(p, Cc, C, eye)
277 |
278 | return C
279 |
280 |
281 | def make_grid(shape, x0, x1, y0, y1, device):
282 | n, c, h, w = shape
283 | grid = torch.empty(n, h, w, 3, device=device)
284 | grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
285 | grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
286 | grid[:, :, :, 2] = 1
287 |
288 | return grid
289 |
290 |
291 | def affine_grid(grid, mat):
292 | n, h, w, _ = grid.shape
293 | return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
294 |
295 |
296 | def get_padding(G, height, width):
297 | extreme = (
298 | G[:, :2, :]
299 | @ torch.tensor([(-1.0, -1, 1), (-1, 1, 1), (1, -1, 1), (1, 1, 1)]).t()
300 | )
301 |
302 | size = torch.tensor((width, height))
303 |
304 | pad_low = (
305 | ((extreme.min(-1).values + 1) * size)
306 | .clamp(max=0)
307 | .abs()
308 | .ceil()
309 | .max(0)
310 | .values.to(torch.int64)
311 | .tolist()
312 | )
313 | pad_high = (
314 | (extreme.max(-1).values * size - size)
315 | .clamp(min=0)
316 | .ceil()
317 | .max(0)
318 | .values.to(torch.int64)
319 | .tolist()
320 | )
321 |
322 | return pad_low[0], pad_high[0], pad_low[1], pad_high[1]
323 |
324 |
325 | def try_sample_affine_and_pad(img, p, pad_k, G=None):
326 | batch, _, height, width = img.shape
327 |
328 | G_try = G
329 |
330 | while True:
331 | if G is None:
332 | G_try = sample_affine(p, batch, height, width)
333 |
334 | pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(
335 | torch.inverse(G_try), height, width
336 | )
337 |
338 | try:
339 | img_pad = F.pad(
340 | img,
341 | (pad_x1 + pad_k, pad_x2 + pad_k, pad_y1 + pad_k, pad_y2 + pad_k),
342 | mode="reflect",
343 | )
344 |
345 | except RuntimeError:
346 | continue
347 |
348 | break
349 |
350 | return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
351 |
352 |
353 | def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
354 | kernel = antialiasing_kernel
355 | len_k = len(kernel)
356 | pad_k = (len_k + 1) // 2
357 |
358 | kernel = torch.as_tensor(kernel)
359 | kernel = torch.ger(kernel, kernel).to(img)
360 | kernel_flip = torch.flip(kernel, (0, 1))
361 |
362 | img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
363 | img, p, pad_k, G
364 | )
365 |
366 | p_ux1 = pad_x1
367 | p_ux2 = pad_x2 + 1
368 | p_uy1 = pad_y1
369 | p_uy2 = pad_y2 + 1
370 | w_p = img_pad.shape[3] - len_k + 1
371 | h_p = img_pad.shape[2] - len_k + 1
372 | h_o = img.shape[2]
373 | w_o = img.shape[3]
374 |
375 | img_2x = upfirdn2d(img_pad, kernel_flip, up=2)
376 |
377 | grid = make_grid(
378 | img_2x.shape,
379 | -2 * p_ux1 / w_o - 1,
380 | 2 * (w_p - p_ux1) / w_o - 1,
381 | -2 * p_uy1 / h_o - 1,
382 | 2 * (h_p - p_uy1) / h_o - 1,
383 | device=img_2x.device,
384 | ).to(img_2x)
385 | grid = affine_grid(grid, torch.inverse(G)[:, :2, :].to(img_2x))
386 | grid = grid * torch.tensor(
387 | [w_o / w_p, h_o / h_p], device=grid.device
388 | ) + torch.tensor(
389 | [(w_o + 2 * p_ux1) / w_p - 1, (h_o + 2 * p_uy1) / h_p - 1], device=grid.device
390 | )
391 |
392 | img_affine = F.grid_sample(
393 | img_2x, grid, mode="bilinear", align_corners=False, padding_mode="zeros"
394 | )
395 |
396 | img_down = upfirdn2d(img_affine, kernel, down=2)
397 |
398 | end_y = -pad_y2 - 1
399 | if end_y == 0:
400 | end_y = img_down.shape[2]
401 |
402 | end_x = -pad_x2 - 1
403 | if end_x == 0:
404 | end_x = img_down.shape[3]
405 |
406 | img = img_down[:, :, pad_y1:end_y, pad_x1:end_x]
407 |
408 | return img, G
409 |
410 |
411 | def apply_color(img, mat):
412 | batch = img.shape[0]
413 | img = img.permute(0, 2, 3, 1)
414 | mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
415 | mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
416 | img = img @ mat_mul + mat_add
417 | img = img.permute(0, 3, 1, 2)
418 |
419 | return img
420 |
421 |
422 | def random_apply_color(img, p, C=None):
423 | if C is None:
424 | C = sample_color(p, img.shape[0])
425 |
426 | img = apply_color(img, C.to(img))
427 |
428 | return img, C
429 |
430 |
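# Non-leaking ADA augmentation: apply random geometric and color transforms
# with probability p, returning the transform matrices so they can be reused
# via the transform_matrix argument.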
431 | def augment(img, p, transform_matrix=(None, None)):
432 | img, G = random_apply_affine(img, p, transform_matrix[0])
433 | img, C = random_apply_color(img, p, transform_matrix[1])
434 |
435 | return img, (G, C)
436 |
--------------------------------------------------------------------------------
/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
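# The fused bias + leaky ReLU CUDA extension is JIT-compiled by
# torch.utils.cpp_extension.load on first import.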
10 | module_path = os.path.dirname(__file__)
11 | fused = load(
12 | "fused",
13 | sources=[
14 | os.path.join(module_path, "fused_bias_act.cpp"),
15 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class FusedLeakyReLUFunctionBackward(Function):
21 | @staticmethod
22 | def forward(ctx, grad_output, out, bias, negative_slope, scale):
23 | ctx.save_for_backward(out)
24 | ctx.negative_slope = negative_slope
25 | ctx.scale = scale
26 |
27 | empty = grad_output.new_empty(0)
28 |
29 | grad_input = fused.fused_bias_act(
30 | grad_output, empty, out, 3, 1, negative_slope, scale
31 | )
32 |
33 | dim = [0]
34 |
35 | if grad_input.ndim > 2:
36 | dim += list(range(2, grad_input.ndim))
37 |
38 | if bias:
39 | grad_bias = grad_input.sum(dim).detach()
40 |
41 | else:
42 | grad_bias = empty
43 |
44 | return grad_input, grad_bias
45 |
46 | @staticmethod
47 | def backward(ctx, gradgrad_input, gradgrad_bias):
48 | out, = ctx.saved_tensors
49 | gradgrad_out = fused.fused_bias_act(
50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
51 | )
52 |
53 | return gradgrad_out, None, None, None, None
54 |
55 |
56 | class FusedLeakyReLUFunction(Function):
57 | @staticmethod
58 | def forward(ctx, input, bias, negative_slope, scale):
59 | empty = input.new_empty(0)
60 |
61 | ctx.bias = bias is not None
62 |
63 | if bias is None:
64 | bias = empty
65 |
66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
67 | ctx.save_for_backward(out)
68 | ctx.negative_slope = negative_slope
69 | ctx.scale = scale
70 |
71 | return out
72 |
73 | @staticmethod
74 | def backward(ctx, grad_output):
75 | out, = ctx.saved_tensors
76 |
77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
79 | )
80 |
81 | if not ctx.bias:
82 | grad_bias = None
83 |
84 | return grad_input, grad_bias, None, None
85 |
86 |
87 | class FusedLeakyReLU(nn.Module):
88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
89 | super().__init__()
90 |
91 | if bias:
92 | self.bias = nn.Parameter(torch.zeros(channel))
93 |
94 | else:
95 | self.bias = None
96 |
97 | self.negative_slope = negative_slope
98 | self.scale = scale
99 |
100 | def forward(self, input):
101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
102 |
103 |
104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
105 | if input.device.type == "cpu":
106 | if bias is not None:
107 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
108 | return (
109 | F.leaky_relu(
110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
111 | )
112 | * scale
113 | )
114 |
115 | else:
116 | return F.leaky_relu(input, negative_slope=0.2) * scale
117 |
118 | else:
119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
120 |
--------------------------------------------------------------------------------
/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 |     fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 |         y.data_ptr<scalar_t>(),
82 |         x.data_ptr<scalar_t>(),
83 |         b.data_ptr<scalar_t>(),
84 |         ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from torch.autograd import Function
6 | from torch.utils.cpp_extension import load
7 |
8 |
9 | module_path = os.path.dirname(__file__)
10 | upfirdn2d_op = load(
11 | "upfirdn2d",
12 | sources=[
13 | os.path.join(module_path, "upfirdn2d.cpp"),
14 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
15 | ],
16 | )
17 |
18 |
19 | class UpFirDn2dBackward(Function):
20 | @staticmethod
21 | def forward(
22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
23 | ):
24 |
25 | up_x, up_y = up
26 | down_x, down_y = down
27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
28 |
29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
30 |
31 | grad_input = upfirdn2d_op.upfirdn2d(
32 | grad_output,
33 | grad_kernel,
34 | down_x,
35 | down_y,
36 | up_x,
37 | up_y,
38 | g_pad_x0,
39 | g_pad_x1,
40 | g_pad_y0,
41 | g_pad_y1,
42 | )
43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
44 |
45 | ctx.save_for_backward(kernel)
46 |
47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
48 |
49 | ctx.up_x = up_x
50 | ctx.up_y = up_y
51 | ctx.down_x = down_x
52 | ctx.down_y = down_y
53 | ctx.pad_x0 = pad_x0
54 | ctx.pad_x1 = pad_x1
55 | ctx.pad_y0 = pad_y0
56 | ctx.pad_y1 = pad_y1
57 | ctx.in_size = in_size
58 | ctx.out_size = out_size
59 |
60 | return grad_input
61 |
62 | @staticmethod
63 | def backward(ctx, gradgrad_input):
64 | kernel, = ctx.saved_tensors
65 |
66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
67 |
68 | gradgrad_out = upfirdn2d_op.upfirdn2d(
69 | gradgrad_input,
70 | kernel,
71 | ctx.up_x,
72 | ctx.up_y,
73 | ctx.down_x,
74 | ctx.down_y,
75 | ctx.pad_x0,
76 | ctx.pad_x1,
77 | ctx.pad_y0,
78 | ctx.pad_y1,
79 | )
80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
81 | gradgrad_out = gradgrad_out.view(
82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
83 | )
84 |
85 | return gradgrad_out, None, None, None, None, None, None, None, None
86 |
87 |
88 | class UpFirDn2d(Function):
89 | @staticmethod
90 | def forward(ctx, input, kernel, up, down, pad):
91 | up_x, up_y = up
92 | down_x, down_y = down
93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
94 |
95 | kernel_h, kernel_w = kernel.shape
96 | batch, channel, in_h, in_w = input.shape
97 | ctx.in_size = input.shape
98 |
99 | input = input.reshape(-1, in_h, in_w, 1)
100 |
101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
102 |
103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
105 | ctx.out_size = (out_h, out_w)
106 |
107 | ctx.up = (up_x, up_y)
108 | ctx.down = (down_x, down_y)
109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
110 |
111 | g_pad_x0 = kernel_w - pad_x0 - 1
112 | g_pad_y0 = kernel_h - pad_y0 - 1
113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
115 |
116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
117 |
118 | out = upfirdn2d_op.upfirdn2d(
119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
120 | )
121 | # out = out.view(major, out_h, out_w, minor)
122 | out = out.view(-1, channel, out_h, out_w)
123 |
124 | return out
125 |
126 | @staticmethod
127 | def backward(ctx, grad_output):
128 | kernel, grad_kernel = ctx.saved_tensors
129 |
130 | grad_input = UpFirDn2dBackward.apply(
131 | grad_output,
132 | kernel,
133 | grad_kernel,
134 | ctx.up,
135 | ctx.down,
136 | ctx.pad,
137 | ctx.g_pad,
138 | ctx.in_size,
139 | ctx.out_size,
140 | )
141 |
142 | return grad_input, None, None, None, None
143 |
144 |
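# upfirdn2d: upsample by zero-insertion, apply the FIR kernel, then downsample.
# CPU tensors use the native PyTorch implementation below; CUDA tensors go
# through the compiled extension.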
145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
146 | if input.device.type == "cpu":
147 | out = upfirdn2d_native(
148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
149 | )
150 |
151 | else:
152 | out = UpFirDn2d.apply(
153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
154 | )
155 |
156 | return out
157 |
158 |
159 | def upfirdn2d_native(
160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
161 | ):
162 | _, channel, in_h, in_w = input.shape
163 | input = input.reshape(-1, in_h, in_w, 1)
164 |
165 | _, in_h, in_w, minor = input.shape
166 | kernel_h, kernel_w = kernel.shape
167 |
168 | out = input.view(-1, in_h, 1, in_w, 1, minor)
169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
171 |
172 | out = F.pad(
173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
174 | )
175 | out = out[
176 | :,
177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
179 | :,
180 | ]
181 |
182 | out = out.permute(0, 3, 1, 2)
183 | out = out.reshape(
184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
185 | )
186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
187 | out = F.conv2d(out, w)
188 | out = out.reshape(
189 | -1,
190 | minor,
191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
193 | )
194 | out = out.permute(0, 2, 3, 1)
195 | out = out[:, ::down_y, ::down_x, :]
196 |
197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
199 |
200 | return out.view(-1, channel, out_h, out_w)
201 |
--------------------------------------------------------------------------------
/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 | int c = a / b;
19 |
20 | if (c * b > a) {
21 | c--;
22 | }
23 |
24 | return c;
25 | }
26 |
27 | struct UpFirDn2DKernelParams {
28 | int up_x;
29 | int up_y;
30 | int down_x;
31 | int down_y;
32 | int pad_x0;
33 | int pad_x1;
34 | int pad_y0;
35 | int pad_y1;
36 |
37 | int major_dim;
38 | int in_h;
39 | int in_w;
40 | int minor_dim;
41 | int kernel_h;
42 | int kernel_w;
43 | int out_h;
44 | int out_w;
45 | int loop_major;
46 | int loop_x;
47 | };
48 |
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 | const scalar_t *kernel,
52 | const UpFirDn2DKernelParams p) {
53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 | int out_y = minor_idx / p.minor_dim;
55 | minor_idx -= out_y * p.minor_dim;
56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 | int major_idx_base = blockIdx.z * p.loop_major;
58 |
59 | if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 | major_idx_base >= p.major_dim) {
61 | return;
62 | }
63 |
64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 |
69 | for (int loop_major = 0, major_idx = major_idx_base;
70 | loop_major < p.loop_major && major_idx < p.major_dim;
71 | loop_major++, major_idx++) {
72 | for (int loop_x = 0, out_x = out_x_base;
73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 |
79 | const scalar_t *x_p =
80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 | minor_idx];
82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 | int x_px = p.minor_dim;
84 | int k_px = -p.up_x;
85 | int x_py = p.in_w * p.minor_dim;
86 | int k_py = -p.up_y * p.kernel_w;
87 |
88 | scalar_t v = 0.0f;
89 |
90 | for (int y = 0; y < h; y++) {
91 | for (int x = 0; x < w; x++) {
92 |           v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93 | x_p += x_px;
94 | k_p += k_px;
95 | }
96 |
97 | x_p += x_py - w * x_px;
98 | k_p += k_py - w * k_px;
99 | }
100 |
101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 | minor_idx] = v;
103 | }
104 | }
105 | }
106 |
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 |           int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 | const scalar_t *kernel,
111 | const UpFirDn2DKernelParams p) {
112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 |
115 | __shared__ volatile float sk[kernel_h][kernel_w];
116 | __shared__ volatile float sx[tile_in_h][tile_in_w];
117 |
118 | int minor_idx = blockIdx.x;
119 | int tile_out_y = minor_idx / p.minor_dim;
120 | minor_idx -= tile_out_y * p.minor_dim;
121 | tile_out_y *= tile_out_h;
122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 | int major_idx_base = blockIdx.z * p.loop_major;
124 |
125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126 | major_idx_base >= p.major_dim) {
127 | return;
128 | }
129 |
130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131 | tap_idx += blockDim.x) {
132 | int ky = tap_idx / kernel_w;
133 | int kx = tap_idx - ky * kernel_w;
134 | scalar_t v = 0.0;
135 |
136 | if (kx < p.kernel_w & ky < p.kernel_h) {
137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138 | }
139 |
140 | sk[ky][kx] = v;
141 | }
142 |
143 | for (int loop_major = 0, major_idx = major_idx_base;
144 | loop_major < p.loop_major & major_idx < p.major_dim;
145 | loop_major++, major_idx++) {
146 | for (int loop_x = 0, tile_out_x = tile_out_x_base;
147 | loop_x < p.loop_x & tile_out_x < p.out_w;
148 | loop_x++, tile_out_x += tile_out_w) {
149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151 | int tile_in_x = floor_div(tile_mid_x, up_x);
152 | int tile_in_y = floor_div(tile_mid_y, up_y);
153 |
154 | __syncthreads();
155 |
156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157 | in_idx += blockDim.x) {
158 | int rel_in_y = in_idx / tile_in_w;
159 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
160 | int in_x = rel_in_x + tile_in_x;
161 | int in_y = rel_in_y + tile_in_y;
162 |
163 | scalar_t v = 0.0;
164 |
165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167 | p.minor_dim +
168 | minor_idx];
169 | }
170 |
171 | sx[rel_in_y][rel_in_x] = v;
172 | }
173 |
174 | __syncthreads();
175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176 | out_idx += blockDim.x) {
177 | int rel_out_y = out_idx / tile_out_w;
178 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
179 | int out_x = rel_out_x + tile_out_x;
180 | int out_y = rel_out_y + tile_out_y;
181 |
182 | int mid_x = tile_mid_x + rel_out_x * down_x;
183 | int mid_y = tile_mid_y + rel_out_y * down_y;
184 | int in_x = floor_div(mid_x, up_x);
185 | int in_y = floor_div(mid_y, up_y);
186 | int rel_in_x = in_x - tile_in_x;
187 | int rel_in_y = in_y - tile_in_y;
188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190 |
191 | scalar_t v = 0.0;
192 |
193 | #pragma unroll
194 | for (int y = 0; y < kernel_h / up_y; y++)
195 | #pragma unroll
196 | for (int x = 0; x < kernel_w / up_x; x++)
197 | v += sx[rel_in_y + y][rel_in_x + x] *
198 | sk[kernel_y + y * up_y][kernel_x + x * up_x];
199 |
200 | if (out_x < p.out_w & out_y < p.out_h) {
201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202 | minor_idx] = v;
203 | }
204 | }
205 | }
206 | }
207 | }
208 |
209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210 | const torch::Tensor &kernel, int up_x, int up_y,
211 | int down_x, int down_y, int pad_x0, int pad_x1,
212 | int pad_y0, int pad_y1) {
213 | int curDevice = -1;
214 | cudaGetDevice(&curDevice);
215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
216 |
217 | UpFirDn2DKernelParams p;
218 |
219 | auto x = input.contiguous();
220 | auto k = kernel.contiguous();
221 |
222 | p.major_dim = x.size(0);
223 | p.in_h = x.size(1);
224 | p.in_w = x.size(2);
225 | p.minor_dim = x.size(3);
226 | p.kernel_h = k.size(0);
227 | p.kernel_w = k.size(1);
228 | p.up_x = up_x;
229 | p.up_y = up_y;
230 | p.down_x = down_x;
231 | p.down_y = down_y;
232 | p.pad_x0 = pad_x0;
233 | p.pad_x1 = pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 |     case 1:
314 |       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 |                                                  x.data_ptr<scalar_t>(),
317 |                                                  k.data_ptr<scalar_t>(), p);
318 |
319 |       break;
320 |
321 |     case 2:
322 |       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 |                                                  x.data_ptr<scalar_t>(),
325 |                                                  k.data_ptr<scalar_t>(), p);
326 |
327 |       break;
328 |
329 |     case 3:
330 |       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 |                                                  x.data_ptr<scalar_t>(),
333 |                                                  k.data_ptr<scalar_t>(), p);
334 |
335 |       break;
336 |
337 |     case 4:
338 |       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 |                                                  x.data_ptr<scalar_t>(),
341 |                                                  k.data_ptr<scalar_t>(), p);
342 |
343 |       break;
344 |
345 |     case 5:
346 |       upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 |                                                  x.data_ptr<scalar_t>(),
349 |                                                  k.data_ptr<scalar_t>(), p);
350 |
351 |       break;
352 |
353 |     case 6:
354 |       upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
355 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 |                                                  x.data_ptr<scalar_t>(),
357 |                                                  k.data_ptr<scalar_t>(), p);
358 |
359 |       break;
360 |
361 |     default:
362 |       upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 |           out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 |           k.data_ptr<scalar_t>(), p);
365 |     }
366 | });
367 |
368 | return out;
369 | }
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import subprocess
3 | import tempfile
4 | from pathlib import Path
5 | import os
6 | import torch
7 | import numpy as np
8 | import soundfile as sf
9 | import yaml
10 | import cog
11 |
12 | from model_drum_four_bar import Generator
13 |
14 | sys.path.append("./melgan")
15 |
16 | from modules import Generator_melgan
17 |
18 |
19 | class Predictor(cog.Predictor):
20 | def setup(self):
21 | self.device = "cuda"
22 | checkpoint_path = "checkpoint-four-bar.pt"
23 | self.latent = 512
24 |
25 | self.g_ema = Generator(
26 | size=64,
27 | style_dim=self.latent,
28 | n_mlp=8,
29 | channel_multiplier=2,
30 | ).to(self.device)
31 | checkpoint = torch.load(checkpoint_path)
32 | self.g_ema.load_state_dict(checkpoint["g_ema"])
33 | self.g_ema.eval()
34 |
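    |         # per-bin mel-spectrogram mean/std of the training data, used to de-normalize the generator output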
35 | data_path = "data/looperman_four_bar"
36 | feat_dim = 80
37 | mean_fp = f"{data_path}/mean.mel.npy"
38 | std_fp = f"{data_path}/std.mel.npy"
39 | self.mean = (
40 | torch.from_numpy(np.load(mean_fp))
41 | .float()
42 | .view(1, feat_dim, 1)
43 | .to(self.device)
44 | )
45 | self.std = (
46 | torch.from_numpy(np.load(std_fp))
47 | .float()
48 | .view(1, feat_dim, 1)
49 | .to(self.device)
50 | )
51 | vocoder_config_fp = "melgan/args.yml"
52 | vocoder_config = read_yaml(vocoder_config_fp)
53 | n_mel_channels = vocoder_config.n_mel_channels
54 | ngf = vocoder_config.ngf
55 | n_residual_layers = vocoder_config.n_residual_layers
56 | self.sr = 44100
57 |
58 | self.vocoder = Generator_melgan(n_mel_channels, ngf, n_residual_layers).to(
59 | self.device
60 | )
61 | self.vocoder.eval()
62 | vocoder_param_fp = "melgan/best_netG.pt"
63 | self.vocoder.load_state_dict(torch.load(vocoder_param_fp))
64 |
65 | @cog.input("seed", type=int, default=-1, help="Random seed, -1 for random")
66 | def predict(self, seed):
67 | if seed < 0:
68 | seed = int.from_bytes(os.urandom(2), "big")
69 | torch.manual_seed(seed)
70 | np.random.seed(seed)
71 | print(f"Prediction seed: {seed}")
72 |
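    |         # draw a random latent, generate a normalized mel-spectrogram, de-normalize it, then vocode it to a waveform with MelGAN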
73 | sample_z = torch.randn(1, self.latent, device=self.device)
74 | sample, _ = self.g_ema([sample_z], truncation=1, truncation_latent=None)
75 | de_norm = sample.squeeze(0) * self.std + self.mean
76 | audio_output = self.vocoder(de_norm)
77 | out_dir = Path(tempfile.mkdtemp())
78 | wav_path = out_dir / "out.wav"
79 | mp3_path = out_dir / "out.mp3"
80 |
81 | try:
82 | sf.write(
83 | str(wav_path), audio_output.squeeze().detach().cpu().numpy(), self.sr
84 | )
85 | subprocess.check_output(
86 | [
87 | "ffmpeg",
88 | "-i",
89 | str(wav_path),
90 | str(mp3_path),
91 | ],
92 | )
93 | return mp3_path
94 | finally:
95 | wav_path.unlink(missing_ok=True)
96 |
97 |
98 | def read_yaml(fp):
99 | with open(fp) as file:
100 | # return yaml.load(file)
101 | return yaml.load(file, Loader=yaml.Loader)
102 |
--------------------------------------------------------------------------------
/preprocess/collect_audio_clips.py:
--------------------------------------------------------------------------------
1 | import os
2 | # import numpy as np
3 | import librosa
4 | from pydub import AudioSegment
5 | # from shutil import copyfile
6 | from multiprocessing import Pool
7 |
8 |
9 | def process_one(args):
10 | fn, out_dir = args
11 |
12 | in_fp = os.path.join(audio_dir, f'{fn}.wav')
13 | if not os.path.exists(in_fp):
14 |         print('file not found:', in_fp)
15 | return
16 |
17 | duration = librosa.get_duration(filename=in_fp)
18 |
19 | # duration = song.duration_seconds
20 | num_subclips = int(duration // subclip_duration)
21 | # num_subclips = int(np.ceil(duration / subclip_duration))
22 |
23 | try:
24 | song = AudioSegment.from_wav(in_fp)
25 | except Exception:
26 |         print('error loading file:', in_fp)
27 | return
28 |
29 | for ii in range(num_subclips):
30 | start = ii*subclip_duration
31 | end = (ii+1)*subclip_duration
32 | print(fn, start, end)
33 |
34 | out_fp = os.path.join(out_dir, f'{fn}.{start}_{end}.wav')
35 | if os.path.exists(out_fp):
36 | print('Done before')
37 | continue
38 |
39 | subclip = song[start*1000:end*1000]
40 | subclip.export(out_fp, format='wav')
41 |
42 |
43 | if __name__ == '__main__':
44 | audio_dir = '/home/allenhung/nas189/Database/Looper_man/drum_loops/audio' # Clean audios or separated audios from mixture
45 | out_dir = './training_data/drum_clips_7.9/'
46 |
47 | subclip_duration = 7.9
48 | sr = 22050
49 | ext = '.wav'
50 |
51 | # ### Process ###
52 | num_samples = int(round(subclip_duration * sr))
53 | os.makedirs(out_dir, exist_ok=True)
54 |
55 | fns = [fn.replace(ext, '') for fn in os.listdir(audio_dir) if fn.endswith('.wav')]
56 | print(fns)
57 | pool = Pool(10)
58 |
59 | args_list = []
60 |
61 | for fn in fns:
62 | args_list.append((fn, out_dir))
63 |
64 | pool.map(process_one, args_list)
65 |
--------------------------------------------------------------------------------
/preprocess/compute_mean_std.mel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | from sklearn.preprocessing import StandardScaler
5 |
6 |
7 | if __name__ == "__main__":
8 |
9 | feat_type = 'mel'
10 | exp_dir = '/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar_one_bar/mel_80_320'
11 |
12 | out_dir = exp_dir
13 |
14 | # ### Process ###
15 |
16 |     dataset_fp = os.path.join(exp_dir, 'dataset.pkl')
17 | #feat_dir = os.path.join(exp_dir, feat_type)
18 | feat_dir = exp_dir
19 | out_fp_mean = os.path.join(out_dir, f'mean.{feat_type}.npy')
20 | out_fp_std = os.path.join(out_dir, f'std.{feat_type}.npy')
21 |
22 | with open(dataset_fp, 'rb') as f:
23 | dataset = pickle.load(f)
24 |
25 | in_fns = [fn for fn, _ in dataset]
26 |
27 | scaler = StandardScaler()
28 |
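    |     # incrementally accumulate per-mel-bin mean/std statistics over all feature files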
29 | for fn in in_fns:
30 | print(fn)
31 | in_fp = os.path.join(feat_dir, f'{fn}.npy')
32 | data = np.load(in_fp).T
33 | print(data.shape)
34 | print('data: ', data)
35 | scaler.partial_fit(data)
36 | print(scaler.mean_, scaler.scale_)
37 |         if np.isnan(scaler.scale_).any():
38 |             break
39 |
40 | mean = scaler.mean_
41 | std = scaler.scale_
42 | np.save(out_fp_mean, mean)
43 | np.save(out_fp_std, std)
44 |
--------------------------------------------------------------------------------
/preprocess/extract_mel.py:
--------------------------------------------------------------------------------
1 | # import glob
2 | import os
3 | import librosa
4 | import numpy as np
5 | # from utils.display import *
6 | # from utils.dsp import *
7 | # import hparams as hp
8 | # from multiprocessing import Pool, cpu_count
9 | from multiprocessing import Pool
10 | from librosa.filters import mel as librosa_mel_fn
11 | import torch
12 | from torch import nn
13 | from torch.nn import functional as F
14 |
15 | import argparse
16 | '''
17 | Modified from
18 | https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py#L26
19 | '''
20 |
21 |
22 | class Audio2Mel(nn.Module):
23 | def __init__(
24 | self,
25 | n_fft=1024,
26 | hop_length=256,
27 | win_length=1024,
28 | sampling_rate=22050,
29 | n_mel_channels=80,
30 | mel_fmin=0.0,
31 | mel_fmax=None,
32 | ):
33 | super().__init__()
34 | ##############################################
35 | # FFT Parameters #
36 | ##############################################
37 | window = torch.hann_window(win_length).float()
38 | mel_basis = librosa_mel_fn(
39 | sampling_rate, n_fft, n_mel_channels, mel_fmin, mel_fmax
40 | )
41 | mel_basis = torch.from_numpy(mel_basis).float()
42 | self.register_buffer("mel_basis", mel_basis)
43 | self.register_buffer("window", window)
44 | self.n_fft = n_fft
45 | self.hop_length = hop_length
46 | self.win_length = win_length
47 | self.sampling_rate = sampling_rate
48 | self.n_mel_channels = n_mel_channels
49 |
50 | def forward(self, audio):
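    |         # reflect-pad the waveform, take the STFT magnitude, apply the mel filterbank, then log10-compress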
51 | p = (self.n_fft - self.hop_length) // 2
52 | audio = F.pad(audio, (p, p), "reflect").squeeze(1)
53 | fft = torch.stft(
54 | audio,
55 | n_fft=self.n_fft,
56 | hop_length=self.hop_length,
57 | win_length=self.win_length,
58 | window=self.window,
59 | center=False,
60 | )
61 | real_part, imag_part = fft.unbind(-1)
62 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
63 | mel_output = torch.matmul(self.mel_basis, magnitude)
64 | log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5))
65 | return log_mel_spec[:, :, :]
66 |
67 |
68 | def convert_file(path):
69 | y, _ = librosa.load(path, sr=sr)
70 | peak = np.abs(y).max()
71 | if peak_norm or peak > 1.0:
72 | y /= peak
73 |
74 | y = torch.from_numpy(y)
75 | y = y[None, None]
76 | mel = extract_func(y)
77 | mel = mel.numpy()
78 | mel = mel[0]
79 | print(mel.shape)
80 |
81 | return mel.astype(np.float32)
82 |
83 |
84 | def process_audios(path):
85 | id = path.split('/')[-1][:-4]
86 |
87 | out_dir = os.path.join(base_out_dir, feat_type)
88 | os.makedirs(out_dir, exist_ok=True)
89 |
90 | out_fp = os.path.join(out_dir, f'{id}.npy')
91 |
92 | if os.path.exists(out_fp):
93 | print('Done before')
94 | return id, 0
95 |
96 | try:
97 | m = convert_file(path)
98 |
99 | np.save(out_fp, m, allow_pickle=False)
100 | except Exception:
101 | return id, 0
102 | return id, m.shape[-1]
103 |
104 |
105 | if __name__ == "__main__":
106 |     parser = argparse.ArgumentParser(description="extract mel-spectrograms from audio clips")
107 |
108 |     parser.add_argument("--epoch", type=str, help="unused in this script")
109 |
110 | args = parser.parse_args()
111 | base_out_dir = f'/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar_one_bar'
112 | os.makedirs(base_out_dir, exist_ok=True)
113 | clip_dir = f'/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar_one_bar' #out_dir from step1
114 |
115 | feat_type = 'mel_80_320'
116 | extension = '.wav'
117 | peak_norm = True
118 |
119 | n_fft = 1024
120 | hop_length = 275 #[241, 482, 964, 1928, 3856]
121 | win_length = 1024
122 | sampling_rate = 44100
123 | n_mel_channels = 80 #[80, 40, 20, 10, 5]
124 |
125 | # ### Process ###
126 | extract_func = Audio2Mel(n_fft, hop_length, win_length, sampling_rate, n_mel_channels)
127 | sr = sampling_rate
128 |
129 | audio_fns = [fn for fn in os.listdir(clip_dir) if fn.endswith(extension)]
130 |
131 | audio_fns = sorted(list(audio_fns))
132 |
133 | audio_files = [os.path.join(clip_dir, fn) for fn in audio_fns]
134 |
135 | pool = Pool(processes=20)
136 | dataset = []
137 |
138 | for i, (id, length) in enumerate(pool.imap_unordered(process_audios, audio_files), 1):
139 | print(id)
140 | if length == 0:
141 | continue
142 | dataset += [(id, length)]
143 |
--------------------------------------------------------------------------------
/preprocess/make_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from multiprocessing import Pool
4 | import pickle
5 |
6 |
7 | def process_audios(feat_fn):
8 | feat_fp = os.path.join(feat_dir, f'{feat_fn}.npy')
9 |
10 | if os.path.exists(feat_fp):
11 | return feat_fn, np.load(feat_fp).shape[-1]
12 | else:
13 | return feat_fn, 0
14 |
15 |
16 | if __name__ == "__main__":
17 | feat_type = ''
18 | exp_dir = '/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar_one_bar/mel_80_320' # base_out_dir from step2
19 |
20 | out_fp = os.path.join(exp_dir, 'dataset.pkl')
21 |
22 | # ### Process ###
23 | feat_dir = os.path.join(exp_dir, feat_type)
24 |
25 | feat_fns = [fn.replace('.npy', '') for fn in os.listdir(feat_dir)]
26 |
27 | pool = Pool(processes=20)
28 | dataset = []
29 |
30 | for i, (feat_fn, length) in enumerate(pool.imap_unordered(process_audios, feat_fns), 1):
31 | print(feat_fn)
32 | if length == 0:
33 | continue
34 | dataset += [(feat_fn, length)]
35 |
36 | with open(out_fp, 'wb') as f:
37 | pickle.dump(dataset, f)
38 |
39 | print(len(dataset))
40 |
--------------------------------------------------------------------------------
/preprocess/trim_2_seconds.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import pyrubberband as pyrb
3 | import madmom
4 | from madmom.features.downbeats import RNNDownBeatProcessor
5 | from madmom.features.downbeats import DBNDownBeatTrackingProcessor
6 | import os
7 | import matplotlib.pyplot as plot
8 | import numpy as np
9 | import soundfile as sf
10 | from multiprocessing import Pool
11 | loop_dir='/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar'
12 | out_dir='/home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Guitar_one_bar'
13 | os.makedirs(out_dir, exist_ok=True)
14 | def one_bar_segment(file):
15 | file_path = os.path.join(loop_dir, file)
16 | try:
17 | y, sr = librosa.core.load(file_path, sr=None) # sr = None will retrieve the original sampling rate = 44100
18 |     except Exception:
19 |         print('failed to load file:', file_path)
20 | return
21 | try:
22 | act = RNNDownBeatProcessor()(file_path)
23 | down_beat=proc(act) # [..., 2] 2d-shape numpy array
24 |     except Exception:
25 |         print('downbeat tracking failed:', file_path)
26 | return
27 | #print(down_beat)
28 | #print(len(y) / sr)
29 | #import pdb; pdb.set_trace()
30 | #retrieve 1, 2, 3, 4, 1blocks
31 | count = 0
32 | bar_list = []
33 | #print(file)
34 | name = file.replace('.wav', '')
35 | print(down_beat)
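    |     # a bar spans downbeat i to i+4; each extracted bar is time-stretched to exactly 2 seconds before saving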
36 | for i in range(down_beat.shape[0]):
37 | if down_beat[i][1] == 1 and i + 4 < down_beat.shape[0] and down_beat[i+4][1] == 1:
38 | print(down_beat[i: i + 5, :])
39 | start_time = down_beat[i][0]
40 | end_time = down_beat[i + 4][0]
41 | count += 1
42 | out_path = os.path.join(out_dir, f'{name}_{count}.wav')
43 | #print(len(y) / sr)
44 | #print(sr)
45 | y_one_bar, _ = librosa.core.load(file_path, offset=start_time, duration = end_time - start_time, sr=None)
46 | y_stretch = pyrb.time_stretch(y_one_bar, sr, (end_time - start_time) / 2)
47 | #print((end_time - start_time))
48 | #print()
49 | sf.write(out_path, y_stretch, sr)
50 |
51 | print('save file: ', f'{name}_{count}.wav')
52 | #y, sr = librosa.core.load(out_path, sr=None)
53 | #print(librosa.get_duration(y, sr=sr))
54 |
55 | if __name__ == '__main__':
56 | #dur_list = []
57 | #for file in os.listdir(loop_dir):
58 | # file_path = os.path.join(loop_dir, file)
59 | # y, sr = librosa.core.load(file_path)
60 | # dur = librosa.get_duration(y, sr)
61 | # dur_list.append(dur)
62 | #num_bins = 10
63 | #plot.hist(dur_list, num_bins, density=True)
64 | #plot.savefig('./duration.png ')
65 |
66 | proc = DBNDownBeatTrackingProcessor(beats_per_bar=4, fps = 100)
67 | file_list = list(os.listdir(loop_dir))
68 | #print(file_list[1])
69 | #one_bar_segment(file_list[1])
70 | #print(file_list[:10])
71 | #for file in os.listdir(loop_dir):
72 | # file_list.append(file)
73 | with Pool(processes=10) as pool:
74 | pool.map(one_bar_segment, file_list)
75 |
--------------------------------------------------------------------------------
/scripts/generate_freesound.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=3 python generate_audio.py \
3 | --ckpt "freesound_checkpoint.pt" \
4 | --pics 2000 --data_path "./data/freesound" \
5 | --store_path "./generated_freesound_one_bar" \
6 | --style_mixing
7 |
--------------------------------------------------------------------------------
/scripts/generate_looperman_four_bar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=3 python generate_looperman_four_bar.py \
3 | --ckpt "looperman_four_bar_checkpoint.pt" \
4 | --pics 100 \
5 | --data_path "./data/looperman_four_bar" \
6 | --store_path "./generated_audio_looperman_four_bar"
7 |
--------------------------------------------------------------------------------
/scripts/generate_looperman_one_bar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=3 python generate_audio.py \
3 | --ckpt "looperman_one_bar_checkpoint.pt" \
4 | --pics 2000 --data_path "./data/looperman/" \
5 | --store_path "./generated_looperman_one_bar"
6 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=1 python train_drum.py \
3 | --size 64 --batch 8 --sample_dir sample_Bandlab_Beats_one_bar \
4 | --checkpoint_dir checkpoint_Bandlab_Beats_one_bar \
5 | /home/allenhung/nas189/home/bandlab/BANDLAB_INSTRUMENT/Beats_one_bar/mel_80_320
6 |
--------------------------------------------------------------------------------
/train_drum.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import random
4 | import os
5 |
6 | import numpy as np
7 | import torch
8 | from torch import nn, autograd, optim
9 | from torch.nn import functional as F
10 | from torch.utils import data
11 | import torch.distributed as dist
12 | from torchvision import transforms, utils
13 | from tqdm import tqdm
14 |
15 | try:
16 | import wandb
17 |
18 | except ImportError:
19 | wandb = None
20 |
21 | from model_drum import Generator, Discriminator
22 | from dataset import MultiResolutionDataset_drum
23 | from distributed import (
24 | get_rank,
25 | synchronize,
26 | reduce_loss_dict,
27 | reduce_sum,
28 | get_world_size,
29 | )
30 | from non_leaking import augment, AdaptiveAugment
31 |
32 |
33 | def data_sampler(dataset, shuffle, distributed):
34 | if distributed:
35 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
36 |
37 | if shuffle:
38 | return data.RandomSampler(dataset)
39 |
40 | else:
41 | return data.SequentialSampler(dataset)
42 |
43 |
44 | def requires_grad(model, flag=True):
45 | for p in model.parameters():
46 | p.requires_grad = flag
47 |
48 |
49 | def accumulate(model1, model2, decay=0.999):
50 | par1 = dict(model1.named_parameters())
51 | par2 = dict(model2.named_parameters())
52 |
53 | for k in par1.keys():
54 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
55 |
56 |
57 | def sample_data(loader):
58 | while True:
59 | for batch in loader:
60 | yield batch
61 |
62 |
63 | def d_logistic_loss(real_pred, fake_pred):
64 | real_loss = F.softplus(-real_pred)
65 | fake_loss = F.softplus(fake_pred)
66 |
67 | return real_loss.mean() + fake_loss.mean()
68 |
69 |
70 | def d_r1_loss(real_pred, real_img):
71 | grad_real, = autograd.grad(
72 | outputs=real_pred.sum(), inputs=real_img, create_graph=True
73 | )
74 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
75 |
76 | return grad_penalty
77 |
78 |
79 | def g_nonsaturating_loss(fake_pred):
80 | loss = F.softplus(-fake_pred).mean()
81 |
82 | return loss
83 |
84 |
85 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
86 | noise = torch.randn_like(fake_img) / math.sqrt(
87 | fake_img.shape[2] * fake_img.shape[3]
88 | )
89 | grad, = autograd.grad(
90 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
91 | )
92 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
93 |
94 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
95 |
96 | path_penalty = (path_lengths - path_mean).pow(2).mean()
97 |
98 | return path_penalty, path_mean.detach(), path_lengths
99 |
100 |
101 | def make_noise(batch, latent_dim, n_noise, device):
102 | if n_noise == 1:
103 | return torch.randn(batch, latent_dim, device=device)
104 |
105 | noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
106 |
107 | return noises
108 |
109 |
110 | def mixing_noise(batch, latent_dim, prob, device):
111 | if prob > 0 and random.random() < prob:
112 | return make_noise(batch, latent_dim, 2, device)
113 |
114 | else:
115 | return [make_noise(batch, latent_dim, 1, device)]
116 |
117 |
118 | def set_grad_none(model, targets):
119 | for n, p in model.named_parameters():
120 | if n in targets:
121 | p.grad = None
122 |
123 |
124 | def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
125 | os.makedirs(args.sample_dir, exist_ok=True)
126 | os.makedirs(args.checkpoint_dir, exist_ok=True)
127 |
128 | loader = sample_data(loader)
129 |
130 | pbar = range(args.iter)
131 |
132 | if get_rank() == 0:
133 | pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
134 |
135 | mean_path_length = 0
136 |
137 | d_loss_val = 0
138 | r1_loss = torch.tensor(0.0, device=device)
139 | g_loss_val = 0
140 | path_loss = torch.tensor(0.0, device=device)
141 | path_lengths = torch.tensor(0.0, device=device)
142 | mean_path_length_avg = 0
143 | loss_dict = {}
144 |
145 | if args.distributed:
146 | g_module = generator.module
147 | d_module = discriminator.module
148 |
149 | else:
150 | g_module = generator
151 | d_module = discriminator
152 |
153 | accum = 0.5 ** (32 / (10 * 1000))
154 | ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
155 | r_t_stat = 0
156 |
157 | if args.augment and args.augment_p == 0:
158 | ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device)
159 |
160 | sample_z = torch.randn(args.n_sample, args.latent, device=device)
161 |
162 | for idx in pbar:
163 | i = idx + args.start_iter
164 |
165 | if i > args.iter:
166 | print("Done!")
167 |
168 | break
169 |
170 | real_img = next(loader)
171 | real_img = real_img.to(device)
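    |         # standardize real mel-spectrograms with the dataset's per-bin mean/std (reloaded from disk each iteration)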
172 | mean_fp = os.path.join(args.path, f'mean.mel.npy')
173 | std_fp = os.path.join(args.path, f'std.mel.npy')
174 | feat_dim = 80
175 | mean = torch.from_numpy(np.load(mean_fp)).float().cuda().view(1, feat_dim, 1)
176 | std = torch.from_numpy(np.load(std_fp)).float().cuda().view(1, feat_dim, 1)
177 | real_img = (real_img - mean) / std
178 | requires_grad(generator, False)
179 | requires_grad(discriminator, True)
180 |
181 | noise = mixing_noise(args.batch, args.latent, args.mixing, device)
182 | fake_img, _ = generator(noise)
183 |
184 | if args.augment:
185 | real_img_aug, _ = augment(real_img, ada_aug_p)
186 | fake_img, _ = augment(fake_img, ada_aug_p)
187 |
188 | else:
189 | real_img_aug = real_img
190 |
191 | fake_pred = discriminator(fake_img)
192 | real_pred = discriminator(real_img_aug)
193 | d_loss = d_logistic_loss(real_pred, fake_pred)
194 |
195 | loss_dict["d"] = d_loss
196 | loss_dict["real_score"] = real_pred.mean()
197 | loss_dict["fake_score"] = fake_pred.mean()
198 |
199 | discriminator.zero_grad()
200 | d_loss.backward()
201 | d_optim.step()
202 |
203 | if args.augment and args.augment_p == 0:
204 | ada_aug_p = ada_augment.tune(real_pred)
205 | r_t_stat = ada_augment.r_t_stat
206 |
207 | d_regularize = i % args.d_reg_every == 0
208 |
209 | if d_regularize:
210 | real_img.requires_grad = True
211 | real_pred = discriminator(real_img)
212 | r1_loss = d_r1_loss(real_pred, real_img)
213 |
214 | discriminator.zero_grad()
215 | (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
216 |
217 | d_optim.step()
218 |
219 | loss_dict["r1"] = r1_loss
220 |
221 | requires_grad(generator, True)
222 | requires_grad(discriminator, False)
223 |
224 | noise = mixing_noise(args.batch, args.latent, args.mixing, device)
225 | fake_img, _ = generator(noise)
226 |
227 | if args.augment:
228 | fake_img, _ = augment(fake_img, ada_aug_p)
229 |
230 | fake_pred = discriminator(fake_img)
231 | g_loss = g_nonsaturating_loss(fake_pred)
232 |
233 | loss_dict["g"] = g_loss
234 |
235 | generator.zero_grad()
236 | g_loss.backward()
237 | g_optim.step()
238 |
239 | g_regularize = i % args.g_reg_every == 0
240 |
241 | if g_regularize:
242 | path_batch_size = max(1, args.batch // args.path_batch_shrink)
243 | noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
244 | fake_img, latents = generator(noise, return_latents=True)
245 |
246 | path_loss, mean_path_length, path_lengths = g_path_regularize(
247 | fake_img, latents, mean_path_length
248 | )
249 |
250 | generator.zero_grad()
251 | weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
252 |
253 | if args.path_batch_shrink:
254 | weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
255 |
256 | weighted_path_loss.backward()
257 |
258 | g_optim.step()
259 |
260 | mean_path_length_avg = (
261 | reduce_sum(mean_path_length).item() / get_world_size()
262 | )
263 |
264 | loss_dict["path"] = path_loss
265 | loss_dict["path_length"] = path_lengths.mean()
266 |
267 | accumulate(g_ema, g_module, accum)
268 |
269 | loss_reduced = reduce_loss_dict(loss_dict)
270 |
271 | d_loss_val = loss_reduced["d"].mean().item()
272 | g_loss_val = loss_reduced["g"].mean().item()
273 | r1_val = loss_reduced["r1"].mean().item()
274 | path_loss_val = loss_reduced["path"].mean().item()
275 | real_score_val = loss_reduced["real_score"].mean().item()
276 | fake_score_val = loss_reduced["fake_score"].mean().item()
277 | path_length_val = loss_reduced["path_length"].mean().item()
278 |
279 | if get_rank() == 0:
280 | pbar.set_description(
281 | (
282 | f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
283 | f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
284 | f"augment: {ada_aug_p:.4f}"
285 | )
286 | )
287 |
288 | if wandb and args.wandb:
289 | wandb.log(
290 | {
291 | "Generator": g_loss_val,
292 | "Discriminator": d_loss_val,
293 | "Augment": ada_aug_p,
294 | "Rt": r_t_stat,
295 | "R1": r1_val,
296 | "Path Length Regularization": path_loss_val,
297 | "Mean Path Length": mean_path_length,
298 | "Real Score": real_score_val,
299 | "Fake Score": fake_score_val,
300 | "Path Length": path_length_val,
301 | }
302 | )
303 |
304 | if i % 100 == 0:
305 | with torch.no_grad():
306 | g_ema.eval()
307 | sample, _ = g_ema([sample_z])
308 | utils.save_image(
309 | sample,
310 | f"{args.sample_dir}/{str(i).zfill(6)}.png",
311 | nrow=int(args.n_sample ** 0.5),
312 | normalize=True,
313 | range=(-1, 1),
314 | )
315 |
316 | if i % 10000 == 0:
317 | torch.save(
318 | {
319 | "g": g_module.state_dict(),
320 | "d": d_module.state_dict(),
321 | "g_ema": g_ema.state_dict(),
322 | "g_optim": g_optim.state_dict(),
323 | "d_optim": d_optim.state_dict(),
324 | "args": args,
325 | "ada_aug_p": ada_aug_p,
326 | },
327 | f"{args.checkpoint_dir}/{str(i).zfill(6)}.pt",
328 | )
329 |
330 |
331 | if __name__ == "__main__":
332 | device = "cuda"
333 |
334 | parser = argparse.ArgumentParser(description="StyleGAN2 trainer")
335 |
336 | parser.add_argument("path", type=str, help="path to the lmdb dataset")
337 | parser.add_argument(
338 | "--iter", type=int, default=800000, help="total training iterations"
339 | )
340 | parser.add_argument(
341 | "--batch", type=int, default=16, help="batch sizes for each gpus"
342 | )
343 | parser.add_argument(
344 | "--n_sample",
345 | type=int,
346 | default=16,
347 | help="number of the samples generated during training",
348 | )
349 | parser.add_argument(
350 | "--size", type=int, default=256, help="image sizes for the model"
351 | )
352 | parser.add_argument(
353 | "--r1", type=float, default=10, help="weight of the r1 regularization"
354 | )
355 | parser.add_argument(
356 | "--path_regularize",
357 | type=float,
358 | default=2,
359 | help="weight of the path length regularization",
360 | )
361 | parser.add_argument(
362 | "--path_batch_shrink",
363 | type=int,
364 | default=2,
365 | help="batch size reducing factor for the path length regularization (reduce memory consumption)",
366 | )
367 | parser.add_argument(
368 | "--d_reg_every",
369 | type=int,
370 | default=16,
371 | help="interval of the applying r1 regularization",
372 | )
373 | parser.add_argument(
374 | "--g_reg_every",
375 | type=int,
376 | default=4,
377 | help="interval of the applying path length regularization",
378 | )
379 | parser.add_argument(
380 | "--mixing", type=float, default=0.9, help="probability of latent code mixing"
381 | )
382 | parser.add_argument(
383 | "--ckpt",
384 | type=str,
385 | default=None,
386 | help="path to the checkpoints to resume training",
387 | )
388 | parser.add_argument("--lr", type=float, default=0.002, help="learning rate")
389 | parser.add_argument(
390 | "--channel_multiplier",
391 | type=int,
392 | default=2,
393 | help="channel multiplier factor for the model. config-f = 2, else = 1",
394 | )
395 | parser.add_argument(
396 | "--wandb", action="store_true", help="use weights and biases logging"
397 | )
398 | parser.add_argument(
399 | "--local_rank", type=int, default=0, help="local rank for distributed training"
400 | )
401 | parser.add_argument(
402 | "--augment", action="store_true", help="apply non leaking augmentation"
403 | )
404 | parser.add_argument(
405 | "--augment_p",
406 | type=float,
407 | default=0,
408 | help="probability of applying augmentation. 0 = use adaptive augmentation",
409 | )
410 | parser.add_argument(
411 | "--ada_target",
412 | type=float,
413 | default=0.6,
414 | help="target augmentation probability for adaptive augmentation",
415 | )
416 | parser.add_argument(
417 | "--ada_length",
418 | type=int,
419 | default=500 * 1000,
420 | help="target duraing to reach augmentation probability for adaptive augmentation",
421 | )
422 | parser.add_argument(
423 | "--ada_every",
424 | type=int,
425 | default=256,
426 | help="probability update interval of the adaptive augmentation",
427 | )
428 |
429 | parser.add_argument(
430 | "--sample_dir",
431 | type=str,
432 | default='sample',
433 | help="sample directory",
434 | )
435 |
436 | parser.add_argument(
437 | "--checkpoint_dir",
438 | type=str,
439 | default='checkpoint',
440 | help="checkpoint directory",
441 | )
442 | args = parser.parse_args()
443 |
444 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
445 | args.distributed = n_gpu > 1
446 |
447 | if args.distributed:
448 | torch.cuda.set_device(args.local_rank)
449 | torch.distributed.init_process_group(backend="nccl", init_method="env://")
450 | synchronize()
451 |
452 | args.latent = 512
453 | args.n_mlp = 8
454 |
455 | args.start_iter = 0
456 |
457 | generator = Generator(
458 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
459 | ).to(device)
460 | discriminator = Discriminator(
461 | args.size, channel_multiplier=args.channel_multiplier
462 | ).to(device)
463 | g_ema = Generator(
464 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
465 | ).to(device)
466 | g_ema.eval()
467 | accumulate(g_ema, generator, 0)
468 |
469 | g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
470 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
471 |
472 | g_optim = optim.Adam(
473 | generator.parameters(),
474 | lr=args.lr * g_reg_ratio,
475 | betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
476 | )
477 | d_optim = optim.Adam(
478 | discriminator.parameters(),
479 | lr=args.lr * d_reg_ratio,
480 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
481 | )
482 |
483 | if args.ckpt is not None:
484 | print("load model:", args.ckpt)
485 |
486 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
487 |
488 | try:
489 | ckpt_name = os.path.basename(args.ckpt)
490 | args.start_iter = int(os.path.splitext(ckpt_name)[0])
491 |
492 | except ValueError:
493 | pass
494 |
495 | generator.load_state_dict(ckpt["g"])
496 | discriminator.load_state_dict(ckpt["d"])
497 | g_ema.load_state_dict(ckpt["g_ema"])
498 |
499 | g_optim.load_state_dict(ckpt["g_optim"])
500 | d_optim.load_state_dict(ckpt["d_optim"])
501 |
502 | if args.distributed:
503 | generator = nn.parallel.DistributedDataParallel(
504 | generator,
505 | device_ids=[args.local_rank],
506 | output_device=args.local_rank,
507 | broadcast_buffers=False,
508 | )
509 |
510 | discriminator = nn.parallel.DistributedDataParallel(
511 | discriminator,
512 | device_ids=[args.local_rank],
513 | output_device=args.local_rank,
514 | broadcast_buffers=False,
515 | )
516 |
517 | transform = transforms.Compose(
518 | [
519 | #transforms.RandomHorizontalFlip(),
520 | transforms.ToTensor(),
521 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
522 | ]
523 | )
524 |
525 | dataset = MultiResolutionDataset_drum(args.path, transform)
526 | loader = data.DataLoader(
527 | dataset,
528 | batch_size=args.batch,
529 | sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
530 | drop_last=True,
531 | )
532 |
533 | if get_rank() == 0 and wandb is not None and args.wandb:
534 | wandb.init(project="stylegan 2")
535 |
536 | train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
537 |
--------------------------------------------------------------------------------