├── .gitignore ├── academicodec ├── __init__.py ├── binary.py ├── models │ ├── encodec │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── distributed │ │ │ ├── distributed.py │ │ │ └── launch.py │ │ ├── loss.py │ │ ├── main_launch.py │ │ ├── msstftd.py │ │ ├── net3.py │ │ └── test.py │ ├── hificodec │ │ ├── __init__.py │ │ ├── env.py │ │ ├── meldataset.py │ │ ├── models.py │ │ ├── train.py │ │ ├── vqvae.py │ │ ├── vqvae_copy_syn.py │ │ └── vqvae_tester.py │ └── soundstream │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── loss.py │ │ └── models.py ├── modules │ ├── __init__.py │ ├── conv.py │ ├── lstm.py │ ├── norm.py │ ├── seanet.py │ └── transformer.py ├── quantization │ ├── __init__.py │ ├── ac.py │ ├── core_vq.py │ ├── distrib.py │ └── vq.py └── utils.py ├── egs ├── Encodec_16k_320d │ ├── path.sh │ ├── readme.md │ ├── start.sh │ └── test.sh ├── Encodec_24k_240d │ ├── path.sh │ ├── readme.md │ ├── start.sh │ └── test.sh ├── Encodec_24k_32d │ ├── path.sh │ ├── readme.md │ ├── start.sh │ └── test.sh ├── HiFi-Codec-16k-320d │ ├── config_16k_320d.json │ ├── path.sh │ ├── readme.md │ ├── start.sh │ └── test.sh ├── HiFi-Codec-24k-240d │ ├── config_24k_240d.json │ ├── path.sh │ ├── readme.md │ ├── start.sh │ └── test.sh ├── HiFi-Codec-24k-320d │ ├── config_24k_320d.json │ ├── infer.ipynb │ ├── path.sh │ ├── readme.md │ ├── start.sh │ └── test.sh ├── SoundStream_24k_240d │ ├── main3_ddp.py │ ├── path.sh │ ├── readme.md │ ├── start.sh │ └── test.sh └── util │ └── wavlstgen.py ├── evaluation_metric └── calculate_voc_obj_metrics │ ├── compute_metrics.sh │ └── metrics │ ├── compute_pesq.py │ ├── compute_stoi.py │ └── utils.py ├── readme.md └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | ckpt 2 | outputdir 3 | __pycache__ 4 | -------------------------------------------------------------------------------- /academicodec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangdongchao/AcademiCodec/b6ac134735f6079543db959a60eb77a7bab4277b/academicodec/__init__.py -------------------------------------------------------------------------------- /academicodec/binary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`.""" 7 | import io 8 | import json 9 | import struct 10 | import typing as tp 11 | 12 | # format is `ECDC` magic code, followed by the header size as uint32. 13 | # Then an uint8 indicates the protocol version (0.) 14 | # The header is then provided as json and should contain all required 15 | # informations for decoding. A raw stream of bytes is then provided 16 | # and should be interpretable using the json header. 
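# Editor's illustration (not part of the original file): the '!4sBI' struct defined
# just below packs the 4-byte magic, a uint8 protocol version and a big-endian
# uint32 metadata length, so the 9-byte header written for a 70-byte JSON payload is:
#   >>> import struct
#   >>> struct.pack('!4sBI', b'ECDC', 0, 70)
#   b'ECDC\x00\x00\x00\x00F'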
17 | _encodec_header_struct = struct.Struct('!4sBI') 18 | _ENCODEC_MAGIC = b'ECDC' 19 | 20 | 21 | def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any): 22 | meta_dumped = json.dumps(metadata).encode('utf-8') 23 | version = 0 24 | header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, 25 | len(meta_dumped)) 26 | fo.write(header) 27 | fo.write(meta_dumped) 28 | fo.flush() 29 | 30 | 31 | def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes: 32 | buf = b"" 33 | while len(buf) < size: 34 | new_buf = fo.read(size) 35 | if not new_buf: 36 | raise EOFError("Impossible to read enough data from the stream, " 37 | f"{size} bytes remaining.") 38 | buf += new_buf 39 | size -= len(new_buf) 40 | return buf 41 | 42 | 43 | def read_ecdc_header(fo: tp.IO[bytes]): 44 | header_bytes = _read_exactly(fo, _encodec_header_struct.size) 45 | magic, version, meta_size = _encodec_header_struct.unpack(header_bytes) 46 | if magic != _ENCODEC_MAGIC: 47 | raise ValueError("File is not in ECDC format.") 48 | if version != 0: 49 | raise ValueError("Version not supported.") 50 | meta_bytes = _read_exactly(fo, meta_size) 51 | return json.loads(meta_bytes.decode('utf-8')) 52 | 53 | 54 | class BitPacker: 55 | """Simple bit packer to handle ints with a non standard width, e.g. 10 bits. 56 | Note that for some bandwidth (1.5, 3), the codebook representation 57 | will not cover an integer number of bytes. 58 | 59 | Args: 60 | bits (int): number of bits per value that will be pushed. 61 | fo (IO[bytes]): file-object to push the bytes to. 62 | """ 63 | 64 | def __init__(self, bits: int, fo: tp.IO[bytes]): 65 | self._current_value = 0 66 | self._current_bits = 0 67 | self.bits = bits 68 | self.fo = fo 69 | 70 | def push(self, value: int): 71 | """Push a new value to the stream. This will immediately 72 | write as many uint8 as possible to the underlying file-object.""" 73 | self._current_value += (value << self._current_bits) 74 | self._current_bits += self.bits 75 | while self._current_bits >= 8: 76 | lower_8bits = self._current_value & 0xff 77 | self._current_bits -= 8 78 | self._current_value >>= 8 79 | self.fo.write(bytes([lower_8bits])) 80 | 81 | def flush(self): 82 | """Flushes the remaining partial uint8, call this at the end 83 | of the stream to encode.""" 84 | if self._current_bits: 85 | self.fo.write(bytes([self._current_value])) 86 | self._current_value = 0 87 | self._current_bits = 0 88 | self.fo.flush() 89 | 90 | 91 | class BitUnpacker: 92 | """BitUnpacker does the opposite of `BitPacker`. 93 | 94 | Args: 95 | bits (int): number of bits of the values to decode. 96 | fo (IO[bytes]): file-object to push the bytes to. 97 | """ 98 | 99 | def __init__(self, bits: int, fo: tp.IO[bytes]): 100 | self.bits = bits 101 | self.fo = fo 102 | self._mask = (1 << bits) - 1 103 | self._current_value = 0 104 | self._current_bits = 0 105 | 106 | def pull(self) -> tp.Optional[int]: 107 | """ 108 | Pull a single value from the stream, potentially reading some 109 | extra bytes from the underlying file-object. 110 | Returns `None` when reaching the end of the stream. 
111 | """ 112 | while self._current_bits < self.bits: 113 | buf = self.fo.read(1) 114 | if not buf: 115 | return None 116 | character = buf[0] 117 | self._current_value += character << self._current_bits 118 | self._current_bits += 8 119 | 120 | out = self._current_value & self._mask 121 | self._current_value >>= self.bits 122 | self._current_bits -= self.bits 123 | return out 124 | 125 | 126 | def test(): 127 | import torch 128 | torch.manual_seed(1234) 129 | for rep in range(4): 130 | length: int = torch.randint(10, 2_000, (1, )).item() 131 | bits: int = torch.randint(1, 16, (1, )).item() 132 | tokens: tp.List[int] = torch.randint(2**bits, (length, )).tolist() 133 | rebuilt: tp.List[int] = [] 134 | buf = io.BytesIO() 135 | packer = BitPacker(bits, buf) 136 | for token in tokens: 137 | packer.push(token) 138 | packer.flush() 139 | buf.seek(0) 140 | unpacker = BitUnpacker(bits, buf) 141 | while True: 142 | value = unpacker.pull() 143 | if value is None: 144 | break 145 | rebuilt.append(value) 146 | assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens)) 147 | # The flushing mechanism might lead to "ghost" values at the end of the stream. 148 | assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt), 149 | len(tokens), bits) 150 | for idx, (a, b) in enumerate(zip(tokens, rebuilt)): 151 | assert a == b, (idx, a, b) 152 | 153 | 154 | if __name__ == '__main__': 155 | test() 156 | -------------------------------------------------------------------------------- /academicodec/models/encodec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangdongchao/AcademiCodec/b6ac134735f6079543db959a60eb77a7bab4277b/academicodec/models/encodec/__init__.py -------------------------------------------------------------------------------- /academicodec/models/encodec/dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | 4 | import torch 5 | import torchaudio 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class NSynthDataset(Dataset): 10 | """Dataset to load NSynth data.""" 11 | 12 | def __init__(self, audio_dir): 13 | super().__init__() 14 | self.filenames = [] 15 | self.filenames.extend(glob.glob(audio_dir + "/*.wav")) 16 | print(len(self.filenames)) 17 | _, self.sr = torchaudio.load(self.filenames[0]) 18 | self.max_len = 24000 # 24000 19 | 20 | def __len__(self): 21 | return len(self.filenames) 22 | 23 | def __getitem__(self, index): 24 | ans = torch.zeros(1, self.max_len) 25 | audio = torchaudio.load(self.filenames[index])[0] 26 | if audio.shape[1] > self.max_len: 27 | st = random.randint(0, audio.shape[1] - self.max_len - 1) 28 | ed = st + self.max_len 29 | return audio[:, st:ed] 30 | else: 31 | ans[:, :audio.shape[1]] = audio 32 | return ans 33 | -------------------------------------------------------------------------------- /academicodec/models/encodec/distributed/distributed.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # Diffsound 3 | # code based https://github.com/cientgu/VQ-Diffusion 4 | # ------------------------------------------ 5 | import pickle 6 | 7 | import torch 8 | from torch import distributed as dist 9 | from torch.utils import data 10 | 11 | LOCAL_PROCESS_GROUP = None 12 | 13 | 14 | def is_primary(): 15 | return get_rank() == 0 16 | 17 | 18 | def get_rank(): 19 | if not dist.is_available(): 20 | return 0 21 | 22 | if not 
dist.is_initialized(): 23 | return 0 24 | 25 | return dist.get_rank() 26 | 27 | 28 | def get_local_rank(): 29 | if not dist.is_available(): 30 | return 0 31 | 32 | if not dist.is_initialized(): 33 | return 0 34 | 35 | if LOCAL_PROCESS_GROUP is None: 36 | raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None") 37 | 38 | return dist.get_rank(group=LOCAL_PROCESS_GROUP) 39 | 40 | 41 | def synchronize(): 42 | if not dist.is_available(): 43 | return 44 | 45 | if not dist.is_initialized(): 46 | return 47 | 48 | world_size = dist.get_world_size() 49 | 50 | if world_size == 1: 51 | return 52 | 53 | dist.barrier() 54 | 55 | 56 | def get_world_size(): 57 | if not dist.is_available(): 58 | return 1 59 | 60 | if not dist.is_initialized(): 61 | return 1 62 | 63 | return dist.get_world_size() 64 | 65 | 66 | def is_distributed(): 67 | raise RuntimeError('Please debug this function!') 68 | return get_world_size() > 1 69 | 70 | 71 | def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False): 72 | world_size = get_world_size() 73 | 74 | if world_size == 1: 75 | return tensor 76 | dist.all_reduce(tensor, op=op, async_op=async_op) 77 | 78 | return tensor 79 | 80 | 81 | def all_gather(data): 82 | world_size = get_world_size() 83 | 84 | if world_size == 1: 85 | return [data] 86 | 87 | buffer = pickle.dumps(data) 88 | storage = torch.ByteStorage.from_buffer(buffer) 89 | tensor = torch.ByteTensor(storage).to("cuda") 90 | 91 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 92 | size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)] 93 | dist.all_gather(size_list, local_size) 94 | size_list = [int(size.item()) for size in size_list] 95 | max_size = max(size_list) 96 | 97 | tensor_list = [] 98 | for _ in size_list: 99 | tensor_list.append(torch.ByteTensor(size=(max_size, )).to("cuda")) 100 | 101 | if local_size != max_size: 102 | padding = torch.ByteTensor(size=(max_size - local_size, )).to("cuda") 103 | tensor = torch.cat((tensor, padding), 0) 104 | 105 | dist.all_gather(tensor_list, tensor) 106 | 107 | data_list = [] 108 | 109 | for size, tensor in zip(size_list, tensor_list): 110 | buffer = tensor.cpu().numpy().tobytes()[:size] 111 | data_list.append(pickle.loads(buffer)) 112 | 113 | return data_list 114 | 115 | 116 | def reduce_dict(input_dict, average=True): 117 | world_size = get_world_size() 118 | 119 | if world_size < 2: 120 | return input_dict 121 | 122 | with torch.no_grad(): 123 | keys = [] 124 | values = [] 125 | 126 | for k in sorted(input_dict.keys()): 127 | keys.append(k) 128 | values.append(input_dict[k]) 129 | 130 | values = torch.stack(values, 0) 131 | dist.reduce(values, dst=0) 132 | 133 | if dist.get_rank() == 0 and average: 134 | values /= world_size 135 | 136 | reduced_dict = {k: v for k, v in zip(keys, values)} 137 | 138 | return reduced_dict 139 | 140 | 141 | def data_sampler(dataset, shuffle, distributed): 142 | if distributed: 143 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 144 | 145 | if shuffle: 146 | return data.RandomSampler(dataset) 147 | 148 | else: 149 | return data.SequentialSampler(dataset) 150 | -------------------------------------------------------------------------------- /academicodec/models/encodec/distributed/launch.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------ 2 | # Diffsound 3 | # code based https://github.com/cientgu/VQ-Diffusion 4 | # ------------------------------------------ 5 | import distributed.distributed as 
dist_fn 6 | import torch 7 | from torch import distributed as dist 8 | from torch import multiprocessing as mp 9 | 10 | # import distributed as dist_fn 11 | 12 | 13 | def find_free_port(): 14 | import socket 15 | 16 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 17 | 18 | sock.bind(("", 0)) 19 | port = sock.getsockname()[1] 20 | sock.close() 21 | 22 | return port 23 | 24 | 25 | def launch(fn, 26 | n_gpu_per_machine, 27 | n_machine=1, 28 | machine_rank=0, 29 | dist_url=None, 30 | args=()): 31 | world_size = n_machine * n_gpu_per_machine 32 | 33 | if world_size > 1: 34 | # if "OMP_NUM_THREADS" not in os.environ: 35 | # os.environ["OMP_NUM_THREADS"] = "1" 36 | if dist_url == "auto": 37 | if n_machine != 1: 38 | raise ValueError( 39 | 'dist_url="auto" not supported in multi-machine jobs') 40 | port = find_free_port() 41 | dist_url = f"tcp://127.0.0.1:{port}" 42 | print('dist_url ', dist_url) 43 | print('n_machine ', n_machine) 44 | print('args ', args) 45 | print('world_size ', world_size) 46 | print('machine_rank ', machine_rank) 47 | if n_machine > 1 and dist_url.startswith("file://"): 48 | raise ValueError( 49 | "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://" 50 | ) 51 | 52 | mp.spawn( 53 | distributed_worker, 54 | nprocs=n_gpu_per_machine, 55 | args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, 56 | args), 57 | daemon=False, ) 58 | # n_machine ? world_size 59 | else: 60 | local_rank = 0 61 | fn(local_rank, *args) 62 | 63 | 64 | def distributed_worker(local_rank, fn, world_size, n_gpu_per_machine, 65 | machine_rank, dist_url, args): 66 | if not torch.cuda.is_available(): 67 | raise OSError("CUDA is not available. Please check your environments") 68 | 69 | global_rank = machine_rank * n_gpu_per_machine + local_rank 70 | print('local_rank ', local_rank) 71 | print('global_rank ', global_rank) 72 | try: 73 | dist.init_process_group( 74 | backend="NCCL", 75 | init_method=dist_url, 76 | world_size=world_size, 77 | rank=global_rank, ) 78 | 79 | except Exception: 80 | raise OSError("failed to initialize NCCL groups") 81 | 82 | # changed 83 | dist_fn.synchronize() 84 | 85 | if n_gpu_per_machine > torch.cuda.device_count(): 86 | raise ValueError( 87 | f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})" 88 | ) 89 | 90 | torch.cuda.set_device(local_rank) 91 | 92 | if dist_fn.LOCAL_PROCESS_GROUP is not None: 93 | raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None") 94 | 95 | # change paert 96 | 97 | n_machine = world_size // n_gpu_per_machine 98 | for i in range(n_machine): 99 | ranks_on_i = list( 100 | range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine)) 101 | pg = dist.new_group(ranks_on_i) 102 | 103 | if i == machine_rank: 104 | dist_fn.LOCAL_PROCESS_GROUP = pg 105 | 106 | fn(local_rank, *args) 107 | -------------------------------------------------------------------------------- /academicodec/models/encodec/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchaudio.transforms import MelSpectrogram 4 | 5 | 6 | def adversarial_g_loss(y_disc_gen): 7 | """Hinge loss""" 8 | loss = 0.0 9 | for i in range(len(y_disc_gen)): 10 | stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze() 11 | loss += stft_loss 12 | return loss / len(y_disc_gen) 13 | 14 | 15 | def feature_loss(fmap_r, fmap_gen): 16 | loss = 0.0 17 | for i in range(len(fmap_r)): 18 | for j in range(len(fmap_r[i])): 19 | stft_loss 
= ((fmap_r[i][j] - fmap_gen[i][j]).abs() / 20 | (fmap_r[i][j].abs().mean())).mean() 21 | loss += stft_loss 22 | return loss / (len(fmap_r) * len(fmap_r[0])) 23 | 24 | 25 | def sim_loss(y_disc_r, y_disc_gen): 26 | loss = 0.0 27 | for i in range(len(y_disc_r)): 28 | loss += F.mse_loss(y_disc_r[i], y_disc_gen[i]) 29 | return loss / len(y_disc_r) 30 | 31 | # def sisnr_loss(x, s, eps=1e-8): 32 | # """ 33 | # calculate training loss 34 | # input: 35 | # x: separated signal, N x S tensor, estimate value 36 | # s: reference signal, N x S tensor, True value 37 | # Return: 38 | # sisnr: N tensor 39 | # """ 40 | # if x.shape != s.shape: 41 | # if x.shape[-1] > s.shape[-1]: 42 | # x = x[:, :s.shape[-1]] 43 | # else: 44 | # s = s[:, :x.shape[-1]] 45 | # def l2norm(mat, keepdim=False): 46 | # return torch.norm(mat, dim=-1, keepdim=keepdim) 47 | # if x.shape != s.shape: 48 | # raise RuntimeError( 49 | # "Dimention mismatch when calculate si-snr, {} vs {}".format( 50 | # x.shape, s.shape)) 51 | # x_zm = x - torch.mean(x, dim=-1, keepdim=True) 52 | # s_zm = s - torch.mean(s, dim=-1, keepdim=True) 53 | # t = torch.sum( 54 | # x_zm * s_zm, dim=-1, 55 | # keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps) 56 | # loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps)) 57 | # return torch.sum(loss) / x.shape[0] 58 | 59 | 60 | def reconstruction_loss(x, G_x, args, eps=1e-7): 61 | # NOTE (lsx): hard-coded now 62 | L = args.LAMBDA_WAV * F.mse_loss(x, G_x) # wav L1 loss 63 | # loss_sisnr = sisnr_loss(G_x, x) # 64 | # L += 0.01*loss_sisnr 65 | # 2^6=64 -> 2^10=1024 66 | # NOTE (lsx): add 2^11 67 | for i in range(6, 12): 68 | # for i in range(5, 12): # Encodec setting 69 | s = 2**i 70 | melspec = MelSpectrogram( 71 | sample_rate=args.sr, 72 | n_fft=max(s, 512), 73 | win_length=s, 74 | hop_length=s // 4, 75 | n_mels=64, 76 | wkwargs={"device": args.device}).to(args.device) 77 | S_x = melspec(x) 78 | S_G_x = melspec(G_x) 79 | l1_loss = (S_x - S_G_x).abs().mean() 80 | l2_loss = (((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps))**2).mean(dim=-2)**0.5).mean() 81 | 82 | alpha = (s / 2) ** 0.5 83 | L += (l1_loss + alpha * l2_loss) 84 | return L 85 | 86 | 87 | def criterion_d(y_disc_r, y_disc_gen, fmap_r_det, fmap_gen_det, y_df_hat_r, 88 | y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, 89 | fmap_s_r, fmap_s_g): 90 | """Hinge Loss""" 91 | loss = 0.0 92 | loss1 = 0.0 93 | loss2 = 0.0 94 | loss3 = 0.0 95 | for i in range(len(y_disc_r)): 96 | loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[ 97 | i]).mean() 98 | for i in range(len(y_df_hat_r)): 99 | loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[ 100 | i]).mean() 101 | for i in range(len(y_ds_hat_r)): 102 | loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[ 103 | i]).mean() 104 | 105 | loss = (loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / 106 | len(y_ds_hat_r)) / 3.0 107 | 108 | return loss 109 | 110 | 111 | def criterion_g(commit_loss, x, G_x, fmap_r, fmap_gen, y_disc_r, y_disc_gen, 112 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, 113 | y_ds_hat_g, fmap_s_r, fmap_s_g, args): 114 | adv_g_loss = adversarial_g_loss(y_disc_gen) 115 | feat_loss = (feature_loss(fmap_r, fmap_gen) + sim_loss( 116 | y_disc_r, y_disc_gen) + feature_loss(fmap_f_r, fmap_f_g) + sim_loss( 117 | y_df_hat_r, y_df_hat_g) + feature_loss(fmap_s_r, fmap_s_g) + 118 | sim_loss(y_ds_hat_r, y_ds_hat_g)) / 3.0 119 | rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args) 120 | 
total_loss = args.LAMBDA_COM * commit_loss + args.LAMBDA_ADV * adv_g_loss + args.LAMBDA_FEAT * feat_loss + args.LAMBDA_REC * rec_loss 121 | return total_loss, adv_g_loss, feat_loss, rec_loss 122 | 123 | 124 | def adopt_weight(weight, global_step, threshold=0, value=0.): 125 | if global_step < threshold: 126 | weight = value 127 | return weight 128 | 129 | 130 | def adopt_dis_weight(weight, global_step, threshold=0, value=0.): 131 | # 0,3,6,9,13....这些时间步,不更新dis 132 | if global_step % 3 == 0: 133 | weight = value 134 | return weight 135 | 136 | 137 | def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args): 138 | if last_layer is not None: 139 | nll_grads = torch.autograd.grad( 140 | nll_loss, last_layer, retain_graph=True)[0] 141 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 142 | else: 143 | print('last_layer cannot be none') 144 | assert 1 == 2 145 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 146 | d_weight = torch.clamp(d_weight, 1.0, 1.0).detach() 147 | d_weight = d_weight * args.LAMBDA_ADV 148 | return d_weight 149 | 150 | 151 | def loss_g(codebook_loss, 152 | inputs, 153 | reconstructions, 154 | fmap_r, 155 | fmap_gen, 156 | y_disc_r, 157 | y_disc_gen, 158 | global_step, 159 | y_df_hat_r, 160 | y_df_hat_g, 161 | y_ds_hat_r, 162 | y_ds_hat_g, 163 | fmap_f_r, 164 | fmap_f_g, 165 | fmap_s_r, 166 | fmap_s_g, 167 | last_layer=None, 168 | is_training=True, 169 | args=None): 170 | """ 171 | args: 172 | codebook_loss: commit loss. 173 | inputs: ground-truth wav. 174 | reconstructions: reconstructed wav. 175 | fmap_r: real stft-D feature map. 176 | fmap_gen: fake stft-D feature map. 177 | y_disc_r: real stft-D logits. 178 | y_disc_gen: fake stft-D logits. 179 | global_step: global training step. 180 | y_df_hat_r: real MPD logits. 181 | y_df_hat_g: fake MPD logits. 182 | y_ds_hat_r: real MSD logits. 183 | y_ds_hat_g: fake MSD logits. 184 | fmap_f_r: real MPD feature map. 185 | fmap_f_g: fake MPD feature map. 186 | fmap_s_r: real MSD feature map. 187 | fmap_s_g: fake MSD feature map. 188 | """ 189 | rec_loss = reconstruction_loss(inputs.contiguous(), 190 | reconstructions.contiguous(), args) 191 | adv_g_loss = adversarial_g_loss(y_disc_gen) 192 | adv_mpd_loss = adversarial_g_loss(y_df_hat_g) 193 | adv_msd_loss = adversarial_g_loss(y_ds_hat_g) 194 | adv_loss = (adv_g_loss + adv_mpd_loss + adv_msd_loss 195 | ) / 3.0 # NOTE(lsx): need to divide by 3? 196 | feat_loss = feature_loss( 197 | fmap_r, 198 | fmap_gen) #+ sim_loss(y_disc_r, y_disc_gen) # NOTE(lsx): need logits? 
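    # Editor's note (not part of the original file): the two feature_loss calls that
    # follow cover the MPD and MSD discriminators; together with the MS-STFT term
    # above they are averaged into feat_loss_tot, and both the adversarial and
    # feature-matching terms are zeroed via disc_factor until global_step reaches
    # args.discriminator_iter_start.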
199 | feat_loss_mpd = feature_loss(fmap_f_r, 200 | fmap_f_g) #+ sim_loss(y_df_hat_r, y_df_hat_g) 201 | feat_loss_msd = feature_loss(fmap_s_r, 202 | fmap_s_g) #+ sim_loss(y_ds_hat_r, y_ds_hat_g) 203 | feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0 204 | d_weight = torch.tensor(1.0) 205 | # try: 206 | # d_weight = calculate_adaptive_weight(rec_loss, adv_g_loss, last_layer, args) # 动态调整重构损失和对抗损失 207 | # except RuntimeError: 208 | # assert not is_training 209 | # d_weight = torch.tensor(0.0) 210 | disc_factor = adopt_weight( 211 | args.LAMBDA_ADV, global_step, threshold=args.discriminator_iter_start) 212 | if disc_factor == 0.: 213 | fm_loss_wt = 0 214 | else: 215 | fm_loss_wt = args.LAMBDA_FEAT 216 | #feat_factor = adopt_weight(args.LAMBDA_FEAT, global_step, threshold=args.discriminator_iter_start) 217 | loss = rec_loss + d_weight * disc_factor * adv_loss + \ 218 | fm_loss_wt * feat_loss_tot + args.LAMBDA_COM * codebook_loss 219 | return loss, rec_loss, adv_loss, feat_loss_tot, d_weight 220 | 221 | 222 | def loss_dis(y_disc_r_det, y_disc_gen_det, fmap_r_det, fmap_gen_det, y_df_hat_r, 223 | y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, fmap_s_r, 224 | fmap_s_g, global_step, args): 225 | disc_factor = adopt_weight( 226 | args.LAMBDA_ADV, global_step, threshold=args.discriminator_iter_start) 227 | d_loss = disc_factor * criterion_d(y_disc_r_det, y_disc_gen_det, fmap_r_det, 228 | fmap_gen_det, y_df_hat_r, y_df_hat_g, 229 | fmap_f_r, fmap_f_g, y_ds_hat_r, 230 | y_ds_hat_g, fmap_s_r, fmap_s_g) 231 | return d_loss 232 | -------------------------------------------------------------------------------- /academicodec/models/encodec/msstftd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """MS-STFT discriminator, provided here for reference.""" 7 | import typing as tp 8 | 9 | import torch 10 | import torchaudio 11 | from einops import rearrange 12 | from torch import nn 13 | 14 | from academicodec.modules import NormConv2d 15 | 16 | FeatureMapType = tp.List[torch.Tensor] 17 | LogitsType = torch.Tensor 18 | DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] 19 | 20 | 21 | def get_2d_padding(kernel_size: tp.Tuple[int, int], 22 | dilation: tp.Tuple[int, int]=(1, 1)): 23 | return (((kernel_size[0] - 1) * dilation[0]) // 2, ( 24 | (kernel_size[1] - 1) * dilation[1]) // 2) 25 | 26 | 27 | class DiscriminatorSTFT(nn.Module): 28 | """STFT sub-discriminator. 29 | Args: 30 | filters (int): Number of filters in convolutions 31 | in_channels (int): Number of input channels. Default: 1 32 | out_channels (int): Number of output channels. Default: 1 33 | n_fft (int): Size of FFT for each scale. Default: 1024 34 | hop_length (int): Length of hop between STFT windows for each scale. Default: 256 35 | kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` 36 | stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` 37 | dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` 38 | win_length (int): Window size for each scale. Default: 1024 39 | normalized (bool): Whether to normalize by magnitude after stft. Default: True 40 | norm (str): Normalization method. Default: `'weight_norm'` 41 | activation (str): Activation function. 
Default: `'LeakyReLU'` 42 | activation_params (dict): Parameters to provide to the activation function. 43 | growth (int): Growth factor for the filters. Default: 1 44 | """ 45 | 46 | def __init__(self, 47 | filters: int, 48 | in_channels: int=1, 49 | out_channels: int=1, 50 | n_fft: int=1024, 51 | hop_length: int=256, 52 | win_length: int=1024, 53 | max_filters: int=1024, 54 | filters_scale: int=1, 55 | kernel_size: tp.Tuple[int, int]=(3, 9), 56 | dilations: tp.List=[1, 2, 4], 57 | stride: tp.Tuple[int, int]=(1, 2), 58 | normalized: bool=True, 59 | norm: str='weight_norm', 60 | activation: str='LeakyReLU', 61 | activation_params: dict={'negative_slope': 0.2}): 62 | super().__init__() 63 | assert len(kernel_size) == 2 64 | assert len(stride) == 2 65 | self.filters = filters 66 | self.in_channels = in_channels 67 | self.out_channels = out_channels 68 | self.n_fft = n_fft 69 | self.hop_length = hop_length 70 | self.win_length = win_length 71 | self.normalized = normalized 72 | self.activation = getattr(torch.nn, activation)(**activation_params) 73 | self.spec_transform = torchaudio.transforms.Spectrogram( 74 | n_fft=self.n_fft, 75 | hop_length=self.hop_length, 76 | win_length=self.win_length, 77 | window_fn=torch.hann_window, 78 | normalized=self.normalized, 79 | center=False, 80 | pad_mode=None, 81 | power=None) 82 | spec_channels = 2 * self.in_channels 83 | self.convs = nn.ModuleList() 84 | self.convs.append( 85 | NormConv2d( 86 | spec_channels, 87 | self.filters, 88 | kernel_size=kernel_size, 89 | padding=get_2d_padding(kernel_size))) 90 | in_chs = min(filters_scale * self.filters, max_filters) 91 | for i, dilation in enumerate(dilations): 92 | out_chs = min((filters_scale**(i + 1)) * self.filters, max_filters) 93 | self.convs.append( 94 | NormConv2d( 95 | in_chs, 96 | out_chs, 97 | kernel_size=kernel_size, 98 | stride=stride, 99 | dilation=(dilation, 1), 100 | padding=get_2d_padding(kernel_size, (dilation, 1)), 101 | norm=norm)) 102 | in_chs = out_chs 103 | out_chs = min((filters_scale**(len(dilations) + 1)) * self.filters, 104 | max_filters) 105 | self.convs.append( 106 | NormConv2d( 107 | in_chs, 108 | out_chs, 109 | kernel_size=(kernel_size[0], kernel_size[0]), 110 | padding=get_2d_padding((kernel_size[0], kernel_size[0])), 111 | norm=norm)) 112 | self.conv_post = NormConv2d( 113 | out_chs, 114 | self.out_channels, 115 | kernel_size=(kernel_size[0], kernel_size[0]), 116 | padding=get_2d_padding((kernel_size[0], kernel_size[0])), 117 | norm=norm) 118 | 119 | def forward(self, x: torch.Tensor): 120 | fmap = [] 121 | # print('x ', x.shape) 122 | z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] 123 | # print('z ', z.shape) 124 | z = torch.cat([z.real, z.imag], dim=1) 125 | # print('cat_z ', z.shape) 126 | z = rearrange(z, 'b c w t -> b c t w') 127 | for i, layer in enumerate(self.convs): 128 | z = layer(z) 129 | z = self.activation(z) 130 | # print('z i', i, z.shape) 131 | fmap.append(z) 132 | z = self.conv_post(z) 133 | # print('logit ', z.shape) 134 | return z, fmap 135 | 136 | 137 | class MultiScaleSTFTDiscriminator(nn.Module): 138 | """Multi-Scale STFT (MS-STFT) discriminator. 139 | Args: 140 | filters (int): Number of filters in convolutions 141 | in_channels (int): Number of input channels. Default: 1 142 | out_channels (int): Number of output channels. 
Default: 1 143 | n_ffts (Sequence[int]): Size of FFT for each scale 144 | hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale 145 | win_lengths (Sequence[int]): Window size for each scale 146 | **kwargs: additional args for STFTDiscriminator 147 | """ 148 | 149 | def __init__(self, 150 | filters: int, 151 | in_channels: int=1, 152 | out_channels: int=1, 153 | n_ffts: tp.List[int]=[1024, 2048, 512, 256, 128], 154 | hop_lengths: tp.List[int]=[256, 512, 128, 64, 32], 155 | win_lengths: tp.List[int]=[1024, 2048, 512, 256, 128], 156 | **kwargs): 157 | super().__init__() 158 | assert len(n_ffts) == len(hop_lengths) == len(win_lengths) 159 | self.discriminators = nn.ModuleList([ 160 | DiscriminatorSTFT( 161 | filters, 162 | in_channels=in_channels, 163 | out_channels=out_channels, 164 | n_fft=n_ffts[i], 165 | win_length=win_lengths[i], 166 | hop_length=hop_lengths[i], 167 | **kwargs) for i in range(len(n_ffts)) 168 | ]) 169 | self.num_discriminators = len(self.discriminators) 170 | 171 | def forward(self, x: torch.Tensor) -> DiscriminatorOutput: 172 | logits = [] 173 | fmaps = [] 174 | for disc in self.discriminators: 175 | logit, fmap = disc(x) 176 | logits.append(logit) 177 | fmaps.append(fmap) 178 | return logits, fmaps 179 | 180 | 181 | def test(): 182 | disc = MultiScaleSTFTDiscriminator(filters=32) 183 | y = torch.randn(1, 1, 24000) 184 | y_hat = torch.randn(1, 1, 24000) 185 | 186 | y_disc_r, fmap_r = disc(y) 187 | y_disc_gen, fmap_gen = disc(y_hat) 188 | assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len( 189 | fmap_gen) == disc.num_discriminators 190 | 191 | assert all([len(fm) == 5 for fm in fmap_r + fmap_gen]) 192 | assert all( 193 | [list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm]) 194 | assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen]) 195 | 196 | 197 | if __name__ == '__main__': 198 | test() 199 | -------------------------------------------------------------------------------- /academicodec/models/encodec/net3.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy as np 5 | import torch.nn as nn 6 | from academicodec.modules.seanet import SEANetDecoder 7 | from academicodec.modules.seanet import SEANetEncoder 8 | from academicodec.quantization import ResidualVectorQuantizer 9 | 10 | 11 | # Generator 12 | class SoundStream(nn.Module): 13 | def __init__(self, 14 | n_filters, 15 | D, 16 | target_bandwidths=[7.5, 15], 17 | ratios=[8, 5, 4, 2], 18 | sample_rate=24000, 19 | bins=1024, 20 | normalize=False): 21 | super().__init__() 22 | self.hop_length = np.prod(ratios) # 计算乘积 23 | self.encoder = SEANetEncoder( 24 | n_filters=n_filters, dimension=D, ratios=ratios) 25 | n_q = int(1000 * target_bandwidths[-1] // 26 | (math.ceil(sample_rate / self.hop_length) * 10)) 27 | self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 75 28 | self.bits_per_codebook = int(math.log2(bins)) 29 | self.target_bandwidths = target_bandwidths 30 | self.quantizer = ResidualVectorQuantizer( 31 | dimension=D, n_q=n_q, bins=bins) 32 | self.decoder = SEANetDecoder( 33 | n_filters=n_filters, dimension=D, ratios=ratios) 34 | 35 | def get_last_layer(self): 36 | return self.decoder.layers[-1].weight 37 | 38 | def forward(self, x): 39 | e = self.encoder(x) 40 | max_idx = len(self.target_bandwidths) - 1 41 | bw = self.target_bandwidths[random.randint(0, max_idx)] 42 | quantized, codes, bandwidth, commit_loss = self.quantizer( 43 | e, self.frame_rate, 
bw) 44 | o = self.decoder(quantized) 45 | return o, commit_loss, None 46 | 47 | def encode(self, x, target_bw=None, st=None): 48 | e = self.encoder(x) 49 | if target_bw is None: 50 | bw = self.target_bandwidths[-1] 51 | else: 52 | bw = target_bw 53 | if st is None: 54 | st = 0 55 | codes = self.quantizer.encode(e, self.frame_rate, bw, st) 56 | return codes 57 | 58 | def decode(self, codes): 59 | quantized = self.quantizer.decode(codes) 60 | o = self.decoder(quantized) 61 | return o 62 | -------------------------------------------------------------------------------- /academicodec/models/encodec/test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Command-line for audio compression.""" 7 | import argparse 8 | import os 9 | import sys 10 | import typing as tp 11 | from collections import OrderedDict 12 | from pathlib import Path 13 | 14 | import librosa 15 | import soundfile as sf 16 | import torch 17 | from academicodec.models.encodec.net3 import SoundStream 18 | 19 | 20 | def save_audio(wav: torch.Tensor, 21 | path: tp.Union[Path, str], 22 | sample_rate: int, 23 | rescale: bool=False): 24 | limit = 0.99 25 | mx = wav.abs().max() 26 | if rescale: 27 | wav = wav * min(limit / mx, 1) 28 | else: 29 | wav = wav.clamp(-limit, limit) 30 | wav = wav.squeeze().cpu().numpy() 31 | sf.write(path, wav, sample_rate) 32 | 33 | 34 | def get_parser(): 35 | parser = argparse.ArgumentParser( 36 | 'encodec', 37 | description='High fidelity neural audio codec. ' 38 | 'If input is a .ecdc, decompresses it. ' 39 | 'If input is .wav, compresses it. If output is also wav, ' 40 | 'do a compression/decompression cycle.') 41 | parser.add_argument( 42 | '--input', 43 | type=Path, 44 | help='Input file, whatever is supported by torchaudio on your system.') 45 | parser.add_argument( 46 | '--output', 47 | type=Path, 48 | nargs='?', 49 | help='Output file, otherwise inferred from input file.') 50 | parser.add_argument( 51 | '--resume_path', type=str, default='resume_path', help='resume_path') 52 | parser.add_argument( 53 | '--sr', type=int, default=16000, help='sample rate of model') 54 | parser.add_argument( 55 | '-r', 56 | '--rescale', 57 | action='store_true', 58 | help='Automatically rescale the output to avoid clipping.') 59 | parser.add_argument( 60 | '--ratios', 61 | type=int, 62 | nargs='+', 63 | # probs(ratios) = hop_size 64 | default=[8, 5, 4, 2], 65 | help='ratios of SoundStream, shoud be set for different hop_size (32d, 320, 240d, ...)' 66 | ) 67 | parser.add_argument( 68 | '--target_bandwidths', 69 | type=float, 70 | nargs='+', 71 | # default for 16k_320d 72 | default=[1, 1.5, 2, 4, 6, 12], 73 | help='target_bandwidths of net3.py') 74 | parser.add_argument( 75 | '--target_bw', 76 | type=float, 77 | # default for 16k_320d 78 | default=12, 79 | help='target_bw of net3.py') 80 | 81 | return parser 82 | 83 | 84 | def fatal(*args): 85 | print(*args, file=sys.stderr) 86 | sys.exit(1) 87 | 88 | 89 | # 这只是打印了但是没有真的 clip 90 | def check_clipping(wav, rescale): 91 | if rescale: 92 | return 93 | mx = wav.abs().max() 94 | limit = 0.99 95 | if mx > limit: 96 | print( 97 | f"Clipping!! max scale {mx}, limit is {limit}. 
" 98 | "To avoid clipping, use the `-r` option to rescale the output.", 99 | file=sys.stderr) 100 | 101 | 102 | def test_one(args, wav_root, store_root, rescale, soundstream): 103 | # torchaudio.load 的采样率为原始音频的采样率,不会自动下采样 104 | # wav, sr = torchaudio.load(wav_root) 105 | # # 取单声道, output shape [1, T] 106 | # wav = wav[0].unsqueeze(0) 107 | # # 重采样为模型的采样率 108 | # wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=args.sr)(wav) 109 | 110 | # load wav with librosa 111 | wav, sr = librosa.load(wav_root, sr=args.sr) 112 | wav = torch.tensor(wav).unsqueeze(0) 113 | 114 | # add batch axis 115 | wav = wav.unsqueeze(1).cuda() 116 | 117 | # compressing 118 | compressed = soundstream.encode(wav, target_bw=args.target_bw) 119 | print('finish compressing') 120 | out = soundstream.decode(compressed) 121 | out = out.detach().cpu().squeeze(0) 122 | check_clipping(out, rescale) 123 | save_audio(wav=out, path=store_root, sample_rate=args.sr, rescale=rescale) 124 | print('finish decompressing') 125 | 126 | 127 | def remove_encodec_weight_norm(model): 128 | from academicodec.modules import SConv1d 129 | from academicodec.modules.seanet import SConvTranspose1d 130 | from academicodec.modules.seanet import SEANetResnetBlock 131 | from torch.nn.utils import remove_weight_norm 132 | 133 | encoder = model.encoder.model 134 | for key in encoder._modules: 135 | if isinstance(encoder._modules[key], SEANetResnetBlock): 136 | remove_weight_norm(encoder._modules[key].shortcut.conv.conv) 137 | block_modules = encoder._modules[key].block._modules 138 | for skey in block_modules: 139 | if isinstance(block_modules[skey], SConv1d): 140 | remove_weight_norm(block_modules[skey].conv.conv) 141 | elif isinstance(encoder._modules[key], SConv1d): 142 | remove_weight_norm(encoder._modules[key].conv.conv) 143 | 144 | decoder = model.decoder.model 145 | for key in decoder._modules: 146 | if isinstance(decoder._modules[key], SEANetResnetBlock): 147 | remove_weight_norm(decoder._modules[key].shortcut.conv.conv) 148 | block_modules = decoder._modules[key].block._modules 149 | for skey in block_modules: 150 | if isinstance(block_modules[skey], SConv1d): 151 | remove_weight_norm(block_modules[skey].conv.conv) 152 | elif isinstance(decoder._modules[key], SConvTranspose1d): 153 | remove_weight_norm(decoder._modules[key].convtr.convtr) 154 | elif isinstance(decoder._modules[key], SConv1d): 155 | remove_weight_norm(decoder._modules[key].conv.conv) 156 | 157 | 158 | def test_batch(): 159 | args = get_parser().parse_args() 160 | print("args.target_bandwidths:", args.target_bandwidths) 161 | if not args.input.exists(): 162 | fatal(f"Input file {args.input} does not exist.") 163 | input_lists = os.listdir(args.input) 164 | input_lists.sort() 165 | soundstream = SoundStream( 166 | n_filters=32, 167 | D=512, 168 | ratios=args.ratios, 169 | sample_rate=args.sr, 170 | target_bandwidths=args.target_bandwidths) 171 | parameter_dict = torch.load(args.resume_path) 172 | new_state_dict = OrderedDict() 173 | # k 为 module.xxx.weight, v 为权重 174 | for k, v in parameter_dict.items(): 175 | # 截取`module.`后面的xxx.weight 176 | name = k[7:] 177 | new_state_dict[name] = v 178 | soundstream.load_state_dict(new_state_dict) # load model 179 | remove_encodec_weight_norm(soundstream) 180 | soundstream.cuda() 181 | soundstream.eval() 182 | os.makedirs(args.output, exist_ok=True) 183 | for audio in input_lists: 184 | test_one( 185 | args=args, 186 | wav_root=os.path.join(args.input, audio), 187 | store_root=os.path.join(args.output, audio), 188 | 
rescale=args.rescale, 189 | soundstream=soundstream) 190 | 191 | 192 | if __name__ == '__main__': 193 | test_batch() 194 | -------------------------------------------------------------------------------- /academicodec/models/hificodec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangdongchao/AcademiCodec/b6ac134735f6079543db959a60eb77a7bab4277b/academicodec/models/hificodec/__init__.py -------------------------------------------------------------------------------- /academicodec/models/hificodec/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /academicodec/models/hificodec/meldataset.py: -------------------------------------------------------------------------------- 1 | # code based on https://github.com/b04901014/MQTTS 2 | import math 3 | import os 4 | import random 5 | 6 | import librosa 7 | import numpy as np 8 | import torch.utils.data 9 | from librosa.filters import mel as librosa_mel_fn 10 | 11 | 12 | def load_wav(full_path, sr): 13 | wav, sr = librosa.load(full_path, sr=sr) 14 | return wav, sr 15 | 16 | 17 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 18 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 19 | 20 | 21 | def dynamic_range_decompression(x, C=1): 22 | return np.exp(x) / C 23 | 24 | 25 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 26 | return torch.log(torch.clamp(x, min=clip_val) * C) 27 | 28 | 29 | def dynamic_range_decompression_torch(x, C=1): 30 | return torch.exp(x) / C 31 | 32 | 33 | def spectral_normalize_torch(magnitudes): 34 | output = dynamic_range_compression_torch(magnitudes) 35 | return output 36 | 37 | 38 | def spectral_de_normalize_torch(magnitudes): 39 | output = dynamic_range_decompression_torch(magnitudes) 40 | return output 41 | 42 | 43 | mel_basis = {} 44 | hann_window = {} 45 | 46 | 47 | def mel_spectrogram(y, 48 | n_fft, 49 | num_mels, 50 | sampling_rate, 51 | hop_size, 52 | win_size, 53 | fmin, 54 | fmax, 55 | center=False): 56 | if torch.min(y) < -1.: 57 | print('min value is ', torch.min(y)) 58 | if torch.max(y) > 1.: 59 | print('max value is ', torch.max(y)) 60 | 61 | global mel_basis, hann_window 62 | if fmax not in mel_basis: 63 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 64 | mel_basis[str(fmax) + '_' + 65 | str(y.device)] = torch.from_numpy(mel).float().to(y.device) 66 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 67 | 68 | y = torch.nn.functional.pad( 69 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int( 70 | (n_fft - hop_size) / 2)), 71 | mode='reflect') 72 | y = y.squeeze(1) 73 | 74 | spec = torch.stft( 75 | y, 76 | n_fft, 77 | hop_length=hop_size, 78 | win_length=win_size, 79 | window=hann_window[str(y.device)], 80 | center=center, 81 | pad_mode='reflect', 82 | normalized=False, 83 | onesided=True) 84 | 85 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 86 | 87 | spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec) 88 | 
spec = spectral_normalize_torch(spec) 89 | 90 | return spec 91 | 92 | 93 | def get_dataset_filelist(a): 94 | with open(a.input_training_file, 'r') as f: 95 | training_files = [l.strip() for l in f] 96 | with open(a.input_validation_file, 'r') as f: 97 | validation_files = [l.strip() for l in f] 98 | return training_files, validation_files 99 | 100 | 101 | class MelDataset(torch.utils.data.Dataset): 102 | def __init__(self, 103 | training_files, 104 | segment_size, 105 | n_fft, 106 | num_mels, 107 | hop_size, 108 | win_size, 109 | sampling_rate, 110 | fmin, 111 | fmax, 112 | split=True, 113 | shuffle=True, 114 | n_cache_reuse=1, 115 | device=None, 116 | fmax_loss=None, 117 | fine_tuning=False, 118 | base_mels_path=None): 119 | self.audio_files = training_files 120 | random.seed(1234) 121 | if shuffle: 122 | random.shuffle(self.audio_files) 123 | self.segment_size = segment_size 124 | self.sampling_rate = sampling_rate 125 | self.split = split 126 | self.n_fft = n_fft 127 | self.num_mels = num_mels 128 | self.hop_size = hop_size 129 | self.win_size = win_size 130 | self.fmin = fmin 131 | self.fmax = fmax 132 | self.fmax_loss = fmax_loss 133 | self.cached_wav = None 134 | self.n_cache_reuse = n_cache_reuse 135 | self._cache_ref_count = 0 136 | self.device = device 137 | self.fine_tuning = fine_tuning 138 | self.base_mels_path = base_mels_path 139 | 140 | def __getitem__(self, index): 141 | filename = self.audio_files[index] 142 | if self._cache_ref_count == 0: 143 | try: 144 | # Note by yuantian: load with the sample_rate of config 145 | audio, sampling_rate = load_wav(filename, sr=self.sampling_rate) 146 | except Exception as e: 147 | print(f"Error on audio: {filename}") 148 | audio = np.random.normal(size=(160000, )) * 0.05 149 | sampling_rate = self.sampling_rate 150 | self.cached_wav = audio 151 | if sampling_rate != self.sampling_rate: 152 | raise ValueError("{} SR doesn't match target {} SR".format( 153 | sampling_rate, self.sampling_rate)) 154 | self._cache_ref_count = self.n_cache_reuse 155 | else: 156 | audio = self.cached_wav 157 | self._cache_ref_count -= 1 158 | 159 | audio = torch.FloatTensor(audio) 160 | audio = audio.unsqueeze(0) 161 | 162 | if not self.fine_tuning: 163 | if self.split: 164 | if audio.size(1) >= self.segment_size: 165 | max_audio_start = audio.size(1) - self.segment_size 166 | audio_start = random.randint(0, max_audio_start) 167 | audio = audio[:, audio_start:audio_start + 168 | self.segment_size] 169 | else: 170 | audio = torch.nn.functional.pad(audio, ( 171 | 0, self.segment_size - audio.size(1)), 'constant') 172 | 173 | mel = mel_spectrogram( 174 | audio, 175 | self.n_fft, 176 | self.num_mels, 177 | self.sampling_rate, 178 | self.hop_size, 179 | self.win_size, 180 | self.fmin, 181 | self.fmax, 182 | center=False) 183 | else: 184 | mel = np.load( 185 | os.path.join(self.base_mels_path, 186 | os.path.splitext(os.path.split(filename)[-1])[0] + 187 | '.npy')) 188 | mel = torch.from_numpy(mel) 189 | 190 | if len(mel.shape) < 3: 191 | mel = mel.unsqueeze(0) 192 | 193 | if self.split: 194 | frames_per_seg = math.ceil(self.segment_size / self.hop_size) 195 | 196 | if audio.size(1) >= self.segment_size: 197 | mel_start = random.randint(0, 198 | mel.size(2) - frames_per_seg - 1) 199 | mel = mel[:, :, mel_start:mel_start + frames_per_seg] 200 | audio = audio[:, mel_start * self.hop_size:( 201 | mel_start + frames_per_seg) * self.hop_size] 202 | else: 203 | mel = torch.nn.functional.pad(mel, ( 204 | 0, frames_per_seg - mel.size(2)), 'constant') 205 | audio = 
torch.nn.functional.pad(audio, ( 206 | 0, self.segment_size - audio.size(1)), 'constant') 207 | 208 | mel_loss = mel_spectrogram( 209 | audio, 210 | self.n_fft, 211 | self.num_mels, 212 | self.sampling_rate, 213 | self.hop_size, 214 | self.win_size, 215 | self.fmin, 216 | self.fmax_loss, 217 | center=False) 218 | 219 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) 220 | 221 | def __len__(self): 222 | return len(self.audio_files) 223 | -------------------------------------------------------------------------------- /academicodec/models/hificodec/train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.simplefilter(action='ignore', category=FutureWarning) 3 | import itertools 4 | import os 5 | import time 6 | import argparse 7 | import json 8 | import torch 9 | import torch.nn.functional as F 10 | from torchaudio.transforms import MelSpectrogram 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data import DistributedSampler, DataLoader 13 | import torch.multiprocessing as mp 14 | from torch.distributed import init_process_group 15 | from torch.nn.parallel import DistributedDataParallel 16 | 17 | from academicodec.models.hificodec.env import AttrDict, build_env 18 | from academicodec.models.hificodec.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist 19 | from academicodec.models.encodec.msstftd import MultiScaleSTFTDiscriminator 20 | from academicodec.models.hificodec.models import Generator 21 | from academicodec.models.hificodec.models import MultiPeriodDiscriminator 22 | from academicodec.models.hificodec.models import MultiScaleDiscriminator 23 | from academicodec.models.hificodec.models import feature_loss 24 | from academicodec.models.hificodec.models import generator_loss 25 | from academicodec.models.hificodec.models import discriminator_loss 26 | from academicodec.models.hificodec.models import Encoder 27 | from academicodec.models.hificodec.models import Quantizer 28 | from academicodec.utils import plot_spectrogram 29 | from academicodec.utils import scan_checkpoint 30 | from academicodec.utils import load_checkpoint 31 | from academicodec.utils import save_checkpoint 32 | 33 | torch.backends.cudnn.benchmark = True 34 | 35 | 36 | def reconstruction_loss(x, G_x, device, eps=1e-7): 37 | L = 100 * F.mse_loss(x, G_x) # wav L1 loss 38 | for i in range(6, 11): 39 | s = 2**i 40 | melspec = MelSpectrogram( 41 | sample_rate=24000, 42 | n_fft=s, 43 | hop_length=s // 4, 44 | n_mels=64, 45 | wkwargs={"device": device}).to(device) 46 | # 64, 16, 64 47 | # 128, 32, 128 48 | # 256, 64, 256 49 | # 512, 128, 512 50 | # 1024, 256, 1024 51 | S_x = melspec(x) 52 | S_G_x = melspec(G_x) 53 | loss = ((S_x - S_G_x).abs().mean() + ( 54 | ((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps))**2 55 | ).mean(dim=-2)**0.5).mean()) / (i) 56 | L += loss 57 | #print('i ,loss ', i, loss) 58 | #assert 1==2 59 | return L 60 | 61 | 62 | def train(rank, a, h): 63 | torch.cuda.set_device(rank) 64 | if h.num_gpus > 1: 65 | init_process_group( 66 | backend=h.dist_config['dist_backend'], 67 | init_method=h.dist_config['dist_url'], 68 | world_size=h.dist_config['world_size'] * h.num_gpus, 69 | rank=rank) 70 | 71 | torch.cuda.manual_seed(h.seed) 72 | device = torch.device('cuda:{:d}'.format(rank)) 73 | 74 | encoder = Encoder(h).to(device) 75 | generator = Generator(h).to(device) 76 | quantizer = Quantizer(h).to(device) 77 | mpd = MultiPeriodDiscriminator().to(device) 78 | msd = 
MultiScaleDiscriminator().to(device) 79 | mstftd = MultiScaleSTFTDiscriminator(32).to(device) 80 | if rank == 0: 81 | print(encoder) 82 | print(quantizer) 83 | print(generator) 84 | os.makedirs(a.checkpoint_path, exist_ok=True) 85 | print("checkpoints directory : ", a.checkpoint_path) 86 | 87 | if os.path.isdir(a.checkpoint_path): 88 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_') 89 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_') 90 | 91 | steps = 0 92 | if cp_g is None or cp_do is None: 93 | state_dict_do = None 94 | last_epoch = -1 95 | else: 96 | state_dict_g = load_checkpoint(cp_g, device) 97 | state_dict_do = load_checkpoint(cp_do, device) 98 | generator.load_state_dict(state_dict_g['generator']) 99 | encoder.load_state_dict(state_dict_g['encoder']) 100 | quantizer.load_state_dict(state_dict_g['quantizer']) 101 | mpd.load_state_dict(state_dict_do['mpd']) 102 | msd.load_state_dict(state_dict_do['msd']) 103 | mstftd.load_state_dict(state_dict_do['mstftd']) 104 | steps = state_dict_do['steps'] + 1 105 | last_epoch = state_dict_do['epoch'] 106 | 107 | if h.num_gpus > 1: 108 | generator = DistributedDataParallel( 109 | generator, device_ids=[rank]).to(device) 110 | encoder = DistributedDataParallel(encoder, device_ids=[rank]).to(device) 111 | quantizer = DistributedDataParallel( 112 | quantizer, device_ids=[rank]).to(device) 113 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) 114 | msd = DistributedDataParallel(msd, device_ids=[rank]).to(device) 115 | mstftd = DistributedDataParallel(mstftd, device_ids=[rank]).to(device) 116 | 117 | optim_g = torch.optim.Adam( 118 | itertools.chain(generator.parameters(), 119 | encoder.parameters(), quantizer.parameters()), 120 | h.learning_rate, 121 | betas=[h.adam_b1, h.adam_b2]) 122 | optim_d = torch.optim.Adam( 123 | itertools.chain(msd.parameters(), mpd.parameters(), 124 | mstftd.parameters()), 125 | h.learning_rate, 126 | betas=[h.adam_b1, h.adam_b2]) 127 | if state_dict_do is not None: 128 | optim_g.load_state_dict(state_dict_do['optim_g']) 129 | optim_d.load_state_dict(state_dict_do['optim_d']) 130 | 131 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR( 132 | optim_g, gamma=h.lr_decay, last_epoch=last_epoch) 133 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR( 134 | optim_d, gamma=h.lr_decay, last_epoch=last_epoch) 135 | 136 | training_filelist, validation_filelist = get_dataset_filelist(a) 137 | 138 | trainset = MelDataset( 139 | training_filelist, 140 | h.segment_size, 141 | h.n_fft, 142 | h.num_mels, 143 | h.hop_size, 144 | h.win_size, 145 | h.sampling_rate, 146 | h.fmin, 147 | h.fmax, 148 | n_cache_reuse=0, 149 | shuffle=False if h.num_gpus > 1 else True, 150 | fmax_loss=h.fmax_for_loss, 151 | device=device, 152 | fine_tuning=a.fine_tuning, 153 | base_mels_path=a.input_mels_dir) 154 | 155 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None 156 | 157 | train_loader = DataLoader( 158 | trainset, 159 | num_workers=h.num_workers, 160 | shuffle=False, 161 | sampler=train_sampler, 162 | batch_size=h.batch_size, 163 | pin_memory=True, 164 | drop_last=True) 165 | 166 | if rank == 0: 167 | validset = MelDataset( 168 | validation_filelist, 169 | h.segment_size, 170 | h.n_fft, 171 | h.num_mels, 172 | h.hop_size, 173 | h.win_size, 174 | h.sampling_rate, 175 | h.fmin, 176 | h.fmax, 177 | False, 178 | False, 179 | n_cache_reuse=0, 180 | fmax_loss=h.fmax_for_loss, 181 | device=device, 182 | fine_tuning=a.fine_tuning, 183 | base_mels_path=a.input_mels_dir) 184 | validation_loader = DataLoader( 185 | 
validset, 186 | num_workers=1, 187 | shuffle=False, 188 | sampler=None, 189 | batch_size=1, 190 | pin_memory=True, 191 | drop_last=True) 192 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs')) 193 | plot_gt_once = False 194 | generator.train() 195 | encoder.train() 196 | quantizer.train() 197 | mpd.train() 198 | msd.train() 199 | for epoch in range(max(0, last_epoch), a.training_epochs): 200 | if rank == 0: 201 | start = time.time() 202 | print("Epoch: {}".format(epoch + 1)) 203 | if h.num_gpus > 1: 204 | train_sampler.set_epoch(epoch) 205 | for i, batch in enumerate(train_loader): 206 | if rank == 0: 207 | start_b = time.time() 208 | x, y, _, y_mel = batch 209 | x = torch.autograd.Variable(x.to(device, non_blocking=True)) 210 | y = torch.autograd.Variable(y.to(device, non_blocking=True)) 211 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True)) 212 | y = y.unsqueeze(1) 213 | 214 | c = encoder(y) 215 | # print("c.shape: ", c.shape) 216 | q, loss_q, c = quantizer(c) 217 | # print("q.shape: ", q.shape) 218 | y_g_hat = generator(q) 219 | y_g_hat_mel = mel_spectrogram( 220 | y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, 221 | h.hop_size, h.win_size, h.fmin, 222 | h.fmax_for_loss) # 1024, 80, 24000, 240,1024 223 | y_r_mel_1 = mel_spectrogram( 224 | y.squeeze(1), 512, h.num_mels, h.sampling_rate, 120, 512, 225 | h.fmin, h.fmax_for_loss) 226 | y_g_mel_1 = mel_spectrogram( 227 | y_g_hat.squeeze(1), 512, h.num_mels, h.sampling_rate, 120, 512, 228 | h.fmin, h.fmax_for_loss) 229 | y_r_mel_2 = mel_spectrogram( 230 | y.squeeze(1), 256, h.num_mels, h.sampling_rate, 60, 256, h.fmin, 231 | h.fmax_for_loss) 232 | y_g_mel_2 = mel_spectrogram( 233 | y_g_hat.squeeze(1), 256, h.num_mels, h.sampling_rate, 60, 256, 234 | h.fmin, h.fmax_for_loss) 235 | y_r_mel_3 = mel_spectrogram( 236 | y.squeeze(1), 128, h.num_mels, h.sampling_rate, 30, 128, h.fmin, 237 | h.fmax_for_loss) 238 | y_g_mel_3 = mel_spectrogram( 239 | y_g_hat.squeeze(1), 128, h.num_mels, h.sampling_rate, 30, 128, 240 | h.fmin, h.fmax_for_loss) 241 | # print("x.shape: ", x.shape) 242 | # print("y.shape: ", y.shape) 243 | # print("y_g_hat.shape: ", y_g_hat.shape) 244 | optim_d.zero_grad() 245 | 246 | # MPD 247 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) 248 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( 249 | y_df_hat_r, y_df_hat_g) 250 | 251 | # MSD 252 | y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach()) 253 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( 254 | y_ds_hat_r, y_ds_hat_g) 255 | 256 | y_disc_r, fmap_r = mstftd(y) 257 | y_disc_gen, fmap_gen = mstftd(y_g_hat.detach()) 258 | loss_disc_stft, losses_disc_stft_r, losses_disc_stft_g = discriminator_loss( 259 | y_disc_r, y_disc_gen) 260 | loss_disc_all = loss_disc_s + loss_disc_f + loss_disc_stft 261 | 262 | loss_disc_all.backward() 263 | optim_d.step() 264 | 265 | # Generator 266 | optim_g.zero_grad() 267 | 268 | # L1 Mel-Spectrogram Loss 269 | loss_mel1 = F.l1_loss(y_r_mel_1, y_g_mel_1) 270 | loss_mel2 = F.l1_loss(y_r_mel_2, y_g_mel_2) 271 | loss_mel3 = F.l1_loss(y_r_mel_3, y_g_mel_3) 272 | #print('loss_mel1, loss_mel2 ', loss_mel1, loss_mel2) 273 | loss_mel = F.l1_loss(y_mel, 274 | y_g_hat_mel) * 45 + loss_mel1 + loss_mel2 275 | # print('loss_mel ', loss_mel) 276 | # assert 1==2 277 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) 278 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat) 279 | y_stftd_hat_r, fmap_stftd_r = mstftd(y) 280 | y_stftd_hat_g, fmap_stftd_g = 
mstftd(y_g_hat) 281 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) 282 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) 283 | loss_fm_stft = feature_loss(fmap_stftd_r, fmap_stftd_g) 284 | loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) 285 | loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) 286 | loss_gen_stft, losses_gen_stft = generator_loss(y_stftd_hat_g) 287 | loss_gen_all = loss_gen_s + loss_gen_f + loss_gen_stft + loss_fm_s + loss_fm_f + loss_fm_stft + loss_mel + loss_q * 10 288 | loss_gen_all.backward() 289 | optim_g.step() 290 | if rank == 0: 291 | # STDOUT logging 292 | if steps % a.stdout_interval == 0: 293 | with torch.no_grad(): 294 | mel_error = F.l1_loss(y_mel, y_g_hat_mel).item() 295 | print( 296 | 'Steps : {:d}, Gen Loss Total : {:4.3f}, Loss Q : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'. 297 | format(steps, loss_gen_all, loss_q, mel_error, 298 | time.time() - start_b)) 299 | # checkpointing 300 | if steps % a.checkpoint_interval == 0 and steps != 0: 301 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, 302 | steps) 303 | save_checkpoint( 304 | checkpoint_path, { 305 | 'generator': (generator.module if h.num_gpus > 1 306 | else generator).state_dict(), 307 | 'encoder': (encoder.module if h.num_gpus > 1 else 308 | encoder).state_dict(), 309 | 'quantizer': (quantizer.module if h.num_gpus > 1 310 | else quantizer).state_dict() 311 | }, 312 | num_ckpt_keep=a.num_ckpt_keep) 313 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, 314 | steps) 315 | save_checkpoint( 316 | checkpoint_path, { 317 | 'mpd': (mpd.module 318 | if h.num_gpus > 1 else mpd).state_dict(), 319 | 'msd': (msd.module 320 | if h.num_gpus > 1 else msd).state_dict(), 321 | 'mstftd': (mstftd.module 322 | if h.num_gpus > 1 else msd).state_dict(), 323 | 'optim_g': 324 | optim_g.state_dict(), 325 | 'optim_d': 326 | optim_d.state_dict(), 327 | 'steps': 328 | steps, 329 | 'epoch': 330 | epoch 331 | }, 332 | num_ckpt_keep=a.num_ckpt_keep) 333 | # Tensorboard summary logging 334 | if steps % a.summary_interval == 0: 335 | sw.add_scalar("training/gen_loss_total", loss_gen_all, 336 | steps) 337 | sw.add_scalar("training/mel_spec_error", mel_error, steps) 338 | 339 | # Validation 340 | if steps % a.validation_interval == 0 and steps != 0: 341 | generator.eval() 342 | encoder.eval() 343 | quantizer.eval() 344 | torch.cuda.empty_cache() 345 | val_err_tot = 0 346 | with torch.no_grad(): 347 | for j, batch in enumerate(validation_loader): 348 | x, y, _, y_mel = batch 349 | c = encoder(y.to(device).unsqueeze(1)) 350 | q, loss_q, c = quantizer(c) 351 | y_g_hat = generator(q) 352 | y_mel = torch.autograd.Variable(y_mel.to(device)) 353 | y_g_hat_mel = mel_spectrogram( 354 | y_g_hat.squeeze(1), h.n_fft, h.num_mels, 355 | h.sampling_rate, h.hop_size, h.win_size, h.fmin, 356 | h.fmax_for_loss) 357 | i_size = min(y_mel.size(2), y_g_hat_mel.size(2)) 358 | val_err_tot += F.l1_loss( 359 | y_mel[:, :, :i_size], 360 | y_g_hat_mel[:, :, :i_size]).item() 361 | 362 | if j <= 8: 363 | # if steps == 0: 364 | if not plot_gt_once: 365 | sw.add_audio('gt/y_{}'.format(j), y[0], 366 | steps, h.sampling_rate) 367 | sw.add_figure('gt/y_spec_{}'.format(j), 368 | plot_spectrogram(x[0]), steps) 369 | 370 | sw.add_audio('generated/y_hat_{}'.format(j), 371 | y_g_hat[0], steps, h.sampling_rate) 372 | y_hat_spec = mel_spectrogram( 373 | y_g_hat.squeeze(1), h.n_fft, h.num_mels, 374 | h.sampling_rate, h.hop_size, h.win_size, 375 | h.fmin, h.fmax) 376 | sw.add_figure( 377 | 'generated/y_hat_spec_{}'.format(j), 378 | 
plot_spectrogram( 379 | y_hat_spec.squeeze(0).cpu().numpy()), 380 | steps) 381 | 382 | val_err = val_err_tot / (j + 1) 383 | sw.add_scalar("validation/mel_spec_error", val_err, 384 | steps) 385 | if not plot_gt_once: 386 | plot_gt_once = True 387 | 388 | generator.train() 389 | 390 | steps += 1 391 | 392 | scheduler_g.step() 393 | scheduler_d.step() 394 | 395 | if rank == 0: 396 | print('Time taken for epoch {} is {} sec\n'.format( 397 | epoch + 1, int(time.time() - start))) 398 | 399 | 400 | def main(): 401 | print('Initializing Training Process..') 402 | 403 | parser = argparse.ArgumentParser() 404 | 405 | # parser.add_argument('--group_name', default=None) 406 | # parser.add_argument('--input_wavs_dir', default='../datasets/audios') 407 | parser.add_argument('--input_mels_dir', default=None) 408 | parser.add_argument('--input_training_file', required=True) 409 | parser.add_argument('--input_validation_file', required=True) 410 | parser.add_argument('--checkpoint_path', default='checkpoints') 411 | parser.add_argument('--config', default='') 412 | parser.add_argument('--training_epochs', default=2000, type=int) 413 | parser.add_argument('--stdout_interval', default=5, type=int) 414 | parser.add_argument('--checkpoint_interval', default=5000, type=int) 415 | parser.add_argument('--summary_interval', default=100, type=int) 416 | parser.add_argument('--validation_interval', default=5000, type=int) 417 | parser.add_argument('--num_ckpt_keep', default=5, type=int) 418 | parser.add_argument('--fine_tuning', default=False, type=bool) 419 | 420 | a = parser.parse_args() 421 | 422 | with open(a.config) as f: 423 | data = f.read() 424 | 425 | json_config = json.loads(data) 426 | h = AttrDict(json_config) 427 | build_env(a.config, 'config.json', a.checkpoint_path) 428 | 429 | torch.manual_seed(h.seed) 430 | if torch.cuda.is_available(): 431 | torch.cuda.manual_seed(h.seed) 432 | h.num_gpus = torch.cuda.device_count() 433 | h.batch_size = int(h.batch_size / h.num_gpus) 434 | print('Batch size per GPU :', h.batch_size) 435 | else: 436 | pass 437 | 438 | if h.num_gpus > 1: 439 | mp.spawn(train, nprocs=h.num_gpus, args=(a, h, )) 440 | else: 441 | train(0, a, h) 442 | 443 | 444 | if __name__ == '__main__': 445 | main() 446 | -------------------------------------------------------------------------------- /academicodec/models/hificodec/vqvae.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from academicodec.models.hificodec.env import AttrDict 7 | from academicodec.models.hificodec.models import Encoder 8 | from academicodec.models.hificodec.models import Generator 9 | from academicodec.models.hificodec.models import Quantizer 10 | 11 | 12 | class VQVAE(nn.Module): 13 | def __init__(self, 14 | config_path, 15 | ckpt_path, 16 | with_encoder=False): 17 | super(VQVAE, self).__init__() 18 | ckpt = torch.load(ckpt_path) 19 | with open(config_path) as f: 20 | data = f.read() 21 | json_config = json.loads(data) 22 | self.h = AttrDict(json_config) 23 | self.quantizer = Quantizer(self.h) 24 | self.generator = Generator(self.h) 25 | self.generator.load_state_dict(ckpt['generator']) 26 | self.quantizer.load_state_dict(ckpt['quantizer']) 27 | if with_encoder: 28 | self.encoder = Encoder(self.h) 29 | self.encoder.load_state_dict(ckpt['encoder']) 30 | 31 | def forward(self, x): 32 | # x is the codebook 33 | # x.shape (B, T, Nq) 34 | quant_emb = self.quantizer.embed(x) 35 | return self.generator(quant_emb) 36 | 37 
| def encode(self, x): 38 | batch_size = x.size(0) 39 | if len(x.shape) == 3 and x.shape[-1] == 1: 40 | x = x.squeeze(-1) 41 | c = self.encoder(x.unsqueeze(1)) 42 | q, loss_q, c = self.quantizer(c) 43 | c = [code.reshape(batch_size, -1) for code in c] 44 | # shape: [N, T, 4] 45 | return torch.stack(c, -1) 46 | -------------------------------------------------------------------------------- /academicodec/models/hificodec/vqvae_copy_syn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | import os 5 | from pathlib import Path 6 | 7 | import soundfile as sf 8 | from tqdm import tqdm 9 | 10 | from academicodec.models.hificodec.vqvae_tester import VqvaeTester 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | #Path 15 | parser.add_argument('--outputdir', type=str, required=True) 16 | parser.add_argument('--model_path', type=str, required=True) 17 | parser.add_argument('--input_wavdir', type=str, required=True) 18 | parser.add_argument('--config_path', type=str, required=True) 19 | parser.add_argument('--num_gens', type=int, default=1024) 20 | 21 | #Data 22 | parser.add_argument('--sample_rate', type=int, default=24000) 23 | 24 | args = parser.parse_args() 25 | 26 | with open(args.config_path, 'r') as f: 27 | argdict = json.load(f) 28 | assert argdict['sampling_rate'] == args.sample_rate, \ 29 | f"Sampling rate not consistent, stated {args.sample_rate}, but the model is trained on {argdict['sample_rate']}" 30 | argdict.update(args.__dict__) 31 | args.__dict__ = argdict 32 | 33 | if __name__ == '__main__': 34 | Path(args.outputdir).mkdir(parents=True, exist_ok=True) 35 | print("Init model and load weights") 36 | model = VqvaeTester(config_path=args.config_path, model_path=args.model_path,sample_rate=args.sample_rate) 37 | model.cuda() 38 | model.vqvae.generator.remove_weight_norm() 39 | model.vqvae.encoder.remove_weight_norm() 40 | model.eval() 41 | print("Model ready") 42 | 43 | wav_paths = glob.glob(f"{args.input_wavdir}/*.wav")[:args.num_gens] 44 | print(f"Globbed {len(wav_paths)} wav files.") 45 | 46 | for wav_path in wav_paths: 47 | fid, wav = model(wav_path) 48 | wav = wav.squeeze().cpu().numpy() 49 | sf.write( 50 | os.path.join(args.outputdir, f'{fid}.wav'), wav, args.sample_rate) 51 | -------------------------------------------------------------------------------- /academicodec/models/hificodec/vqvae_tester.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import torch 5 | import torch.nn as nn 6 | 7 | from academicodec.models.hificodec.vqvae import VQVAE 8 | 9 | 10 | class VqvaeTester(nn.Module): 11 | def __init__(self, config_path, model_path, sample_rate=24000): 12 | super().__init__() 13 | self.vqvae = VQVAE(config_path, model_path, with_encoder=True) 14 | self.sample_rate = sample_rate 15 | 16 | @torch.no_grad() 17 | def forward(self, wav_path): 18 | # 单声道 19 | # wav.shape (T, ), 按照模型的 sr 读取 20 | wav, sr = librosa.load(wav_path, sr=self.sample_rate) 21 | fid = os.path.basename(wav_path)[:-4] 22 | wav = torch.tensor(wav).unsqueeze(0) 23 | wav = wav.cuda() 24 | # vq_codes is acoustic token 25 | vq_codes = self.vqvae.encode(wav) 26 | syn = self.vqvae(vq_codes) 27 | return fid, syn 28 | 29 | @torch.no_grad() 30 | def vq(self, wav_path): 31 | wav, sr = librosa.load(wav_path, sr=self.sample_rate) 32 | fid = os.path.basename(wav_path)[:-4] 33 | wav = torch.tensor(wav).unsqueeze(0) 34 | wav = wav.cuda() 35 | # vq_codes is 
acoustic token 36 | vq_codes = self.vqvae.encode(wav) 37 | return fid, vq_codes 38 | -------------------------------------------------------------------------------- /academicodec/models/soundstream/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangdongchao/AcademiCodec/b6ac134735f6079543db959a60eb77a7bab4277b/academicodec/models/soundstream/__init__.py -------------------------------------------------------------------------------- /academicodec/models/soundstream/dataset.py: -------------------------------------------------------------------------------- 1 | # 和 Encodec* 的 dataset.py 有点类似但是不完全一样 2 | # 主要是 prob > 0.7 的时候多了 ans2 3 | import glob 4 | import random 5 | 6 | import torch 7 | import torchaudio 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class NSynthDataset(Dataset): 12 | """Dataset to load NSynth data.""" 13 | 14 | def __init__(self, audio_dir): 15 | super().__init__() 16 | self.filenames = [] 17 | self.filenames.extend(glob.glob(audio_dir + "/*.wav")) 18 | print(len(self.filenames)) 19 | _, self.sr = torchaudio.load(self.filenames[0]) 20 | self.max_len = 24000 # 24000 21 | 22 | def __len__(self): 23 | return len(self.filenames) 24 | 25 | def __getitem__(self, index): 26 | #print(self.filenames[index]) 27 | prob = random.random() # (0,1) 28 | if prob > 0.7: 29 | # data augmentation 30 | ans1 = torch.zeros(1, self.max_len) 31 | ans2 = torch.zeros(1, self.max_len) 32 | audio1 = torchaudio.load(self.filenames[index])[0] 33 | index2 = random.randint(0, len(self.filenames) - 1) 34 | audio2 = torchaudio.load(self.filenames[index2])[0] 35 | if audio1.shape[1] > self.max_len: 36 | st = random.randint(0, audio1.shape[1] - self.max_len - 1) 37 | ed = st + self.max_len 38 | ans1 = audio1[:, st:ed] 39 | else: 40 | ans1[:, :audio1.shape[1]] = audio1 41 | if audio2.shape[1] > self.max_len: 42 | st = random.randint(0, audio2.shape[1] - self.max_len - 1) 43 | ed = st + self.max_len 44 | ans2 = audio2[:, st:ed] 45 | else: 46 | ans2[:, :audio2.shape[1]] = audio2 47 | ans = ans1 + ans2 48 | return ans 49 | else: 50 | ans = torch.zeros(1, self.max_len) 51 | audio = torchaudio.load(self.filenames[index])[0] 52 | if audio.shape[1] > self.max_len: 53 | st = random.randint(0, audio.shape[1] - self.max_len - 1) 54 | ed = st + self.max_len 55 | return audio[:, st:ed] 56 | else: 57 | ans[:, :audio.shape[1]] = audio 58 | return ans 59 | -------------------------------------------------------------------------------- /academicodec/models/soundstream/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchaudio.transforms import MelSpectrogram 4 | 5 | 6 | def adversarial_g_loss(y_disc_gen): 7 | loss = 0.0 8 | for i in range(len(y_disc_gen)): 9 | #print(y_disc_gen[i].shape) 10 | # assert 1==2 11 | stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze() 12 | loss += stft_loss 13 | return loss / len(y_disc_gen) 14 | 15 | 16 | def feature_loss(fmap_r, fmap_gen): 17 | loss = 0.0 18 | for i in range(len(fmap_r)): 19 | for j in range(len(fmap_r[i])): 20 | stft_loss = ((fmap_r[i][j] - fmap_gen[i][j]).abs() / 21 | (fmap_r[i][j].abs().mean())).mean() 22 | loss += stft_loss 23 | return loss / (len(fmap_r) * len(fmap_r[0])) 24 | 25 | 26 | def sim_loss(y_disc_r, y_disc_gen): 27 | loss = 0.0 28 | for i in range(len(y_disc_r)): 29 | loss += F.mse_loss(y_disc_r[i], y_disc_gen[i]) 30 | return loss / len(y_disc_r) 31 | 32 | 33 | def 
sisnr_loss(x, s, eps=1e-8): 34 | """ 35 | calculate training loss 36 | input: 37 | x: separated signal, N x S tensor, estimate value 38 | s: reference signal, N x S tensor, True value 39 | Return: 40 | sisnr: N tensor 41 | """ 42 | if x.shape != s.shape: 43 | if x.shape[-1] > s.shape[-1]: 44 | x = x[:, :s.shape[-1]] 45 | else: 46 | s = s[:, :x.shape[-1]] 47 | 48 | def l2norm(mat, keepdim=False): 49 | return torch.norm(mat, dim=-1, keepdim=keepdim) 50 | 51 | if x.shape != s.shape: 52 | raise RuntimeError("Dimention mismatch when calculate si-snr, {} vs {}". 53 | format(x.shape, s.shape)) 54 | x_zm = x - torch.mean(x, dim=-1, keepdim=True) 55 | s_zm = s - torch.mean(s, dim=-1, keepdim=True) 56 | t = torch.sum( 57 | x_zm * s_zm, dim=-1, 58 | keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps) 59 | loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps)) 60 | return torch.sum(loss) / x.shape[0] 61 | 62 | 63 | def reconstruction_loss(x, G_x, args, eps=1e-7): 64 | L = 100 * F.mse_loss(x, G_x) # wav L1 loss 65 | #loss_sisnr = sisnr_loss(G_x, x) # 66 | #L += 0.01*loss_sisnr 67 | # print('L0 ', L) 68 | # print('loss_sisnr ', 0.01*loss_sisnr) 69 | # print('L0 ', L) 70 | for i in range(6, 11): 71 | s = 2**i 72 | melspec = MelSpectrogram( 73 | sample_rate=args.sr, 74 | n_fft=max(s, 512), 75 | win_length=s, 76 | hop_length=s // 4, 77 | n_mels=64, 78 | wkwargs={"device": args.device}).to(args.device) 79 | S_x = melspec(x) 80 | S_G_x = melspec(G_x) 81 | l1_loss = (S_x - S_G_x).abs().mean() 82 | l2_loss = (((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps))**2).mean(dim=-2)**0.5).mean() 83 | 84 | alpha = (s / 2) ** 0.5 85 | L += (l1_loss + alpha * l2_loss) 86 | #print('i ,loss ', i, loss) 87 | #assert 1==2 88 | return L 89 | 90 | 91 | def criterion_d(y_disc_r, y_disc_gen, fmap_r_det, fmap_gen_det, y_df_hat_r, 92 | y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, 93 | fmap_s_r, fmap_s_g): 94 | loss = 0.0 95 | loss1 = 0.0 96 | loss2 = 0.0 97 | loss3 = 0.0 98 | loss_f = feature_loss(fmap_r_det, fmap_gen_det) + feature_loss( 99 | fmap_f_r, fmap_f_g) + feature_loss(fmap_s_r, fmap_s_g) 100 | for i in range(len(y_disc_r)): 101 | loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[ 102 | i]).mean() 103 | for i in range(len(y_df_hat_r)): 104 | loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[ 105 | i]).mean() 106 | for i in range(len(y_ds_hat_r)): 107 | loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[ 108 | i]).mean() 109 | loss = (loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / 110 | len(y_ds_hat_r)) / 3.0 111 | return loss + 0.0 * loss_f 112 | 113 | 114 | def criterion_g(commit_loss, x, G_x, fmap_r, fmap_gen, y_disc_r, y_disc_gen, 115 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, 116 | y_ds_hat_g, fmap_s_r, fmap_s_g, args): 117 | adv_g_loss = adversarial_g_loss(y_disc_gen) 118 | feat_loss = (feature_loss(fmap_r, fmap_gen) + sim_loss( 119 | y_disc_r, y_disc_gen) + feature_loss(fmap_f_r, fmap_f_g) + sim_loss( 120 | y_df_hat_r, y_df_hat_g) + feature_loss(fmap_s_r, fmap_s_g) + 121 | sim_loss(y_ds_hat_r, y_ds_hat_g)) / 3.0 122 | rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args) 123 | total_loss = args.LAMBDA_COM * commit_loss + args.LAMBDA_ADV * adv_g_loss + \ 124 | args.LAMBDA_FEAT * feat_loss + args.LAMBDA_REC * rec_loss 125 | return total_loss, adv_g_loss, feat_loss, rec_loss 126 | 127 | 128 | def adopt_weight(weight, global_step, threshold=0, value=0.): 129 | if global_step < 
threshold: 130 | weight = value 131 | return weight 132 | 133 | 134 | def adopt_dis_weight(weight, global_step, threshold=0, value=0.): 135 | if global_step % 3 == 0: # 0,3,6,9,13....这些时间步,不更新dis 136 | weight = value 137 | return weight 138 | 139 | 140 | def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args): 141 | if last_layer is not None: 142 | nll_grads = torch.autograd.grad( 143 | nll_loss, last_layer, retain_graph=True)[0] 144 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 145 | else: 146 | print('last_layer cannot be none') 147 | assert 1 == 2 148 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 149 | d_weight = torch.clamp(d_weight, 1.0, 1.0).detach() 150 | d_weight = d_weight * args.LAMBDA_ADV 151 | return d_weight 152 | 153 | 154 | def loss_g(codebook_loss, 155 | inputs, 156 | reconstructions, 157 | fmap_r, 158 | fmap_gen, 159 | y_disc_r, 160 | y_disc_gen, 161 | global_step, 162 | y_df_hat_r, 163 | y_df_hat_g, 164 | y_ds_hat_r, 165 | y_ds_hat_g, 166 | fmap_f_r, 167 | fmap_f_g, 168 | fmap_s_r, 169 | fmap_s_g, 170 | last_layer=None, 171 | is_training=True, 172 | args=None): 173 | rec_loss = reconstruction_loss(inputs.contiguous(), 174 | reconstructions.contiguous(), args) 175 | adv_g_loss = adversarial_g_loss(y_disc_gen) 176 | adv_mpd_loss = adversarial_g_loss(y_df_hat_g) 177 | adv_msd_loss = adversarial_g_loss(y_ds_hat_g) 178 | adv_loss = (adv_g_loss + adv_mpd_loss + adv_msd_loss) / 3.0 179 | feat_loss = feature_loss(fmap_r, fmap_gen) + sim_loss(y_disc_r, 180 | y_disc_gen) # 181 | feat_loss_mpd = feature_loss(fmap_f_r, fmap_f_g) + sim_loss(y_df_hat_r, 182 | y_df_hat_g) 183 | feat_loss_msd = feature_loss(fmap_s_r, fmap_s_g) + sim_loss(y_ds_hat_r, 184 | y_ds_hat_g) 185 | feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0 186 | d_weight = torch.tensor(1.0) 187 | # try: 188 | # d_weight = calculate_adaptive_weight(rec_loss, adv_g_loss, last_layer, args) # 动态调整重构损失和对抗损失 189 | # except RuntimeError: 190 | # assert not is_training 191 | # d_weight = torch.tensor(0.0) 192 | disc_factor = adopt_weight( 193 | args.LAMBDA_ADV, global_step, threshold=args.discriminator_iter_start) 194 | #feat_factor = adopt_weight(args.LAMBDA_FEAT, global_step, threshold=args.discriminator_iter_start) 195 | loss = rec_loss + d_weight * disc_factor * adv_loss + \ 196 | args.LAMBDA_FEAT * feat_loss_tot + args.LAMBDA_COM * codebook_loss 197 | return loss, rec_loss, adv_loss, feat_loss_tot, d_weight 198 | 199 | 200 | def loss_dis(y_disc_r_det, y_disc_gen_det, fmap_r_det, fmap_gen_det, y_df_hat_r, 201 | y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, fmap_s_r, 202 | fmap_s_g, global_step, args): 203 | disc_factor = adopt_weight( 204 | args.LAMBDA_ADV, global_step, threshold=args.discriminator_iter_start) 205 | d_loss = disc_factor * criterion_d(y_disc_r_det, y_disc_gen_det, fmap_r_det, 206 | fmap_gen_det, y_df_hat_r, y_df_hat_g, 207 | fmap_f_r, fmap_f_g, y_ds_hat_r, 208 | y_ds_hat_g, fmap_s_r, fmap_s_g) 209 | return d_loss 210 | -------------------------------------------------------------------------------- /academicodec/models/soundstream/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from academicodec.modules import NormConv1d 5 | from academicodec.modules import NormConv2d 6 | from academicodec.utils import get_padding 7 | from torch.nn import AvgPool1d 8 | from torch.nn.utils import spectral_norm 9 | from 
torch.nn.utils import weight_norm 10 | 11 | LRELU_SLOPE = 0.1 12 | 13 | 14 | class DiscriminatorP(torch.nn.Module): 15 | def __init__(self, 16 | period, 17 | kernel_size=5, 18 | stride=3, 19 | use_spectral_norm=False, 20 | activation: str='LeakyReLU', 21 | activation_params: dict={'negative_slope': 0.2}): 22 | super(DiscriminatorP, self).__init__() 23 | self.period = period 24 | norm_f = weight_norm if use_spectral_norm is False else spectral_norm 25 | self.activation = getattr(torch.nn, activation)(**activation_params) 26 | self.convs = nn.ModuleList([ 27 | NormConv2d( 28 | 1, 29 | 32, (kernel_size, 1), (stride, 1), 30 | padding=(get_padding(5, 1), 0)), 31 | NormConv2d( 32 | 32, 33 | 32, (kernel_size, 1), (stride, 1), 34 | padding=(get_padding(5, 1), 0)), 35 | NormConv2d( 36 | 32, 37 | 32, (kernel_size, 1), (stride, 1), 38 | padding=(get_padding(5, 1), 0)), 39 | NormConv2d( 40 | 32, 41 | 32, (kernel_size, 1), (stride, 1), 42 | padding=(get_padding(5, 1), 0)), 43 | NormConv2d(32, 32, (kernel_size, 1), 1, padding=(2, 0)), 44 | ]) 45 | self.conv_post = NormConv2d(32, 1, (3, 1), 1, padding=(1, 0)) 46 | 47 | def forward(self, x): 48 | fmap = [] 49 | # 1d to 2d 50 | b, c, t = x.shape 51 | if t % self.period != 0: # pad first 52 | n_pad = self.period - (t % self.period) 53 | x = F.pad(x, (0, n_pad), "reflect") 54 | t = t + n_pad 55 | x = x.view(b, c, t // self.period, self.period) 56 | 57 | for l in self.convs: 58 | x = l(x) 59 | x = self.activation(x) 60 | fmap.append(x) 61 | x = self.conv_post(x) 62 | fmap.append(x) 63 | x = torch.flatten(x, 1, -1) 64 | 65 | return x, fmap 66 | 67 | 68 | class MultiPeriodDiscriminator(torch.nn.Module): 69 | def __init__(self): 70 | super(MultiPeriodDiscriminator, self).__init__() 71 | self.discriminators = nn.ModuleList([ 72 | DiscriminatorP(2), 73 | DiscriminatorP(3), 74 | DiscriminatorP(5), 75 | DiscriminatorP(7), 76 | DiscriminatorP(11), 77 | ]) 78 | 79 | def forward(self, y, y_hat): 80 | y_d_rs = [] 81 | y_d_gs = [] 82 | fmap_rs = [] 83 | fmap_gs = [] 84 | for i, d in enumerate(self.discriminators): 85 | y_d_r, fmap_r = d(y) 86 | y_d_g, fmap_g = d(y_hat) 87 | y_d_rs.append(y_d_r) 88 | fmap_rs.append(fmap_r) 89 | y_d_gs.append(y_d_g) 90 | fmap_gs.append(fmap_g) 91 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 92 | 93 | 94 | class DiscriminatorS(torch.nn.Module): 95 | def __init__(self, 96 | use_spectral_norm=False, 97 | activation: str='LeakyReLU', 98 | activation_params: dict={'negative_slope': 0.2}): 99 | super(DiscriminatorS, self).__init__() 100 | self.activation = getattr(torch.nn, activation)(**activation_params) 101 | self.convs = nn.ModuleList([ 102 | NormConv1d(1, 32, 15, 1, padding=7), 103 | NormConv1d(32, 32, 41, 2, groups=4, padding=20), 104 | NormConv1d(32, 32, 41, 2, groups=16, padding=20), 105 | NormConv1d(32, 32, 41, 4, groups=16, padding=20), 106 | NormConv1d(32, 32, 41, 4, groups=16, padding=20), 107 | NormConv1d(32, 32, 41, 1, groups=16, padding=20), 108 | NormConv1d(32, 32, 5, 1, padding=2), 109 | ]) 110 | self.conv_post = NormConv1d(32, 1, 3, 1, padding=1) 111 | 112 | def forward(self, x): 113 | fmap = [] 114 | for l in self.convs: 115 | x = l(x) 116 | x = self.activation(x) 117 | fmap.append(x) 118 | x = self.conv_post(x) 119 | fmap.append(x) 120 | x = torch.flatten(x, 1, -1) 121 | return x, fmap 122 | 123 | 124 | class MultiScaleDiscriminator(torch.nn.Module): 125 | def __init__(self): 126 | super(MultiScaleDiscriminator, self).__init__() 127 | self.discriminators = nn.ModuleList([ 128 | DiscriminatorS(), 129 | DiscriminatorS(), 130 | 
DiscriminatorS(), 131 | ]) 132 | self.meanpools = nn.ModuleList( 133 | [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) 134 | 135 | def forward(self, y, y_hat): 136 | y_d_rs = [] 137 | y_d_gs = [] 138 | fmap_rs = [] 139 | fmap_gs = [] 140 | for i, d in enumerate(self.discriminators): 141 | if i != 0: 142 | y = self.meanpools[i - 1](y) 143 | y_hat = self.meanpools[i - 1](y_hat) 144 | y_d_r, fmap_r = d(y) 145 | y_d_g, fmap_g = d(y_hat) 146 | y_d_rs.append(y_d_r) 147 | fmap_rs.append(fmap_r) 148 | y_d_gs.append(y_d_g) 149 | fmap_gs.append(fmap_g) 150 | 151 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 152 | -------------------------------------------------------------------------------- /academicodec/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Torch modules.""" 7 | # flake8: noqa 8 | from .conv import NormConv1d 9 | from .conv import NormConv2d 10 | from .conv import NormConvTranspose1d 11 | from .conv import NormConvTranspose2d 12 | from .conv import pad1d 13 | from .conv import SConv1d 14 | from .conv import SConvTranspose1d 15 | from .conv import unpad1d 16 | from .lstm import SLSTM 17 | from .seanet import SEANetDecoder 18 | from .seanet import SEANetEncoder 19 | from .transformer import StreamingTransformerEncoder 20 | -------------------------------------------------------------------------------- /academicodec/modules/conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Convolutional layers wrappers and utilities.""" 7 | import math 8 | import typing as tp 9 | import warnings 10 | 11 | import torch 12 | from torch import nn 13 | from torch.nn import functional as F 14 | from torch.nn.utils import spectral_norm 15 | from torch.nn.utils import weight_norm 16 | 17 | from academicodec.modules.norm import ConvLayerNorm 18 | 19 | CONV_NORMALIZATIONS = frozenset([ 20 | 'none', 'weight_norm', 'spectral_norm', 'time_layer_norm', 'layer_norm', 21 | 'time_group_norm' 22 | ]) 23 | 24 | 25 | def apply_parametrization_norm(module: nn.Module, 26 | norm: str='none') -> nn.Module: 27 | assert norm in CONV_NORMALIZATIONS 28 | if norm == 'weight_norm': 29 | return weight_norm(module) 30 | elif norm == 'spectral_norm': 31 | return spectral_norm(module) 32 | else: 33 | # We already check was in CONV_NORMALIZATION, so any other choice 34 | # doesn't need reparametrization. 35 | return module 36 | 37 | 38 | def get_norm_module(module: nn.Module, 39 | causal: bool=False, 40 | norm: str='none', 41 | **norm_kwargs) -> nn.Module: 42 | """Return the proper normalization module. If causal is True, this will ensure the returned 43 | module is causal, or return an error if the normalization doesn't support causal evaluation. 
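    A minimal dispatch sketch, assuming a plain ``nn.Conv1d`` module and the
    repr formatting of current PyTorch:

        >>> import torch.nn as nn
        >>> conv = nn.Conv1d(8, 16, kernel_size=3)
        >>> get_norm_module(conv, norm='time_group_norm')
        GroupNorm(1, 16, eps=1e-05, affine=True)
        >>> get_norm_module(conv, norm='weight_norm')  # reparametrization norms need no extra module
        Identity()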
44 | """ 45 | assert norm in CONV_NORMALIZATIONS 46 | if norm == 'layer_norm': 47 | assert isinstance(module, nn.modules.conv._ConvNd) 48 | return ConvLayerNorm(module.out_channels, **norm_kwargs) 49 | elif norm == 'time_group_norm': 50 | if causal: 51 | raise ValueError("GroupNorm doesn't support causal evaluation.") 52 | assert isinstance(module, nn.modules.conv._ConvNd) 53 | return nn.GroupNorm(1, module.out_channels, **norm_kwargs) 54 | else: 55 | return nn.Identity() 56 | 57 | 58 | def get_extra_padding_for_conv1d(x: torch.Tensor, 59 | kernel_size: int, 60 | stride: int, 61 | padding_total: int=0) -> int: 62 | """See `pad_for_conv1d`. 63 | """ 64 | length = x.shape[-1] 65 | n_frames = (length - kernel_size + padding_total) / stride + 1 66 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - 67 | padding_total) 68 | return ideal_length - length 69 | 70 | 71 | def pad_for_conv1d(x: torch.Tensor, 72 | kernel_size: int, 73 | stride: int, 74 | padding_total: int=0): 75 | """Pad for a convolution to make sure that the last window is full. 76 | Extra padding is added at the end. This is required to ensure that we can rebuild 77 | an output of the same length, as otherwise, even with padding, some time steps 78 | might get removed. 79 | For instance, with total padding = 4, kernel size = 4, stride = 2: 80 | 0 0 1 2 3 4 5 0 0 # (0s are padding) 81 | 1 2 3 # (output frames of a convolution, last 0 is never used) 82 | 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) 83 | 1 2 3 4 # once you removed padding, we are missing one time step ! 84 | """ 85 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, 86 | padding_total) 87 | return F.pad(x, (0, extra_padding)) 88 | 89 | 90 | def pad1d(x: torch.Tensor, 91 | paddings: tp.Tuple[int, int], 92 | mode: str='zero', 93 | value: float=0.): 94 | """Tiny wrapper around F.pad, just to allow for reflect padding on small input. 95 | If this is the case, we insert extra 0 padding to the right before the reflection happen. 96 | """ 97 | length = x.shape[-1] 98 | padding_left, padding_right = paddings 99 | assert padding_left >= 0 and padding_right >= 0, (padding_left, 100 | padding_right) 101 | if mode == 'reflect': 102 | max_pad = max(padding_left, padding_right) 103 | extra_pad = 0 104 | if length <= max_pad: 105 | extra_pad = max_pad - length + 1 106 | x = F.pad(x, (0, extra_pad)) 107 | padded = F.pad(x, paddings, mode, value) 108 | end = padded.shape[-1] - extra_pad 109 | return padded[..., :end] 110 | else: 111 | return F.pad(x, paddings, mode, value) 112 | 113 | 114 | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): 115 | """Remove padding from x, handling properly zero padding. Only for 1d!""" 116 | padding_left, padding_right = paddings 117 | assert padding_left >= 0 and padding_right >= 0, (padding_left, 118 | padding_right) 119 | assert (padding_left + padding_right) <= x.shape[-1] 120 | end = x.shape[-1] - padding_right 121 | return x[..., padding_left:end] 122 | 123 | 124 | class NormConv1d(nn.Module): 125 | """Wrapper around Conv1d and normalization applied to this conv 126 | to provide a uniform interface across normalization approaches. 
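    A minimal usage sketch, assuming plain ``nn.Conv1d`` shape semantics
    (no implicit padding):

        >>> import torch
        >>> conv = NormConv1d(1, 16, kernel_size=3, norm='weight_norm')
        >>> conv(torch.randn(2, 1, 100)).shape
        torch.Size([2, 16, 98])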
127 | """ 128 | 129 | def __init__(self, 130 | *args, 131 | causal: bool=False, 132 | norm: str='none', 133 | norm_kwargs: tp.Dict[str, tp.Any]={}, 134 | **kwargs): 135 | super().__init__() 136 | self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) 137 | self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) 138 | self.norm_type = norm 139 | 140 | def forward(self, x): 141 | x = self.conv(x) 142 | x = self.norm(x) 143 | return x 144 | 145 | 146 | class NormConv2d(nn.Module): 147 | """Wrapper around Conv2d and normalization applied to this conv 148 | to provide a uniform interface across normalization approaches. 149 | """ 150 | 151 | def __init__(self, 152 | *args, 153 | norm: str='none', 154 | norm_kwargs: tp.Dict[str, tp.Any]={}, 155 | **kwargs): 156 | super().__init__() 157 | self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) 158 | self.norm = get_norm_module( 159 | self.conv, causal=False, norm=norm, **norm_kwargs) 160 | self.norm_type = norm 161 | 162 | def forward(self, x): 163 | x = self.conv(x) 164 | x = self.norm(x) 165 | return x 166 | 167 | 168 | class NormConvTranspose1d(nn.Module): 169 | """Wrapper around ConvTranspose1d and normalization applied to this conv 170 | to provide a uniform interface across normalization approaches. 171 | """ 172 | 173 | def __init__(self, 174 | *args, 175 | causal: bool=False, 176 | norm: str='none', 177 | norm_kwargs: tp.Dict[str, tp.Any]={}, 178 | **kwargs): 179 | super().__init__() 180 | self.convtr = apply_parametrization_norm( 181 | nn.ConvTranspose1d(*args, **kwargs), norm) 182 | self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) 183 | self.norm_type = norm 184 | 185 | def forward(self, x): 186 | x = self.convtr(x) 187 | x = self.norm(x) 188 | return x 189 | 190 | 191 | class NormConvTranspose2d(nn.Module): 192 | """Wrapper around ConvTranspose2d and normalization applied to this conv 193 | to provide a uniform interface across normalization approaches. 194 | """ 195 | 196 | def __init__(self, 197 | *args, 198 | norm: str='none', 199 | norm_kwargs: tp.Dict[str, tp.Any]={}, 200 | **kwargs): 201 | super().__init__() 202 | self.convtr = apply_parametrization_norm( 203 | nn.ConvTranspose2d(*args, **kwargs), norm) 204 | self.norm = get_norm_module( 205 | self.convtr, causal=False, norm=norm, **norm_kwargs) 206 | 207 | def forward(self, x): 208 | x = self.convtr(x) 209 | x = self.norm(x) 210 | return x 211 | 212 | 213 | class SConv1d(nn.Module): 214 | """Conv1d with some builtin handling of asymmetric or causal padding 215 | and normalization. 216 | """ 217 | 218 | def __init__(self, 219 | in_channels: int, 220 | out_channels: int, 221 | kernel_size: int, 222 | stride: int=1, 223 | dilation: int=1, 224 | groups: int=1, 225 | bias: bool=True, 226 | causal: bool=False, 227 | norm: str='none', 228 | norm_kwargs: tp.Dict[str, tp.Any]={}, 229 | pad_mode: str='reflect'): 230 | super().__init__() 231 | # warn user on unusual setup between dilation and stride 232 | if stride > 1 and dilation > 1: 233 | warnings.warn( 234 | 'SConv1d has been initialized with stride > 1 and dilation > 1' 235 | f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).' 
236 | ) 237 | self.conv = NormConv1d( 238 | in_channels, 239 | out_channels, 240 | kernel_size, 241 | stride, 242 | dilation=dilation, 243 | groups=groups, 244 | bias=bias, 245 | causal=causal, 246 | norm=norm, 247 | norm_kwargs=norm_kwargs) 248 | self.causal = causal 249 | self.pad_mode = pad_mode 250 | 251 | def forward(self, x): 252 | B, C, T = x.shape 253 | kernel_size = self.conv.conv.kernel_size[0] 254 | stride = self.conv.conv.stride[0] 255 | dilation = self.conv.conv.dilation[0] 256 | padding_total = (kernel_size - 1) * dilation - (stride - 1) 257 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, 258 | padding_total) 259 | if self.causal: 260 | # Left padding for causal 261 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) 262 | else: 263 | # Asymmetric padding required for odd strides 264 | padding_right = padding_total // 2 265 | padding_left = padding_total - padding_right 266 | x = pad1d( 267 | x, (padding_left, padding_right + extra_padding), 268 | mode=self.pad_mode) 269 | return self.conv(x) 270 | 271 | 272 | class SConvTranspose1d(nn.Module): 273 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding 274 | and normalization. 275 | """ 276 | 277 | def __init__(self, 278 | in_channels: int, 279 | out_channels: int, 280 | kernel_size: int, 281 | stride: int=1, 282 | causal: bool=False, 283 | norm: str='none', 284 | trim_right_ratio: float=1., 285 | norm_kwargs: tp.Dict[str, tp.Any]={}): 286 | super().__init__() 287 | self.convtr = NormConvTranspose1d( 288 | in_channels, 289 | out_channels, 290 | kernel_size, 291 | stride, 292 | causal=causal, 293 | norm=norm, 294 | norm_kwargs=norm_kwargs) 295 | self.causal = causal 296 | self.trim_right_ratio = trim_right_ratio 297 | assert self.causal or self.trim_right_ratio == 1., \ 298 | "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" 299 | assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. 300 | 301 | def forward(self, x): 302 | kernel_size = self.convtr.convtr.kernel_size[0] 303 | stride = self.convtr.convtr.stride[0] 304 | padding_total = kernel_size - stride 305 | 306 | y = self.convtr(x) 307 | 308 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be 309 | # removed at the very end, when keeping only the right length for the output, 310 | # as removing it here would require also passing the length at the matching layer 311 | # in the encoder. 312 | if self.causal: 313 | # Trim the padding on the right according to the specified ratio 314 | # if trim_right_ratio = 1.0, trim everything from right 315 | padding_right = math.ceil(padding_total * self.trim_right_ratio) 316 | padding_left = padding_total - padding_right 317 | y = unpad1d(y, (padding_left, padding_right)) 318 | else: 319 | # Asymmetric padding required for odd strides 320 | padding_right = padding_total // 2 321 | padding_left = padding_total - padding_right 322 | y = unpad1d(y, (padding_left, padding_right)) 323 | return y 324 | -------------------------------------------------------------------------------- /academicodec/modules/lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
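# A minimal sketch of the padding bookkeeping implemented by SConv1d and
# SConvTranspose1d in academicodec/modules/conv.py above, assuming the package
# is importable. With kernel_size=7, stride=2, dilation=1 the fixed padding is
# padding_total = (7 - 1) * 1 - (2 - 1) = 5, split as (left=3, right=2) in the
# non-causal case or applied entirely on the left when causal=True, so the
# output length stays at ceil(T / stride):
#
#   import torch
#   from academicodec.modules import SConv1d, SConvTranspose1d
#   x = torch.randn(2, 1, 24000)
#   down = SConv1d(1, 32, kernel_size=7, stride=2, norm='weight_norm')
#   up = SConvTranspose1d(32, 1, kernel_size=4, stride=2, norm='weight_norm')
#   print(down(x).shape)      # torch.Size([2, 32, 12000]), i.e. 24000 / 2
#   print(up(down(x)).shape)  # torch.Size([2, 1, 24000]), length restored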
6 | """LSTM layers module.""" 7 | from torch import nn 8 | 9 | 10 | class SLSTM(nn.Module): 11 | """ 12 | LSTM without worrying about the hidden state, nor the layout of the data. 13 | Expects input as convolutional layout. 14 | """ 15 | 16 | def __init__(self, dimension: int, num_layers: int=2, skip: bool=True): 17 | super().__init__() 18 | self.skip = skip 19 | self.lstm = nn.LSTM(dimension, dimension, num_layers) 20 | 21 | def forward(self, x): 22 | x = x.permute(2, 0, 1) 23 | y, _ = self.lstm(x) 24 | if self.skip: 25 | y = y + x 26 | y = y.permute(1, 2, 0) 27 | return y 28 | -------------------------------------------------------------------------------- /academicodec/modules/norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Normalization modules.""" 7 | import typing as tp 8 | 9 | import einops 10 | import torch 11 | from torch import nn 12 | 13 | 14 | class ConvLayerNorm(nn.LayerNorm): 15 | """ 16 | Convolution-friendly LayerNorm that moves channels to last dimensions 17 | before running the normalization and moves them back to original position right after. 18 | """ 19 | 20 | def __init__(self, 21 | normalized_shape: tp.Union[int, tp.List[int], torch.Size], 22 | **kwargs): 23 | super().__init__(normalized_shape, **kwargs) 24 | 25 | def forward(self, x): 26 | x = einops.rearrange(x, 'b ... t -> b t ...') 27 | x = super().forward(x) 28 | x = einops.rearrange(x, 'b t ... -> b ... t') 29 | return 30 | -------------------------------------------------------------------------------- /academicodec/modules/seanet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Encodec SEANet-based encoder and decoder implementation.""" 7 | import typing as tp 8 | 9 | import numpy as np 10 | import torch.nn as nn 11 | 12 | from academicodec.modules import SConv1d 13 | from academicodec.modules import SConvTranspose1d 14 | from academicodec.modules import SLSTM 15 | 16 | 17 | class SEANetResnetBlock(nn.Module): 18 | """Residual block from SEANet model. 19 | Args: 20 | dim (int): Dimension of the input/output 21 | kernel_sizes (list): List of kernel sizes for the convolutions. 22 | dilations (list): List of dilations for the convolutions. 23 | activation (str): Activation function. 24 | activation_params (dict): Parameters to provide to the activation function 25 | norm (str): Normalization method. 26 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 27 | causal (bool): Whether to use fully causal convolution. 28 | pad_mode (str): Padding mode for the convolutions. 29 | compress (int): Reduced dimensionality in residual branches (from Demucs v3) 30 | true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection. 
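    A shape-preserving sketch with the default arguments (stride-1 convolutions
    and an identity shortcut, so channel count and time length are unchanged):

        >>> import torch
        >>> block = SEANetResnetBlock(32)
        >>> block(torch.randn(1, 32, 100)).shape
        torch.Size([1, 32, 100])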
31 | """ 32 | 33 | def __init__(self, 34 | dim: int, 35 | kernel_sizes: tp.List[int]=[3, 1], 36 | dilations: tp.List[int]=[1, 1], 37 | activation: str='ELU', 38 | activation_params: dict={'alpha': 1.0}, 39 | norm: str='weight_norm', 40 | norm_params: tp.Dict[str, tp.Any]={}, 41 | causal: bool=False, 42 | pad_mode: str='reflect', 43 | compress: int=2, 44 | true_skip: bool=True): 45 | super().__init__() 46 | assert len(kernel_sizes) == len( 47 | dilations), 'Number of kernel sizes should match number of dilations' 48 | act = getattr(nn, activation) 49 | hidden = dim // compress 50 | block = [] 51 | for i, (kernel_size, 52 | dilation) in enumerate(zip(kernel_sizes, dilations)): 53 | in_chs = dim if i == 0 else hidden 54 | out_chs = dim if i == len(kernel_sizes) - 1 else hidden 55 | block += [ 56 | act(**activation_params), 57 | SConv1d( 58 | in_chs, 59 | out_chs, 60 | kernel_size=kernel_size, 61 | dilation=dilation, 62 | norm=norm, 63 | norm_kwargs=norm_params, 64 | causal=causal, 65 | pad_mode=pad_mode), 66 | ] 67 | self.block = nn.Sequential(*block) 68 | self.shortcut: nn.Module 69 | if true_skip: 70 | self.shortcut = nn.Identity() 71 | else: 72 | self.shortcut = SConv1d( 73 | dim, 74 | dim, 75 | kernel_size=1, 76 | norm=norm, 77 | norm_kwargs=norm_params, 78 | causal=causal, 79 | pad_mode=pad_mode) 80 | 81 | def forward(self, x): 82 | return self.shortcut(x) + self.block(x) 83 | 84 | 85 | class SEANetEncoder(nn.Module): 86 | """SEANet encoder. 87 | Args: 88 | channels (int): Audio channels. 89 | dimension (int): Intermediate representation dimension. 90 | n_filters (int): Base width for the model. 91 | n_residual_layers (int): nb of residual layers. 92 | ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of 93 | upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here 94 | that must match the decoder order 95 | activation (str): Activation function. 96 | activation_params (dict): Parameters to provide to the activation function 97 | norm (str): Normalization method. 98 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 99 | kernel_size (int): Kernel size for the initial convolution. 100 | last_kernel_size (int): Kernel size for the initial convolution. 101 | residual_kernel_size (int): Kernel size for the residual layers. 102 | dilation_base (int): How much to increase the dilation with each layer. 103 | causal (bool): Whether to use fully causal convolution. 104 | pad_mode (str): Padding mode for the convolutions. 105 | true_skip (bool): Whether to use true skip connection or a simple 106 | (streamable) convolution as the skip connection in the residual network blocks. 107 | compress (int): Reduced dimensionality in residual branches (from Demucs v3). 108 | lstm (int): Number of LSTM layers at the end of the encoder. 
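    A minimal sketch with the default configuration; the same shapes are
    exercised by ``test()`` at the bottom of this file:

        >>> import torch
        >>> encoder = SEANetEncoder()                 # ratios [8, 5, 4, 2] -> hop_length 320
        >>> encoder(torch.randn(1, 1, 24000)).shape   # 24000 / 320 = 75 frames of dimension 128
        torch.Size([1, 128, 75])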
109 | """ 110 | 111 | def __init__(self, 112 | channels: int=1, 113 | dimension: int=128, 114 | n_filters: int=32, 115 | n_residual_layers: int=1, 116 | ratios: tp.List[int]=[8, 5, 4, 2], 117 | activation: str='ELU', 118 | activation_params: dict={'alpha': 1.0}, 119 | norm: str='weight_norm', 120 | norm_params: tp.Dict[str, tp.Any]={}, 121 | kernel_size: int=7, 122 | last_kernel_size: int=7, 123 | residual_kernel_size: int=3, 124 | dilation_base: int=2, 125 | causal: bool=False, 126 | pad_mode: str='reflect', 127 | true_skip: bool=False, 128 | compress: int=2, 129 | lstm: int=2): 130 | super().__init__() 131 | self.channels = channels 132 | self.dimension = dimension 133 | self.n_filters = n_filters 134 | self.ratios = list(reversed(ratios)) 135 | del ratios 136 | self.n_residual_layers = n_residual_layers 137 | self.hop_length = np.prod(self.ratios) # 计算乘积 138 | 139 | act = getattr(nn, activation) 140 | mult = 1 141 | model: tp.List[nn.Module] = [ 142 | SConv1d( 143 | channels, 144 | mult * n_filters, 145 | kernel_size, 146 | norm=norm, 147 | norm_kwargs=norm_params, 148 | causal=causal, 149 | pad_mode=pad_mode) 150 | ] 151 | # Downsample to raw audio scale 152 | for i, ratio in enumerate(self.ratios): 153 | # Add residual layers 154 | for j in range(n_residual_layers): 155 | model += [ 156 | SEANetResnetBlock( 157 | mult * n_filters, 158 | kernel_sizes=[residual_kernel_size, 1], 159 | dilations=[dilation_base**j, 1], 160 | norm=norm, 161 | norm_params=norm_params, 162 | activation=activation, 163 | activation_params=activation_params, 164 | causal=causal, 165 | pad_mode=pad_mode, 166 | compress=compress, 167 | true_skip=true_skip) 168 | ] 169 | 170 | # Add downsampling layers 171 | model += [ 172 | act(**activation_params), 173 | SConv1d( 174 | mult * n_filters, 175 | mult * n_filters * 2, 176 | kernel_size=ratio * 2, 177 | stride=ratio, 178 | norm=norm, 179 | norm_kwargs=norm_params, 180 | causal=causal, 181 | pad_mode=pad_mode), 182 | ] 183 | mult *= 2 184 | 185 | if lstm: 186 | model += [SLSTM(mult * n_filters, num_layers=lstm)] 187 | 188 | model += [ 189 | act(**activation_params), SConv1d( 190 | mult * n_filters, 191 | dimension, 192 | last_kernel_size, 193 | norm=norm, 194 | norm_kwargs=norm_params, 195 | causal=causal, 196 | pad_mode=pad_mode) 197 | ] 198 | 199 | self.model = nn.Sequential(*model) 200 | 201 | def forward(self, x): 202 | return self.model(x) 203 | 204 | 205 | class SEANetDecoder(nn.Module): 206 | """SEANet decoder. 207 | Args: 208 | channels (int): Audio channels. 209 | dimension (int): Intermediate representation dimension. 210 | n_filters (int): Base width for the model. 211 | n_residual_layers (int): nb of residual layers. 212 | ratios (Sequence[int]): kernel size and stride ratios 213 | activation (str): Activation function. 214 | activation_params (dict): Parameters to provide to the activation function 215 | final_activation (str): Final activation function after all convolutions. 216 | final_activation_params (dict): Parameters to provide to the activation function 217 | norm (str): Normalization method. 218 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 219 | kernel_size (int): Kernel size for the initial convolution. 220 | last_kernel_size (int): Kernel size for the initial convolution. 221 | residual_kernel_size (int): Kernel size for the residual layers. 222 | dilation_base (int): How much to increase the dilation with each layer. 223 | causal (bool): Whether to use fully causal convolution. 
224 | pad_mode (str): Padding mode for the convolutions. 225 | true_skip (bool): Whether to use true skip connection or a simple 226 | (streamable) convolution as the skip connection in the residual network blocks. 227 | compress (int): Reduced dimensionality in residual branches (from Demucs v3). 228 | lstm (int): Number of LSTM layers at the end of the encoder. 229 | trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. 230 | If equal to 1.0, it means that all the trimming is done at the right. 231 | """ 232 | 233 | def __init__(self, 234 | channels: int=1, 235 | dimension: int=128, 236 | n_filters: int=32, 237 | n_residual_layers: int=1, 238 | ratios: tp.List[int]=[8, 5, 4, 2], 239 | activation: str='ELU', 240 | activation_params: dict={'alpha': 1.0}, 241 | final_activation: tp.Optional[str]=None, 242 | final_activation_params: tp.Optional[dict]=None, 243 | norm: str='weight_norm', 244 | norm_params: tp.Dict[str, tp.Any]={}, 245 | kernel_size: int=7, 246 | last_kernel_size: int=7, 247 | residual_kernel_size: int=3, 248 | dilation_base: int=2, 249 | causal: bool=False, 250 | pad_mode: str='reflect', 251 | true_skip: bool=False, 252 | compress: int=2, 253 | lstm: int=2, 254 | trim_right_ratio: float=1.0): 255 | super().__init__() 256 | self.dimension = dimension 257 | self.channels = channels 258 | self.n_filters = n_filters 259 | self.ratios = ratios 260 | del ratios 261 | self.n_residual_layers = n_residual_layers 262 | self.hop_length = np.prod(self.ratios) 263 | 264 | act = getattr(nn, activation) 265 | mult = int(2**len(self.ratios)) 266 | model: tp.List[nn.Module] = [ 267 | SConv1d( 268 | dimension, 269 | mult * n_filters, 270 | kernel_size, 271 | norm=norm, 272 | norm_kwargs=norm_params, 273 | causal=causal, 274 | pad_mode=pad_mode) 275 | ] 276 | 277 | if lstm: 278 | model += [SLSTM(mult * n_filters, num_layers=lstm)] 279 | 280 | # Upsample to raw audio scale 281 | for i, ratio in enumerate(self.ratios): 282 | # Add upsampling layers 283 | model += [ 284 | act(**activation_params), 285 | SConvTranspose1d( 286 | mult * n_filters, 287 | mult * n_filters // 2, 288 | kernel_size=ratio * 2, 289 | stride=ratio, 290 | norm=norm, 291 | norm_kwargs=norm_params, 292 | causal=causal, 293 | trim_right_ratio=trim_right_ratio), 294 | ] 295 | # Add residual layers 296 | for j in range(n_residual_layers): 297 | model += [ 298 | SEANetResnetBlock( 299 | mult * n_filters // 2, 300 | kernel_sizes=[residual_kernel_size, 1], 301 | dilations=[dilation_base**j, 1], 302 | activation=activation, 303 | activation_params=activation_params, 304 | norm=norm, 305 | norm_params=norm_params, 306 | causal=causal, 307 | pad_mode=pad_mode, 308 | compress=compress, 309 | true_skip=true_skip) 310 | ] 311 | 312 | mult //= 2 313 | 314 | # Add final layers 315 | model += [ 316 | act(**activation_params), SConv1d( 317 | n_filters, 318 | channels, 319 | last_kernel_size, 320 | norm=norm, 321 | norm_kwargs=norm_params, 322 | causal=causal, 323 | pad_mode=pad_mode) 324 | ] 325 | # Add optional final activation to decoder (eg. 
tanh) 326 | if final_activation is not None: 327 | final_act = getattr(nn, final_activation) 328 | final_activation_params = final_activation_params or {} 329 | model += [final_act(**final_activation_params)] 330 | self.model = nn.Sequential(*model) 331 | 332 | def forward(self, z): 333 | y = self.model(z) 334 | return y 335 | 336 | 337 | def test(): 338 | import torch 339 | encoder = SEANetEncoder() 340 | decoder = SEANetDecoder() 341 | x = torch.randn(1, 1, 24000) 342 | z = encoder(x) 343 | print('z ', z.shape) 344 | assert 1 == 2 345 | assert list(z.shape) == [1, 128, 75], z.shape 346 | y = decoder(z) 347 | assert y.shape == x.shape, (x.shape, y.shape) 348 | 349 | 350 | if __name__ == '__main__': 351 | test() 352 | -------------------------------------------------------------------------------- /academicodec/modules/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """A streamable transformer.""" 7 | import typing as tp 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def create_sin_embedding(positions: torch.Tensor, 15 | dim: int, 16 | max_period: float=10000): 17 | """Create time embedding for the given positions, target dimension `dim`. 18 | """ 19 | # We aim for BTC format 20 | assert dim % 2 == 0 21 | half_dim = dim // 2 22 | adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1) 23 | phase = positions / (max_period**(adim / (half_dim - 1))) 24 | return torch.cat( 25 | [ 26 | torch.cos(phase), 27 | torch.sin(phase), 28 | ], dim=-1) 29 | 30 | 31 | class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer): 32 | def forward(self, x: torch.Tensor, x_past: torch.Tensor, 33 | past_context: int): # type: ignore 34 | if self.norm_first: 35 | sa_input = self.norm1(x) 36 | x = x + self._sa_block(sa_input, x_past, past_context) 37 | x = x + self._ff_block(self.norm2(x)) 38 | else: 39 | sa_input = x 40 | x = self.norm1(x + self._sa_block(sa_input, x_past, past_context)) 41 | x = self.norm2(x + self._ff_block(x)) 42 | 43 | return x, sa_input 44 | 45 | # self-attention block 46 | def _sa_block(self, 47 | x: torch.Tensor, 48 | x_past: torch.Tensor, 49 | past_context: int): # type: ignore 50 | _, T, _ = x.shape 51 | _, H, _ = x_past.shape 52 | 53 | queries = x 54 | keys = torch.cat([x_past, x], dim=1) 55 | values = keys 56 | 57 | queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1) 58 | keys_pos = torch.arange(T + H, device=x.device).view(1, -1) 59 | delta = queries_pos - keys_pos 60 | valid_access = (delta >= 0) & (delta <= past_context) 61 | x = self.self_attn( 62 | queries, keys, values, attn_mask=~valid_access, 63 | need_weights=False)[0] 64 | return self.dropout1(x) 65 | 66 | 67 | class StreamingTransformerEncoder(nn.Module): 68 | """TransformerEncoder with streaming support. 69 | 70 | Args: 71 | dim (int): dimension of the data. 72 | hidden_scale (int): intermediate dimension of FF module is this times the dimension. 73 | num_heads (int): number of heads. 74 | num_layers (int): number of layers. 75 | max_period (float): maxium period of cosines in the positional embedding. 76 | past_context (int or None): receptive field for the causal mask, infinite if None. 77 | gelu (bool): if true uses GeLUs, otherwise use ReLUs. 
78 | norm_in (bool): normalize the input. 79 | dropout (float): dropout probability. 80 | **kwargs: See `nn.TransformerEncoderLayer`. 81 | """ 82 | 83 | def __init__(self, 84 | dim, 85 | hidden_scale: float=4., 86 | num_heads: int=8, 87 | num_layers: int=5, 88 | max_period: float=10000, 89 | past_context: int=1000, 90 | gelu: bool=True, 91 | norm_in: bool=True, 92 | dropout: float=0., 93 | **kwargs): 94 | super().__init__() 95 | assert dim % num_heads == 0 96 | hidden_dim = int(dim * hidden_scale) 97 | 98 | self.max_period = max_period 99 | self.past_context = past_context 100 | activation: tp.Any = F.gelu if gelu else F.relu 101 | 102 | self.norm_in: nn.Module 103 | if norm_in: 104 | self.norm_in = nn.LayerNorm(dim) 105 | else: 106 | self.norm_in = nn.Identity() 107 | 108 | self.layers = nn.ModuleList() 109 | for idx in range(num_layers): 110 | self.layers.append( 111 | StreamingTransformerEncoderLayer( 112 | dim, 113 | num_heads, 114 | hidden_dim, 115 | activation=activation, 116 | batch_first=True, 117 | dropout=dropout, 118 | **kwargs)) 119 | 120 | def forward(self, 121 | x: torch.Tensor, 122 | states: tp.Optional[tp.List[torch.Tensor]]=None, 123 | offset: tp.Union[int, torch.Tensor]=0): 124 | B, T, C = x.shape 125 | if states is None: 126 | states = [ 127 | torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers)) 128 | ] 129 | 130 | positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset 131 | pos_emb = create_sin_embedding(positions, C, max_period=self.max_period) 132 | 133 | new_state: tp.List[torch.Tensor] = [] 134 | x = self.norm_in(x) 135 | x = x + pos_emb 136 | 137 | for layer_state, layer in zip(states, self.layers): 138 | x, new_layer_state = layer(x, layer_state, self.past_context) 139 | new_layer_state = torch.cat([layer_state, new_layer_state], dim=1) 140 | new_state.append(new_layer_state[:, -self.past_context:, :]) 141 | return x, new_state, offset + T 142 | -------------------------------------------------------------------------------- /academicodec/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # flake8: noqa 7 | from .vq import QuantizedResult 8 | from .vq import ResidualVectorQuantizer 9 | -------------------------------------------------------------------------------- /academicodec/quantization/ac.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Arithmetic coder.""" 7 | import io 8 | import math 9 | import random 10 | import typing as tp 11 | 12 | import torch 13 | 14 | from academicodec.binary import BitPacker 15 | from academicodec.binary import BitUnpacker 16 | 17 | 18 | def build_stable_quantized_cdf(pdf: torch.Tensor, 19 | total_range_bits: int, 20 | roundoff: float=1e-8, 21 | min_range: int=2, 22 | check: bool=True) -> torch.Tensor: 23 | """Turn the given PDF into a quantized CDF that splits 24 | [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional 25 | to the PDF. 26 | 27 | Args: 28 | pdf (torch.Tensor): probability distribution, shape should be `[N]`. 
29 | total_range_bits (int): see `ArithmeticCoder`, the typical range we expect 30 | during the coding process is `[0, 2 ** total_range_bits - 1]`. 31 | roundoff (float): will round the pdf up to that level to remove difference coming 32 | from e.g. evaluating the Language Model on different architectures. 33 | min_range (int): minimum range width. Should always be at least 2 for numerical 34 | stability. Use this to avoid pathological behavior is a value 35 | that is expected to be rare actually happens in real life. 36 | check (bool): if True, checks that nothing bad happened, can be deactivated for speed. 37 | """ 38 | pdf = pdf.detach() 39 | if roundoff: 40 | pdf = (pdf / roundoff).floor() * roundoff 41 | # interpolate with uniform distribution to achieve desired minimum probability. 42 | total_range = 2**total_range_bits 43 | cardinality = len(pdf) 44 | alpha = min_range * cardinality / total_range 45 | assert alpha <= 1, "you must reduce min_range" 46 | ranges = (((1 - alpha) * total_range) * pdf).floor().long() 47 | ranges += min_range 48 | quantized_cdf = torch.cumsum(ranges, dim=-1) 49 | if min_range < 2: 50 | raise ValueError("min_range must be at least 2.") 51 | if check: 52 | assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1] 53 | if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range 54 | ).any() or quantized_cdf[0] < min_range: 55 | raise ValueError("You must increase your total_range_bits.") 56 | return quantized_cdf 57 | 58 | 59 | class ArithmeticCoder: 60 | """ArithmeticCoder, 61 | Let us take a distribution `p` over `N` symbols, and assume we have a stream 62 | of random variables `s_t` sampled from `p`. Let us assume that we have a budget 63 | of `B` bits that we can afford to write on device. There are `2**B` possible numbers, 64 | corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single 65 | sequence `(s_t)` by doing the following: 66 | 67 | 1) Initialize the current range to` [0 ** 2 B - 1]`. 68 | 2) For each time step t, split the current range into contiguous chunks, 69 | one for each possible outcome, with size roughly proportional to `p`. 70 | For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks 71 | would be `{[0, 2], [3, 3]}`. 72 | 3) Select the chunk corresponding to `s_t`, and replace the current range with this. 73 | 4) When done encoding all the values, just select any value remaining in the range. 74 | 75 | You will notice that this procedure can fail: for instance if at any point in time 76 | the range is smaller than `N`, then we can no longer assign a non-empty chunk to each 77 | possible outcome. Intuitively, the more likely a value is, the less the range width 78 | will reduce, and the longer we can go on encoding values. This makes sense: for any efficient 79 | coding scheme, likely outcomes would take less bits, and more of them can be coded 80 | with a fixed budget. 81 | 82 | In practice, we do not know `B` ahead of time, but we have a way to inject new bits 83 | when the current range decreases below a given limit (given by `total_range_bits`), without 84 | having to redo all the computations. If we encode mostly likely values, we will seldom 85 | need to inject new bits, but a single rare value can deplete our stock of entropy! 86 | 87 | In this explanation, we assumed that the distribution `p` was constant. In fact, the present 88 | code works for any sequence `(p_t)` possibly different for each timestep. 
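A minimal round-trip sketch of the coder described above; it mirrors the `test()` function further down in this file, and the toy distribution and symbol sequence are arbitrary:

```python
import io
import torch
# `ArithmeticCoder`, `ArithmeticDecoder` and `build_stable_quantized_cdf` are the ones in this file.

pdf = torch.softmax(torch.randn(16), dim=0)          # toy distribution over 16 symbols
symbols = [3, 1, 4, 1, 5]

fo = io.BytesIO()
coder = ArithmeticCoder(fo)
q_cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
for s in symbols:
    coder.push(s, q_cdf)
coder.flush()

fo.seek(0)
decoder = ArithmeticDecoder(fo)
decoded = [decoder.pull(q_cdf) for _ in symbols]
assert decoded == symbols        # the decoder must see exactly the same quantized cdf
```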
89 | We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller 90 | the KL between the true distribution and `p_t`, the most efficient the coding will be. 91 | 92 | Args: 93 | fo (IO[bytes]): file-like object to which the bytes will be written to. 94 | total_range_bits (int): the range `M` described above is `2 ** total_range_bits. 95 | Any time the current range width fall under this limit, new bits will 96 | be injected to rescale the initial range. 97 | """ 98 | 99 | def __init__(self, fo: tp.IO[bytes], total_range_bits: int=24): 100 | assert total_range_bits <= 30 101 | self.total_range_bits = total_range_bits 102 | self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. 103 | self.low: int = 0 104 | self.high: int = 0 105 | self.max_bit: int = -1 106 | self._dbg: tp.List[tp.Any] = [] 107 | self._dbg2: tp.List[tp.Any] = [] 108 | 109 | @property 110 | def delta(self) -> int: 111 | """Return the current range width.""" 112 | return self.high - self.low + 1 113 | 114 | def _flush_common_prefix(self): 115 | # If self.low and self.high start with the sames bits, 116 | # those won't change anymore as we always just increase the range 117 | # by powers of 2, and we can flush them out to the bit stream. 118 | assert self.high >= self.low, (self.low, self.high) 119 | assert self.high < 2**(self.max_bit + 1) 120 | while self.max_bit >= 0: 121 | b1 = self.low >> self.max_bit 122 | b2 = self.high >> self.max_bit 123 | if b1 == b2: 124 | self.low -= (b1 << self.max_bit) 125 | self.high -= (b1 << self.max_bit) 126 | assert self.high >= self.low, (self.high, self.low, 127 | self.max_bit) 128 | assert self.low >= 0 129 | self.max_bit -= 1 130 | self.packer.push(b1) 131 | else: 132 | break 133 | 134 | def push(self, symbol: int, quantized_cdf: torch.Tensor): 135 | """Push the given symbol on the stream, flushing out bits 136 | if possible. 137 | 138 | Args: 139 | symbol (int): symbol to encode with the AC. 140 | quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` 141 | to build this from your pdf estimate. 142 | """ 143 | while self.delta < 2**self.total_range_bits: 144 | self.low *= 2 145 | self.high = self.high * 2 + 1 146 | self.max_bit += 1 147 | 148 | range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() 149 | range_high = quantized_cdf[symbol].item() - 1 150 | effective_low = int( 151 | math.ceil(range_low * (self.delta / (2**self.total_range_bits)))) 152 | effective_high = int( 153 | math.floor(range_high * (self.delta / (2**self.total_range_bits)))) 154 | assert self.low <= self.high 155 | self.high = self.low + effective_high 156 | self.low = self.low + effective_low 157 | assert self.low <= self.high, (effective_low, effective_high, range_low, 158 | range_high) 159 | self._dbg.append((self.low, self.high)) 160 | self._dbg2.append((self.low, self.high)) 161 | outs = self._flush_common_prefix() 162 | assert self.low <= self.high 163 | assert self.max_bit >= -1 164 | assert self.max_bit <= 61, self.max_bit 165 | return outs 166 | 167 | def flush(self): 168 | """Flush the remaining information to the stream. 169 | """ 170 | while self.max_bit >= 0: 171 | b1 = (self.low >> self.max_bit) & 1 172 | self.packer.push(b1) 173 | self.max_bit -= 1 174 | self.packer.flush() 175 | 176 | 177 | class ArithmeticDecoder: 178 | """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. 
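A toy numeric illustration of the prefix-flushing step used by `push()` / `_flush_common_prefix` above; the values are arbitrary and chosen only to show a single bit being emitted:

```python
# Once `low` and `high` share their most significant bit, that bit can never change again,
# so it is written to the stream and dropped from both bounds.
low, high, max_bit = 0b1010, 0b1101, 3
while max_bit >= 0 and (low >> max_bit) == (high >> max_bit):
    bit = low >> max_bit
    low -= bit << max_bit
    high -= bit << max_bit
    max_bit -= 1
    print("flushed bit:", bit)   # here a single 1 is flushed, leaving low=0b010, high=0b101
```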
179 | 180 | Note that this must be called with **exactly** the same parameters and sequence 181 | of quantized cdf as the arithmetic encoder or the wrong values will be decoded. 182 | 183 | If the AC encoder current range is [L, H], with `L` and `H` having the some common 184 | prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. 185 | For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside 186 | `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained 187 | for a specific sequence of symbols and a binary-search allows us to decode those symbols. 188 | At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, 189 | and we will need to read new bits from the stream and repeat the process. 190 | 191 | """ 192 | 193 | def __init__(self, fo: tp.IO[bytes], total_range_bits: int=24): 194 | self.total_range_bits = total_range_bits 195 | self.low: int = 0 196 | self.high: int = 0 197 | self.current: int = 0 198 | self.max_bit: int = -1 199 | self.unpacker = BitUnpacker( 200 | bits=1, fo=fo) # we pull single bits at a time. 201 | # Following is for debugging 202 | self._dbg: tp.List[tp.Any] = [] 203 | self._dbg2: tp.List[tp.Any] = [] 204 | self._last: tp.Any = None 205 | 206 | @property 207 | def delta(self) -> int: 208 | return self.high - self.low + 1 209 | 210 | def _flush_common_prefix(self): 211 | # Given the current range [L, H], if both have a common prefix, 212 | # we know we can remove it from our representation to avoid handling large numbers. 213 | while self.max_bit >= 0: 214 | b1 = self.low >> self.max_bit 215 | b2 = self.high >> self.max_bit 216 | if b1 == b2: 217 | self.low -= (b1 << self.max_bit) 218 | self.high -= (b1 << self.max_bit) 219 | self.current -= (b1 << self.max_bit) 220 | assert self.high >= self.low 221 | assert self.low >= 0 222 | self.max_bit -= 1 223 | else: 224 | break 225 | 226 | def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: 227 | """Pull a symbol, reading as many bits from the stream as required. 228 | This returns `None` when the stream has been exhausted. 229 | 230 | Args: 231 | quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` 232 | to build this from your pdf estimate. This must be **exatly** 233 | the same cdf as the one used at encoding time. 
234 | """ 235 | while self.delta < 2**self.total_range_bits: 236 | bit = self.unpacker.pull() 237 | if bit is None: 238 | return None 239 | self.low *= 2 240 | self.high = self.high * 2 + 1 241 | self.current = self.current * 2 + bit 242 | self.max_bit += 1 243 | 244 | def bin_search(low_idx: int, high_idx: int): 245 | # Binary search is not just for coding interviews :) 246 | if high_idx < low_idx: 247 | raise RuntimeError("Binary search failed") 248 | mid = (low_idx + high_idx) // 2 249 | range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 250 | range_high = quantized_cdf[mid].item() - 1 251 | effective_low = int( 252 | math.ceil(range_low * (self.delta / (2**self.total_range_bits) 253 | ))) 254 | effective_high = int( 255 | math.floor(range_high * (self.delta / (2**self.total_range_bits) 256 | ))) 257 | low = effective_low + self.low 258 | high = effective_high + self.low 259 | if self.current >= low: 260 | if self.current <= high: 261 | return (mid, low, high, self.current) 262 | else: 263 | return bin_search(mid + 1, high_idx) 264 | else: 265 | return bin_search(low_idx, mid - 1) 266 | 267 | self._last = (self.low, self.high, self.current, self.max_bit) 268 | sym, self.low, self.high, self.current = bin_search( 269 | 0, len(quantized_cdf) - 1) 270 | self._dbg.append((self.low, self.high, self.current)) 271 | self._flush_common_prefix() 272 | self._dbg2.append((self.low, self.high, self.current)) 273 | 274 | return sym 275 | 276 | 277 | def test(): 278 | torch.manual_seed(1234) 279 | random.seed(1234) 280 | for _ in range(4): 281 | pdfs = [] 282 | cardinality = random.randrange(4000) 283 | steps = random.randrange(100, 500) 284 | fo = io.BytesIO() 285 | encoder = ArithmeticCoder(fo) 286 | symbols = [] 287 | for step in range(steps): 288 | pdf = torch.softmax(torch.randn(cardinality), dim=0) 289 | pdfs.append(pdf) 290 | q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) 291 | symbol = torch.multinomial(pdf, 1).item() 292 | symbols.append(symbol) 293 | encoder.push(symbol, q_cdf) 294 | encoder.flush() 295 | 296 | fo.seek(0) 297 | decoder = ArithmeticDecoder(fo) 298 | for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): 299 | q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) 300 | decoded_symbol = decoder.pull(q_cdf) 301 | assert decoded_symbol == symbol, idx 302 | assert decoder.pull(torch.zeros(1)) is None 303 | 304 | 305 | if __name__ == "__main__": 306 | test() 307 | -------------------------------------------------------------------------------- /academicodec/quantization/core_vq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # This implementation is inspired from 8 | # https://github.com/lucidrains/vector-quantize-pytorch 9 | # which is released under MIT License. 
Hereafter, the original license: 10 | # MIT License 11 | # 12 | # Copyright (c) 2020 Phil Wang 13 | # 14 | # Permission is hereby granted, free of charge, to any person obtaining a copy 15 | # of this software and associated documentation files (the "Software"), to deal 16 | # in the Software without restriction, including without limitation the rights 17 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 18 | # copies of the Software, and to permit persons to whom the Software is 19 | # furnished to do so, subject to the following conditions: 20 | # 21 | # The above copyright notice and this permission notice shall be included in all 22 | # copies or substantial portions of the Software. 23 | # 24 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 27 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 28 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 29 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 30 | # SOFTWARE. 31 | """Core vector quantization implementation.""" 32 | import typing as tp 33 | 34 | import torch 35 | import torch.nn.functional as F 36 | from einops import rearrange 37 | from einops import repeat 38 | from torch import nn 39 | 40 | from academicodec.quantization.distrib import broadcast_tensors 41 | 42 | 43 | def default(val: tp.Any, d: tp.Any) -> tp.Any: 44 | return val if val is not None else d 45 | 46 | 47 | def ema_inplace(moving_avg, new, decay: float): 48 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 49 | 50 | 51 | def laplace_smoothing(x, n_categories: int, epsilon: float=1e-5): 52 | return (x + epsilon) / (x.sum() + n_categories * epsilon) 53 | 54 | 55 | def uniform_init(*shape: int): 56 | t = torch.empty(shape) 57 | nn.init.kaiming_uniform_(t) 58 | return t 59 | 60 | 61 | def sample_vectors(samples, num: int): 62 | num_samples, device = samples.shape[0], samples.device 63 | 64 | if num_samples >= num: 65 | indices = torch.randperm(num_samples, device=device)[:num] 66 | else: 67 | indices = torch.randint(0, num_samples, (num, ), device=device) 68 | 69 | return samples[indices] 70 | 71 | 72 | def kmeans(samples, num_clusters: int, num_iters: int=10): 73 | dim, dtype = samples.shape[-1], samples.dtype 74 | 75 | means = sample_vectors(samples, num_clusters) 76 | 77 | for _ in range(num_iters): 78 | diffs = rearrange(samples, "n d -> n () d") - rearrange(means, 79 | "c d -> () c d") 80 | dists = -(diffs**2).sum(dim=-1) 81 | 82 | buckets = dists.max(dim=-1).indices 83 | bins = torch.bincount(buckets, minlength=num_clusters) 84 | zero_mask = bins == 0 85 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 86 | 87 | new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) 88 | new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) 89 | new_means = new_means / bins_min_clamped[..., None] 90 | 91 | means = torch.where(zero_mask[..., None], means, new_means) 92 | 93 | return means, bins 94 | 95 | 96 | class EuclideanCodebook(nn.Module): 97 | """Codebook with Euclidean distance. 98 | Args: 99 | dim (int): Dimension. 100 | codebook_size (int): Codebook size. 101 | kmeans_init (bool): Whether to use k-means to initialize the codebooks. 
102 | If set to true, run the k-means algorithm on the first training batch and use 103 | the learned centroids as initialization. 104 | kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. 105 | decay (float): Decay for exponential moving average over the codebooks. 106 | epsilon (float): Epsilon value for numerical stability. 107 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 108 | that have an exponential moving average cluster size less than the specified threshold with 109 | randomly selected vector from the current batch. 110 | """ 111 | 112 | def __init__( 113 | self, 114 | dim: int, 115 | codebook_size: int, 116 | kmeans_init: int=False, 117 | kmeans_iters: int=10, 118 | decay: float=0.99, 119 | epsilon: float=1e-5, 120 | threshold_ema_dead_code: int=2, ): 121 | super().__init__() 122 | self.decay = decay 123 | init_fn: tp.Union[ 124 | tp.Callable[..., torch.Tensor], 125 | tp.Any] = uniform_init if not kmeans_init else torch.zeros 126 | embed = init_fn(codebook_size, dim) 127 | 128 | self.codebook_size = codebook_size 129 | 130 | self.kmeans_iters = kmeans_iters 131 | self.epsilon = epsilon 132 | self.threshold_ema_dead_code = threshold_ema_dead_code 133 | 134 | self.register_buffer("inited", torch.Tensor([not kmeans_init])) 135 | self.register_buffer("cluster_size", torch.zeros(codebook_size)) 136 | self.register_buffer("embed", embed) 137 | self.register_buffer("embed_avg", embed.clone()) 138 | 139 | @torch.jit.ignore 140 | def init_embed_(self, data): 141 | if self.inited: 142 | return 143 | 144 | embed, cluster_size = kmeans(data, self.codebook_size, 145 | self.kmeans_iters) 146 | self.embed.data.copy_(embed) 147 | self.embed_avg.data.copy_(embed.clone()) 148 | self.cluster_size.data.copy_(cluster_size) 149 | self.inited.data.copy_(torch.Tensor([True])) 150 | # Make sure all buffers across workers are in sync after initialization 151 | broadcast_tensors(self.buffers()) 152 | 153 | def replace_(self, samples, mask): 154 | modified_codebook = torch.where( 155 | mask[..., None], 156 | sample_vectors(samples, self.codebook_size), self.embed) 157 | self.embed.data.copy_(modified_codebook) 158 | 159 | def expire_codes_(self, batch_samples): 160 | if self.threshold_ema_dead_code == 0: 161 | return 162 | 163 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 164 | if not torch.any(expired_codes): 165 | return 166 | 167 | batch_samples = rearrange(batch_samples, "... d -> (...) d") 168 | self.replace_(batch_samples, mask=expired_codes) 169 | broadcast_tensors(self.buffers()) 170 | 171 | def preprocess(self, x): 172 | x = rearrange(x, "... d -> (...) 
d") 173 | return x 174 | 175 | def quantize(self, x): 176 | embed = self.embed.t() 177 | dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + 178 | embed.pow(2).sum(0, keepdim=True)) 179 | embed_ind = dist.max(dim=-1).indices 180 | return embed_ind 181 | 182 | def postprocess_emb(self, embed_ind, shape): 183 | return embed_ind.view(*shape[:-1]) 184 | 185 | def dequantize(self, embed_ind): 186 | quantize = F.embedding(embed_ind, self.embed) 187 | return quantize 188 | 189 | def encode(self, x): 190 | shape = x.shape 191 | # pre-process 192 | x = self.preprocess(x) 193 | # quantize 194 | embed_ind = self.quantize(x) 195 | # post-process 196 | embed_ind = self.postprocess_emb(embed_ind, shape) 197 | return embed_ind 198 | 199 | def decode(self, embed_ind): 200 | quantize = self.dequantize(embed_ind) 201 | return quantize 202 | 203 | def forward(self, x): 204 | shape, dtype = x.shape, x.dtype 205 | x = self.preprocess(x) 206 | 207 | self.init_embed_(x) 208 | 209 | embed_ind = self.quantize(x) 210 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 211 | embed_ind = self.postprocess_emb(embed_ind, shape) 212 | quantize = self.dequantize(embed_ind) 213 | 214 | if self.training: 215 | # We do the expiry of code at that point as buffers are in sync 216 | # and all the workers will take the same decision. 217 | self.expire_codes_(x) 218 | ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) 219 | embed_sum = x.t() @ embed_onehot 220 | ema_inplace(self.embed_avg, embed_sum.t(), self.decay) 221 | cluster_size = ( 222 | laplace_smoothing(self.cluster_size, self.codebook_size, 223 | self.epsilon) * self.cluster_size.sum()) 224 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) 225 | self.embed.data.copy_(embed_normalized) 226 | 227 | return quantize, embed_ind 228 | 229 | 230 | class VectorQuantization(nn.Module): 231 | """Vector quantization implementation. 232 | Currently supports only euclidean distance. 233 | Args: 234 | dim (int): Dimension 235 | codebook_size (int): Codebook size 236 | codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. 237 | decay (float): Decay for exponential moving average over the codebooks. 238 | epsilon (float): Epsilon value for numerical stability. 239 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 240 | kmeans_iters (int): Number of iterations used for kmeans initialization. 241 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 242 | that have an exponential moving average cluster size less than the specified threshold with 243 | randomly selected vector from the current batch. 244 | commitment_weight (float): Weight for commitment loss. 
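A hedged usage sketch for the `VectorQuantization` wrapper described above, applied to `(batch, dim, frames)` features; the shapes and the commented import path are illustrative assumptions:

```python
import torch
# from academicodec.quantization.core_vq import VectorQuantization  # assumed import path

vq = VectorQuantization(dim=128, codebook_size=1024)   # codebook is k-means-initialized on the first batch
x = torch.randn(2, 128, 50)                            # (batch, dim, frames)
quantized, codes, loss = vq(x)                         # straight-through output, indices, commitment loss
assert quantized.shape == x.shape
assert codes.shape == (2, 50)                          # one codebook index per frame
```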
245 | """ 246 | 247 | def __init__( 248 | self, 249 | dim: int, 250 | codebook_size: int, 251 | codebook_dim: tp.Optional[int]=None, 252 | decay: float=0.99, 253 | epsilon: float=1e-5, 254 | kmeans_init: bool=True, 255 | kmeans_iters: int=50, 256 | threshold_ema_dead_code: int=2, 257 | commitment_weight: float=1., ): 258 | super().__init__() 259 | _codebook_dim: int = default(codebook_dim, dim) 260 | 261 | requires_projection = _codebook_dim != dim 262 | self.project_in = (nn.Linear(dim, _codebook_dim) 263 | if requires_projection else nn.Identity()) 264 | self.project_out = (nn.Linear(_codebook_dim, dim) 265 | if requires_projection else nn.Identity()) 266 | 267 | self.epsilon = epsilon 268 | self.commitment_weight = commitment_weight 269 | 270 | self._codebook = EuclideanCodebook( 271 | dim=_codebook_dim, 272 | codebook_size=codebook_size, 273 | kmeans_init=kmeans_init, 274 | kmeans_iters=kmeans_iters, 275 | decay=decay, 276 | epsilon=epsilon, 277 | threshold_ema_dead_code=threshold_ema_dead_code) 278 | self.codebook_size = codebook_size 279 | 280 | @property 281 | def codebook(self): 282 | return self._codebook.embed 283 | 284 | def encode(self, x): 285 | x = rearrange(x, "b d n -> b n d") 286 | x = self.project_in(x) 287 | embed_in = self._codebook.encode(x) 288 | return embed_in 289 | 290 | def decode(self, embed_ind): 291 | quantize = self._codebook.decode(embed_ind) 292 | quantize = self.project_out(quantize) 293 | quantize = rearrange(quantize, "b n d -> b d n") 294 | return quantize 295 | 296 | def forward(self, x): 297 | device = x.device 298 | x = rearrange(x, "b d n -> b n d") 299 | x = self.project_in(x) 300 | 301 | quantize, embed_ind = self._codebook(x) 302 | 303 | if self.training: 304 | quantize = x + (quantize - x).detach() 305 | 306 | loss = torch.tensor([0.0], device=device, requires_grad=self.training) 307 | 308 | if self.training: 309 | if self.commitment_weight > 0: 310 | commit_loss = F.mse_loss(quantize.detach(), x) 311 | loss = loss + commit_loss * self.commitment_weight 312 | 313 | quantize = self.project_out(quantize) 314 | quantize = rearrange(quantize, "b n d -> b d n") 315 | return quantize, embed_ind, loss 316 | 317 | 318 | class ResidualVectorQuantization(nn.Module): 319 | """Residual vector quantization implementation. 320 | Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf 321 | """ 322 | 323 | def __init__(self, *, num_quantizers, **kwargs): 324 | super().__init__() 325 | self.layers = nn.ModuleList( 326 | [VectorQuantization(**kwargs) for _ in range(num_quantizers)]) 327 | 328 | def forward(self, x, n_q: tp.Optional[int]=None): 329 | quantized_out = 0.0 330 | residual = x 331 | 332 | all_losses = [] 333 | all_indices = [] 334 | 335 | n_q = n_q or len(self.layers) 336 | 337 | for layer in self.layers[:n_q]: 338 | quantized, indices, loss = layer(residual) 339 | residual = residual - quantized 340 | quantized_out = quantized_out + quantized 341 | 342 | all_indices.append(indices) 343 | all_losses.append(loss) 344 | 345 | out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) 346 | return quantized_out, out_indices, out_losses 347 | 348 | def encode(self, 349 | x: torch.Tensor, 350 | n_q: tp.Optional[int]=None, 351 | st: tp.Optional[int]=None) -> torch.Tensor: 352 | residual = x 353 | all_indices = [] 354 | n_q = n_q or len(self.layers) 355 | st = st or 0 356 | for layer in self.layers[st:n_q]: # select the range of quantizer layers to use, from `st` to `n_q` 357 | indices = layer.encode(residual) 358 | quantized = layer.decode(indices) 359 | residual = residual - quantized 360 | all_indices.append(indices) 361 | out_indices = torch.stack(all_indices) 362 | return out_indices 363 | 364 | def decode(self, q_indices: torch.Tensor) -> torch.Tensor: 365 | quantized_out = torch.tensor(0.0, device=q_indices.device) 366 | for i, indices in enumerate(q_indices): 367 | layer = self.layers[i] 368 | quantized = layer.decode(indices) 369 | quantized_out = quantized_out + quantized 370 | return quantized_out 371 | -------------------------------------------------------------------------------- /academicodec/quantization/distrib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Torch distributed utilities.""" 7 | import typing as tp 8 | 9 | import torch 10 | 11 | 12 | def rank(): 13 | if torch.distributed.is_initialized(): 14 | return torch.distributed.get_rank() 15 | else: 16 | return 0 17 | 18 | 19 | def world_size(): 20 | if torch.distributed.is_initialized(): 21 | return torch.distributed.get_world_size() 22 | else: 23 | return 1 24 | 25 | 26 | def is_distributed(): 27 | return world_size() > 1 28 | 29 | 30 | def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): 31 | if is_distributed(): 32 | return torch.distributed.all_reduce(tensor, op) 33 | 34 | 35 | def _is_complex_or_float(tensor): 36 | return torch.is_floating_point(tensor) or torch.is_complex(tensor) 37 | 38 | 39 | def _check_number_of_params(params: tp.List[torch.Tensor]): 40 | # utility function to check that the number of params in all workers is the same, 41 | # and thus avoid a deadlock with distributed all reduce. 42 | if not is_distributed() or not params: 43 | return 44 | #print('params[0].device ', params[0].device) 45 | tensor = torch.tensor( 46 | [len(params)], device=params[0].device, dtype=torch.long) 47 | all_reduce(tensor) 48 | if tensor.item() != len(params) * world_size(): 49 | # If not all the workers have the same number, for at least one of them, 50 | # this inequality will hold.
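A hedged round-trip sketch for the `ResidualVectorQuantization` class defined just above; the shapes and the commented import path are illustrative assumptions:

```python
import torch
# from academicodec.quantization.core_vq import ResidualVectorQuantization  # assumed import path

rvq = ResidualVectorQuantization(num_quantizers=4, dim=128, codebook_size=1024)
x = torch.randn(2, 128, 50)            # (batch, dim, frames)
codes = rvq.encode(x)                  # (n_q, batch, frames): one index per layer per frame
x_hat = rvq.decode(codes)              # sum of the per-layer dequantized vectors
assert codes.shape == (4, 2, 50) and x_hat.shape == x.shape
# During training, `rvq(x)` instead returns (quantized, codes, per-layer commitment losses),
# and the codebooks are k-means-initialized on the first forward pass.
```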
51 | raise RuntimeError( 52 | f"Mismatch in number of params: ours is {len(params)}, " 53 | "at least one worker has a different one.") 54 | 55 | 56 | def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int=0): 57 | """Broadcast the tensors from the given parameters to all workers. 58 | This can be used to ensure that all workers have the same model to start with. 59 | """ 60 | if not is_distributed(): 61 | return 62 | tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] 63 | _check_number_of_params(tensors) 64 | handles = [] 65 | for tensor in tensors: 66 | # src = int(rank()) # added code 67 | handle = torch.distributed.broadcast( 68 | tensor.data, src=src, async_op=True) 69 | handles.append(handle) 70 | for handle in handles: 71 | handle.wait() 72 | 73 | 74 | def sync_buffer(buffers, average=True): 75 | """ 76 | Sync the given buffers across workers. If average is False, broadcast instead of averaging. 77 | """ 78 | if not is_distributed(): 79 | return 80 | handles = [] 81 | for buffer in buffers: 82 | if torch.is_floating_point(buffer.data): 83 | if average: 84 | handle = torch.distributed.all_reduce( 85 | buffer.data, 86 | op=torch.distributed.ReduceOp.SUM, 87 | async_op=True) 88 | else: 89 | handle = torch.distributed.broadcast( 90 | buffer.data, src=0, async_op=True) 91 | handles.append((buffer, handle)) 92 | for buffer, handle in handles: 93 | handle.wait() 94 | if average: 95 | buffer.data /= world_size() 96 | 97 | 98 | def sync_grad(params): 99 | """ 100 | Simpler alternative to DistributedDataParallel, that doesn't rely 101 | on any black magic. For simple models it can also be as fast. 102 | Just call this on your model parameters after the call to backward! 103 | """ 104 | if not is_distributed(): 105 | return 106 | handles = [] 107 | for p in params: 108 | if p.grad is not None: 109 | handle = torch.distributed.all_reduce( 110 | p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) 111 | handles.append((p, handle)) 112 | for p, handle in handles: 113 | handle.wait() 114 | p.grad.data /= world_size() 115 | 116 | 117 | def average_metrics(metrics: tp.Dict[str, float], count=1.): 118 | """Average a dictionary of metrics across all workers, using the optional 119 | `count` as unnormalized weight. 120 | """ 121 | if not is_distributed(): 122 | return metrics 123 | keys, values = zip(*metrics.items()) 124 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 125 | tensor = torch.tensor( 126 | list(values) + [1], device=device, dtype=torch.float32) 127 | tensor *= count 128 | all_reduce(tensor) 129 | averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() 130 | return dict(zip(keys, averaged)) 131 | -------------------------------------------------------------------------------- /academicodec/quantization/vq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
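As the `sync_grad` docstring above suggests, these helpers can stand in for DistributedDataParallel in a simple training step. A minimal sketch, assuming the helpers are importable from this module (all of them are no-ops in a single-process run):

```python
import torch
from torch import nn
# from academicodec.quantization.distrib import sync_grad, sync_buffer  # assumed import path

model = nn.Linear(8, 4)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
x, y = torch.randn(16, 8), torch.randn(16, 4)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()
sync_grad(model.parameters())                 # all-reduce and average grads across workers
opt.step()
opt.zero_grad()
sync_buffer(model.buffers(), average=True)    # keep floating-point buffers in sync as well
```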
6 | """Residual vector quantizer implementation.""" 7 | import math 8 | import typing as tp 9 | from dataclasses import dataclass 10 | from dataclasses import field 11 | 12 | import torch 13 | from torch import nn 14 | 15 | from academicodec.quantization.core_vq import ResidualVectorQuantization 16 | 17 | 18 | @dataclass 19 | class QuantizedResult: 20 | quantized: torch.Tensor 21 | codes: torch.Tensor 22 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. 23 | penalty: tp.Optional[torch.Tensor] = None 24 | metrics: dict = field(default_factory=dict) 25 | 26 | 27 | class ResidualVectorQuantizer(nn.Module): 28 | """Residual Vector Quantizer. 29 | Args: 30 | dimension (int): Dimension of the codebooks. 31 | n_q (int): Number of residual vector quantizers used. 32 | bins (int): Codebook size. 33 | decay (float): Decay for exponential moving average over the codebooks. 34 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 35 | kmeans_iters (int): Number of iterations used for kmeans initialization. 36 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 37 | that have an exponential moving average cluster size less than the specified threshold with 38 | randomly selected vector from the current batch. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | dimension: int=256, 44 | n_q: int=8, 45 | bins: int=1024, 46 | decay: float=0.99, 47 | kmeans_init: bool=True, 48 | kmeans_iters: int=50, 49 | threshold_ema_dead_code: int=2, ): 50 | super().__init__() 51 | self.n_q = n_q 52 | self.dimension = dimension 53 | self.bins = bins 54 | self.decay = decay 55 | self.kmeans_init = kmeans_init 56 | self.kmeans_iters = kmeans_iters 57 | self.threshold_ema_dead_code = threshold_ema_dead_code 58 | self.vq = ResidualVectorQuantization( 59 | dim=self.dimension, 60 | codebook_size=self.bins, 61 | num_quantizers=self.n_q, 62 | decay=self.decay, 63 | kmeans_init=self.kmeans_init, 64 | kmeans_iters=self.kmeans_iters, 65 | threshold_ema_dead_code=self.threshold_ema_dead_code, ) 66 | 67 | def forward(self, 68 | x: torch.Tensor, 69 | sample_rate: int, 70 | bandwidth: tp.Optional[float]=None) -> QuantizedResult: 71 | """Residual vector quantization on the given input tensor. 72 | Args: 73 | x (torch.Tensor): Input tensor. 74 | sample_rate (int): Sample rate of the input tensor. 75 | bandwidth (float): Target bandwidth. 76 | Returns: 77 | QuantizedResult: 78 | The quantized (or approximately quantized) representation with 79 | the associated bandwidth and any penalty term for the loss. 80 | """ 81 | bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) 82 | n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) 83 | quantized, codes, commit_loss = self.vq(x, n_q=n_q) 84 | bw = torch.tensor(n_q * bw_per_q).to(x) 85 | return quantized, codes, bw, torch.mean(commit_loss) 86 | #return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) 87 | 88 | def get_num_quantizers_for_bandwidth( 89 | self, sample_rate: int, bandwidth: tp.Optional[float]=None) -> int: 90 | """Return n_q based on specified target bandwidth. 91 | """ 92 | bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) 93 | n_q = self.n_q 94 | if bandwidth and bandwidth > 0.: 95 | n_q = int(max(1, math.floor(bandwidth / bw_per_q))) 96 | return n_q 97 | 98 | def get_bandwidth_per_quantizer(self, sample_rate: int): 99 | """Return bandwidth per quantizer for a given input sample rate. 
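Worked numbers for the bandwidth bookkeeping above, under the assumption that the value passed as `sample_rate` is really the codec's frame rate; the formula only yields codec-scale kb/s figures in that case, and the concrete numbers below are illustrative:

```python
import math

# With bins=1024 each quantizer emits log2(1024) = 10 bits per frame, so at an
# assumed 75 Hz frame rate one quantizer costs 10 * 75 / 1000 = 0.75 kb/s.
bins, frame_rate = 1024, 75
bw_per_q = math.log2(bins) * frame_rate / 1000          # 0.75 kb/s per quantizer
target_bw = 6.0                                         # requested bandwidth in kb/s
n_q = int(max(1, math.floor(target_bw / bw_per_q)))     # -> 8 quantizers selected
print(bw_per_q, n_q)
```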
100 | """ 101 | return math.log2(self.bins) * sample_rate / 1000 102 | 103 | def encode(self, 104 | x: torch.Tensor, 105 | sample_rate: int, 106 | bandwidth: tp.Optional[float]=None, 107 | st: tp.Optional[int]=None) -> torch.Tensor: 108 | """Encode a given input tensor with the specified sample rate at the given bandwidth. 109 | The RVQ encode method sets the appropriate number of quantizer to use 110 | and returns indices for each quantizer. 111 | """ 112 | n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) 113 | st = st or 0 114 | codes = self.vq.encode(x, n_q=n_q, st=st) 115 | return codes 116 | 117 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 118 | """Decode the given codes to the quantized representation. 119 | """ 120 | quantized = self.vq.decode(codes) 121 | return quantized 122 | -------------------------------------------------------------------------------- /academicodec/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import random 5 | import sys 6 | import time 7 | import warnings 8 | 9 | import matplotlib 10 | import numpy as np 11 | import torch 12 | import yaml 13 | from torch import distributed as dist 14 | from torch.nn.utils import weight_norm 15 | matplotlib.use("Agg") 16 | import matplotlib.pylab as plt 17 | import re 18 | import pathlib 19 | 20 | 21 | def seed_everything(seed, cudnn_deterministic=False): 22 | """ 23 | Function that sets seed for pseudo-random number generators in: 24 | pytorch, numpy, python.random 25 | 26 | Args: 27 | seed: the integer value seed for global random state 28 | """ 29 | if seed is not None: 30 | # print(f"Global seed set to {seed}") 31 | random.seed(seed) 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | 36 | # if cudnn_deterministic: 37 | # torch.backends.cudnn.deterministic = True 38 | # warnings.warn('You have chosen to seed training. ' 39 | # 'This will turn on the CUDNN deterministic setting, ' 40 | # 'which can slow down your training considerably! 
' 41 | # 'You may see unexpected behavior when restarting ' 42 | # 'from checkpoints.') 43 | 44 | 45 | def is_primary(): 46 | return get_rank() == 0 47 | 48 | 49 | def get_rank(): 50 | if not dist.is_available(): 51 | return 0 52 | if not dist.is_initialized(): 53 | return 0 54 | 55 | return dist.get_rank() 56 | 57 | 58 | def load_yaml_config(path): 59 | with open(path) as f: 60 | config = yaml.full_load(f) 61 | return config 62 | 63 | 64 | def save_config_to_yaml(config, path): 65 | assert path.endswith('.yaml') 66 | with open(path, 'w') as f: 67 | f.write(yaml.dump(config)) 68 | f.close() 69 | 70 | 71 | def save_dict_to_json(d, path, indent=None): 72 | json.dump(d, open(path, 'w'), indent=indent) 73 | 74 | 75 | def load_dict_from_json(path): 76 | return json.load(open(path, 'r')) 77 | 78 | 79 | def write_args(args, path): 80 | args_dict = dict((name, getattr(args, name)) for name in dir(args) 81 | if not name.startswith('_')) 82 | with open(path, 'a') as args_file: 83 | args_file.write('==> torch version: {}\n'.format(torch.__version__)) 84 | args_file.write( 85 | '==> cudnn version: {}\n'.format(torch.backends.cudnn.version())) 86 | args_file.write('==> Cmd:\n') 87 | args_file.write(str(sys.argv)) 88 | args_file.write('\n==> args:\n') 89 | for k, v in sorted(args_dict.items()): 90 | args_file.write(' %s: %s\n' % (str(k), str(v))) 91 | args_file.close() 92 | 93 | 94 | class Logger(object): 95 | def __init__(self, args): 96 | self.args = args 97 | self.save_dir = args.save_dir 98 | self.is_primary = is_primary() 99 | 100 | if self.is_primary: 101 | os.makedirs(self.save_dir, exist_ok=True) 102 | 103 | # save the args and config 104 | self.config_dir = os.path.join(self.save_dir, 'configs') 105 | os.makedirs(self.config_dir, exist_ok=True) 106 | file_name = os.path.join(self.config_dir, 'args.txt') 107 | write_args(args, file_name) 108 | 109 | log_dir = os.path.join(self.save_dir, 'logs') 110 | if not os.path.exists(log_dir): 111 | os.makedirs(log_dir, exist_ok=True) 112 | self.text_writer = open(os.path.join(log_dir, 'log.txt'), 113 | 'a') # 'w') 114 | if args.tensorboard: 115 | self.log_info('using tensorboard') 116 | self.tb_writer = torch.utils.tensorboard.SummaryWriter( 117 | log_dir=log_dir 118 | ) # tensorboard.SummaryWriter(log_dir=log_dir) 119 | else: 120 | self.tb_writer = None 121 | 122 | def save_config(self, config): 123 | if self.is_primary: 124 | save_config_to_yaml(config, 125 | os.path.join(self.config_dir, 'config.yaml')) 126 | 127 | def log_info(self, info, check_primary=True): 128 | if self.is_primary or (not check_primary): 129 | print(info) 130 | if self.is_primary: 131 | info = str(info) 132 | time_str = time.strftime('%Y-%m-%d-%H-%M') 133 | info = '{}: {}'.format(time_str, info) 134 | if not info.endswith('\n'): 135 | info += '\n' 136 | self.text_writer.write(info) 137 | self.text_writer.flush() 138 | 139 | def add_scalar(self, **kargs): 140 | """Log a scalar variable.""" 141 | if self.is_primary: 142 | if self.tb_writer is not None: 143 | self.tb_writer.add_scalar(**kargs) 144 | 145 | def add_scalars(self, **kargs): 146 | """Log a scalar variable.""" 147 | if self.is_primary: 148 | if self.tb_writer is not None: 149 | self.tb_writer.add_scalars(**kargs) 150 | 151 | def add_image(self, **kargs): 152 | """Log a scalar variable.""" 153 | if self.is_primary: 154 | if self.tb_writer is not None: 155 | self.tb_writer.add_image(**kargs) 156 | 157 | def add_images(self, **kargs): 158 | """Log a scalar variable.""" 159 | if self.is_primary: 160 | if self.tb_writer is not 
None: 161 | self.tb_writer.add_images(**kargs) 162 | 163 | def close(self): 164 | if self.is_primary: 165 | self.text_writer.close() 166 | self.tb_writer.close() 167 | 168 | 169 | def plot_spectrogram(spectrogram): 170 | fig, ax = plt.subplots(figsize=(10, 2)) 171 | im = ax.imshow( 172 | spectrogram, aspect="auto", origin="lower", interpolation='none') 173 | plt.colorbar(im, ax=ax) 174 | 175 | fig.canvas.draw() 176 | plt.close() 177 | 178 | return fig 179 | 180 | 181 | def init_weights(m, mean=0.0, std=0.01): 182 | classname = m.__class__.__name__ 183 | if classname.find("Conv") != -1: 184 | m.weight.data.normal_(mean, std) 185 | 186 | 187 | def apply_weight_norm(m): 188 | classname = m.__class__.__name__ 189 | if classname.find("Conv") != -1: 190 | weight_norm(m) 191 | 192 | 193 | def get_padding(kernel_size, dilation=1): 194 | return int((kernel_size * dilation - dilation) / 2) 195 | 196 | 197 | def load_checkpoint(filepath, device): 198 | assert os.path.isfile(filepath) 199 | print("Loading '{}'".format(filepath)) 200 | checkpoint_dict = torch.load(filepath, map_location=device) 201 | print("Complete.") 202 | return checkpoint_dict 203 | 204 | 205 | def save_checkpoint(filepath, obj, num_ckpt_keep=5): 206 | name = re.match(r'(do|g)_\d+', pathlib.Path(filepath).name).group(1) 207 | ckpts = sorted(pathlib.Path(filepath).parent.glob(f'{name}_*')) 208 | if len(ckpts) > num_ckpt_keep: 209 | [os.remove(c) for c in ckpts[:-num_ckpt_keep]] 210 | print("Saving checkpoint to {}".format(filepath)) 211 | torch.save(obj, filepath) 212 | print("Complete.") 213 | 214 | 215 | def scan_checkpoint(cp_dir, prefix): 216 | pattern = os.path.join(cp_dir, prefix + '????????') 217 | cp_list = glob.glob(pattern) 218 | if len(cp_list) == 0: 219 | return None 220 | return sorted(cp_list)[-1] 221 | -------------------------------------------------------------------------------- /egs/Encodec_16k_320d/path.sh: -------------------------------------------------------------------------------- 1 | ../Encodec_24k_32d/path.sh -------------------------------------------------------------------------------- /egs/Encodec_16k_320d/readme.md: -------------------------------------------------------------------------------- 1 | # The training code of Encodec 2 | 3 | ### Note that, this part of code is based on Facebook's Encodec. We just provide the training process. The license is the same as Encodec. 4 | 5 | ### For Training 6 | set the right path to start.sh 7 | `bash start.sh` 8 | 9 | ### For Inference 10 | if you want to use our checkpoint. Run the following
11 | ```bash 12 | mkdir checkpoint 13 | cd checkpoint 14 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/encodec_16khz_320d.pth 15 | bash test.sh # set the root in test.sh, before running it. 16 | ``` -------------------------------------------------------------------------------- /egs/Encodec_16k_320d/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | log_root=logs 4 | # 16kHz *.wav in train_data_dir 5 | train_data_dir=dump/train 6 | valid_data_dir=dump/valid 7 | 8 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 9 | python3 -m torch.distributed.launch --nproc_per_node 8 ${BIN_DIR}/main_launch.py \ 10 | --BATCH_SIZE 16 \ 11 | --N_EPOCHS 300 \ 12 | --save_dir ${log_root} \ 13 | --PATH ${log_root} \ 14 | --train_data_path ${train_data_dir} \ 15 | --valid_data_path ${valid_data_dir} \ 16 | --sr 16000 \ 17 | --ratios 8 5 4 2 \ 18 | --target_bandwidths 1 1.5 2 4 6 12 19 | -------------------------------------------------------------------------------- /egs/Encodec_16k_320d/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | 4 | python3 ${BIN_DIR}/test.py \ 5 | --input=./test_wav \ 6 | --output=./output \ 7 | --resume_path=checkpoint/encodec_16k_320d.pth \ 8 | --sr=16000 \ 9 | --ratios 8 5 4 2 \ 10 | --target_bandwidths 1 1.5 2 4 6 12 \ 11 | --target_bw=12 \ 12 | -r 13 | -------------------------------------------------------------------------------- /egs/Encodec_24k_240d/path.sh: -------------------------------------------------------------------------------- 1 | ../Encodec_24k_32d/path.sh -------------------------------------------------------------------------------- /egs/Encodec_24k_240d/readme.md: -------------------------------------------------------------------------------- 1 | # The training code of Encodec 2 | 3 | ### Note that this part of the code is based on Facebook's Encodec. We just provide the training process. The license is the same as Encodec. 4 | 5 | ### For Training 6 | set the right path in start.sh 7 | 8 | run: `bash start.sh` 9 | 10 | ### For Finetune 11 | If you want to finetune the model, you can use the following command: 12 | ```bash 13 | python3 main3_ddp.py --BATCH_SIZE 16 --N_EPOCHS 300 \ 14 | --save_dir path_to_save_log \ 15 | --PATH path_to_save_model \ 16 | --train_data_path path_to_training_data \ 17 | --valid_data_path path_to_val_data \ 18 | --resume --resume_path the_model_path 19 | ``` 20 | 21 | ### For Inference 22 | If you want to use our checkpoint, run the following
23 | ```bash 24 | mkdir checkpoint 25 | cd checkpoint 26 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/encodec_24khz_240d.pth 27 | bash test.sh # set the root in test.sh, before runing it. 28 | ``` -------------------------------------------------------------------------------- /egs/Encodec_24k_240d/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | log_root=logs 4 | # 24kHz *.wav in train_data_dir 5 | train_data_dir=dump/train 6 | valid_data_dir=dump/valid 7 | 8 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 9 | python3 -m torch.distributed.launch --nproc_per_node 8 ${BIN_DIR}/main_launch.py \ 10 | --BATCH_SIZE 16 \ 11 | --N_EPOCHS 300 \ 12 | --save_dir ${log_root} \ 13 | --PATH ${log_root} \ 14 | --train_data_path ${train_data_dir} \ 15 | --valid_data_path ${valid_data_dir} \ 16 | --sr 24000 \ 17 | --ratios 6 5 4 2 \ 18 | --target_bandwidths 1 2 4 8 12 -------------------------------------------------------------------------------- /egs/Encodec_24k_240d/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | 4 | python3 ${BIN_DIR}/test.py \ 5 | --input=./test_wav \ 6 | --output=./output \ 7 | --resume_path=checkpoint/encodec_24khz_240d.pth \ 8 | --sr=24000 \ 9 | --ratios 6 5 4 2 \ 10 | --target_bandwidths 1 2 4 8 12 \ 11 | --target_bw=12 \ 12 | -r 13 | -------------------------------------------------------------------------------- /egs/Encodec_24k_32d/path.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export MAIN_ROOT=`realpath ${PWD}/../../` 3 | 4 | export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} 5 | MODEL=encodec 6 | export BIN_DIR=${MAIN_ROOT}/academicodec/models/${MODEL} -------------------------------------------------------------------------------- /egs/Encodec_24k_32d/readme.md: -------------------------------------------------------------------------------- 1 | # The training code of Encodec 2 | 3 | ### Note that, this part of code is based on Facebook's Encodec. We just provide the training process. The license is the same as Encodec. 4 | 5 | ### For Training 6 | set the right path to start.sh 7 | 8 | `bash start.sh` 9 | 10 | ### For Inference 11 | if you want to use our checkpoint. Run the following
12 | ```bash 13 | mkdir checkpoint 14 | cd checkpoint` 15 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/encodec_24khz_32d.pth 16 | bash test.sh # set the root in test.sh, before runing it. 17 | ``` 18 | -------------------------------------------------------------------------------- /egs/Encodec_24k_32d/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | log_root=logs 4 | # 24kHz *.wav in train_data_dir 5 | train_data_dir=dump/train 6 | valid_data_dir=dump/valid 7 | 8 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 9 | python3 -m torch.distributed.launch --nproc_per_node 8 ${BIN_DIR}/main_launch.py \ 10 | --BATCH_SIZE 16 \ 11 | --N_EPOCHS 300 \ 12 | --save_dir ${log_root} \ 13 | --PATH ${log_root} \ 14 | --train_data_path ${train_data_dir} \ 15 | --valid_data_path ${valid_data_dir} \ 16 | --sr 24000 \ 17 | --ratios 2 2 2 4 \ 18 | --target_bandwidths 7.5 15 19 | -------------------------------------------------------------------------------- /egs/Encodec_24k_32d/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | 4 | python3 ${BIN_DIR}/test.py \ 5 | --input=./test_wav \ 6 | --output=./output \ 7 | --resume_path=checkpoint/Encodec_24khz_32d.pth \ 8 | --sr=24000 \ 9 | --ratios 2 2 2 4 \ 10 | --target_bandwidths 7.5 15 \ 11 | --target_bw=7.5 \ 12 | -r 13 | 14 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-16k-320d/config_16k_320d.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 8, 4 | "batch_size": 64, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.5, 7 | "adam_b2": 0.9, 8 | "lr_decay": 0.98, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,5,4,2], 12 | "upsample_kernel_sizes": [16,11,8,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 16000, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 200, 22 | "win_size": 800, 23 | 24 | "sampling_rate": 16000, 25 | 26 | "n_code_groups": 2, 27 | "n_codes": 1024, 28 | "codebook_loss_lambda": 1.0, 29 | "commitment_loss_lambda": 0.25, 30 | 31 | "fmin": 0, 32 | "fmax": 8000, 33 | "fmax_for_loss": null, 34 | 35 | "num_workers": 12, 36 | 37 | "dist_config": { 38 | "dist_backend": "nccl", 39 | "dist_url": "tcp://localhost:54321", 40 | "world_size": 1 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-16k-320d/path.sh: -------------------------------------------------------------------------------- 1 | ../HiFi-Codec-24k-240d/path.sh -------------------------------------------------------------------------------- /egs/HiFi-Codec-16k-320d/readme.md: -------------------------------------------------------------------------------- 1 | ## How to train your model 2 | Firstly, set the related path in start.sh file, then
3 | ```bash 4 | bash start.sh 5 | ``` 6 | 7 | ## How to Inference 8 | ```bash 9 | mkdir checkpoint 10 | cd checkpoint 11 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/HiFi-Codec-16k-320d 12 | bash test.sh 13 | ``` 14 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-16k-320d/start.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | source path.sh 4 | set -e 5 | 6 | log_root="logs" 7 | # .lst save the wav path. 8 | input_training_file="train.lst" 9 | input_validation_file="valid.lst" 10 | 11 | #mode=debug 12 | mode=train 13 | 14 | if [ "${mode}" == "debug" ]; then 15 | ## debug 16 | echo "Debug" 17 | log_root=${log_root}_debug 18 | export CUDA_VISIBLE_DEVICES=0 19 | python ${BIN_DIR}/train.py \ 20 | --config config_16k_320d.json \ 21 | --checkpoint_path ${log_root} \ 22 | --input_training_file ${input_training_file} \ 23 | --input_validation_file ${input_validation_file} \ 24 | --checkpoint_interval 100 \ 25 | --summary_interval 10 \ 26 | --validation_interval 100 \ 27 | 28 | elif [ "$mode" == "train" ]; then 29 | ## train 30 | echo "Train model..." 31 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 32 | python ${BIN_DIR}/train.py \ 33 | --config config_16k_320d.json \ 34 | --checkpoint_path ${log_root} \ 35 | --input_training_file ${input_training_file} \ 36 | --input_validation_file ${input_validation_file} \ 37 | --checkpoint_interval 5000 \ 38 | --summary_interval 100 \ 39 | --validation_interval 5000 40 | fi 41 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-16k-320d/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | 4 | ckpt=checkpoint/HiFi-Codec-16k-320d 5 | echo checkpoint path: ${ckpt} 6 | 7 | # the path of test wave 8 | wav_dir=test_wav 9 | 10 | outputdir=output 11 | mkdir -p ${outputdir} 12 | 13 | python3 ${BIN_DIR}/vqvae_copy_syn.py \ 14 | --model_path=${ckpt} \ 15 | --config_path=config_16k_320d.json \ 16 | --input_wavdir=${wav_dir} \ 17 | --outputdir=${outputdir} \ 18 | --num_gens=10000 \ 19 | --sample_rate=16000 20 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-24k-240d/config_24k_240d.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 8, 4 | "batch_size": 32, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.5, 7 | "adam_b2": 0.9, 8 | "lr_decay": 0.98, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,5,3,2], 12 | "upsample_kernel_sizes": [16,11,7,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 12000, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 240, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 24000, 25 | 26 | "n_code_groups": 2, 27 | "n_codes": 1024, 28 | "codebook_loss_lambda": 1.0, 29 | "commitment_loss_lambda": 0.25, 30 | 31 | "fmin": 0, 32 | "fmax": 8000, 33 | "fmax_for_loss": null, 34 | 35 | "num_workers": 12, 36 | 37 | "dist_config": { 38 | "dist_backend": "nccl", 39 | "dist_url": "tcp://localhost:54321", 40 | "world_size": 1 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-24k-240d/path.sh: -------------------------------------------------------------------------------- 
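A hedged sanity check of the config above: the `240d` in the recipe name appears to correspond to the total downsampling factor, i.e. the product of `upsample_rates`, which also equals `hop_size`; the small sketch below assumes it is run from the recipe directory containing the JSON file:

```python
import json
import math

# 8 * 5 * 3 * 2 = 240 samples per frame, so at 24 kHz the codec runs at 100 frames per second.
cfg = json.load(open("config_24k_240d.json"))
hop = math.prod(cfg["upsample_rates"])        # 240, matches cfg["hop_size"]
print(hop, cfg["sampling_rate"] / hop)        # 240  100.0
```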
1 | #!/bin/bash 2 | export MAIN_ROOT=`realpath ${PWD}/../../` 3 | 4 | export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} 5 | MODEL=hificodec 6 | export BIN_DIR=${MAIN_ROOT}/academicodec/models/${MODEL} 7 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-24k-240d/readme.md: -------------------------------------------------------------------------------- 1 | ## How to train your model 2 | Firstly, set the related path in start.sh file, then
3 | ```bash 4 | bash start.sh 5 | ``` 6 | 7 | ## How to Inference 8 | ```bash 9 | mkdir checkpoint 10 | cd checkpoint 11 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/HiFi-Codec-24k-240d 12 | bash test.sh 13 | ``` 14 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-24k-240d/start.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | source path.sh 4 | set -e 5 | 6 | log_root="logs" 7 | # .lst save the wav path. 8 | input_training_file="train.lst" 9 | input_validation_file="valid.lst" 10 | 11 | #mode=debug 12 | mode=train 13 | 14 | if [ "${mode}" == "debug" ]; then 15 | ## debug 16 | echo "Debug" 17 | log_root=${log_root}_debug 18 | export CUDA_VISIBLE_DEVICES=0 19 | python ${BIN_DIR}/train.py \ 20 | --config config_24k_240d.json \ 21 | --checkpoint_path ${log_root} \ 22 | --input_training_file ${input_training_file} \ 23 | --input_validation_file ${input_validation_file} \ 24 | --checkpoint_interval 100 \ 25 | --summary_interval 10 \ 26 | --validation_interval 100 \ 27 | 28 | elif [ "$mode" == "train" ]; then 29 | ## train 30 | echo "Train model..." 31 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 32 | python ${BIN_DIR}/train.py \ 33 | --config config_24k_240d.json \ 34 | --checkpoint_path ${log_root} \ 35 | --input_training_file ${input_training_file} \ 36 | --input_validation_file ${input_validation_file} \ 37 | --checkpoint_interval 5000 \ 38 | --summary_interval 100 \ 39 | --validation_interval 5000 40 | fi 41 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-24k-240d/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | 4 | ckpt=checkpoint/HiFi-Codec-24k-240d 5 | echo checkpoint path: ${ckpt} 6 | 7 | # the path of test wave 8 | wav_dir=test_wav 9 | 10 | outputdir=output 11 | mkdir -p ${outputdir} 12 | 13 | python3 ${BIN_DIR}/vqvae_copy_syn.py \ 14 | --model_path=${ckpt} \ 15 | --config_path=config_24k_240d.json \ 16 | --input_wavdir=${wav_dir} \ 17 | --outputdir=${outputdir} \ 18 | --num_gens=10000 19 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-24k-320d/config_24k_320d.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 8, 4 | "batch_size": 80, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.5, 7 | "adam_b2": 0.9, 8 | "lr_decay": 0.98, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,5,4,2], 12 | "upsample_kernel_sizes": [16,11,8,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 16000, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 240, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 24000, 25 | 26 | "n_code_groups": 2, 27 | "n_codes": 1024, 28 | "codebook_loss_lambda": 1.0, 29 | "commitment_loss_lambda": 0.25, 30 | 31 | "fmin": 0, 32 | "fmax": 8000, 33 | "fmax_for_loss": null, 34 | 35 | "num_workers": 12, 36 | 37 | "dist_config": { 38 | "dist_backend": "nccl", 39 | "dist_url": "tcp://localhost:54321", 40 | "world_size": 1 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-24k-320d/infer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 
| { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "\n", 11 | "sys.path.append('../../')" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 3, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "Init model and load weights\n", 24 | "Model ready\n", 25 | "Globbed 12 wav files.\n" 26 | ] 27 | }, 28 | { 29 | "name": "stderr", 30 | "output_type": "stream", 31 | "text": [ 32 | "100%|███████████| 1/1 [00:00<00:00, 11.08it/s]" 33 | ] 34 | }, 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "wav.shape: (97681,)\n", 40 | "acoustic_token: tensor([[[ 11, 591, 281, 629],\n", 41 | " [733, 591, 401, 139],\n", 42 | " [500, 591, 733, 600],\n", 43 | " ...,\n", 44 | " [733, 591, 451, 346],\n", 45 | " [733, 591, 401, 139],\n", 46 | " [386, 591, 281, 461]]], device='cuda:0')\n", 47 | "acoustic_token.shape: torch.Size([1, 305, 4])\n", 48 | "acoustic_token.dtype: torch.int64\n" 49 | ] 50 | }, 51 | { 52 | "name": "stderr", 53 | "output_type": "stream", 54 | "text": [ 55 | "\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "import glob\n", 61 | "import json\n", 62 | "import os\n", 63 | "from pathlib import Path\n", 64 | "\n", 65 | "import librosa\n", 66 | "import torch\n", 67 | "from academicodec.models.hificodec.vqvae import VQVAE\n", 68 | "from librosa.util import normalize\n", 69 | "from tqdm import tqdm\n", 70 | "\n", 71 | "ckpt_path = './checkpoint/HiFi-Codec-24k-320d'\n", 72 | "config_path = './config_24k_320d.json'\n", 73 | "with open(config_path, 'r') as f:\n", 74 | " config = json.load(f)\n", 75 | " sample_rate = config['sampling_rate']\n", 76 | "\n", 77 | "outputdir = './output'\n", 78 | "inputdir = './test_wav'\n", 79 | "num = 1024\n", 80 | "\n", 81 | "if __name__ == '__main__':\n", 82 | " Path(outputdir).mkdir(parents=True, exist_ok=True)\n", 83 | " print(\"Init model and load weights\")\n", 84 | " # make sure you downloaded the weights from https://huggingface.co/Dongchao/AcademiCodec/blob/main/HiFi-Codec-24k-320d \n", 85 | " # and put it in ./checkpoint/\n", 86 | " model = VQVAE(\n", 87 | " config_path,\n", 88 | " ckpt_path,\n", 89 | " with_encoder=True)\n", 90 | " model.cuda()\n", 91 | " model.eval()\n", 92 | " print(\"Model ready\")\n", 93 | "\n", 94 | " wav_paths = glob.glob(f\"{inputdir}/*.wav\")[:num]\n", 95 | " print(f\"Globbed {len(wav_paths)} wav files.\")\n", 96 | " fid_to_acoustic_token = {}\n", 97 | " for wav_path in tqdm(wav_paths[:1]):\n", 98 | " wav, sr = librosa.load(wav_path, sr=sample_rate)\n", 99 | " print(\"wav.shape:\",wav.shape)\n", 100 | " assert sr == sample_rate\n", 101 | " fid = os.path.basename(wav_path)[:-4]\n", 102 | " wav = normalize(wav) * 0.95\n", 103 | " wav = torch.FloatTensor(wav).unsqueeze(0)\n", 104 | " wav = wav.to(torch.device('cuda'))\n", 105 | " acoustic_token = model.encode(wav)\n", 106 | " print(\"acoustic_token:\",acoustic_token)\n", 107 | " print(\"acoustic_token.shape:\",acoustic_token.shape)\n", 108 | " print(\"acoustic_token.dtype:\",acoustic_token.dtype)\n", 109 | " fid = os.path.basename(wav_path)[:-4]\n", 110 | " fid_to_acoustic_token[fid] = acoustic_token\n", 111 | "\n", 112 | " torch.save(fid_to_acoustic_token,\n", 113 | " os.path.join(outputdir, 'fid_to_acoustic_token.pth'))\n" 114 | ] 115 | } 116 | ], 117 | "metadata": { 118 | "kernelspec": { 119 | "display_name": "Python 3 (ipykernel)", 120 | "language": "python", 121 | "name": "python3" 122 | }, 
123 | "language_info": { 124 | "codemirror_mode": { 125 | "name": "ipython", 126 | "version": 3 127 | }, 128 | "file_extension": ".py", 129 | "mimetype": "text/x-python", 130 | "name": "python", 131 | "nbconvert_exporter": "python", 132 | "pygments_lexer": "ipython3", 133 | "version": "3.7.14" 134 | } 135 | }, 136 | "nbformat": 4, 137 | "nbformat_minor": 2 138 | } 139 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-24k-320d/path.sh: -------------------------------------------------------------------------------- 1 | ../HiFi-Codec-24k-240d/path.sh -------------------------------------------------------------------------------- /egs/HiFi-Codec-24k-320d/readme.md: -------------------------------------------------------------------------------- 1 | ## How to train your model 2 | Firstly, set the related path in start.sh file, then
3 | ```bash 4 | bash start.sh 5 | ``` 6 | 7 | ## How to Inference 8 | ```bash 9 | mkdir checkpoint 10 | cd checkpoint 11 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/HiFi-Codec-24k-320d 12 | cd .. 13 | bash test.sh 14 | ``` 15 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-24k-320d/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | set -e 4 | 5 | log_root="logs" 6 | # .lst files save the wav paths. 7 | input_training_file="train.lst" 8 | input_validation_file="valid.lst" 9 | 10 | #mode=debug 11 | mode=train 12 | 13 | if [ "${mode}" == "debug" ]; then 14 | ## debug 15 | echo "Debug" 16 | log_root=${log_root}_debug 17 | export CUDA_VISIBLE_DEVICES=0 18 | python ${BIN_DIR}/train.py \ 19 | --config config_24k_320d.json \ 20 | --checkpoint_path ${log_root} \ 21 | --input_training_file ${input_training_file} \ 22 | --input_validation_file ${input_validation_file} \ 23 | --checkpoint_interval 100 \ 24 | --summary_interval 10 \ 25 | --validation_interval 100 \ 26 | 27 | elif [ "$mode" == "train" ]; then 28 | ## train 29 | echo "Train model..." 30 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 31 | python ${BIN_DIR}/train.py \ 32 | --config config_24k_320d.json \ 33 | --checkpoint_path ${log_root} \ 34 | --input_training_file ${input_training_file} \ 35 | --input_validation_file ${input_validation_file} \ 36 | --checkpoint_interval 5000 \ 37 | --summary_interval 100 \ 38 | --validation_interval 5000 39 | fi 40 | -------------------------------------------------------------------------------- /egs/HiFi-Codec-24k-320d/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | 4 | ckpt=checkpoint/HiFi-Codec-24k-320d 5 | echo checkpoint path: ${ckpt} 6 | 7 | # the path of the test wavs 8 | wav_dir=test_wav 9 | 10 | outputdir=output 11 | mkdir -p ${outputdir} 12 | 13 | python3 ${BIN_DIR}/vqvae_copy_syn.py \ 14 | --model_path=${ckpt} \ 15 | --config_path=config_24k_320d.json \ 16 | --input_wavdir=${wav_dir} \ 17 | --outputdir=${outputdir} \ 18 | --num_gens=10000 19 | -------------------------------------------------------------------------------- /egs/SoundStream_24k_240d/path.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export MAIN_ROOT=`realpath ${PWD}/../../` 3 | 4 | export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} 5 | MODEL=encodec 6 | export BIN_DIR=${MAIN_ROOT}/academicodec/models/${MODEL} -------------------------------------------------------------------------------- /egs/SoundStream_24k_240d/readme.md: -------------------------------------------------------------------------------- 1 | # The training code of SoundStream 2 | 3 | 4 | ### For Training 5 | Set the right paths in start.sh 6 | 7 | run: `bash start.sh` 8 | 9 | ### For Inference 10 | 11 | 12 | The model is not open-sourced, and this directory has not been organized yet. 13 | -------------------------------------------------------------------------------- /egs/SoundStream_24k_240d/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | 4 | python3 main3_ddp.py \ 5 | --BATCH_SIZE 16 \ 6 | --N_EPOCHS 300 \ 7 | --save_dir path_to_save_log \ 8 | --PATH path_to_save_model \ 9 | --train_data_path path_to_training_data \ 10 | --valid_data_path path_to_val_data \ 11 | --sr 24000 \ 12 | --ratios 6 5 4 2 \ 13 | --target_bandwidths 1 2 4 8 12 14 | 
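A note on the `--ratios` flag in the script above: the listed values are the encoder strides, and their product gives the overall downsample (hop) size that the recipe names refer to. The snippet below is only an illustrative sanity check, not part of the training code:

```python
# Relation between --ratios and the "240d" in SoundStream_24k_240d (illustrative sketch).
import math

ratios = [6, 5, 4, 2]           # strides passed via --ratios
hop_size = math.prod(ratios)    # total downsample factor of the encoder
print(hop_size)                 # 240  -> the "240d" in the recipe name
print(24000 / hop_size)         # 100.0 quantized frames per second at --sr 24000
```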
-------------------------------------------------------------------------------- /egs/SoundStream_24k_240d/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source path.sh 3 | python3 ${BIN_DIR}/test.py \ 4 | --input=./test_wav \ 5 | --output=./output \ 6 | --resume_path=checkpoint/soundstream.pth \ 7 | --sr=24000 \ 8 | --ratios 6 5 4 2 \ 9 | --target_bandwidths 1 2 4 8 12 \ 10 | --target_bw=12 11 | -------------------------------------------------------------------------------- /egs/util/wavlstgen.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # 2022-2023 by zhaomingwork@qq.com 3 | # can be used for generating train.lst or valid.lst only given a root dir 4 | # example: 5 | # python wavlstgen.py --wavdir /data/asr_data/aishell/ --outfile train.lst 6 | import os 7 | import time 8 | 9 | import argparse 10 | import json 11 | import traceback 12 | 13 | 14 | import logging 15 | 16 | logging.basicConfig(level=logging.ERROR) 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--wavdir", 20 | type=str, 21 | default="./", 22 | required=True, 23 | help="root dir of wav") 24 | 25 | 26 | parser.add_argument("--outfile", 27 | type=str, 28 | default="./wav.lst", 29 | required=False, 30 | help="output list file name") 31 | 32 | args = parser.parse_args() 33 | 34 | print(args) 35 | 36 | def genwavlist(rootdir): 37 | outlist = open(args.outfile, 'w+') 38 | 39 | for dirpath, dirnames, filenames in os.walk(rootdir): 40 | for filename in filenames: 41 | #print(os.path.join(dirpath, filename)) 42 | if filename.endswith(".wav"): 43 | outlist.write(os.path.join(dirpath, filename)+"\n") 44 | outlist.close() 45 | 46 | 47 | if __name__ == '__main__': 48 | 49 | genwavlist(args.wavdir) 50 | -------------------------------------------------------------------------------- /evaluation_metric/calculate_voc_obj_metrics/compute_metrics.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # pip install pesq 4 | # pip install pystoi 5 | # pip install pyworld 6 | # pip install pysptk 7 | # pip install -U numpy 8 | stage=1 9 | stop_stage=2 10 | 11 | #ref_dir=$1 12 | #gen_dir=$2 13 | 14 | ref_dir='your test folder' 15 | gen_dir='the generated samples' 16 | echo ${ref_dir} 17 | echo ${gen_dir} 18 | 19 | 20 | 21 | if [ $stage -le 1 ] && [ "${stop_stage}" -ge 1 ];then 22 | echo "Compute PESQ" 23 | python metrics/compute_pesq.py \ 24 | -r ${ref_dir} \ 25 | -d ${gen_dir} 26 | fi 27 | 28 | if [ $stage -le 2 ] && [ "${stop_stage}" -ge 2 ];then 29 | echo "Compute STOI" 30 | python metrics/compute_stoi.py \ 31 | -r ${ref_dir} \ 32 | -d ${gen_dir} 33 | fi 34 | -------------------------------------------------------------------------------- /evaluation_metric/calculate_voc_obj_metrics/metrics/compute_pesq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | import scipy.signal as signal 6 | from pesq import pesq 7 | from scipy.io import wavfile 8 | from tqdm import tqdm 9 | 10 | 11 | def cal_pesq(ref_dir, deg_dir): 12 | input_files = glob.glob(f"{deg_dir}/*.wav") 13 | 14 | nb_pesq_scores = 0.0 15 | wb_pesq_scores = 0.0 16 | for deg_wav in tqdm(input_files): 17 | ref_wav = os.path.join(ref_dir, os.path.basename(deg_wav)) 18 | ref_rate, ref = wavfile.read(ref_wav) 19 | deg_rate, deg = wavfile.read(deg_wav) 20 | if ref_rate != 16000: 21 | 
ref = signal.resample_poly(ref, 16000, ref_rate) 22 | if deg_rate != 16000: 23 | deg = signal.resample_poly(deg, 16000, deg_rate) 24 | 25 | min_len = min(len(ref), len(deg)) 26 | ref = ref[:min_len] 27 | deg = deg[:min_len] 28 | 29 | nb_pesq_scores += pesq(16000, ref, deg, 'nb') 30 | wb_pesq_scores += pesq(16000, ref, deg, 'wb') 31 | 32 | return nb_pesq_scores / len(input_files), wb_pesq_scores / len(input_files) 33 | 34 | 35 | if __name__ == '__main__': 36 | 37 | parser = argparse.ArgumentParser(description="Compute PESQ measure.") 38 | 39 | parser.add_argument( 40 | '-r', '--ref_dir', required=True, help="Reference wave folder.") 41 | parser.add_argument( 42 | '-d', '--deg_dir', required=True, help="Degraded wave folder.") 43 | 44 | args = parser.parse_args() 45 | 46 | nb_score, wb_score = cal_pesq(args.ref_dir, args.deg_dir) 47 | print(f"NB PESQ: {nb_score}") 48 | print(f"WB PESQ: {wb_score}") 49 | -------------------------------------------------------------------------------- /evaluation_metric/calculate_voc_obj_metrics/metrics/compute_stoi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | import numpy as np 6 | from pystoi import stoi 7 | from scipy.io import wavfile 8 | from tqdm import tqdm 9 | 10 | 11 | def calculate_stoi(ref_dir, deg_dir): 12 | input_files = glob.glob(f"{deg_dir}/*.wav") 13 | if len(input_files) < 1: 14 | raise RuntimeError(f"Found no wavs in {deg_dir}") 15 | 16 | stoi_scores = [] 17 | for deg_wav in tqdm(input_files): 18 | ref_wav = os.path.join(ref_dir, os.path.basename(deg_wav)) 19 | rate, ref = wavfile.read(ref_wav) 20 | rate, deg = wavfile.read(deg_wav) 21 | min_len = min(len(ref), len(deg)) 22 | ref = ref[:min_len] 23 | deg = deg[:min_len] 24 | cur_stoi = stoi(ref, deg, rate, extended=False) 25 | stoi_scores.append(cur_stoi) 26 | 27 | return np.mean(stoi_scores) 28 | 29 | 30 | if __name__ == '__main__': 31 | 32 | parser = argparse.ArgumentParser(description="Compute STOI measure") 33 | 34 | parser.add_argument( 35 | '-r', '--ref_dir', required=True, help="Reference wave folder.") 36 | parser.add_argument( 37 | '-d', '--deg_dir', required=True, help="Degraded wave folder.") 38 | 39 | args = parser.parse_args() 40 | 41 | stoi_score = calculate_stoi(args.ref_dir, args.deg_dir) 42 | print(f"STOI: {stoi_score}") 43 | -------------------------------------------------------------------------------- /evaluation_metric/calculate_voc_obj_metrics/metrics/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from os.path import join as opj 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torchaudio 9 | 10 | def count_parameters(model): 11 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 12 | 13 | 14 | def get_params(args): 15 | 16 | params = {} 17 | args_ref = vars(args) 18 | args_keys = vars(args).keys() 19 | 20 | for key in args_keys: 21 | if '__' in key: 22 | continue 23 | else: 24 | temp_params = args_ref[key] 25 | if type(temp_params) == dict: 26 | params.update(temp_params) 27 | else: 28 | params[key] = temp_params 29 | 30 | return params 31 | 32 | 33 | def rescale_module(module, reference): 34 | for sub in module.modules(): 35 | if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): 36 | rescale_conv(sub, reference) 37 | 38 | 39 | def rescale_conv(conv, reference): 40 | std = conv.weight.std().detach() 41 | scale = (std / reference)**0.5 42 | conv.weight.data /= scale 43 | if conv.bias is not 
None: 44 | conv.bias.data /= scale 45 | 46 | 47 | def write_result(estimate, noise, file, args): 48 | if not os.path.exists(args.enhanced_path): 49 | os.makedirs(args.enhanced_path) 50 | file_name = opj(args.enhanced_path, 51 | file[0].rsplit('.', 1)[0].replace('\\', '/').split('/')[-1]) 52 | noise_path = file_name + '_noise.wav' 53 | enhanced_path = file_name + '_enhanced.wav' 54 | 55 | torchaudio.save(noise_path, noise.squeeze(1), args.sample_rate) 56 | torchaudio.save(enhanced_path, estimate.squeeze(1), args.sample_rate) 57 | 58 | 59 | def seed_init(seed=100): 60 | 61 | random.seed(seed) 62 | np.random.seed(seed) 63 | torch.manual_seed(seed) 64 | torch.cuda.manual_seed(seed) 65 | torch.cuda.manual_seed_all(seed) 66 | torch.backends.cudnn.deterministic = True 67 | torch.backends.cudnn.benchmark = False 68 | os.environ['PYTHONHASHSEED'] = str(seed) 69 | 70 | 71 | def args_dict(args): 72 | """ 73 | Get your arguments and make a dictionary. 74 | If you add arguments to the model, you should edit here as well. 75 | """ 76 | args.dataset = { 77 | 'train': args.train, 78 | 'val': args.val, 79 | 'test': args.test, 80 | 'matching': args.matching 81 | } 82 | args.setting = { 83 | 'sample_rate': args.sample_rate, 84 | 'segment': args.segment, 85 | 'pad': args.pad, 86 | 'stride': args.set_stride 87 | } 88 | args.manner = { 89 | 'in_channels': args.in_channels, 90 | 'out_channels': args.out_channels, 91 | 'hidden': args.hidden, 92 | 'depth': args.depth, 93 | 'kernel_size': args.kernel_size, 94 | 'stride': args.stride, 95 | 'growth': args.growth, 96 | 'head': args.head, 97 | 'segment_len': args.segment_len 98 | } 99 | 100 | args.ex_name = os.getcwd().replace('\\', '/').split('/')[-1] 101 | 102 | return args 103 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # AcademiCodec: An Open Source Audio Codec Model for Academic Research 2 | 3 | This repo is organized as follows: 4 | 5 | ```text 6 | AcademiCodec 7 | ├── academicodec 8 | │   ├── utils.py # common parts of various models 9 | │   ├── modules # common parts of various models 10 | │   ├── ... 11 | │   ├── quantization # common parts of various models 12 | │   └── models # parts that are not shared by various models 13 | │     ├── hificodec 14 | │     ├── encodec 15 | │     ├── soundstream 16 | │     └── ... 17 | ├── evaluation_metric 18 | ├── egs 19 | │ ├── SoundStream* 20 | │ ├── EnCodec* 21 | │ └── HiFi-Codec* 22 | │      ├── start.sh 23 | │      ├── ... 24 | │     └── test.sh 25 | └── README.md 26 | ``` 27 | 28 | ### Ongoing 29 | This project is ongoing. You can find the paper at https://arxiv.org/pdf/2305.02765.pdf
30 | Furthermore, this project was launched from a university, and we expect more researchers to become contributors.
31 | 32 | #### Abstract 33 | Audio codec models are widely used in audio communication as a crucial technique for compressing audio into discrete representations. Nowadays, audio codec models are increasingly utilized in generation fields as intermediate representations. For instance, AudioLM is an audio generation model that uses the discrete representation of SoundStream as a training target, while VALL-E employs the Encodec model as an intermediate feature to aid TTS tasks. Despite their usefulness, two challenges persist: (1) training these audio codec models can be difficult due to the lack of publicly available training processes and the need for large-scale data and GPUs; (2) achieving good reconstruction performance requires many codebooks, which increases the burden on generation models. In this study, we propose a group-residual vector quantization (GRVQ) technique and use it to develop a novel **Hi**gh **Fi**delity Audio Codec model, HiFi-Codec, which only requires 4 codebooks. We train all the models using publicly available TTS data such as LibriTTS, VCTK, AISHELL, and more, with a total duration of over 1000 hours, using 8 GPUs. Our experimental results show that HiFi-Codec outperforms Encodec in terms of reconstruction performance despite requiring only 4 codebooks. To facilitate research in audio codec and generation, we introduce AcademiCodec, the first open-source audio codec toolkit that offers training codes and pre-trained models for Encodec, SoundStream, and HiFi-Codec. 34 | 35 | ## 🔥 News 36 | #### AcademiCodec 37 | - 2023.4.16: We first release the training code for Encodec and SoundStream and our pre-trained models, including 24 kHz and 16 kHz versions. 38 | - 2023.5.5: We release the code of HiFi-Codec. 39 | - 2023.6.2: Add `HiFi-Codec-24k-320d/infer.ipynb`, which can be used to infer acoustic tokens for later training of VALL-E, SoundStorm, etc. 40 | - 2023.06.13: Refactor the code structure. 41 | ### Dependencies 42 | * [PyTorch](http://pytorch.org/) version >= 1.13.0 43 | * Python version >= 3.8 44 | 45 | # Train your own model 46 | Please refer to the specific version. 47 | 48 | ## Data preparation 49 | Just prepare your audio data in one folder. Make sure the sample rate is right. A minimal sketch of building the `train.lst` / `valid.lst` file lists used by the recipes is shown below. 50 | 
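For the recipes that read `train.lst` / `valid.lst` (see the HiFi-Codec `start.sh` scripts), each list is simply one wav path per line. `egs/util/wavlstgen.py` builds such a list from the command line; the snippet below is a minimal stand-alone sketch that does the same thing and also carves out a small validation split. The `wav_root` path and the 95/5 split are illustrative placeholders, not part of the original recipes.

```python
# Minimal sketch: write train.lst / valid.lst, one wav path per line.
# wav_root and the 95/5 split are illustrative placeholders.
import glob
import os

wav_root = "/path/to/your/wavs"
wav_paths = sorted(
    glob.glob(os.path.join(wav_root, "**", "*.wav"), recursive=True))

split = int(len(wav_paths) * 0.95)
with open("train.lst", "w") as f:
    f.writelines(p + "\n" for p in wav_paths[:split])
with open("valid.lst", "w") as f:
    f.writelines(p + "\n" for p in wav_paths[split:])
```

Alternatively, run `python egs/util/wavlstgen.py --wavdir /path/to/your/wavs --outfile train.lst`, as described in that script.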
51 | ## Training or Inference 52 | Refer to the specific folders, e.g. Encodec_24k_240d denotes the Encodec model with a sample rate of 24 kHz and a downsample rate of 240. If you want to use our pre-trained models, please refer to https://huggingface.co/Dongchao/AcademiCodec/tree/main 53 | 54 | ## Version Description 55 | * Encodec_16k_320d: trained on 16 kHz audio with a downsample rate of 320; it can be used to train SpearTTS. 56 | * Encodec_24k_240d: trained on 24 kHz audio with a downsample rate of 240; it can be used for InstructTTS. 57 | * Encodec_24k_32d: trained on 24 kHz audio with a downsample rate of only 32, so it can use only one codebook, as in AudioGen. 58 | * SoundStream_24k_240d: the same configuration as Encodec_24k_240d. 59 | ## What is the difference between SoundStream, Encodec and HiFi-Codec? 60 | In our view, the main difference between SoundStream and Encodec is the choice of discriminator. Encodec only uses an STFT discriminator, which pushes the STFT spectrogram to be more realistic. SoundStream uses two types of discriminators: one pushes the waveform level to be more realistic, and the other pushes the spectrogram level to be more realistic. In our code, we adopt the waveform-level discriminator from HiFi-GAN and the spectrogram-level discriminator from Encodec. In theory, we think SoundStream should enjoy better performance. Google's official SoundStream proves this: it can reconstruct high-quality audio with only 3 codebooks. Although our implementation can also use 3 codebooks and achieve good performance, we admit that our version cannot yet be compared with Google's.
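To put these codebook counts into perspective, the number of discrete tokens a downstream generation model has to predict scales with the frame rate times the number of codebooks. The arithmetic below is only a back-of-the-envelope sketch; the 1024-entry codebook size is taken from `config_24k_320d.json` and assumed for the other models.

```python
# Back-of-the-envelope token math for residual/group VQ codecs (illustrative only).
import math

def codec_stats(sample_rate, hop_size, n_codebooks, codebook_size=1024):
    """Return (frames per second, tokens per second, bitrate in kbps)."""
    frame_rate = sample_rate / hop_size            # quantized frames per second
    tokens_per_sec = frame_rate * n_codebooks      # one token per codebook per frame
    kbps = tokens_per_sec * math.log2(codebook_size) / 1000.0
    return frame_rate, tokens_per_sec, kbps

print(codec_stats(24000, 240, 12))  # Encodec_24k_240d at 12 kbps: (100.0, 1200.0, 12.0)
print(codec_stats(24000, 320, 4))   # HiFi-Codec-24k-320d:         (75.0, 300.0, 3.0)
```

Under these assumptions, 4 codebooks at a 320-sample hop emit about a quarter of the tokens per second of a 12-codebook, 240-sample-hop setup, which is exactly the burden on generation models discussed in the next paragraph.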
61 | HiFi-Codec is our proposed novel method, which aims to help generation tasks such as VALL-E, AudioLM, MusicLM, SpearTTS, InstructTTS and so on. HiFi-Codec only needs 4 codebooks, which significantly reduces the number of tokens. Some researchers have used our HiFi-Codec to implement VALL-E, which shows that it can achieve better audio quality. 62 | 63 | ## Acknowledgements 64 | This implementation uses parts of the code from the following GitHub repos: 65 | https://github.com/facebookresearch/encodec
66 | https://github.com/yangdongchao/Text-to-sound-Synthesis
67 | https://github.com/b04901014/MQTTS 68 | ## Citations ## 69 | If you find this code useful in your research, please cite our work: 70 | ```bib 71 | @article{yang2023instructtts, 72 | title={InstructTTS: Modelling Expressive TTS in Discrete Latent Space with Natural Language Style Prompt}, 73 | author={Yang, Dongchao and Liu, Songxiang and Huang, Rongjie and Lei, Guangzhi and Weng, Chao and Meng, Helen and Yu, Dong}, 74 | journal={arXiv preprint arXiv:2301.13662}, 75 | year={2023} 76 | } 77 | ``` 78 | ```bibtex 79 | @article{yang2023hifi, 80 | title={HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec}, 81 | author={Yang, Dongchao and Liu, Songxiang and Huang, Rongjie and Tian, Jinchuan and Weng, Chao and Zou, Yuexian}, 82 | journal={arXiv preprint arXiv:2305.02765}, 83 | year={2023} 84 | } 85 | ``` 86 | 87 | ## Disclaimer ## 88 | MIT license 89 | 90 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchaudio 2 | tensorboard 3 | einops 4 | matplotlib 5 | pyyaml 6 | tqdm --------------------------------------------------------------------------------