├── .gitignore
├── academicodec
│   ├── __init__.py
│   ├── binary.py
│   ├── models
│   │   ├── encodec
│   │   │   ├── __init__.py
│   │   │   ├── dataset.py
│   │   │   ├── distributed
│   │   │   │   ├── distributed.py
│   │   │   │   └── launch.py
│   │   │   ├── loss.py
│   │   │   ├── main_launch.py
│   │   │   ├── msstftd.py
│   │   │   ├── net3.py
│   │   │   └── test.py
│   │   ├── hificodec
│   │   │   ├── __init__.py
│   │   │   ├── env.py
│   │   │   ├── meldataset.py
│   │   │   ├── models.py
│   │   │   ├── train.py
│   │   │   ├── vqvae.py
│   │   │   ├── vqvae_copy_syn.py
│   │   │   └── vqvae_tester.py
│   │   └── soundstream
│   │       ├── __init__.py
│   │       ├── dataset.py
│   │       ├── loss.py
│   │       └── models.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── conv.py
│   │   ├── lstm.py
│   │   ├── norm.py
│   │   ├── seanet.py
│   │   └── transformer.py
│   ├── quantization
│   │   ├── __init__.py
│   │   ├── ac.py
│   │   ├── core_vq.py
│   │   ├── distrib.py
│   │   └── vq.py
│   └── utils.py
├── egs
│   ├── Encodec_16k_320d
│   │   ├── path.sh
│   │   ├── readme.md
│   │   ├── start.sh
│   │   └── test.sh
│   ├── Encodec_24k_240d
│   │   ├── path.sh
│   │   ├── readme.md
│   │   ├── start.sh
│   │   └── test.sh
│   ├── Encodec_24k_32d
│   │   ├── path.sh
│   │   ├── readme.md
│   │   ├── start.sh
│   │   └── test.sh
│   ├── HiFi-Codec-16k-320d
│   │   ├── config_16k_320d.json
│   │   ├── path.sh
│   │   ├── readme.md
│   │   ├── start.sh
│   │   └── test.sh
│   ├── HiFi-Codec-24k-240d
│   │   ├── config_24k_240d.json
│   │   ├── path.sh
│   │   ├── readme.md
│   │   ├── start.sh
│   │   └── test.sh
│   ├── HiFi-Codec-24k-320d
│   │   ├── config_24k_320d.json
│   │   ├── infer.ipynb
│   │   ├── path.sh
│   │   ├── readme.md
│   │   ├── start.sh
│   │   └── test.sh
│   ├── SoundStream_24k_240d
│   │   ├── main3_ddp.py
│   │   ├── path.sh
│   │   ├── readme.md
│   │   ├── start.sh
│   │   └── test.sh
│   └── util
│       └── wavlstgen.py
├── evaluation_metric
│   └── calculate_voc_obj_metrics
│       ├── compute_metrics.sh
│       └── metrics
│           ├── compute_pesq.py
│           ├── compute_stoi.py
│           └── utils.py
├── readme.md
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | ckpt
2 | outputdir
3 | __pycache__
4 |
--------------------------------------------------------------------------------
/academicodec/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangdongchao/AcademiCodec/b6ac134735f6079543db959a60eb77a7bab4277b/academicodec/__init__.py
--------------------------------------------------------------------------------
/academicodec/binary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
7 | import io
8 | import json
9 | import struct
10 | import typing as tp
11 |
12 | # Format is the `ECDC` magic code, followed by an uint8 protocol version (0)
13 | # and then the header size as uint32, matching `_encodec_header_struct` below.
14 | # The header is then provided as json and should contain all required
15 | # information for decoding. A raw stream of bytes is then provided
16 | # and should be interpretable using the json header.
17 | _encodec_header_struct = struct.Struct('!4sBI')
18 | _ENCODEC_MAGIC = b'ECDC'
19 |
20 |
21 | def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
22 | meta_dumped = json.dumps(metadata).encode('utf-8')
23 | version = 0
24 | header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version,
25 | len(meta_dumped))
26 | fo.write(header)
27 | fo.write(meta_dumped)
28 | fo.flush()
29 |
30 |
31 | def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
32 | buf = b""
33 |     # keep reading until the full `size` bytes have been accumulated
34 |     while len(buf) < size:
35 |         new_buf = fo.read(size - len(buf))
36 |         if not new_buf:
37 |             raise EOFError("Impossible to read enough data from the stream, "
38 |                            f"{size - len(buf)} bytes remaining.")
39 |         buf += new_buf
40 | return buf
41 |
42 |
43 | def read_ecdc_header(fo: tp.IO[bytes]):
44 | header_bytes = _read_exactly(fo, _encodec_header_struct.size)
45 | magic, version, meta_size = _encodec_header_struct.unpack(header_bytes)
46 | if magic != _ENCODEC_MAGIC:
47 | raise ValueError("File is not in ECDC format.")
48 | if version != 0:
49 | raise ValueError("Version not supported.")
50 | meta_bytes = _read_exactly(fo, meta_size)
51 | return json.loads(meta_bytes.decode('utf-8'))
52 |
53 |
54 | class BitPacker:
55 | """Simple bit packer to handle ints with a non standard width, e.g. 10 bits.
56 | Note that for some bandwidth (1.5, 3), the codebook representation
57 | will not cover an integer number of bytes.
58 |
59 | Args:
60 | bits (int): number of bits per value that will be pushed.
61 | fo (IO[bytes]): file-object to push the bytes to.
62 | """
63 |
64 | def __init__(self, bits: int, fo: tp.IO[bytes]):
65 | self._current_value = 0
66 | self._current_bits = 0
67 | self.bits = bits
68 | self.fo = fo
69 |
70 | def push(self, value: int):
71 | """Push a new value to the stream. This will immediately
72 | write as many uint8 as possible to the underlying file-object."""
73 | self._current_value += (value << self._current_bits)
74 | self._current_bits += self.bits
75 | while self._current_bits >= 8:
76 | lower_8bits = self._current_value & 0xff
77 | self._current_bits -= 8
78 | self._current_value >>= 8
79 | self.fo.write(bytes([lower_8bits]))
80 |
81 | def flush(self):
82 | """Flushes the remaining partial uint8, call this at the end
83 | of the stream to encode."""
84 | if self._current_bits:
85 | self.fo.write(bytes([self._current_value]))
86 | self._current_value = 0
87 | self._current_bits = 0
88 | self.fo.flush()
89 |
90 |
91 | class BitUnpacker:
92 | """BitUnpacker does the opposite of `BitPacker`.
93 |
94 | Args:
95 | bits (int): number of bits of the values to decode.
96 |         fo (IO[bytes]): file-object to pull the bytes from.
97 | """
98 |
99 | def __init__(self, bits: int, fo: tp.IO[bytes]):
100 | self.bits = bits
101 | self.fo = fo
102 | self._mask = (1 << bits) - 1
103 | self._current_value = 0
104 | self._current_bits = 0
105 |
106 | def pull(self) -> tp.Optional[int]:
107 | """
108 | Pull a single value from the stream, potentially reading some
109 | extra bytes from the underlying file-object.
110 | Returns `None` when reaching the end of the stream.
111 | """
112 | while self._current_bits < self.bits:
113 | buf = self.fo.read(1)
114 | if not buf:
115 | return None
116 | character = buf[0]
117 | self._current_value += character << self._current_bits
118 | self._current_bits += 8
119 |
120 | out = self._current_value & self._mask
121 | self._current_value >>= self.bits
122 | self._current_bits -= self.bits
123 | return out
124 |
125 |
126 | def test():
127 | import torch
128 | torch.manual_seed(1234)
129 | for rep in range(4):
130 | length: int = torch.randint(10, 2_000, (1, )).item()
131 | bits: int = torch.randint(1, 16, (1, )).item()
132 | tokens: tp.List[int] = torch.randint(2**bits, (length, )).tolist()
133 | rebuilt: tp.List[int] = []
134 | buf = io.BytesIO()
135 | packer = BitPacker(bits, buf)
136 | for token in tokens:
137 | packer.push(token)
138 | packer.flush()
139 | buf.seek(0)
140 | unpacker = BitUnpacker(bits, buf)
141 | while True:
142 | value = unpacker.pull()
143 | if value is None:
144 | break
145 | rebuilt.append(value)
146 | assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens))
147 | # The flushing mechanism might lead to "ghost" values at the end of the stream.
148 | assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt),
149 | len(tokens), bits)
150 | for idx, (a, b) in enumerate(zip(tokens, rebuilt)):
151 | assert a == b, (idx, a, b)
152 |
153 |
154 | if __name__ == '__main__':
155 | test()
156 |
--------------------------------------------------------------------------------
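
A minimal usage sketch (not part of the repository) showing how the ECDC header helpers and the bit packer above round-trip a short token sequence through an in-memory stream; the metadata keys are illustrative only:

    import io
    from academicodec.binary import (BitPacker, BitUnpacker, read_ecdc_header,
                                     write_ecdc_header)

    buf = io.BytesIO()
    write_ecdc_header(buf, {"sr": 16000, "bits": 10})   # hypothetical metadata
    packer = BitPacker(bits=10, fo=buf)
    for token in (3, 517, 1023):
        packer.push(token)
    packer.flush()

    buf.seek(0)
    meta = read_ecdc_header(buf)
    unpacker = BitUnpacker(bits=meta["bits"], fo=buf)
    tokens = []
    while (value := unpacker.pull()) is not None:
        tokens.append(value)
    # trailing "ghost" values may appear after flush(), as noted in test()
    print(meta, tokens[:3])
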
/academicodec/models/encodec/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangdongchao/AcademiCodec/b6ac134735f6079543db959a60eb77a7bab4277b/academicodec/models/encodec/__init__.py
--------------------------------------------------------------------------------
/academicodec/models/encodec/dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import random
3 |
4 | import torch
5 | import torchaudio
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class NSynthDataset(Dataset):
10 | """Dataset to load NSynth data."""
11 |
12 | def __init__(self, audio_dir):
13 | super().__init__()
14 | self.filenames = []
15 | self.filenames.extend(glob.glob(audio_dir + "/*.wav"))
16 | print(len(self.filenames))
17 | _, self.sr = torchaudio.load(self.filenames[0])
18 | self.max_len = 24000 # 24000
19 |
20 | def __len__(self):
21 | return len(self.filenames)
22 |
23 | def __getitem__(self, index):
24 | ans = torch.zeros(1, self.max_len)
25 | audio = torchaudio.load(self.filenames[index])[0]
26 | if audio.shape[1] > self.max_len:
27 | st = random.randint(0, audio.shape[1] - self.max_len - 1)
28 | ed = st + self.max_len
29 | return audio[:, st:ed]
30 | else:
31 | ans[:, :audio.shape[1]] = audio
32 | return ans
33 |
--------------------------------------------------------------------------------
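
A short sketch (not part of the repository) of feeding NSynthDataset into a PyTorch DataLoader; the audio directory is a placeholder:

    from torch.utils.data import DataLoader
    from academicodec.models.encodec.dataset import NSynthDataset

    dataset = NSynthDataset(audio_dir="/path/to/wavs")   # any folder of .wav files
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
    for batch in loader:
        print(batch.shape)  # [8, 1, 24000]: random crops or zero-padded clips
        break
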
/academicodec/models/encodec/distributed/distributed.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------
2 | # Diffsound
3 | # code based https://github.com/cientgu/VQ-Diffusion
4 | # ------------------------------------------
5 | import pickle
6 |
7 | import torch
8 | from torch import distributed as dist
9 | from torch.utils import data
10 |
11 | LOCAL_PROCESS_GROUP = None
12 |
13 |
14 | def is_primary():
15 | return get_rank() == 0
16 |
17 |
18 | def get_rank():
19 | if not dist.is_available():
20 | return 0
21 |
22 | if not dist.is_initialized():
23 | return 0
24 |
25 | return dist.get_rank()
26 |
27 |
28 | def get_local_rank():
29 | if not dist.is_available():
30 | return 0
31 |
32 | if not dist.is_initialized():
33 | return 0
34 |
35 | if LOCAL_PROCESS_GROUP is None:
36 | raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None")
37 |
38 | return dist.get_rank(group=LOCAL_PROCESS_GROUP)
39 |
40 |
41 | def synchronize():
42 | if not dist.is_available():
43 | return
44 |
45 | if not dist.is_initialized():
46 | return
47 |
48 | world_size = dist.get_world_size()
49 |
50 | if world_size == 1:
51 | return
52 |
53 | dist.barrier()
54 |
55 |
56 | def get_world_size():
57 | if not dist.is_available():
58 | return 1
59 |
60 | if not dist.is_initialized():
61 | return 1
62 |
63 | return dist.get_world_size()
64 |
65 |
66 | def is_distributed():
67 | raise RuntimeError('Please debug this function!')
68 | return get_world_size() > 1
69 |
70 |
71 | def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False):
72 | world_size = get_world_size()
73 |
74 | if world_size == 1:
75 | return tensor
76 | dist.all_reduce(tensor, op=op, async_op=async_op)
77 |
78 | return tensor
79 |
80 |
81 | def all_gather(data):
82 | world_size = get_world_size()
83 |
84 | if world_size == 1:
85 | return [data]
86 |
87 | buffer = pickle.dumps(data)
88 | storage = torch.ByteStorage.from_buffer(buffer)
89 | tensor = torch.ByteTensor(storage).to("cuda")
90 |
91 | local_size = torch.IntTensor([tensor.numel()]).to("cuda")
92 | size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)]
93 | dist.all_gather(size_list, local_size)
94 | size_list = [int(size.item()) for size in size_list]
95 | max_size = max(size_list)
96 |
97 | tensor_list = []
98 | for _ in size_list:
99 | tensor_list.append(torch.ByteTensor(size=(max_size, )).to("cuda"))
100 |
101 | if local_size != max_size:
102 | padding = torch.ByteTensor(size=(max_size - local_size, )).to("cuda")
103 | tensor = torch.cat((tensor, padding), 0)
104 |
105 | dist.all_gather(tensor_list, tensor)
106 |
107 | data_list = []
108 |
109 | for size, tensor in zip(size_list, tensor_list):
110 | buffer = tensor.cpu().numpy().tobytes()[:size]
111 | data_list.append(pickle.loads(buffer))
112 |
113 | return data_list
114 |
115 |
116 | def reduce_dict(input_dict, average=True):
117 | world_size = get_world_size()
118 |
119 | if world_size < 2:
120 | return input_dict
121 |
122 | with torch.no_grad():
123 | keys = []
124 | values = []
125 |
126 | for k in sorted(input_dict.keys()):
127 | keys.append(k)
128 | values.append(input_dict[k])
129 |
130 | values = torch.stack(values, 0)
131 | dist.reduce(values, dst=0)
132 |
133 | if dist.get_rank() == 0 and average:
134 | values /= world_size
135 |
136 | reduced_dict = {k: v for k, v in zip(keys, values)}
137 |
138 | return reduced_dict
139 |
140 |
141 | def data_sampler(dataset, shuffle, distributed):
142 | if distributed:
143 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
144 |
145 | if shuffle:
146 | return data.RandomSampler(dataset)
147 |
148 | else:
149 | return data.SequentialSampler(dataset)
150 |
--------------------------------------------------------------------------------
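
A sketch (not part of the repository, assuming the package import resolves and a CUDA-capable NCCL process group is initialized; on a single process every helper above degrades to a no-op) of averaging per-rank loss values for rank-0 logging:

    import torch
    from academicodec.models.encodec.distributed import distributed as dist_fn

    losses = {"rec": torch.tensor(0.7, device="cuda"),
              "adv": torch.tensor(0.1, device="cuda")}
    reduced = dist_fn.reduce_dict(losses, average=True)   # reduced onto rank 0
    if dist_fn.is_primary():
        print({k: v.item() for k, v in reduced.items()})
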
/academicodec/models/encodec/distributed/launch.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------
2 | # Diffsound
3 | # code based https://github.com/cientgu/VQ-Diffusion
4 | # ------------------------------------------
5 | import distributed.distributed as dist_fn
6 | import torch
7 | from torch import distributed as dist
8 | from torch import multiprocessing as mp
9 |
10 | # import distributed as dist_fn
11 |
12 |
13 | def find_free_port():
14 | import socket
15 |
16 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
17 |
18 | sock.bind(("", 0))
19 | port = sock.getsockname()[1]
20 | sock.close()
21 |
22 | return port
23 |
24 |
25 | def launch(fn,
26 | n_gpu_per_machine,
27 | n_machine=1,
28 | machine_rank=0,
29 | dist_url=None,
30 | args=()):
31 | world_size = n_machine * n_gpu_per_machine
32 |
33 | if world_size > 1:
34 | # if "OMP_NUM_THREADS" not in os.environ:
35 | # os.environ["OMP_NUM_THREADS"] = "1"
36 | if dist_url == "auto":
37 | if n_machine != 1:
38 | raise ValueError(
39 | 'dist_url="auto" not supported in multi-machine jobs')
40 | port = find_free_port()
41 | dist_url = f"tcp://127.0.0.1:{port}"
42 | print('dist_url ', dist_url)
43 | print('n_machine ', n_machine)
44 | print('args ', args)
45 | print('world_size ', world_size)
46 | print('machine_rank ', machine_rank)
47 | if n_machine > 1 and dist_url.startswith("file://"):
48 | raise ValueError(
49 | "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://"
50 | )
51 |
52 | mp.spawn(
53 | distributed_worker,
54 | nprocs=n_gpu_per_machine,
55 | args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url,
56 | args),
57 | daemon=False, )
58 | # n_machine ? world_size
59 | else:
60 | local_rank = 0
61 | fn(local_rank, *args)
62 |
63 |
64 | def distributed_worker(local_rank, fn, world_size, n_gpu_per_machine,
65 | machine_rank, dist_url, args):
66 | if not torch.cuda.is_available():
67 |         raise OSError("CUDA is not available. Please check your environment")
68 |
69 | global_rank = machine_rank * n_gpu_per_machine + local_rank
70 | print('local_rank ', local_rank)
71 | print('global_rank ', global_rank)
72 | try:
73 | dist.init_process_group(
74 | backend="NCCL",
75 | init_method=dist_url,
76 | world_size=world_size,
77 | rank=global_rank, )
78 |
79 | except Exception:
80 | raise OSError("failed to initialize NCCL groups")
81 |
82 | # changed
83 | dist_fn.synchronize()
84 |
85 | if n_gpu_per_machine > torch.cuda.device_count():
86 | raise ValueError(
87 | f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
88 | )
89 |
90 | torch.cuda.set_device(local_rank)
91 |
92 | if dist_fn.LOCAL_PROCESS_GROUP is not None:
93 | raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None")
94 |
95 |     # changed part
96 |
97 | n_machine = world_size // n_gpu_per_machine
98 | for i in range(n_machine):
99 | ranks_on_i = list(
100 | range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
101 | pg = dist.new_group(ranks_on_i)
102 |
103 | if i == machine_rank:
104 | dist_fn.LOCAL_PROCESS_GROUP = pg
105 |
106 | fn(local_rank, *args)
107 |
--------------------------------------------------------------------------------
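
A sketch (not part of the repository) of starting one worker per local GPU with launch(); train_fn and its config are placeholders, and it assumes launch.py's own `import distributed.distributed` resolves from the working directory:

    import torch
    from academicodec.models.encodec.distributed.launch import launch

    def train_fn(local_rank, cfg):
        # each spawned process lands here with its own local_rank
        print("worker", local_rank, "lr", cfg["lr"])

    if __name__ == "__main__":
        launch(train_fn,
               n_gpu_per_machine=torch.cuda.device_count(),
               dist_url="auto",           # picks a free local TCP port
               args=({"lr": 3e-4},))
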
/academicodec/models/encodec/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torchaudio.transforms import MelSpectrogram
4 |
5 |
6 | def adversarial_g_loss(y_disc_gen):
7 | """Hinge loss"""
8 | loss = 0.0
9 | for i in range(len(y_disc_gen)):
10 | stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze()
11 | loss += stft_loss
12 | return loss / len(y_disc_gen)
13 |
14 |
15 | def feature_loss(fmap_r, fmap_gen):
16 | loss = 0.0
17 | for i in range(len(fmap_r)):
18 | for j in range(len(fmap_r[i])):
19 | stft_loss = ((fmap_r[i][j] - fmap_gen[i][j]).abs() /
20 | (fmap_r[i][j].abs().mean())).mean()
21 | loss += stft_loss
22 | return loss / (len(fmap_r) * len(fmap_r[0]))
23 |
24 |
25 | def sim_loss(y_disc_r, y_disc_gen):
26 | loss = 0.0
27 | for i in range(len(y_disc_r)):
28 | loss += F.mse_loss(y_disc_r[i], y_disc_gen[i])
29 | return loss / len(y_disc_r)
30 |
31 | # def sisnr_loss(x, s, eps=1e-8):
32 | # """
33 | # calculate training loss
34 | # input:
35 | # x: separated signal, N x S tensor, estimate value
36 | # s: reference signal, N x S tensor, True value
37 | # Return:
38 | # sisnr: N tensor
39 | # """
40 | # if x.shape != s.shape:
41 | # if x.shape[-1] > s.shape[-1]:
42 | # x = x[:, :s.shape[-1]]
43 | # else:
44 | # s = s[:, :x.shape[-1]]
45 | # def l2norm(mat, keepdim=False):
46 | # return torch.norm(mat, dim=-1, keepdim=keepdim)
47 | # if x.shape != s.shape:
48 | # raise RuntimeError(
49 | # "Dimention mismatch when calculate si-snr, {} vs {}".format(
50 | # x.shape, s.shape))
51 | # x_zm = x - torch.mean(x, dim=-1, keepdim=True)
52 | # s_zm = s - torch.mean(s, dim=-1, keepdim=True)
53 | # t = torch.sum(
54 | # x_zm * s_zm, dim=-1,
55 | # keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
56 | # loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
57 | # return torch.sum(loss) / x.shape[0]
58 |
59 |
60 | def reconstruction_loss(x, G_x, args, eps=1e-7):
61 | # NOTE (lsx): hard-coded now
62 |     L = args.LAMBDA_WAV * F.mse_loss(x, G_x)  # wav reconstruction loss (MSE)
63 | # loss_sisnr = sisnr_loss(G_x, x) #
64 | # L += 0.01*loss_sisnr
65 | # 2^6=64 -> 2^10=1024
66 | # NOTE (lsx): add 2^11
67 | for i in range(6, 12):
68 | # for i in range(5, 12): # Encodec setting
69 | s = 2**i
70 | melspec = MelSpectrogram(
71 | sample_rate=args.sr,
72 | n_fft=max(s, 512),
73 | win_length=s,
74 | hop_length=s // 4,
75 | n_mels=64,
76 | wkwargs={"device": args.device}).to(args.device)
77 | S_x = melspec(x)
78 | S_G_x = melspec(G_x)
79 | l1_loss = (S_x - S_G_x).abs().mean()
80 | l2_loss = (((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps))**2).mean(dim=-2)**0.5).mean()
81 |
82 | alpha = (s / 2) ** 0.5
83 | L += (l1_loss + alpha * l2_loss)
84 | return L
85 |
86 |
87 | def criterion_d(y_disc_r, y_disc_gen, fmap_r_det, fmap_gen_det, y_df_hat_r,
88 | y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g,
89 | fmap_s_r, fmap_s_g):
90 | """Hinge Loss"""
91 | loss = 0.0
92 | loss1 = 0.0
93 | loss2 = 0.0
94 | loss3 = 0.0
95 | for i in range(len(y_disc_r)):
96 | loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[
97 | i]).mean()
98 | for i in range(len(y_df_hat_r)):
99 | loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[
100 | i]).mean()
101 | for i in range(len(y_ds_hat_r)):
102 | loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[
103 | i]).mean()
104 |
105 | loss = (loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 /
106 | len(y_ds_hat_r)) / 3.0
107 |
108 | return loss
109 |
110 |
111 | def criterion_g(commit_loss, x, G_x, fmap_r, fmap_gen, y_disc_r, y_disc_gen,
112 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r,
113 | y_ds_hat_g, fmap_s_r, fmap_s_g, args):
114 | adv_g_loss = adversarial_g_loss(y_disc_gen)
115 | feat_loss = (feature_loss(fmap_r, fmap_gen) + sim_loss(
116 | y_disc_r, y_disc_gen) + feature_loss(fmap_f_r, fmap_f_g) + sim_loss(
117 | y_df_hat_r, y_df_hat_g) + feature_loss(fmap_s_r, fmap_s_g) +
118 | sim_loss(y_ds_hat_r, y_ds_hat_g)) / 3.0
119 | rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)
120 | total_loss = args.LAMBDA_COM * commit_loss + args.LAMBDA_ADV * adv_g_loss + args.LAMBDA_FEAT * feat_loss + args.LAMBDA_REC * rec_loss
121 | return total_loss, adv_g_loss, feat_loss, rec_loss
122 |
123 |
124 | def adopt_weight(weight, global_step, threshold=0, value=0.):
125 | if global_step < threshold:
126 | weight = value
127 | return weight
128 |
129 |
130 | def adopt_dis_weight(weight, global_step, threshold=0, value=0.):
131 |     # at steps 0, 3, 6, 9, ... the discriminator is not updated
132 | if global_step % 3 == 0:
133 | weight = value
134 | return weight
135 |
136 |
137 | def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
138 | if last_layer is not None:
139 | nll_grads = torch.autograd.grad(
140 | nll_loss, last_layer, retain_graph=True)[0]
141 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
142 |     else:
143 |         # fail fast instead of computing with undefined gradients
144 |         raise ValueError('last_layer cannot be None')
145 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
146 | d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
147 | d_weight = d_weight * args.LAMBDA_ADV
148 | return d_weight
149 |
150 |
151 | def loss_g(codebook_loss,
152 | inputs,
153 | reconstructions,
154 | fmap_r,
155 | fmap_gen,
156 | y_disc_r,
157 | y_disc_gen,
158 | global_step,
159 | y_df_hat_r,
160 | y_df_hat_g,
161 | y_ds_hat_r,
162 | y_ds_hat_g,
163 | fmap_f_r,
164 | fmap_f_g,
165 | fmap_s_r,
166 | fmap_s_g,
167 | last_layer=None,
168 | is_training=True,
169 | args=None):
170 | """
171 | args:
172 | codebook_loss: commit loss.
173 | inputs: ground-truth wav.
174 | reconstructions: reconstructed wav.
175 | fmap_r: real stft-D feature map.
176 | fmap_gen: fake stft-D feature map.
177 | y_disc_r: real stft-D logits.
178 | y_disc_gen: fake stft-D logits.
179 | global_step: global training step.
180 | y_df_hat_r: real MPD logits.
181 | y_df_hat_g: fake MPD logits.
182 | y_ds_hat_r: real MSD logits.
183 | y_ds_hat_g: fake MSD logits.
184 | fmap_f_r: real MPD feature map.
185 | fmap_f_g: fake MPD feature map.
186 | fmap_s_r: real MSD feature map.
187 | fmap_s_g: fake MSD feature map.
188 | """
189 | rec_loss = reconstruction_loss(inputs.contiguous(),
190 | reconstructions.contiguous(), args)
191 | adv_g_loss = adversarial_g_loss(y_disc_gen)
192 | adv_mpd_loss = adversarial_g_loss(y_df_hat_g)
193 | adv_msd_loss = adversarial_g_loss(y_ds_hat_g)
194 | adv_loss = (adv_g_loss + adv_mpd_loss + adv_msd_loss
195 | ) / 3.0 # NOTE(lsx): need to divide by 3?
196 | feat_loss = feature_loss(
197 | fmap_r,
198 | fmap_gen) #+ sim_loss(y_disc_r, y_disc_gen) # NOTE(lsx): need logits?
199 | feat_loss_mpd = feature_loss(fmap_f_r,
200 | fmap_f_g) #+ sim_loss(y_df_hat_r, y_df_hat_g)
201 | feat_loss_msd = feature_loss(fmap_s_r,
202 | fmap_s_g) #+ sim_loss(y_ds_hat_r, y_ds_hat_g)
203 | feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
204 | d_weight = torch.tensor(1.0)
205 | # try:
206 |     # d_weight = calculate_adaptive_weight(rec_loss, adv_g_loss, last_layer, args)  # dynamically balance reconstruction and adversarial losses
207 | # except RuntimeError:
208 | # assert not is_training
209 | # d_weight = torch.tensor(0.0)
210 | disc_factor = adopt_weight(
211 | args.LAMBDA_ADV, global_step, threshold=args.discriminator_iter_start)
212 | if disc_factor == 0.:
213 | fm_loss_wt = 0
214 | else:
215 | fm_loss_wt = args.LAMBDA_FEAT
216 | #feat_factor = adopt_weight(args.LAMBDA_FEAT, global_step, threshold=args.discriminator_iter_start)
217 | loss = rec_loss + d_weight * disc_factor * adv_loss + \
218 | fm_loss_wt * feat_loss_tot + args.LAMBDA_COM * codebook_loss
219 | return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
220 |
221 |
222 | def loss_dis(y_disc_r_det, y_disc_gen_det, fmap_r_det, fmap_gen_det, y_df_hat_r,
223 | y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, fmap_s_r,
224 | fmap_s_g, global_step, args):
225 | disc_factor = adopt_weight(
226 | args.LAMBDA_ADV, global_step, threshold=args.discriminator_iter_start)
227 | d_loss = disc_factor * criterion_d(y_disc_r_det, y_disc_gen_det, fmap_r_det,
228 | fmap_gen_det, y_df_hat_r, y_df_hat_g,
229 | fmap_f_r, fmap_f_g, y_ds_hat_r,
230 | y_ds_hat_g, fmap_s_r, fmap_s_g)
231 | return d_loss
232 |
--------------------------------------------------------------------------------
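
A small sketch (not part of the repository) of calling reconstruction_loss directly; the args namespace only needs the fields the function reads, and the lambda value here is illustrative:

    import argparse
    import torch
    from academicodec.models.encodec.loss import reconstruction_loss

    args = argparse.Namespace(LAMBDA_WAV=100, sr=16000, device="cpu")
    x = torch.randn(2, 1, 16000)                  # reference batch [B, 1, T]
    G_x = x + 0.01 * torch.randn_like(x)          # stand-in "reconstruction"
    print(reconstruction_loss(x, G_x, args).item())
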
/academicodec/models/encodec/msstftd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """MS-STFT discriminator, provided here for reference."""
7 | import typing as tp
8 |
9 | import torch
10 | import torchaudio
11 | from einops import rearrange
12 | from torch import nn
13 |
14 | from academicodec.modules import NormConv2d
15 |
16 | FeatureMapType = tp.List[torch.Tensor]
17 | LogitsType = torch.Tensor
18 | DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
19 |
20 |
21 | def get_2d_padding(kernel_size: tp.Tuple[int, int],
22 | dilation: tp.Tuple[int, int]=(1, 1)):
23 | return (((kernel_size[0] - 1) * dilation[0]) // 2, (
24 | (kernel_size[1] - 1) * dilation[1]) // 2)
25 |
26 |
27 | class DiscriminatorSTFT(nn.Module):
28 | """STFT sub-discriminator.
29 | Args:
30 | filters (int): Number of filters in convolutions
31 | in_channels (int): Number of input channels. Default: 1
32 | out_channels (int): Number of output channels. Default: 1
33 | n_fft (int): Size of FFT for each scale. Default: 1024
34 | hop_length (int): Length of hop between STFT windows for each scale. Default: 256
35 | kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
36 | stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
37 | dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
38 | win_length (int): Window size for each scale. Default: 1024
39 | normalized (bool): Whether to normalize by magnitude after stft. Default: True
40 | norm (str): Normalization method. Default: `'weight_norm'`
41 | activation (str): Activation function. Default: `'LeakyReLU'`
42 | activation_params (dict): Parameters to provide to the activation function.
43 | growth (int): Growth factor for the filters. Default: 1
44 | """
45 |
46 | def __init__(self,
47 | filters: int,
48 | in_channels: int=1,
49 | out_channels: int=1,
50 | n_fft: int=1024,
51 | hop_length: int=256,
52 | win_length: int=1024,
53 | max_filters: int=1024,
54 | filters_scale: int=1,
55 | kernel_size: tp.Tuple[int, int]=(3, 9),
56 | dilations: tp.List=[1, 2, 4],
57 | stride: tp.Tuple[int, int]=(1, 2),
58 | normalized: bool=True,
59 | norm: str='weight_norm',
60 | activation: str='LeakyReLU',
61 | activation_params: dict={'negative_slope': 0.2}):
62 | super().__init__()
63 | assert len(kernel_size) == 2
64 | assert len(stride) == 2
65 | self.filters = filters
66 | self.in_channels = in_channels
67 | self.out_channels = out_channels
68 | self.n_fft = n_fft
69 | self.hop_length = hop_length
70 | self.win_length = win_length
71 | self.normalized = normalized
72 | self.activation = getattr(torch.nn, activation)(**activation_params)
73 | self.spec_transform = torchaudio.transforms.Spectrogram(
74 | n_fft=self.n_fft,
75 | hop_length=self.hop_length,
76 | win_length=self.win_length,
77 | window_fn=torch.hann_window,
78 | normalized=self.normalized,
79 | center=False,
80 | pad_mode=None,
81 | power=None)
82 | spec_channels = 2 * self.in_channels
83 | self.convs = nn.ModuleList()
84 | self.convs.append(
85 | NormConv2d(
86 | spec_channels,
87 | self.filters,
88 | kernel_size=kernel_size,
89 | padding=get_2d_padding(kernel_size)))
90 | in_chs = min(filters_scale * self.filters, max_filters)
91 | for i, dilation in enumerate(dilations):
92 | out_chs = min((filters_scale**(i + 1)) * self.filters, max_filters)
93 | self.convs.append(
94 | NormConv2d(
95 | in_chs,
96 | out_chs,
97 | kernel_size=kernel_size,
98 | stride=stride,
99 | dilation=(dilation, 1),
100 | padding=get_2d_padding(kernel_size, (dilation, 1)),
101 | norm=norm))
102 | in_chs = out_chs
103 | out_chs = min((filters_scale**(len(dilations) + 1)) * self.filters,
104 | max_filters)
105 | self.convs.append(
106 | NormConv2d(
107 | in_chs,
108 | out_chs,
109 | kernel_size=(kernel_size[0], kernel_size[0]),
110 | padding=get_2d_padding((kernel_size[0], kernel_size[0])),
111 | norm=norm))
112 | self.conv_post = NormConv2d(
113 | out_chs,
114 | self.out_channels,
115 | kernel_size=(kernel_size[0], kernel_size[0]),
116 | padding=get_2d_padding((kernel_size[0], kernel_size[0])),
117 | norm=norm)
118 |
119 | def forward(self, x: torch.Tensor):
120 | fmap = []
121 | # print('x ', x.shape)
122 | z = self.spec_transform(x) # [B, 2, Freq, Frames, 2]
123 | # print('z ', z.shape)
124 | z = torch.cat([z.real, z.imag], dim=1)
125 | # print('cat_z ', z.shape)
126 | z = rearrange(z, 'b c w t -> b c t w')
127 | for i, layer in enumerate(self.convs):
128 | z = layer(z)
129 | z = self.activation(z)
130 | # print('z i', i, z.shape)
131 | fmap.append(z)
132 | z = self.conv_post(z)
133 | # print('logit ', z.shape)
134 | return z, fmap
135 |
136 |
137 | class MultiScaleSTFTDiscriminator(nn.Module):
138 | """Multi-Scale STFT (MS-STFT) discriminator.
139 | Args:
140 | filters (int): Number of filters in convolutions
141 | in_channels (int): Number of input channels. Default: 1
142 | out_channels (int): Number of output channels. Default: 1
143 | n_ffts (Sequence[int]): Size of FFT for each scale
144 | hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
145 | win_lengths (Sequence[int]): Window size for each scale
146 |         **kwargs: Additional args passed to each DiscriminatorSTFT
147 | """
148 |
149 | def __init__(self,
150 | filters: int,
151 | in_channels: int=1,
152 | out_channels: int=1,
153 | n_ffts: tp.List[int]=[1024, 2048, 512, 256, 128],
154 | hop_lengths: tp.List[int]=[256, 512, 128, 64, 32],
155 | win_lengths: tp.List[int]=[1024, 2048, 512, 256, 128],
156 | **kwargs):
157 | super().__init__()
158 | assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
159 | self.discriminators = nn.ModuleList([
160 | DiscriminatorSTFT(
161 | filters,
162 | in_channels=in_channels,
163 | out_channels=out_channels,
164 | n_fft=n_ffts[i],
165 | win_length=win_lengths[i],
166 | hop_length=hop_lengths[i],
167 | **kwargs) for i in range(len(n_ffts))
168 | ])
169 | self.num_discriminators = len(self.discriminators)
170 |
171 | def forward(self, x: torch.Tensor) -> DiscriminatorOutput:
172 | logits = []
173 | fmaps = []
174 | for disc in self.discriminators:
175 | logit, fmap = disc(x)
176 | logits.append(logit)
177 | fmaps.append(fmap)
178 | return logits, fmaps
179 |
180 |
181 | def test():
182 | disc = MultiScaleSTFTDiscriminator(filters=32)
183 | y = torch.randn(1, 1, 24000)
184 | y_hat = torch.randn(1, 1, 24000)
185 |
186 | y_disc_r, fmap_r = disc(y)
187 | y_disc_gen, fmap_gen = disc(y_hat)
188 | assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(
189 | fmap_gen) == disc.num_discriminators
190 |
191 | assert all([len(fm) == 5 for fm in fmap_r + fmap_gen])
192 | assert all(
193 | [list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm])
194 | assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen])
195 |
196 |
197 | if __name__ == '__main__':
198 | test()
199 |
--------------------------------------------------------------------------------
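
A sketch (not part of the repository) of wiring the discriminator's outputs into the hinge and feature-matching losses from loss.py above:

    import torch
    from academicodec.models.encodec.loss import adversarial_g_loss, feature_loss
    from academicodec.models.encodec.msstftd import MultiScaleSTFTDiscriminator

    disc = MultiScaleSTFTDiscriminator(filters=32)
    real = torch.randn(1, 1, 24000)
    fake = torch.randn(1, 1, 24000)
    logits_real, fmap_real = disc(real)
    logits_fake, fmap_fake = disc(fake)
    print(adversarial_g_loss(logits_fake).item(),     # generator hinge term
          feature_loss(fmap_real, fmap_fake).item())  # feature-matching term
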
/academicodec/models/encodec/net3.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import numpy as np
5 | import torch.nn as nn
6 | from academicodec.modules.seanet import SEANetDecoder
7 | from academicodec.modules.seanet import SEANetEncoder
8 | from academicodec.quantization import ResidualVectorQuantizer
9 |
10 |
11 | # Generator
12 | class SoundStream(nn.Module):
13 | def __init__(self,
14 | n_filters,
15 | D,
16 | target_bandwidths=[7.5, 15],
17 | ratios=[8, 5, 4, 2],
18 | sample_rate=24000,
19 | bins=1024,
20 | normalize=False):
21 | super().__init__()
22 |         self.hop_length = np.prod(ratios)  # product of the strides gives the hop length
23 | self.encoder = SEANetEncoder(
24 | n_filters=n_filters, dimension=D, ratios=ratios)
25 | n_q = int(1000 * target_bandwidths[-1] //
26 | (math.ceil(sample_rate / self.hop_length) * 10))
27 | self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 75
28 | self.bits_per_codebook = int(math.log2(bins))
29 | self.target_bandwidths = target_bandwidths
30 | self.quantizer = ResidualVectorQuantizer(
31 | dimension=D, n_q=n_q, bins=bins)
32 | self.decoder = SEANetDecoder(
33 | n_filters=n_filters, dimension=D, ratios=ratios)
34 |
35 | def get_last_layer(self):
36 | return self.decoder.layers[-1].weight
37 |
38 | def forward(self, x):
39 | e = self.encoder(x)
40 | max_idx = len(self.target_bandwidths) - 1
41 | bw = self.target_bandwidths[random.randint(0, max_idx)]
42 | quantized, codes, bandwidth, commit_loss = self.quantizer(
43 | e, self.frame_rate, bw)
44 | o = self.decoder(quantized)
45 | return o, commit_loss, None
46 |
47 | def encode(self, x, target_bw=None, st=None):
48 | e = self.encoder(x)
49 | if target_bw is None:
50 | bw = self.target_bandwidths[-1]
51 | else:
52 | bw = target_bw
53 | if st is None:
54 | st = 0
55 | codes = self.quantizer.encode(e, self.frame_rate, bw, st)
56 | return codes
57 |
58 | def decode(self, codes):
59 | quantized = self.quantizer.decode(codes)
60 | o = self.decoder(quantized)
61 | return o
62 |
--------------------------------------------------------------------------------
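
A sketch (not part of the repository) of a SoundStream round trip; the filter/width settings mirror the 16k_320d values used in test.py and are not the only valid choice:

    import torch
    from academicodec.models.encodec.net3 import SoundStream

    model = SoundStream(n_filters=32, D=512, ratios=[8, 5, 4, 2],
                        sample_rate=16000,
                        target_bandwidths=[1, 1.5, 2, 4, 6, 12])
    model.eval()
    wav = torch.randn(1, 1, 16000)                # one second of audio
    with torch.no_grad():
        codes = model.encode(wav, target_bw=12)   # discrete RVQ indices
        recon = model.decode(codes)
    print(codes.shape, recon.shape)
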
/academicodec/models/encodec/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Command-line for audio compression."""
7 | import argparse
8 | import os
9 | import sys
10 | import typing as tp
11 | from collections import OrderedDict
12 | from pathlib import Path
13 |
14 | import librosa
15 | import soundfile as sf
16 | import torch
17 | from academicodec.models.encodec.net3 import SoundStream
18 |
19 |
20 | def save_audio(wav: torch.Tensor,
21 | path: tp.Union[Path, str],
22 | sample_rate: int,
23 | rescale: bool=False):
24 | limit = 0.99
25 | mx = wav.abs().max()
26 | if rescale:
27 | wav = wav * min(limit / mx, 1)
28 | else:
29 | wav = wav.clamp(-limit, limit)
30 | wav = wav.squeeze().cpu().numpy()
31 | sf.write(path, wav, sample_rate)
32 |
33 |
34 | def get_parser():
35 | parser = argparse.ArgumentParser(
36 | 'encodec',
37 | description='High fidelity neural audio codec. '
38 | 'If input is a .ecdc, decompresses it. '
39 | 'If input is .wav, compresses it. If output is also wav, '
40 | 'do a compression/decompression cycle.')
41 | parser.add_argument(
42 | '--input',
43 | type=Path,
44 | help='Input file, whatever is supported by torchaudio on your system.')
45 | parser.add_argument(
46 | '--output',
47 | type=Path,
48 | nargs='?',
49 | help='Output file, otherwise inferred from input file.')
50 | parser.add_argument(
51 | '--resume_path', type=str, default='resume_path', help='resume_path')
52 | parser.add_argument(
53 | '--sr', type=int, default=16000, help='sample rate of model')
54 | parser.add_argument(
55 | '-r',
56 | '--rescale',
57 | action='store_true',
58 | help='Automatically rescale the output to avoid clipping.')
59 | parser.add_argument(
60 | '--ratios',
61 | type=int,
62 | nargs='+',
63 |         # prod(ratios) = hop_size
64 |         default=[8, 5, 4, 2],
65 |         help='ratios of SoundStream, should be set for different hop_size (32d, 320d, 240d, ...)'
66 | )
67 | parser.add_argument(
68 | '--target_bandwidths',
69 | type=float,
70 | nargs='+',
71 | # default for 16k_320d
72 | default=[1, 1.5, 2, 4, 6, 12],
73 | help='target_bandwidths of net3.py')
74 | parser.add_argument(
75 | '--target_bw',
76 | type=float,
77 | # default for 16k_320d
78 | default=12,
79 | help='target_bw of net3.py')
80 |
81 | return parser
82 |
83 |
84 | def fatal(*args):
85 | print(*args, file=sys.stderr)
86 | sys.exit(1)
87 |
88 |
89 | # this only prints a warning, it does not actually clip the signal
90 | def check_clipping(wav, rescale):
91 | if rescale:
92 | return
93 | mx = wav.abs().max()
94 | limit = 0.99
95 | if mx > limit:
96 | print(
97 | f"Clipping!! max scale {mx}, limit is {limit}. "
98 | "To avoid clipping, use the `-r` option to rescale the output.",
99 | file=sys.stderr)
100 |
101 |
102 | def test_one(args, wav_root, store_root, rescale, soundstream):
103 |     # torchaudio.load keeps the file's original sample rate; it does not resample
104 |     # wav, sr = torchaudio.load(wav_root)
105 |     # # take the mono channel, output shape [1, T]
106 |     # wav = wav[0].unsqueeze(0)
107 |     # # resample to the model's sample rate
108 |     # wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=args.sr)(wav)
109 |
110 | # load wav with librosa
111 | wav, sr = librosa.load(wav_root, sr=args.sr)
112 | wav = torch.tensor(wav).unsqueeze(0)
113 |
114 | # add batch axis
115 | wav = wav.unsqueeze(1).cuda()
116 |
117 | # compressing
118 | compressed = soundstream.encode(wav, target_bw=args.target_bw)
119 | print('finish compressing')
120 | out = soundstream.decode(compressed)
121 | out = out.detach().cpu().squeeze(0)
122 | check_clipping(out, rescale)
123 | save_audio(wav=out, path=store_root, sample_rate=args.sr, rescale=rescale)
124 | print('finish decompressing')
125 |
126 |
127 | def remove_encodec_weight_norm(model):
128 | from academicodec.modules import SConv1d
129 | from academicodec.modules.seanet import SConvTranspose1d
130 | from academicodec.modules.seanet import SEANetResnetBlock
131 | from torch.nn.utils import remove_weight_norm
132 |
133 | encoder = model.encoder.model
134 | for key in encoder._modules:
135 | if isinstance(encoder._modules[key], SEANetResnetBlock):
136 | remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
137 | block_modules = encoder._modules[key].block._modules
138 | for skey in block_modules:
139 | if isinstance(block_modules[skey], SConv1d):
140 | remove_weight_norm(block_modules[skey].conv.conv)
141 | elif isinstance(encoder._modules[key], SConv1d):
142 | remove_weight_norm(encoder._modules[key].conv.conv)
143 |
144 | decoder = model.decoder.model
145 | for key in decoder._modules:
146 | if isinstance(decoder._modules[key], SEANetResnetBlock):
147 | remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
148 | block_modules = decoder._modules[key].block._modules
149 | for skey in block_modules:
150 | if isinstance(block_modules[skey], SConv1d):
151 | remove_weight_norm(block_modules[skey].conv.conv)
152 | elif isinstance(decoder._modules[key], SConvTranspose1d):
153 | remove_weight_norm(decoder._modules[key].convtr.convtr)
154 | elif isinstance(decoder._modules[key], SConv1d):
155 | remove_weight_norm(decoder._modules[key].conv.conv)
156 |
157 |
158 | def test_batch():
159 | args = get_parser().parse_args()
160 | print("args.target_bandwidths:", args.target_bandwidths)
161 | if not args.input.exists():
162 | fatal(f"Input file {args.input} does not exist.")
163 | input_lists = os.listdir(args.input)
164 | input_lists.sort()
165 | soundstream = SoundStream(
166 | n_filters=32,
167 | D=512,
168 | ratios=args.ratios,
169 | sample_rate=args.sr,
170 | target_bandwidths=args.target_bandwidths)
171 | parameter_dict = torch.load(args.resume_path)
172 | new_state_dict = OrderedDict()
173 |     # k is 'module.xxx.weight' (saved from a DataParallel/DDP model), v is the tensor
174 |     for k, v in parameter_dict.items():
175 |         # strip the leading 'module.' prefix, keeping 'xxx.weight'
176 | name = k[7:]
177 | new_state_dict[name] = v
178 | soundstream.load_state_dict(new_state_dict) # load model
179 | remove_encodec_weight_norm(soundstream)
180 | soundstream.cuda()
181 | soundstream.eval()
182 | os.makedirs(args.output, exist_ok=True)
183 | for audio in input_lists:
184 | test_one(
185 | args=args,
186 | wav_root=os.path.join(args.input, audio),
187 | store_root=os.path.join(args.output, audio),
188 | rescale=args.rescale,
189 | soundstream=soundstream)
190 |
191 |
192 | if __name__ == '__main__':
193 | test_batch()
194 |
--------------------------------------------------------------------------------
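
For reference, a sketch (not part of the repository) of the flag layout the parser above expects for the 16k_320d recipe; all paths are placeholders, and the egs/*/test.sh scripts normally drive this script:

    from academicodec.models.encodec.test import get_parser

    args = get_parser().parse_args([
        "--input", "/path/to/wav_dir",
        "--output", "/path/to/out_dir",
        "--resume_path", "/path/to/checkpoint.pth",
        "--sr", "16000",
        "--ratios", "8", "5", "4", "2",
        "--target_bandwidths", "1", "1.5", "2", "4", "6", "12",
        "--target_bw", "12",
    ])
    print(args)
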
/academicodec/models/hificodec/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangdongchao/AcademiCodec/b6ac134735f6079543db959a60eb77a7bab4277b/academicodec/models/hificodec/__init__.py
--------------------------------------------------------------------------------
/academicodec/models/hificodec/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class AttrDict(dict):
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
10 |
11 | def build_env(config, config_name, path):
12 | t_path = os.path.join(path, config_name)
13 | if config != t_path:
14 | os.makedirs(path, exist_ok=True)
15 | shutil.copyfile(config, os.path.join(path, config_name))
16 |
--------------------------------------------------------------------------------
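
A tiny sketch (not part of the repository) of what AttrDict and build_env do for train.py: the JSON config becomes attribute-accessible and is copied next to the checkpoints. Keys and paths here are illustrative:

    import json
    from academicodec.models.hificodec.env import AttrDict, build_env

    with open("config_24k_240d.json") as f:       # any of the egs/HiFi-Codec-* configs
        h = AttrDict(json.load(f))
    print(h.sampling_rate, h["hop_size"])         # attribute and key access are equivalent
    build_env("config_24k_240d.json", "config.json", "checkpoints/")
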
/academicodec/models/hificodec/meldataset.py:
--------------------------------------------------------------------------------
1 | # code based on https://github.com/b04901014/MQTTS
2 | import math
3 | import os
4 | import random
5 |
6 | import librosa
7 | import numpy as np
8 | import torch.utils.data
9 | from librosa.filters import mel as librosa_mel_fn
10 |
11 |
12 | def load_wav(full_path, sr):
13 | wav, sr = librosa.load(full_path, sr=sr)
14 | return wav, sr
15 |
16 |
17 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
18 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
19 |
20 |
21 | def dynamic_range_decompression(x, C=1):
22 | return np.exp(x) / C
23 |
24 |
25 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
26 | return torch.log(torch.clamp(x, min=clip_val) * C)
27 |
28 |
29 | def dynamic_range_decompression_torch(x, C=1):
30 | return torch.exp(x) / C
31 |
32 |
33 | def spectral_normalize_torch(magnitudes):
34 | output = dynamic_range_compression_torch(magnitudes)
35 | return output
36 |
37 |
38 | def spectral_de_normalize_torch(magnitudes):
39 | output = dynamic_range_decompression_torch(magnitudes)
40 | return output
41 |
42 |
43 | mel_basis = {}
44 | hann_window = {}
45 |
46 |
47 | def mel_spectrogram(y,
48 | n_fft,
49 | num_mels,
50 | sampling_rate,
51 | hop_size,
52 | win_size,
53 | fmin,
54 | fmax,
55 | center=False):
56 | if torch.min(y) < -1.:
57 | print('min value is ', torch.min(y))
58 | if torch.max(y) > 1.:
59 | print('max value is ', torch.max(y))
60 |
61 | global mel_basis, hann_window
62 |     if str(fmax) + '_' + str(y.device) not in mel_basis:
63 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels,
64 |                              fmin=fmin, fmax=fmax)
65 |         mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
66 |         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
67 |
68 | y = torch.nn.functional.pad(
69 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int(
70 | (n_fft - hop_size) / 2)),
71 | mode='reflect')
72 | y = y.squeeze(1)
73 |
74 | spec = torch.stft(
75 | y,
76 | n_fft,
77 | hop_length=hop_size,
78 | win_length=win_size,
79 | window=hann_window[str(y.device)],
80 | center=center,
81 | pad_mode='reflect',
82 | normalized=False,
83 |         onesided=True, return_complex=False)  # keep real/imag stacked for the .pow(2).sum(-1) below
84 |
85 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
86 |
87 | spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
88 | spec = spectral_normalize_torch(spec)
89 |
90 | return spec
91 |
92 |
93 | def get_dataset_filelist(a):
94 | with open(a.input_training_file, 'r') as f:
95 | training_files = [l.strip() for l in f]
96 | with open(a.input_validation_file, 'r') as f:
97 | validation_files = [l.strip() for l in f]
98 | return training_files, validation_files
99 |
100 |
101 | class MelDataset(torch.utils.data.Dataset):
102 | def __init__(self,
103 | training_files,
104 | segment_size,
105 | n_fft,
106 | num_mels,
107 | hop_size,
108 | win_size,
109 | sampling_rate,
110 | fmin,
111 | fmax,
112 | split=True,
113 | shuffle=True,
114 | n_cache_reuse=1,
115 | device=None,
116 | fmax_loss=None,
117 | fine_tuning=False,
118 | base_mels_path=None):
119 | self.audio_files = training_files
120 | random.seed(1234)
121 | if shuffle:
122 | random.shuffle(self.audio_files)
123 | self.segment_size = segment_size
124 | self.sampling_rate = sampling_rate
125 | self.split = split
126 | self.n_fft = n_fft
127 | self.num_mels = num_mels
128 | self.hop_size = hop_size
129 | self.win_size = win_size
130 | self.fmin = fmin
131 | self.fmax = fmax
132 | self.fmax_loss = fmax_loss
133 | self.cached_wav = None
134 | self.n_cache_reuse = n_cache_reuse
135 | self._cache_ref_count = 0
136 | self.device = device
137 | self.fine_tuning = fine_tuning
138 | self.base_mels_path = base_mels_path
139 |
140 | def __getitem__(self, index):
141 | filename = self.audio_files[index]
142 | if self._cache_ref_count == 0:
143 | try:
144 | # Note by yuantian: load with the sample_rate of config
145 | audio, sampling_rate = load_wav(filename, sr=self.sampling_rate)
146 | except Exception as e:
147 | print(f"Error on audio: {filename}")
148 | audio = np.random.normal(size=(160000, )) * 0.05
149 | sampling_rate = self.sampling_rate
150 | self.cached_wav = audio
151 | if sampling_rate != self.sampling_rate:
152 | raise ValueError("{} SR doesn't match target {} SR".format(
153 | sampling_rate, self.sampling_rate))
154 | self._cache_ref_count = self.n_cache_reuse
155 | else:
156 | audio = self.cached_wav
157 | self._cache_ref_count -= 1
158 |
159 | audio = torch.FloatTensor(audio)
160 | audio = audio.unsqueeze(0)
161 |
162 | if not self.fine_tuning:
163 | if self.split:
164 | if audio.size(1) >= self.segment_size:
165 | max_audio_start = audio.size(1) - self.segment_size
166 | audio_start = random.randint(0, max_audio_start)
167 | audio = audio[:, audio_start:audio_start +
168 | self.segment_size]
169 | else:
170 | audio = torch.nn.functional.pad(audio, (
171 | 0, self.segment_size - audio.size(1)), 'constant')
172 |
173 | mel = mel_spectrogram(
174 | audio,
175 | self.n_fft,
176 | self.num_mels,
177 | self.sampling_rate,
178 | self.hop_size,
179 | self.win_size,
180 | self.fmin,
181 | self.fmax,
182 | center=False)
183 | else:
184 | mel = np.load(
185 | os.path.join(self.base_mels_path,
186 | os.path.splitext(os.path.split(filename)[-1])[0] +
187 | '.npy'))
188 | mel = torch.from_numpy(mel)
189 |
190 | if len(mel.shape) < 3:
191 | mel = mel.unsqueeze(0)
192 |
193 | if self.split:
194 | frames_per_seg = math.ceil(self.segment_size / self.hop_size)
195 |
196 | if audio.size(1) >= self.segment_size:
197 | mel_start = random.randint(0,
198 | mel.size(2) - frames_per_seg - 1)
199 | mel = mel[:, :, mel_start:mel_start + frames_per_seg]
200 | audio = audio[:, mel_start * self.hop_size:(
201 | mel_start + frames_per_seg) * self.hop_size]
202 | else:
203 | mel = torch.nn.functional.pad(mel, (
204 | 0, frames_per_seg - mel.size(2)), 'constant')
205 | audio = torch.nn.functional.pad(audio, (
206 | 0, self.segment_size - audio.size(1)), 'constant')
207 |
208 | mel_loss = mel_spectrogram(
209 | audio,
210 | self.n_fft,
211 | self.num_mels,
212 | self.sampling_rate,
213 | self.hop_size,
214 | self.win_size,
215 | self.fmin,
216 | self.fmax_loss,
217 | center=False)
218 |
219 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
220 |
221 | def __len__(self):
222 | return len(self.audio_files)
223 |
--------------------------------------------------------------------------------
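
A sketch (not part of the repository) of calling the mel_spectrogram helper directly; the STFT settings are illustrative and roughly follow the 24 kHz HiFi-Codec configs:

    import torch
    from academicodec.models.hificodec.meldataset import mel_spectrogram

    y = 0.1 * torch.randn(1, 24000)              # [batch, samples], roughly in [-1, 1]
    mel = mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=24000,
                          hop_size=240, win_size=1024, fmin=0, fmax=8000)
    print(mel.shape)                             # [1, 80, n_frames] log-mel
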
/academicodec/models/hificodec/train.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.simplefilter(action='ignore', category=FutureWarning)
3 | import itertools
4 | import os
5 | import time
6 | import argparse
7 | import json
8 | import torch
9 | import torch.nn.functional as F
10 | from torchaudio.transforms import MelSpectrogram
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch.utils.data import DistributedSampler, DataLoader
13 | import torch.multiprocessing as mp
14 | from torch.distributed import init_process_group
15 | from torch.nn.parallel import DistributedDataParallel
16 |
17 | from academicodec.models.hificodec.env import AttrDict, build_env
18 | from academicodec.models.hificodec.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
19 | from academicodec.models.encodec.msstftd import MultiScaleSTFTDiscriminator
20 | from academicodec.models.hificodec.models import Generator
21 | from academicodec.models.hificodec.models import MultiPeriodDiscriminator
22 | from academicodec.models.hificodec.models import MultiScaleDiscriminator
23 | from academicodec.models.hificodec.models import feature_loss
24 | from academicodec.models.hificodec.models import generator_loss
25 | from academicodec.models.hificodec.models import discriminator_loss
26 | from academicodec.models.hificodec.models import Encoder
27 | from academicodec.models.hificodec.models import Quantizer
28 | from academicodec.utils import plot_spectrogram
29 | from academicodec.utils import scan_checkpoint
30 | from academicodec.utils import load_checkpoint
31 | from academicodec.utils import save_checkpoint
32 |
33 | torch.backends.cudnn.benchmark = True
34 |
35 |
36 | def reconstruction_loss(x, G_x, device, eps=1e-7):
37 |     L = 100 * F.mse_loss(x, G_x)  # wav reconstruction loss (MSE)
38 | for i in range(6, 11):
39 | s = 2**i
40 | melspec = MelSpectrogram(
41 | sample_rate=24000,
42 | n_fft=s,
43 | hop_length=s // 4,
44 | n_mels=64,
45 | wkwargs={"device": device}).to(device)
46 | # 64, 16, 64
47 | # 128, 32, 128
48 | # 256, 64, 256
49 | # 512, 128, 512
50 | # 1024, 256, 1024
51 | S_x = melspec(x)
52 | S_G_x = melspec(G_x)
53 | loss = ((S_x - S_G_x).abs().mean() + (
54 | ((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps))**2
55 | ).mean(dim=-2)**0.5).mean()) / (i)
56 | L += loss
57 | #print('i ,loss ', i, loss)
58 | #assert 1==2
59 | return L
60 |
61 |
62 | def train(rank, a, h):
63 | torch.cuda.set_device(rank)
64 | if h.num_gpus > 1:
65 | init_process_group(
66 | backend=h.dist_config['dist_backend'],
67 | init_method=h.dist_config['dist_url'],
68 | world_size=h.dist_config['world_size'] * h.num_gpus,
69 | rank=rank)
70 |
71 | torch.cuda.manual_seed(h.seed)
72 | device = torch.device('cuda:{:d}'.format(rank))
73 |
74 | encoder = Encoder(h).to(device)
75 | generator = Generator(h).to(device)
76 | quantizer = Quantizer(h).to(device)
77 | mpd = MultiPeriodDiscriminator().to(device)
78 | msd = MultiScaleDiscriminator().to(device)
79 | mstftd = MultiScaleSTFTDiscriminator(32).to(device)
80 | if rank == 0:
81 | print(encoder)
82 | print(quantizer)
83 | print(generator)
84 | os.makedirs(a.checkpoint_path, exist_ok=True)
85 | print("checkpoints directory : ", a.checkpoint_path)
86 |
87 | if os.path.isdir(a.checkpoint_path):
88 | cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
89 | cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
90 |
91 | steps = 0
92 | if cp_g is None or cp_do is None:
93 | state_dict_do = None
94 | last_epoch = -1
95 | else:
96 | state_dict_g = load_checkpoint(cp_g, device)
97 | state_dict_do = load_checkpoint(cp_do, device)
98 | generator.load_state_dict(state_dict_g['generator'])
99 | encoder.load_state_dict(state_dict_g['encoder'])
100 | quantizer.load_state_dict(state_dict_g['quantizer'])
101 | mpd.load_state_dict(state_dict_do['mpd'])
102 | msd.load_state_dict(state_dict_do['msd'])
103 | mstftd.load_state_dict(state_dict_do['mstftd'])
104 | steps = state_dict_do['steps'] + 1
105 | last_epoch = state_dict_do['epoch']
106 |
107 | if h.num_gpus > 1:
108 | generator = DistributedDataParallel(
109 | generator, device_ids=[rank]).to(device)
110 | encoder = DistributedDataParallel(encoder, device_ids=[rank]).to(device)
111 | quantizer = DistributedDataParallel(
112 | quantizer, device_ids=[rank]).to(device)
113 | mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
114 | msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
115 | mstftd = DistributedDataParallel(mstftd, device_ids=[rank]).to(device)
116 |
117 | optim_g = torch.optim.Adam(
118 | itertools.chain(generator.parameters(),
119 | encoder.parameters(), quantizer.parameters()),
120 | h.learning_rate,
121 | betas=[h.adam_b1, h.adam_b2])
122 | optim_d = torch.optim.Adam(
123 | itertools.chain(msd.parameters(), mpd.parameters(),
124 | mstftd.parameters()),
125 | h.learning_rate,
126 | betas=[h.adam_b1, h.adam_b2])
127 | if state_dict_do is not None:
128 | optim_g.load_state_dict(state_dict_do['optim_g'])
129 | optim_d.load_state_dict(state_dict_do['optim_d'])
130 |
131 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
132 | optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
133 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
134 | optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
135 |
136 | training_filelist, validation_filelist = get_dataset_filelist(a)
137 |
138 | trainset = MelDataset(
139 | training_filelist,
140 | h.segment_size,
141 | h.n_fft,
142 | h.num_mels,
143 | h.hop_size,
144 | h.win_size,
145 | h.sampling_rate,
146 | h.fmin,
147 | h.fmax,
148 | n_cache_reuse=0,
149 | shuffle=False if h.num_gpus > 1 else True,
150 | fmax_loss=h.fmax_for_loss,
151 | device=device,
152 | fine_tuning=a.fine_tuning,
153 | base_mels_path=a.input_mels_dir)
154 |
155 | train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
156 |
157 | train_loader = DataLoader(
158 | trainset,
159 | num_workers=h.num_workers,
160 | shuffle=False,
161 | sampler=train_sampler,
162 | batch_size=h.batch_size,
163 | pin_memory=True,
164 | drop_last=True)
165 |
166 | if rank == 0:
167 | validset = MelDataset(
168 | validation_filelist,
169 | h.segment_size,
170 | h.n_fft,
171 | h.num_mels,
172 | h.hop_size,
173 | h.win_size,
174 | h.sampling_rate,
175 | h.fmin,
176 | h.fmax,
177 | False,
178 | False,
179 | n_cache_reuse=0,
180 | fmax_loss=h.fmax_for_loss,
181 | device=device,
182 | fine_tuning=a.fine_tuning,
183 | base_mels_path=a.input_mels_dir)
184 | validation_loader = DataLoader(
185 | validset,
186 | num_workers=1,
187 | shuffle=False,
188 | sampler=None,
189 | batch_size=1,
190 | pin_memory=True,
191 | drop_last=True)
192 | sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
193 | plot_gt_once = False
194 | generator.train()
195 | encoder.train()
196 | quantizer.train()
197 | mpd.train()
198 | msd.train()
199 | for epoch in range(max(0, last_epoch), a.training_epochs):
200 | if rank == 0:
201 | start = time.time()
202 | print("Epoch: {}".format(epoch + 1))
203 | if h.num_gpus > 1:
204 | train_sampler.set_epoch(epoch)
205 | for i, batch in enumerate(train_loader):
206 | if rank == 0:
207 | start_b = time.time()
208 | x, y, _, y_mel = batch
209 | x = torch.autograd.Variable(x.to(device, non_blocking=True))
210 | y = torch.autograd.Variable(y.to(device, non_blocking=True))
211 | y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
212 | y = y.unsqueeze(1)
213 |
214 | c = encoder(y)
215 | # print("c.shape: ", c.shape)
216 | q, loss_q, c = quantizer(c)
217 | # print("q.shape: ", q.shape)
218 | y_g_hat = generator(q)
219 | y_g_hat_mel = mel_spectrogram(
220 | y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
221 | h.hop_size, h.win_size, h.fmin,
222 | h.fmax_for_loss) # 1024, 80, 24000, 240,1024
223 | y_r_mel_1 = mel_spectrogram(
224 | y.squeeze(1), 512, h.num_mels, h.sampling_rate, 120, 512,
225 | h.fmin, h.fmax_for_loss)
226 | y_g_mel_1 = mel_spectrogram(
227 | y_g_hat.squeeze(1), 512, h.num_mels, h.sampling_rate, 120, 512,
228 | h.fmin, h.fmax_for_loss)
229 | y_r_mel_2 = mel_spectrogram(
230 | y.squeeze(1), 256, h.num_mels, h.sampling_rate, 60, 256, h.fmin,
231 | h.fmax_for_loss)
232 | y_g_mel_2 = mel_spectrogram(
233 | y_g_hat.squeeze(1), 256, h.num_mels, h.sampling_rate, 60, 256,
234 | h.fmin, h.fmax_for_loss)
235 | y_r_mel_3 = mel_spectrogram(
236 | y.squeeze(1), 128, h.num_mels, h.sampling_rate, 30, 128, h.fmin,
237 | h.fmax_for_loss)
238 | y_g_mel_3 = mel_spectrogram(
239 | y_g_hat.squeeze(1), 128, h.num_mels, h.sampling_rate, 30, 128,
240 | h.fmin, h.fmax_for_loss)
241 | # print("x.shape: ", x.shape)
242 | # print("y.shape: ", y.shape)
243 | # print("y_g_hat.shape: ", y_g_hat.shape)
244 | optim_d.zero_grad()
245 |
246 | # MPD
247 | y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
248 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
249 | y_df_hat_r, y_df_hat_g)
250 |
251 | # MSD
252 | y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
253 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
254 | y_ds_hat_r, y_ds_hat_g)
255 |
256 | y_disc_r, fmap_r = mstftd(y)
257 | y_disc_gen, fmap_gen = mstftd(y_g_hat.detach())
258 | loss_disc_stft, losses_disc_stft_r, losses_disc_stft_g = discriminator_loss(
259 | y_disc_r, y_disc_gen)
260 | loss_disc_all = loss_disc_s + loss_disc_f + loss_disc_stft
261 |
262 | loss_disc_all.backward()
263 | optim_d.step()
264 |
265 | # Generator
266 | optim_g.zero_grad()
267 |
268 | # L1 Mel-Spectrogram Loss
269 | loss_mel1 = F.l1_loss(y_r_mel_1, y_g_mel_1)
270 | loss_mel2 = F.l1_loss(y_r_mel_2, y_g_mel_2)
271 | loss_mel3 = F.l1_loss(y_r_mel_3, y_g_mel_3)
272 | #print('loss_mel1, loss_mel2 ', loss_mel1, loss_mel2)
273 | loss_mel = F.l1_loss(y_mel,
274 | y_g_hat_mel) * 45 + loss_mel1 + loss_mel2
275 | # print('loss_mel ', loss_mel)
276 | # assert 1==2
277 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
278 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
279 | y_stftd_hat_r, fmap_stftd_r = mstftd(y)
280 | y_stftd_hat_g, fmap_stftd_g = mstftd(y_g_hat)
281 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
282 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
283 | loss_fm_stft = feature_loss(fmap_stftd_r, fmap_stftd_g)
284 | loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
285 | loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
286 | loss_gen_stft, losses_gen_stft = generator_loss(y_stftd_hat_g)
287 | loss_gen_all = loss_gen_s + loss_gen_f + loss_gen_stft + loss_fm_s + loss_fm_f + loss_fm_stft + loss_mel + loss_q * 10
288 | loss_gen_all.backward()
289 | optim_g.step()
290 | if rank == 0:
291 | # STDOUT logging
292 | if steps % a.stdout_interval == 0:
293 | with torch.no_grad():
294 | mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
295 | print(
296 | 'Steps : {:d}, Gen Loss Total : {:4.3f}, Loss Q : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
297 | format(steps, loss_gen_all, loss_q, mel_error,
298 | time.time() - start_b))
299 | # checkpointing
300 | if steps % a.checkpoint_interval == 0 and steps != 0:
301 | checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path,
302 | steps)
303 | save_checkpoint(
304 | checkpoint_path, {
305 | 'generator': (generator.module if h.num_gpus > 1
306 | else generator).state_dict(),
307 | 'encoder': (encoder.module if h.num_gpus > 1 else
308 | encoder).state_dict(),
309 | 'quantizer': (quantizer.module if h.num_gpus > 1
310 | else quantizer).state_dict()
311 | },
312 | num_ckpt_keep=a.num_ckpt_keep)
313 | checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path,
314 | steps)
315 | save_checkpoint(
316 | checkpoint_path, {
317 | 'mpd': (mpd.module
318 | if h.num_gpus > 1 else mpd).state_dict(),
319 | 'msd': (msd.module
320 | if h.num_gpus > 1 else msd).state_dict(),
321 | 'mstftd': (mstftd.module
322 |                                        if h.num_gpus > 1 else mstftd).state_dict(),
323 | 'optim_g':
324 | optim_g.state_dict(),
325 | 'optim_d':
326 | optim_d.state_dict(),
327 | 'steps':
328 | steps,
329 | 'epoch':
330 | epoch
331 | },
332 | num_ckpt_keep=a.num_ckpt_keep)
333 | # Tensorboard summary logging
334 | if steps % a.summary_interval == 0:
335 | sw.add_scalar("training/gen_loss_total", loss_gen_all,
336 | steps)
337 | sw.add_scalar("training/mel_spec_error", mel_error, steps)
338 |
339 | # Validation
340 | if steps % a.validation_interval == 0 and steps != 0:
341 | generator.eval()
342 | encoder.eval()
343 | quantizer.eval()
344 | torch.cuda.empty_cache()
345 | val_err_tot = 0
346 | with torch.no_grad():
347 | for j, batch in enumerate(validation_loader):
348 | x, y, _, y_mel = batch
349 | c = encoder(y.to(device).unsqueeze(1))
350 | q, loss_q, c = quantizer(c)
351 | y_g_hat = generator(q)
352 | y_mel = torch.autograd.Variable(y_mel.to(device))
353 | y_g_hat_mel = mel_spectrogram(
354 | y_g_hat.squeeze(1), h.n_fft, h.num_mels,
355 | h.sampling_rate, h.hop_size, h.win_size, h.fmin,
356 | h.fmax_for_loss)
357 | i_size = min(y_mel.size(2), y_g_hat_mel.size(2))
358 | val_err_tot += F.l1_loss(
359 | y_mel[:, :, :i_size],
360 | y_g_hat_mel[:, :, :i_size]).item()
361 |
362 | if j <= 8:
363 | # if steps == 0:
364 | if not plot_gt_once:
365 | sw.add_audio('gt/y_{}'.format(j), y[0],
366 | steps, h.sampling_rate)
367 | sw.add_figure('gt/y_spec_{}'.format(j),
368 | plot_spectrogram(x[0]), steps)
369 |
370 | sw.add_audio('generated/y_hat_{}'.format(j),
371 | y_g_hat[0], steps, h.sampling_rate)
372 | y_hat_spec = mel_spectrogram(
373 | y_g_hat.squeeze(1), h.n_fft, h.num_mels,
374 | h.sampling_rate, h.hop_size, h.win_size,
375 | h.fmin, h.fmax)
376 | sw.add_figure(
377 | 'generated/y_hat_spec_{}'.format(j),
378 | plot_spectrogram(
379 | y_hat_spec.squeeze(0).cpu().numpy()),
380 | steps)
381 |
382 | val_err = val_err_tot / (j + 1)
383 | sw.add_scalar("validation/mel_spec_error", val_err,
384 | steps)
385 | if not plot_gt_once:
386 | plot_gt_once = True
387 |
388 | generator.train()
389 |
390 | steps += 1
391 |
392 | scheduler_g.step()
393 | scheduler_d.step()
394 |
395 | if rank == 0:
396 | print('Time taken for epoch {} is {} sec\n'.format(
397 | epoch + 1, int(time.time() - start)))
398 |
399 |
400 | def main():
401 | print('Initializing Training Process..')
402 |
403 | parser = argparse.ArgumentParser()
404 |
405 | # parser.add_argument('--group_name', default=None)
406 | # parser.add_argument('--input_wavs_dir', default='../datasets/audios')
407 | parser.add_argument('--input_mels_dir', default=None)
408 | parser.add_argument('--input_training_file', required=True)
409 | parser.add_argument('--input_validation_file', required=True)
410 | parser.add_argument('--checkpoint_path', default='checkpoints')
411 | parser.add_argument('--config', default='')
412 | parser.add_argument('--training_epochs', default=2000, type=int)
413 | parser.add_argument('--stdout_interval', default=5, type=int)
414 | parser.add_argument('--checkpoint_interval', default=5000, type=int)
415 | parser.add_argument('--summary_interval', default=100, type=int)
416 | parser.add_argument('--validation_interval', default=5000, type=int)
417 | parser.add_argument('--num_ckpt_keep', default=5, type=int)
418 | parser.add_argument('--fine_tuning', default=False, type=bool)
419 |
420 | a = parser.parse_args()
421 |
422 | with open(a.config) as f:
423 | data = f.read()
424 |
425 | json_config = json.loads(data)
426 | h = AttrDict(json_config)
427 | build_env(a.config, 'config.json', a.checkpoint_path)
428 |
429 | torch.manual_seed(h.seed)
430 | if torch.cuda.is_available():
431 | torch.cuda.manual_seed(h.seed)
432 | h.num_gpus = torch.cuda.device_count()
433 | h.batch_size = int(h.batch_size / h.num_gpus)
434 | print('Batch size per GPU :', h.batch_size)
435 | else:
436 | pass
437 |
438 | if h.num_gpus > 1:
439 | mp.spawn(train, nprocs=h.num_gpus, args=(a, h, ))
440 | else:
441 | train(0, a, h)
442 |
443 |
444 | if __name__ == '__main__':
445 | main()
446 |
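The checkpoint logic above writes two files per interval: g_<step> with the generator, encoder and quantizer weights, and do_<step> with the discriminators plus optimizer state. A minimal sketch of inspecting a saved generator checkpoint (the path and step value are hypothetical):

    import torch

    state = torch.load("ckpt/g_00050000", map_location="cpu")  # hypothetical checkpoint path
    print(sorted(state.keys()))  # expected: ['encoder', 'generator', 'quantizer']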
--------------------------------------------------------------------------------
/academicodec/models/hificodec/vqvae.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from academicodec.models.hificodec.env import AttrDict
7 | from academicodec.models.hificodec.models import Encoder
8 | from academicodec.models.hificodec.models import Generator
9 | from academicodec.models.hificodec.models import Quantizer
10 |
11 |
12 | class VQVAE(nn.Module):
13 | def __init__(self,
14 | config_path,
15 | ckpt_path,
16 | with_encoder=False):
17 | super(VQVAE, self).__init__()
18 | ckpt = torch.load(ckpt_path)
19 | with open(config_path) as f:
20 | data = f.read()
21 | json_config = json.loads(data)
22 | self.h = AttrDict(json_config)
23 | self.quantizer = Quantizer(self.h)
24 | self.generator = Generator(self.h)
25 | self.generator.load_state_dict(ckpt['generator'])
26 | self.quantizer.load_state_dict(ckpt['quantizer'])
27 | if with_encoder:
28 | self.encoder = Encoder(self.h)
29 | self.encoder.load_state_dict(ckpt['encoder'])
30 |
31 | def forward(self, x):
32 | # x is the codebook
33 | # x.shape (B, T, Nq)
34 | quant_emb = self.quantizer.embed(x)
35 | return self.generator(quant_emb)
36 |
37 | def encode(self, x):
38 | batch_size = x.size(0)
39 | if len(x.shape) == 3 and x.shape[-1] == 1:
40 | x = x.squeeze(-1)
41 | c = self.encoder(x.unsqueeze(1))
42 | q, loss_q, c = self.quantizer(c)
43 | c = [code.reshape(batch_size, -1) for code in c]
44 | # shape: [N, T, 4]
45 | return torch.stack(c, -1)
46 |
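A minimal usage sketch for the VQVAE wrapper above, assuming a trained HiFi-Codec model; the config and checkpoint paths are hypothetical:

    import torch
    from academicodec.models.hificodec.vqvae import VQVAE

    model = VQVAE("config_24k_320d.json", "ckpt/g_00050000", with_encoder=True).eval()
    wav = torch.randn(1, 24000)        # (B, T) mono waveform at the model's sampling rate
    with torch.no_grad():
        codes = model.encode(wav)      # (B, T', 4) acoustic tokens, one column per codebook
        recon = model(codes)           # (B, 1, T) reconstructed waveform
    print(codes.shape, recon.shape)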
--------------------------------------------------------------------------------
/academicodec/models/hificodec/vqvae_copy_syn.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import json
4 | import os
5 | from pathlib import Path
6 |
7 | import soundfile as sf
8 | from tqdm import tqdm
9 |
10 | from academicodec.models.hificodec.vqvae_tester import VqvaeTester
11 |
12 | parser = argparse.ArgumentParser()
13 |
14 | #Path
15 | parser.add_argument('--outputdir', type=str, required=True)
16 | parser.add_argument('--model_path', type=str, required=True)
17 | parser.add_argument('--input_wavdir', type=str, required=True)
18 | parser.add_argument('--config_path', type=str, required=True)
19 | parser.add_argument('--num_gens', type=int, default=1024)
20 |
21 | #Data
22 | parser.add_argument('--sample_rate', type=int, default=24000)
23 |
24 | args = parser.parse_args()
25 |
26 | with open(args.config_path, 'r') as f:
27 | argdict = json.load(f)
28 | assert argdict['sampling_rate'] == args.sample_rate, \
29 |         f"Sampling rate mismatch: got {args.sample_rate}, but the model was trained with {argdict['sampling_rate']}"
30 | argdict.update(args.__dict__)
31 | args.__dict__ = argdict
32 |
33 | if __name__ == '__main__':
34 | Path(args.outputdir).mkdir(parents=True, exist_ok=True)
35 | print("Init model and load weights")
36 | model = VqvaeTester(config_path=args.config_path, model_path=args.model_path,sample_rate=args.sample_rate)
37 | model.cuda()
38 | model.vqvae.generator.remove_weight_norm()
39 | model.vqvae.encoder.remove_weight_norm()
40 | model.eval()
41 | print("Model ready")
42 |
43 | wav_paths = glob.glob(f"{args.input_wavdir}/*.wav")[:args.num_gens]
44 | print(f"Globbed {len(wav_paths)} wav files.")
45 |
46 | for wav_path in wav_paths:
47 | fid, wav = model(wav_path)
48 | wav = wav.squeeze().cpu().numpy()
49 | sf.write(
50 | os.path.join(args.outputdir, f'{fid}.wav'), wav, args.sample_rate)
51 |
--------------------------------------------------------------------------------
/academicodec/models/hificodec/vqvae_tester.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import torch
5 | import torch.nn as nn
6 |
7 | from academicodec.models.hificodec.vqvae import VQVAE
8 |
9 |
10 | class VqvaeTester(nn.Module):
11 | def __init__(self, config_path, model_path, sample_rate=24000):
12 | super().__init__()
13 | self.vqvae = VQVAE(config_path, model_path, with_encoder=True)
14 | self.sample_rate = sample_rate
15 |
16 | @torch.no_grad()
17 | def forward(self, wav_path):
18 |         # mono audio
19 |         # wav.shape: (T,), loaded at the model's sampling rate
20 | wav, sr = librosa.load(wav_path, sr=self.sample_rate)
21 | fid = os.path.basename(wav_path)[:-4]
22 | wav = torch.tensor(wav).unsqueeze(0)
23 | wav = wav.cuda()
24 | # vq_codes is acoustic token
25 | vq_codes = self.vqvae.encode(wav)
26 | syn = self.vqvae(vq_codes)
27 | return fid, syn
28 |
29 | @torch.no_grad()
30 | def vq(self, wav_path):
31 | wav, sr = librosa.load(wav_path, sr=self.sample_rate)
32 | fid = os.path.basename(wav_path)[:-4]
33 | wav = torch.tensor(wav).unsqueeze(0)
34 | wav = wav.cuda()
35 | # vq_codes is acoustic token
36 | vq_codes = self.vqvae.encode(wav)
37 | return fid, vq_codes
38 |
--------------------------------------------------------------------------------
/academicodec/models/soundstream/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangdongchao/AcademiCodec/b6ac134735f6079543db959a60eb77a7bab4277b/academicodec/models/soundstream/__init__.py
--------------------------------------------------------------------------------
/academicodec/models/soundstream/dataset.py:
--------------------------------------------------------------------------------
1 | # Similar to the Encodec* dataset.py, but not identical:
2 | # the main difference is the extra mixing branch (ans2) when prob > 0.7
3 | import glob
4 | import random
5 |
6 | import torch
7 | import torchaudio
8 | from torch.utils.data import Dataset
9 |
10 |
11 | class NSynthDataset(Dataset):
12 | """Dataset to load NSynth data."""
13 |
14 | def __init__(self, audio_dir):
15 | super().__init__()
16 | self.filenames = []
17 | self.filenames.extend(glob.glob(audio_dir + "/*.wav"))
18 | print(len(self.filenames))
19 | _, self.sr = torchaudio.load(self.filenames[0])
20 |         self.max_len = 24000  # one second of audio at 24 kHz
21 |
22 | def __len__(self):
23 | return len(self.filenames)
24 |
25 | def __getitem__(self, index):
26 | #print(self.filenames[index])
27 | prob = random.random() # (0,1)
28 | if prob > 0.7:
29 | # data augmentation
30 | ans1 = torch.zeros(1, self.max_len)
31 | ans2 = torch.zeros(1, self.max_len)
32 | audio1 = torchaudio.load(self.filenames[index])[0]
33 | index2 = random.randint(0, len(self.filenames) - 1)
34 | audio2 = torchaudio.load(self.filenames[index2])[0]
35 | if audio1.shape[1] > self.max_len:
36 | st = random.randint(0, audio1.shape[1] - self.max_len - 1)
37 | ed = st + self.max_len
38 | ans1 = audio1[:, st:ed]
39 | else:
40 | ans1[:, :audio1.shape[1]] = audio1
41 | if audio2.shape[1] > self.max_len:
42 | st = random.randint(0, audio2.shape[1] - self.max_len - 1)
43 | ed = st + self.max_len
44 | ans2 = audio2[:, st:ed]
45 | else:
46 | ans2[:, :audio2.shape[1]] = audio2
47 | ans = ans1 + ans2
48 | return ans
49 | else:
50 | ans = torch.zeros(1, self.max_len)
51 | audio = torchaudio.load(self.filenames[index])[0]
52 | if audio.shape[1] > self.max_len:
53 | st = random.randint(0, audio.shape[1] - self.max_len - 1)
54 | ed = st + self.max_len
55 | return audio[:, st:ed]
56 | else:
57 | ans[:, :audio.shape[1]] = audio
58 | return ans
59 |
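A quick sketch of wiring this dataset into a DataLoader; the audio directory is hypothetical. Every item is cropped or zero-padded to max_len samples, so batches stack without a custom collate function:

    from torch.utils.data import DataLoader
    from academicodec.models.soundstream.dataset import NSynthDataset

    dataset = NSynthDataset(audio_dir="/path/to/wavs")  # hypothetical directory of .wav files
    loader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True)
    batch = next(iter(loader))
    print(batch.shape)  # torch.Size([8, 1, 24000])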
--------------------------------------------------------------------------------
/academicodec/models/soundstream/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torchaudio.transforms import MelSpectrogram
4 |
5 |
6 | def adversarial_g_loss(y_disc_gen):
7 | loss = 0.0
8 | for i in range(len(y_disc_gen)):
9 | #print(y_disc_gen[i].shape)
10 | # assert 1==2
11 | stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze()
12 | loss += stft_loss
13 | return loss / len(y_disc_gen)
14 |
15 |
16 | def feature_loss(fmap_r, fmap_gen):
17 | loss = 0.0
18 | for i in range(len(fmap_r)):
19 | for j in range(len(fmap_r[i])):
20 | stft_loss = ((fmap_r[i][j] - fmap_gen[i][j]).abs() /
21 | (fmap_r[i][j].abs().mean())).mean()
22 | loss += stft_loss
23 | return loss / (len(fmap_r) * len(fmap_r[0]))
24 |
25 |
26 | def sim_loss(y_disc_r, y_disc_gen):
27 | loss = 0.0
28 | for i in range(len(y_disc_r)):
29 | loss += F.mse_loss(y_disc_r[i], y_disc_gen[i])
30 | return loss / len(y_disc_r)
31 |
32 |
33 | def sisnr_loss(x, s, eps=1e-8):
34 | """
35 | calculate training loss
36 | input:
37 | x: separated signal, N x S tensor, estimate value
38 | s: reference signal, N x S tensor, True value
39 | Return:
40 | sisnr: N tensor
41 | """
42 | if x.shape != s.shape:
43 | if x.shape[-1] > s.shape[-1]:
44 | x = x[:, :s.shape[-1]]
45 | else:
46 | s = s[:, :x.shape[-1]]
47 |
48 | def l2norm(mat, keepdim=False):
49 | return torch.norm(mat, dim=-1, keepdim=keepdim)
50 |
51 | if x.shape != s.shape:
52 |         raise RuntimeError("Dimension mismatch when calculating SI-SNR, {} vs {}".
53 | format(x.shape, s.shape))
54 | x_zm = x - torch.mean(x, dim=-1, keepdim=True)
55 | s_zm = s - torch.mean(s, dim=-1, keepdim=True)
56 | t = torch.sum(
57 | x_zm * s_zm, dim=-1,
58 | keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
59 | loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
60 | return torch.sum(loss) / x.shape[0]
61 |
62 |
63 | def reconstruction_loss(x, G_x, args, eps=1e-7):
64 |     L = 100 * F.mse_loss(x, G_x)  # waveform-level MSE loss
65 | #loss_sisnr = sisnr_loss(G_x, x) #
66 | #L += 0.01*loss_sisnr
67 | # print('L0 ', L)
68 | # print('loss_sisnr ', 0.01*loss_sisnr)
69 | # print('L0 ', L)
70 | for i in range(6, 11):
71 | s = 2**i
72 | melspec = MelSpectrogram(
73 | sample_rate=args.sr,
74 | n_fft=max(s, 512),
75 | win_length=s,
76 | hop_length=s // 4,
77 | n_mels=64,
78 | wkwargs={"device": args.device}).to(args.device)
79 | S_x = melspec(x)
80 | S_G_x = melspec(G_x)
81 | l1_loss = (S_x - S_G_x).abs().mean()
82 | l2_loss = (((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps))**2).mean(dim=-2)**0.5).mean()
83 |
84 | alpha = (s / 2) ** 0.5
85 | L += (l1_loss + alpha * l2_loss)
86 | #print('i ,loss ', i, loss)
87 | #assert 1==2
88 | return L
89 |
90 |
91 | def criterion_d(y_disc_r, y_disc_gen, fmap_r_det, fmap_gen_det, y_df_hat_r,
92 | y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g,
93 | fmap_s_r, fmap_s_g):
94 | loss = 0.0
95 | loss1 = 0.0
96 | loss2 = 0.0
97 | loss3 = 0.0
98 | loss_f = feature_loss(fmap_r_det, fmap_gen_det) + feature_loss(
99 | fmap_f_r, fmap_f_g) + feature_loss(fmap_s_r, fmap_s_g)
100 | for i in range(len(y_disc_r)):
101 | loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[
102 | i]).mean()
103 | for i in range(len(y_df_hat_r)):
104 | loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[
105 | i]).mean()
106 | for i in range(len(y_ds_hat_r)):
107 | loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[
108 | i]).mean()
109 | loss = (loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 /
110 | len(y_ds_hat_r)) / 3.0
111 | return loss + 0.0 * loss_f
112 |
113 |
114 | def criterion_g(commit_loss, x, G_x, fmap_r, fmap_gen, y_disc_r, y_disc_gen,
115 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r,
116 | y_ds_hat_g, fmap_s_r, fmap_s_g, args):
117 | adv_g_loss = adversarial_g_loss(y_disc_gen)
118 | feat_loss = (feature_loss(fmap_r, fmap_gen) + sim_loss(
119 | y_disc_r, y_disc_gen) + feature_loss(fmap_f_r, fmap_f_g) + sim_loss(
120 | y_df_hat_r, y_df_hat_g) + feature_loss(fmap_s_r, fmap_s_g) +
121 | sim_loss(y_ds_hat_r, y_ds_hat_g)) / 3.0
122 | rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)
123 | total_loss = args.LAMBDA_COM * commit_loss + args.LAMBDA_ADV * adv_g_loss + \
124 | args.LAMBDA_FEAT * feat_loss + args.LAMBDA_REC * rec_loss
125 | return total_loss, adv_g_loss, feat_loss, rec_loss
126 |
127 |
128 | def adopt_weight(weight, global_step, threshold=0, value=0.):
129 | if global_step < threshold:
130 | weight = value
131 | return weight
132 |
133 |
134 | def adopt_dis_weight(weight, global_step, threshold=0, value=0.):
135 |     if global_step % 3 == 0:  # do not update the discriminator on steps 0, 3, 6, 9, ...
136 | weight = value
137 | return weight
138 |
139 |
140 | def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
141 | if last_layer is not None:
142 | nll_grads = torch.autograd.grad(
143 | nll_loss, last_layer, retain_graph=True)[0]
144 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
145 | else:
146 | print('last_layer cannot be none')
147 | assert 1 == 2
148 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
149 | d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
150 | d_weight = d_weight * args.LAMBDA_ADV
151 | return d_weight
152 |
153 |
154 | def loss_g(codebook_loss,
155 | inputs,
156 | reconstructions,
157 | fmap_r,
158 | fmap_gen,
159 | y_disc_r,
160 | y_disc_gen,
161 | global_step,
162 | y_df_hat_r,
163 | y_df_hat_g,
164 | y_ds_hat_r,
165 | y_ds_hat_g,
166 | fmap_f_r,
167 | fmap_f_g,
168 | fmap_s_r,
169 | fmap_s_g,
170 | last_layer=None,
171 | is_training=True,
172 | args=None):
173 | rec_loss = reconstruction_loss(inputs.contiguous(),
174 | reconstructions.contiguous(), args)
175 | adv_g_loss = adversarial_g_loss(y_disc_gen)
176 | adv_mpd_loss = adversarial_g_loss(y_df_hat_g)
177 | adv_msd_loss = adversarial_g_loss(y_ds_hat_g)
178 | adv_loss = (adv_g_loss + adv_mpd_loss + adv_msd_loss) / 3.0
179 | feat_loss = feature_loss(fmap_r, fmap_gen) + sim_loss(y_disc_r,
180 | y_disc_gen) #
181 | feat_loss_mpd = feature_loss(fmap_f_r, fmap_f_g) + sim_loss(y_df_hat_r,
182 | y_df_hat_g)
183 | feat_loss_msd = feature_loss(fmap_s_r, fmap_s_g) + sim_loss(y_ds_hat_r,
184 | y_ds_hat_g)
185 | feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
186 | d_weight = torch.tensor(1.0)
187 | # try:
188 |     #     d_weight = calculate_adaptive_weight(rec_loss, adv_g_loss, last_layer, args)  # adaptively balance reconstruction and adversarial losses
189 | # except RuntimeError:
190 | # assert not is_training
191 | # d_weight = torch.tensor(0.0)
192 | disc_factor = adopt_weight(
193 | args.LAMBDA_ADV, global_step, threshold=args.discriminator_iter_start)
194 | #feat_factor = adopt_weight(args.LAMBDA_FEAT, global_step, threshold=args.discriminator_iter_start)
195 | loss = rec_loss + d_weight * disc_factor * adv_loss + \
196 | args.LAMBDA_FEAT * feat_loss_tot + args.LAMBDA_COM * codebook_loss
197 | return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
198 |
199 |
200 | def loss_dis(y_disc_r_det, y_disc_gen_det, fmap_r_det, fmap_gen_det, y_df_hat_r,
201 | y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, fmap_s_r,
202 | fmap_s_g, global_step, args):
203 | disc_factor = adopt_weight(
204 | args.LAMBDA_ADV, global_step, threshold=args.discriminator_iter_start)
205 | d_loss = disc_factor * criterion_d(y_disc_r_det, y_disc_gen_det, fmap_r_det,
206 | fmap_gen_det, y_df_hat_r, y_df_hat_g,
207 | fmap_f_r, fmap_f_g, y_ds_hat_r,
208 | y_ds_hat_g, fmap_s_r, fmap_s_g)
209 | return d_loss
210 |
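A minimal sketch exercising the reconstruction and SI-SNR losses above on dummy tensors; the args namespace only carries the two fields reconstruction_loss actually reads (sr and device):

    import argparse

    import torch

    from academicodec.models.soundstream.loss import reconstruction_loss, sisnr_loss

    args = argparse.Namespace(sr=24000, device="cpu")
    x = torch.randn(2, 1, 24000)                     # reference waveforms
    g_x = x + 0.01 * torch.randn_like(x)             # slightly perturbed "reconstructions"
    print(reconstruction_loss(x, g_x, args))         # waveform MSE plus multi-scale mel terms
    print(sisnr_loss(g_x.squeeze(1), x.squeeze(1)))  # negative SI-SNR in dB, lower is better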
--------------------------------------------------------------------------------
/academicodec/models/soundstream/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from academicodec.modules import NormConv1d
5 | from academicodec.modules import NormConv2d
6 | from academicodec.utils import get_padding
7 | from torch.nn import AvgPool1d
8 | from torch.nn.utils import spectral_norm
9 | from torch.nn.utils import weight_norm
10 |
11 | LRELU_SLOPE = 0.1
12 |
13 |
14 | class DiscriminatorP(torch.nn.Module):
15 | def __init__(self,
16 | period,
17 | kernel_size=5,
18 | stride=3,
19 | use_spectral_norm=False,
20 | activation: str='LeakyReLU',
21 | activation_params: dict={'negative_slope': 0.2}):
22 | super(DiscriminatorP, self).__init__()
23 | self.period = period
24 | norm_f = weight_norm if use_spectral_norm is False else spectral_norm
25 | self.activation = getattr(torch.nn, activation)(**activation_params)
26 | self.convs = nn.ModuleList([
27 | NormConv2d(
28 | 1,
29 | 32, (kernel_size, 1), (stride, 1),
30 | padding=(get_padding(5, 1), 0)),
31 | NormConv2d(
32 | 32,
33 | 32, (kernel_size, 1), (stride, 1),
34 | padding=(get_padding(5, 1), 0)),
35 | NormConv2d(
36 | 32,
37 | 32, (kernel_size, 1), (stride, 1),
38 | padding=(get_padding(5, 1), 0)),
39 | NormConv2d(
40 | 32,
41 | 32, (kernel_size, 1), (stride, 1),
42 | padding=(get_padding(5, 1), 0)),
43 | NormConv2d(32, 32, (kernel_size, 1), 1, padding=(2, 0)),
44 | ])
45 | self.conv_post = NormConv2d(32, 1, (3, 1), 1, padding=(1, 0))
46 |
47 | def forward(self, x):
48 | fmap = []
49 | # 1d to 2d
50 | b, c, t = x.shape
51 | if t % self.period != 0: # pad first
52 | n_pad = self.period - (t % self.period)
53 | x = F.pad(x, (0, n_pad), "reflect")
54 | t = t + n_pad
55 | x = x.view(b, c, t // self.period, self.period)
56 |
57 | for l in self.convs:
58 | x = l(x)
59 | x = self.activation(x)
60 | fmap.append(x)
61 | x = self.conv_post(x)
62 | fmap.append(x)
63 | x = torch.flatten(x, 1, -1)
64 |
65 | return x, fmap
66 |
67 |
68 | class MultiPeriodDiscriminator(torch.nn.Module):
69 | def __init__(self):
70 | super(MultiPeriodDiscriminator, self).__init__()
71 | self.discriminators = nn.ModuleList([
72 | DiscriminatorP(2),
73 | DiscriminatorP(3),
74 | DiscriminatorP(5),
75 | DiscriminatorP(7),
76 | DiscriminatorP(11),
77 | ])
78 |
79 | def forward(self, y, y_hat):
80 | y_d_rs = []
81 | y_d_gs = []
82 | fmap_rs = []
83 | fmap_gs = []
84 | for i, d in enumerate(self.discriminators):
85 | y_d_r, fmap_r = d(y)
86 | y_d_g, fmap_g = d(y_hat)
87 | y_d_rs.append(y_d_r)
88 | fmap_rs.append(fmap_r)
89 | y_d_gs.append(y_d_g)
90 | fmap_gs.append(fmap_g)
91 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
92 |
93 |
94 | class DiscriminatorS(torch.nn.Module):
95 | def __init__(self,
96 | use_spectral_norm=False,
97 | activation: str='LeakyReLU',
98 | activation_params: dict={'negative_slope': 0.2}):
99 | super(DiscriminatorS, self).__init__()
100 | self.activation = getattr(torch.nn, activation)(**activation_params)
101 | self.convs = nn.ModuleList([
102 | NormConv1d(1, 32, 15, 1, padding=7),
103 | NormConv1d(32, 32, 41, 2, groups=4, padding=20),
104 | NormConv1d(32, 32, 41, 2, groups=16, padding=20),
105 | NormConv1d(32, 32, 41, 4, groups=16, padding=20),
106 | NormConv1d(32, 32, 41, 4, groups=16, padding=20),
107 | NormConv1d(32, 32, 41, 1, groups=16, padding=20),
108 | NormConv1d(32, 32, 5, 1, padding=2),
109 | ])
110 | self.conv_post = NormConv1d(32, 1, 3, 1, padding=1)
111 |
112 | def forward(self, x):
113 | fmap = []
114 | for l in self.convs:
115 | x = l(x)
116 | x = self.activation(x)
117 | fmap.append(x)
118 | x = self.conv_post(x)
119 | fmap.append(x)
120 | x = torch.flatten(x, 1, -1)
121 | return x, fmap
122 |
123 |
124 | class MultiScaleDiscriminator(torch.nn.Module):
125 | def __init__(self):
126 | super(MultiScaleDiscriminator, self).__init__()
127 | self.discriminators = nn.ModuleList([
128 | DiscriminatorS(),
129 | DiscriminatorS(),
130 | DiscriminatorS(),
131 | ])
132 | self.meanpools = nn.ModuleList(
133 | [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
134 |
135 | def forward(self, y, y_hat):
136 | y_d_rs = []
137 | y_d_gs = []
138 | fmap_rs = []
139 | fmap_gs = []
140 | for i, d in enumerate(self.discriminators):
141 | if i != 0:
142 | y = self.meanpools[i - 1](y)
143 | y_hat = self.meanpools[i - 1](y_hat)
144 | y_d_r, fmap_r = d(y)
145 | y_d_g, fmap_g = d(y_hat)
146 | y_d_rs.append(y_d_r)
147 | fmap_rs.append(fmap_r)
148 | y_d_gs.append(y_d_g)
149 | fmap_gs.append(fmap_g)
150 |
151 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
152 |
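A quick shape-check sketch for the two discriminators above, run on random waveforms:

    import torch

    from academicodec.models.soundstream.models import MultiPeriodDiscriminator
    from academicodec.models.soundstream.models import MultiScaleDiscriminator

    y = torch.randn(2, 1, 24000)      # reference batch
    y_hat = torch.randn(2, 1, 24000)  # generated batch
    mpd, msd = MultiPeriodDiscriminator(), MultiScaleDiscriminator()
    scores_r, scores_g, fmaps_r, fmaps_g = mpd(y, y_hat)
    print(len(scores_r), len(fmaps_r[0]))  # 5 period discriminators, 6 feature maps each
    scores_r, scores_g, fmaps_r, fmaps_g = msd(y, y_hat)
    print(len(scores_r), len(fmaps_r[0]))  # 3 scale discriminators, 8 feature maps each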
--------------------------------------------------------------------------------
/academicodec/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Torch modules."""
7 | # flake8: noqa
8 | from .conv import NormConv1d
9 | from .conv import NormConv2d
10 | from .conv import NormConvTranspose1d
11 | from .conv import NormConvTranspose2d
12 | from .conv import pad1d
13 | from .conv import SConv1d
14 | from .conv import SConvTranspose1d
15 | from .conv import unpad1d
16 | from .lstm import SLSTM
17 | from .seanet import SEANetDecoder
18 | from .seanet import SEANetEncoder
19 | from .transformer import StreamingTransformerEncoder
20 |
--------------------------------------------------------------------------------
/academicodec/modules/conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Convolutional layers wrappers and utilities."""
7 | import math
8 | import typing as tp
9 | import warnings
10 |
11 | import torch
12 | from torch import nn
13 | from torch.nn import functional as F
14 | from torch.nn.utils import spectral_norm
15 | from torch.nn.utils import weight_norm
16 |
17 | from academicodec.modules.norm import ConvLayerNorm
18 |
19 | CONV_NORMALIZATIONS = frozenset([
20 | 'none', 'weight_norm', 'spectral_norm', 'time_layer_norm', 'layer_norm',
21 | 'time_group_norm'
22 | ])
23 |
24 |
25 | def apply_parametrization_norm(module: nn.Module,
26 | norm: str='none') -> nn.Module:
27 | assert norm in CONV_NORMALIZATIONS
28 | if norm == 'weight_norm':
29 | return weight_norm(module)
30 | elif norm == 'spectral_norm':
31 | return spectral_norm(module)
32 | else:
33 |         # We already checked that norm is in CONV_NORMALIZATIONS, so any other
34 |         # choice doesn't need reparametrization.
35 | return module
36 |
37 |
38 | def get_norm_module(module: nn.Module,
39 | causal: bool=False,
40 | norm: str='none',
41 | **norm_kwargs) -> nn.Module:
42 | """Return the proper normalization module. If causal is True, this will ensure the returned
43 | module is causal, or return an error if the normalization doesn't support causal evaluation.
44 | """
45 | assert norm in CONV_NORMALIZATIONS
46 | if norm == 'layer_norm':
47 | assert isinstance(module, nn.modules.conv._ConvNd)
48 | return ConvLayerNorm(module.out_channels, **norm_kwargs)
49 | elif norm == 'time_group_norm':
50 | if causal:
51 | raise ValueError("GroupNorm doesn't support causal evaluation.")
52 | assert isinstance(module, nn.modules.conv._ConvNd)
53 | return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
54 | else:
55 | return nn.Identity()
56 |
57 |
58 | def get_extra_padding_for_conv1d(x: torch.Tensor,
59 | kernel_size: int,
60 | stride: int,
61 | padding_total: int=0) -> int:
62 | """See `pad_for_conv1d`.
63 | """
64 | length = x.shape[-1]
65 | n_frames = (length - kernel_size + padding_total) / stride + 1
66 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size -
67 | padding_total)
68 | return ideal_length - length
69 |
70 |
71 | def pad_for_conv1d(x: torch.Tensor,
72 | kernel_size: int,
73 | stride: int,
74 | padding_total: int=0):
75 | """Pad for a convolution to make sure that the last window is full.
76 | Extra padding is added at the end. This is required to ensure that we can rebuild
77 | an output of the same length, as otherwise, even with padding, some time steps
78 | might get removed.
79 | For instance, with total padding = 4, kernel size = 4, stride = 2:
80 | 0 0 1 2 3 4 5 0 0 # (0s are padding)
81 | 1 2 3 # (output frames of a convolution, last 0 is never used)
82 | 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
83 | 1 2 3 4 # once you removed padding, we are missing one time step !
84 | """
85 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride,
86 | padding_total)
87 | return F.pad(x, (0, extra_padding))
88 |
89 |
90 | def pad1d(x: torch.Tensor,
91 | paddings: tp.Tuple[int, int],
92 | mode: str='zero',
93 | value: float=0.):
94 | """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
95 |     If this is the case, we insert extra 0 padding to the right before the reflection happens.
96 | """
97 | length = x.shape[-1]
98 | padding_left, padding_right = paddings
99 | assert padding_left >= 0 and padding_right >= 0, (padding_left,
100 | padding_right)
101 | if mode == 'reflect':
102 | max_pad = max(padding_left, padding_right)
103 | extra_pad = 0
104 | if length <= max_pad:
105 | extra_pad = max_pad - length + 1
106 | x = F.pad(x, (0, extra_pad))
107 | padded = F.pad(x, paddings, mode, value)
108 | end = padded.shape[-1] - extra_pad
109 | return padded[..., :end]
110 | else:
111 | return F.pad(x, paddings, mode, value)
112 |
113 |
114 | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
115 | """Remove padding from x, handling properly zero padding. Only for 1d!"""
116 | padding_left, padding_right = paddings
117 | assert padding_left >= 0 and padding_right >= 0, (padding_left,
118 | padding_right)
119 | assert (padding_left + padding_right) <= x.shape[-1]
120 | end = x.shape[-1] - padding_right
121 | return x[..., padding_left:end]
122 |
123 |
124 | class NormConv1d(nn.Module):
125 | """Wrapper around Conv1d and normalization applied to this conv
126 | to provide a uniform interface across normalization approaches.
127 | """
128 |
129 | def __init__(self,
130 | *args,
131 | causal: bool=False,
132 | norm: str='none',
133 | norm_kwargs: tp.Dict[str, tp.Any]={},
134 | **kwargs):
135 | super().__init__()
136 | self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
137 | self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
138 | self.norm_type = norm
139 |
140 | def forward(self, x):
141 | x = self.conv(x)
142 | x = self.norm(x)
143 | return x
144 |
145 |
146 | class NormConv2d(nn.Module):
147 | """Wrapper around Conv2d and normalization applied to this conv
148 | to provide a uniform interface across normalization approaches.
149 | """
150 |
151 | def __init__(self,
152 | *args,
153 | norm: str='none',
154 | norm_kwargs: tp.Dict[str, tp.Any]={},
155 | **kwargs):
156 | super().__init__()
157 | self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
158 | self.norm = get_norm_module(
159 | self.conv, causal=False, norm=norm, **norm_kwargs)
160 | self.norm_type = norm
161 |
162 | def forward(self, x):
163 | x = self.conv(x)
164 | x = self.norm(x)
165 | return x
166 |
167 |
168 | class NormConvTranspose1d(nn.Module):
169 | """Wrapper around ConvTranspose1d and normalization applied to this conv
170 | to provide a uniform interface across normalization approaches.
171 | """
172 |
173 | def __init__(self,
174 | *args,
175 | causal: bool=False,
176 | norm: str='none',
177 | norm_kwargs: tp.Dict[str, tp.Any]={},
178 | **kwargs):
179 | super().__init__()
180 | self.convtr = apply_parametrization_norm(
181 | nn.ConvTranspose1d(*args, **kwargs), norm)
182 | self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
183 | self.norm_type = norm
184 |
185 | def forward(self, x):
186 | x = self.convtr(x)
187 | x = self.norm(x)
188 | return x
189 |
190 |
191 | class NormConvTranspose2d(nn.Module):
192 | """Wrapper around ConvTranspose2d and normalization applied to this conv
193 | to provide a uniform interface across normalization approaches.
194 | """
195 |
196 | def __init__(self,
197 | *args,
198 | norm: str='none',
199 | norm_kwargs: tp.Dict[str, tp.Any]={},
200 | **kwargs):
201 | super().__init__()
202 | self.convtr = apply_parametrization_norm(
203 | nn.ConvTranspose2d(*args, **kwargs), norm)
204 | self.norm = get_norm_module(
205 | self.convtr, causal=False, norm=norm, **norm_kwargs)
206 |
207 | def forward(self, x):
208 | x = self.convtr(x)
209 | x = self.norm(x)
210 | return x
211 |
212 |
213 | class SConv1d(nn.Module):
214 | """Conv1d with some builtin handling of asymmetric or causal padding
215 | and normalization.
216 | """
217 |
218 | def __init__(self,
219 | in_channels: int,
220 | out_channels: int,
221 | kernel_size: int,
222 | stride: int=1,
223 | dilation: int=1,
224 | groups: int=1,
225 | bias: bool=True,
226 | causal: bool=False,
227 | norm: str='none',
228 | norm_kwargs: tp.Dict[str, tp.Any]={},
229 | pad_mode: str='reflect'):
230 | super().__init__()
231 | # warn user on unusual setup between dilation and stride
232 | if stride > 1 and dilation > 1:
233 | warnings.warn(
234 | 'SConv1d has been initialized with stride > 1 and dilation > 1'
235 | f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).'
236 | )
237 | self.conv = NormConv1d(
238 | in_channels,
239 | out_channels,
240 | kernel_size,
241 | stride,
242 | dilation=dilation,
243 | groups=groups,
244 | bias=bias,
245 | causal=causal,
246 | norm=norm,
247 | norm_kwargs=norm_kwargs)
248 | self.causal = causal
249 | self.pad_mode = pad_mode
250 |
251 | def forward(self, x):
252 | B, C, T = x.shape
253 | kernel_size = self.conv.conv.kernel_size[0]
254 | stride = self.conv.conv.stride[0]
255 | dilation = self.conv.conv.dilation[0]
256 | padding_total = (kernel_size - 1) * dilation - (stride - 1)
257 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride,
258 | padding_total)
259 | if self.causal:
260 | # Left padding for causal
261 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
262 | else:
263 | # Asymmetric padding required for odd strides
264 | padding_right = padding_total // 2
265 | padding_left = padding_total - padding_right
266 | x = pad1d(
267 | x, (padding_left, padding_right + extra_padding),
268 | mode=self.pad_mode)
269 | return self.conv(x)
270 |
271 |
272 | class SConvTranspose1d(nn.Module):
273 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding
274 | and normalization.
275 | """
276 |
277 | def __init__(self,
278 | in_channels: int,
279 | out_channels: int,
280 | kernel_size: int,
281 | stride: int=1,
282 | causal: bool=False,
283 | norm: str='none',
284 | trim_right_ratio: float=1.,
285 | norm_kwargs: tp.Dict[str, tp.Any]={}):
286 | super().__init__()
287 | self.convtr = NormConvTranspose1d(
288 | in_channels,
289 | out_channels,
290 | kernel_size,
291 | stride,
292 | causal=causal,
293 | norm=norm,
294 | norm_kwargs=norm_kwargs)
295 | self.causal = causal
296 | self.trim_right_ratio = trim_right_ratio
297 | assert self.causal or self.trim_right_ratio == 1., \
298 | "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
299 | assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
300 |
301 | def forward(self, x):
302 | kernel_size = self.convtr.convtr.kernel_size[0]
303 | stride = self.convtr.convtr.stride[0]
304 | padding_total = kernel_size - stride
305 |
306 | y = self.convtr(x)
307 |
308 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
309 | # removed at the very end, when keeping only the right length for the output,
310 | # as removing it here would require also passing the length at the matching layer
311 | # in the encoder.
312 | if self.causal:
313 | # Trim the padding on the right according to the specified ratio
314 | # if trim_right_ratio = 1.0, trim everything from right
315 | padding_right = math.ceil(padding_total * self.trim_right_ratio)
316 | padding_left = padding_total - padding_right
317 | y = unpad1d(y, (padding_left, padding_right))
318 | else:
319 | # Asymmetric padding required for odd strides
320 | padding_right = padding_total // 2
321 | padding_left = padding_total - padding_right
322 | y = unpad1d(y, (padding_left, padding_right))
323 | return y
324 |
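A small sketch of the padding helpers above: with causal padding, SConv1d keeps the output length at ceil(T / stride) for any kernel size, which is what makes the convolutions streamable:

    import torch

    from academicodec.modules.conv import SConv1d
    from academicodec.modules.conv import get_extra_padding_for_conv1d

    x = torch.randn(1, 1, 24000)
    conv = SConv1d(1, 16, kernel_size=7, stride=2, causal=True, norm='weight_norm')
    print(conv(x).shape)  # torch.Size([1, 16, 12000]) == ceil(24000 / 2)
    # padding_total = (kernel_size - 1) * dilation - (stride - 1) = 5 for this layer
    print(get_extra_padding_for_conv1d(x, kernel_size=7, stride=2, padding_total=5))  # 0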
--------------------------------------------------------------------------------
/academicodec/modules/lstm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """LSTM layers module."""
7 | from torch import nn
8 |
9 |
10 | class SLSTM(nn.Module):
11 | """
12 | LSTM without worrying about the hidden state, nor the layout of the data.
13 | Expects input as convolutional layout.
14 | """
15 |
16 | def __init__(self, dimension: int, num_layers: int=2, skip: bool=True):
17 | super().__init__()
18 | self.skip = skip
19 | self.lstm = nn.LSTM(dimension, dimension, num_layers)
20 |
21 | def forward(self, x):
22 | x = x.permute(2, 0, 1)
23 | y, _ = self.lstm(x)
24 | if self.skip:
25 | y = y + x
26 | y = y.permute(1, 2, 0)
27 | return y
28 |
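A shape sketch for SLSTM: input and output both use the convolutional (B, C, T) layout, so it drops straight between convolutional blocks:

    import torch

    from academicodec.modules.lstm import SLSTM

    lstm = SLSTM(dimension=128, num_layers=2)
    x = torch.randn(4, 128, 75)  # (B, C, T): e.g. 75 latent frames of dimension 128
    print(lstm(x).shape)         # torch.Size([4, 128, 75])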
--------------------------------------------------------------------------------
/academicodec/modules/norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Normalization modules."""
7 | import typing as tp
8 |
9 | import einops
10 | import torch
11 | from torch import nn
12 |
13 |
14 | class ConvLayerNorm(nn.LayerNorm):
15 | """
16 | Convolution-friendly LayerNorm that moves channels to last dimensions
17 | before running the normalization and moves them back to original position right after.
18 | """
19 |
20 | def __init__(self,
21 | normalized_shape: tp.Union[int, tp.List[int], torch.Size],
22 | **kwargs):
23 | super().__init__(normalized_shape, **kwargs)
24 |
25 | def forward(self, x):
26 | x = einops.rearrange(x, 'b ... t -> b t ...')
27 | x = super().forward(x)
28 | x = einops.rearrange(x, 'b t ... -> b ... t')
29 |         return x
30 |
--------------------------------------------------------------------------------
/academicodec/modules/seanet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Encodec SEANet-based encoder and decoder implementation."""
7 | import typing as tp
8 |
9 | import numpy as np
10 | import torch.nn as nn
11 |
12 | from academicodec.modules import SConv1d
13 | from academicodec.modules import SConvTranspose1d
14 | from academicodec.modules import SLSTM
15 |
16 |
17 | class SEANetResnetBlock(nn.Module):
18 | """Residual block from SEANet model.
19 | Args:
20 | dim (int): Dimension of the input/output
21 | kernel_sizes (list): List of kernel sizes for the convolutions.
22 | dilations (list): List of dilations for the convolutions.
23 | activation (str): Activation function.
24 | activation_params (dict): Parameters to provide to the activation function
25 | norm (str): Normalization method.
26 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
27 | causal (bool): Whether to use fully causal convolution.
28 | pad_mode (str): Padding mode for the convolutions.
29 | compress (int): Reduced dimensionality in residual branches (from Demucs v3)
30 | true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
31 | """
32 |
33 | def __init__(self,
34 | dim: int,
35 | kernel_sizes: tp.List[int]=[3, 1],
36 | dilations: tp.List[int]=[1, 1],
37 | activation: str='ELU',
38 | activation_params: dict={'alpha': 1.0},
39 | norm: str='weight_norm',
40 | norm_params: tp.Dict[str, tp.Any]={},
41 | causal: bool=False,
42 | pad_mode: str='reflect',
43 | compress: int=2,
44 | true_skip: bool=True):
45 | super().__init__()
46 | assert len(kernel_sizes) == len(
47 | dilations), 'Number of kernel sizes should match number of dilations'
48 | act = getattr(nn, activation)
49 | hidden = dim // compress
50 | block = []
51 | for i, (kernel_size,
52 | dilation) in enumerate(zip(kernel_sizes, dilations)):
53 | in_chs = dim if i == 0 else hidden
54 | out_chs = dim if i == len(kernel_sizes) - 1 else hidden
55 | block += [
56 | act(**activation_params),
57 | SConv1d(
58 | in_chs,
59 | out_chs,
60 | kernel_size=kernel_size,
61 | dilation=dilation,
62 | norm=norm,
63 | norm_kwargs=norm_params,
64 | causal=causal,
65 | pad_mode=pad_mode),
66 | ]
67 | self.block = nn.Sequential(*block)
68 | self.shortcut: nn.Module
69 | if true_skip:
70 | self.shortcut = nn.Identity()
71 | else:
72 | self.shortcut = SConv1d(
73 | dim,
74 | dim,
75 | kernel_size=1,
76 | norm=norm,
77 | norm_kwargs=norm_params,
78 | causal=causal,
79 | pad_mode=pad_mode)
80 |
81 | def forward(self, x):
82 | return self.shortcut(x) + self.block(x)
83 |
84 |
85 | class SEANetEncoder(nn.Module):
86 | """SEANet encoder.
87 | Args:
88 | channels (int): Audio channels.
89 | dimension (int): Intermediate representation dimension.
90 | n_filters (int): Base width for the model.
91 | n_residual_layers (int): nb of residual layers.
92 | ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
93 | upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
94 | that must match the decoder order
95 | activation (str): Activation function.
96 | activation_params (dict): Parameters to provide to the activation function
97 | norm (str): Normalization method.
98 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
99 | kernel_size (int): Kernel size for the initial convolution.
100 |         last_kernel_size (int): Kernel size for the last convolution.
101 | residual_kernel_size (int): Kernel size for the residual layers.
102 | dilation_base (int): How much to increase the dilation with each layer.
103 | causal (bool): Whether to use fully causal convolution.
104 | pad_mode (str): Padding mode for the convolutions.
105 | true_skip (bool): Whether to use true skip connection or a simple
106 | (streamable) convolution as the skip connection in the residual network blocks.
107 | compress (int): Reduced dimensionality in residual branches (from Demucs v3).
108 | lstm (int): Number of LSTM layers at the end of the encoder.
109 | """
110 |
111 | def __init__(self,
112 | channels: int=1,
113 | dimension: int=128,
114 | n_filters: int=32,
115 | n_residual_layers: int=1,
116 | ratios: tp.List[int]=[8, 5, 4, 2],
117 | activation: str='ELU',
118 | activation_params: dict={'alpha': 1.0},
119 | norm: str='weight_norm',
120 | norm_params: tp.Dict[str, tp.Any]={},
121 | kernel_size: int=7,
122 | last_kernel_size: int=7,
123 | residual_kernel_size: int=3,
124 | dilation_base: int=2,
125 | causal: bool=False,
126 | pad_mode: str='reflect',
127 | true_skip: bool=False,
128 | compress: int=2,
129 | lstm: int=2):
130 | super().__init__()
131 | self.channels = channels
132 | self.dimension = dimension
133 | self.n_filters = n_filters
134 | self.ratios = list(reversed(ratios))
135 | del ratios
136 | self.n_residual_layers = n_residual_layers
137 |         self.hop_length = np.prod(self.ratios)  # product of the downsampling ratios
138 |
139 | act = getattr(nn, activation)
140 | mult = 1
141 | model: tp.List[nn.Module] = [
142 | SConv1d(
143 | channels,
144 | mult * n_filters,
145 | kernel_size,
146 | norm=norm,
147 | norm_kwargs=norm_params,
148 | causal=causal,
149 | pad_mode=pad_mode)
150 | ]
151 | # Downsample to raw audio scale
152 | for i, ratio in enumerate(self.ratios):
153 | # Add residual layers
154 | for j in range(n_residual_layers):
155 | model += [
156 | SEANetResnetBlock(
157 | mult * n_filters,
158 | kernel_sizes=[residual_kernel_size, 1],
159 | dilations=[dilation_base**j, 1],
160 | norm=norm,
161 | norm_params=norm_params,
162 | activation=activation,
163 | activation_params=activation_params,
164 | causal=causal,
165 | pad_mode=pad_mode,
166 | compress=compress,
167 | true_skip=true_skip)
168 | ]
169 |
170 | # Add downsampling layers
171 | model += [
172 | act(**activation_params),
173 | SConv1d(
174 | mult * n_filters,
175 | mult * n_filters * 2,
176 | kernel_size=ratio * 2,
177 | stride=ratio,
178 | norm=norm,
179 | norm_kwargs=norm_params,
180 | causal=causal,
181 | pad_mode=pad_mode),
182 | ]
183 | mult *= 2
184 |
185 | if lstm:
186 | model += [SLSTM(mult * n_filters, num_layers=lstm)]
187 |
188 | model += [
189 | act(**activation_params), SConv1d(
190 | mult * n_filters,
191 | dimension,
192 | last_kernel_size,
193 | norm=norm,
194 | norm_kwargs=norm_params,
195 | causal=causal,
196 | pad_mode=pad_mode)
197 | ]
198 |
199 | self.model = nn.Sequential(*model)
200 |
201 | def forward(self, x):
202 | return self.model(x)
203 |
204 |
205 | class SEANetDecoder(nn.Module):
206 | """SEANet decoder.
207 | Args:
208 | channels (int): Audio channels.
209 | dimension (int): Intermediate representation dimension.
210 | n_filters (int): Base width for the model.
211 | n_residual_layers (int): nb of residual layers.
212 | ratios (Sequence[int]): kernel size and stride ratios
213 | activation (str): Activation function.
214 | activation_params (dict): Parameters to provide to the activation function
215 | final_activation (str): Final activation function after all convolutions.
216 | final_activation_params (dict): Parameters to provide to the activation function
217 | norm (str): Normalization method.
218 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
219 | kernel_size (int): Kernel size for the initial convolution.
220 |         last_kernel_size (int): Kernel size for the last convolution.
221 | residual_kernel_size (int): Kernel size for the residual layers.
222 | dilation_base (int): How much to increase the dilation with each layer.
223 | causal (bool): Whether to use fully causal convolution.
224 | pad_mode (str): Padding mode for the convolutions.
225 | true_skip (bool): Whether to use true skip connection or a simple
226 | (streamable) convolution as the skip connection in the residual network blocks.
227 | compress (int): Reduced dimensionality in residual branches (from Demucs v3).
228 | lstm (int): Number of LSTM layers at the end of the encoder.
229 | trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
230 | If equal to 1.0, it means that all the trimming is done at the right.
231 | """
232 |
233 | def __init__(self,
234 | channels: int=1,
235 | dimension: int=128,
236 | n_filters: int=32,
237 | n_residual_layers: int=1,
238 | ratios: tp.List[int]=[8, 5, 4, 2],
239 | activation: str='ELU',
240 | activation_params: dict={'alpha': 1.0},
241 | final_activation: tp.Optional[str]=None,
242 | final_activation_params: tp.Optional[dict]=None,
243 | norm: str='weight_norm',
244 | norm_params: tp.Dict[str, tp.Any]={},
245 | kernel_size: int=7,
246 | last_kernel_size: int=7,
247 | residual_kernel_size: int=3,
248 | dilation_base: int=2,
249 | causal: bool=False,
250 | pad_mode: str='reflect',
251 | true_skip: bool=False,
252 | compress: int=2,
253 | lstm: int=2,
254 | trim_right_ratio: float=1.0):
255 | super().__init__()
256 | self.dimension = dimension
257 | self.channels = channels
258 | self.n_filters = n_filters
259 | self.ratios = ratios
260 | del ratios
261 | self.n_residual_layers = n_residual_layers
262 | self.hop_length = np.prod(self.ratios)
263 |
264 | act = getattr(nn, activation)
265 | mult = int(2**len(self.ratios))
266 | model: tp.List[nn.Module] = [
267 | SConv1d(
268 | dimension,
269 | mult * n_filters,
270 | kernel_size,
271 | norm=norm,
272 | norm_kwargs=norm_params,
273 | causal=causal,
274 | pad_mode=pad_mode)
275 | ]
276 |
277 | if lstm:
278 | model += [SLSTM(mult * n_filters, num_layers=lstm)]
279 |
280 | # Upsample to raw audio scale
281 | for i, ratio in enumerate(self.ratios):
282 | # Add upsampling layers
283 | model += [
284 | act(**activation_params),
285 | SConvTranspose1d(
286 | mult * n_filters,
287 | mult * n_filters // 2,
288 | kernel_size=ratio * 2,
289 | stride=ratio,
290 | norm=norm,
291 | norm_kwargs=norm_params,
292 | causal=causal,
293 | trim_right_ratio=trim_right_ratio),
294 | ]
295 | # Add residual layers
296 | for j in range(n_residual_layers):
297 | model += [
298 | SEANetResnetBlock(
299 | mult * n_filters // 2,
300 | kernel_sizes=[residual_kernel_size, 1],
301 | dilations=[dilation_base**j, 1],
302 | activation=activation,
303 | activation_params=activation_params,
304 | norm=norm,
305 | norm_params=norm_params,
306 | causal=causal,
307 | pad_mode=pad_mode,
308 | compress=compress,
309 | true_skip=true_skip)
310 | ]
311 |
312 | mult //= 2
313 |
314 | # Add final layers
315 | model += [
316 | act(**activation_params), SConv1d(
317 | n_filters,
318 | channels,
319 | last_kernel_size,
320 | norm=norm,
321 | norm_kwargs=norm_params,
322 | causal=causal,
323 | pad_mode=pad_mode)
324 | ]
325 | # Add optional final activation to decoder (eg. tanh)
326 | if final_activation is not None:
327 | final_act = getattr(nn, final_activation)
328 | final_activation_params = final_activation_params or {}
329 | model += [final_act(**final_activation_params)]
330 | self.model = nn.Sequential(*model)
331 |
332 | def forward(self, z):
333 | y = self.model(z)
334 | return y
335 |
336 |
337 | def test():
338 | import torch
339 | encoder = SEANetEncoder()
340 | decoder = SEANetDecoder()
341 | x = torch.randn(1, 1, 24000)
342 | z = encoder(x)
343 | print('z ', z.shape)
344 |     # check the expected latent and reconstruction shapes
345 | assert list(z.shape) == [1, 128, 75], z.shape
346 | y = decoder(z)
347 | assert y.shape == x.shape, (x.shape, y.shape)
348 |
349 |
350 | if __name__ == '__main__':
351 | test()
352 |
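With the default ratios [8, 5, 4, 2] the encoder hop length is 8 * 5 * 4 * 2 = 320 samples, i.e. 75 latent frames per second of 24 kHz audio. A round-trip sketch consistent with the shape assertions in test() above:

    import torch

    from academicodec.modules import SEANetDecoder
    from academicodec.modules import SEANetEncoder

    encoder, decoder = SEANetEncoder(), SEANetDecoder()
    x = torch.randn(1, 1, 24000)  # one second of 24 kHz audio
    z = encoder(x)
    print(int(encoder.hop_length), z.shape)  # 320 torch.Size([1, 128, 75])
    print(decoder(z).shape)                  # torch.Size([1, 1, 24000])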
--------------------------------------------------------------------------------
/academicodec/modules/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """A streamable transformer."""
7 | import typing as tp
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | def create_sin_embedding(positions: torch.Tensor,
15 | dim: int,
16 | max_period: float=10000):
17 | """Create time embedding for the given positions, target dimension `dim`.
18 | """
19 | # We aim for BTC format
20 | assert dim % 2 == 0
21 | half_dim = dim // 2
22 | adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
23 | phase = positions / (max_period**(adim / (half_dim - 1)))
24 | return torch.cat(
25 | [
26 | torch.cos(phase),
27 | torch.sin(phase),
28 | ], dim=-1)
29 |
30 |
31 | class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
32 | def forward(self, x: torch.Tensor, x_past: torch.Tensor,
33 | past_context: int): # type: ignore
34 | if self.norm_first:
35 | sa_input = self.norm1(x)
36 | x = x + self._sa_block(sa_input, x_past, past_context)
37 | x = x + self._ff_block(self.norm2(x))
38 | else:
39 | sa_input = x
40 | x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
41 | x = self.norm2(x + self._ff_block(x))
42 |
43 | return x, sa_input
44 |
45 | # self-attention block
46 | def _sa_block(self,
47 | x: torch.Tensor,
48 | x_past: torch.Tensor,
49 | past_context: int): # type: ignore
50 | _, T, _ = x.shape
51 | _, H, _ = x_past.shape
52 |
53 | queries = x
54 | keys = torch.cat([x_past, x], dim=1)
55 | values = keys
56 |
57 | queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
58 | keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
59 | delta = queries_pos - keys_pos
60 | valid_access = (delta >= 0) & (delta <= past_context)
61 | x = self.self_attn(
62 | queries, keys, values, attn_mask=~valid_access,
63 | need_weights=False)[0]
64 | return self.dropout1(x)
65 |
66 |
67 | class StreamingTransformerEncoder(nn.Module):
68 | """TransformerEncoder with streaming support.
69 |
70 | Args:
71 | dim (int): dimension of the data.
72 | hidden_scale (int): intermediate dimension of FF module is this times the dimension.
73 | num_heads (int): number of heads.
74 | num_layers (int): number of layers.
75 |         max_period (float): maximum period of cosines in the positional embedding.
76 | past_context (int or None): receptive field for the causal mask, infinite if None.
77 | gelu (bool): if true uses GeLUs, otherwise use ReLUs.
78 | norm_in (bool): normalize the input.
79 | dropout (float): dropout probability.
80 | **kwargs: See `nn.TransformerEncoderLayer`.
81 | """
82 |
83 | def __init__(self,
84 | dim,
85 | hidden_scale: float=4.,
86 | num_heads: int=8,
87 | num_layers: int=5,
88 | max_period: float=10000,
89 | past_context: int=1000,
90 | gelu: bool=True,
91 | norm_in: bool=True,
92 | dropout: float=0.,
93 | **kwargs):
94 | super().__init__()
95 | assert dim % num_heads == 0
96 | hidden_dim = int(dim * hidden_scale)
97 |
98 | self.max_period = max_period
99 | self.past_context = past_context
100 | activation: tp.Any = F.gelu if gelu else F.relu
101 |
102 | self.norm_in: nn.Module
103 | if norm_in:
104 | self.norm_in = nn.LayerNorm(dim)
105 | else:
106 | self.norm_in = nn.Identity()
107 |
108 | self.layers = nn.ModuleList()
109 | for idx in range(num_layers):
110 | self.layers.append(
111 | StreamingTransformerEncoderLayer(
112 | dim,
113 | num_heads,
114 | hidden_dim,
115 | activation=activation,
116 | batch_first=True,
117 | dropout=dropout,
118 | **kwargs))
119 |
120 | def forward(self,
121 | x: torch.Tensor,
122 | states: tp.Optional[tp.List[torch.Tensor]]=None,
123 | offset: tp.Union[int, torch.Tensor]=0):
124 | B, T, C = x.shape
125 | if states is None:
126 | states = [
127 | torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))
128 | ]
129 |
130 | positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
131 | pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)
132 |
133 | new_state: tp.List[torch.Tensor] = []
134 | x = self.norm_in(x)
135 | x = x + pos_emb
136 |
137 | for layer_state, layer in zip(states, self.layers):
138 | x, new_layer_state = layer(x, layer_state, self.past_context)
139 | new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
140 | new_state.append(new_layer_state[:, -self.past_context:, :])
141 | return x, new_state, offset + T
142 |
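
A minimal usage sketch (not part of the original file; shapes, chunk size and hyper-parameters are illustrative, and the repository root is assumed to be on `PYTHONPATH`): the encoder can be run over a long sequence chunk by chunk by threading `states` and `offset` through successive calls.

```python
import torch

from academicodec.modules.transformer import StreamingTransformerEncoder

encoder = StreamingTransformerEncoder(
    dim=128, num_heads=8, num_layers=2, past_context=100)
encoder.eval()

x = torch.randn(1, 60, 128)  # (batch, time, channels), i.e. BTC

with torch.no_grad():
    # One pass over the full sequence.
    y_full, _, _ = encoder(x)

    # The same sequence processed in two 30-frame chunks, carrying the
    # per-layer attention history (`states`) and the position `offset`.
    states, offset, outs = None, 0, []
    for chunk in x.split(30, dim=1):
        y, states, offset = encoder(chunk, states, offset)
        outs.append(y)
    y_stream = torch.cat(outs, dim=1)

print(y_full.shape, y_stream.shape)  # torch.Size([1, 60, 128]) twice
```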
--------------------------------------------------------------------------------
/academicodec/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # flake8: noqa
7 | from .vq import QuantizedResult
8 | from .vq import ResidualVectorQuantizer
9 |
--------------------------------------------------------------------------------
/academicodec/quantization/ac.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Arithmetic coder."""
7 | import io
8 | import math
9 | import random
10 | import typing as tp
11 |
12 | import torch
13 |
14 | from academicodec.binary import BitPacker
15 | from academicodec.binary import BitUnpacker
16 |
17 |
18 | def build_stable_quantized_cdf(pdf: torch.Tensor,
19 | total_range_bits: int,
20 | roundoff: float=1e-8,
21 | min_range: int=2,
22 | check: bool=True) -> torch.Tensor:
23 | """Turn the given PDF into a quantized CDF that splits
24 |     [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
25 | to the PDF.
26 |
27 | Args:
28 | pdf (torch.Tensor): probability distribution, shape should be `[N]`.
29 | total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
30 | during the coding process is `[0, 2 ** total_range_bits - 1]`.
31 | roundoff (float): will round the pdf up to that level to remove difference coming
32 | from e.g. evaluating the Language Model on different architectures.
33 | min_range (int): minimum range width. Should always be at least 2 for numerical
34 |             stability. Use this to avoid pathological behavior if a value
35 | that is expected to be rare actually happens in real life.
36 | check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
37 | """
38 | pdf = pdf.detach()
39 | if roundoff:
40 | pdf = (pdf / roundoff).floor() * roundoff
41 | # interpolate with uniform distribution to achieve desired minimum probability.
42 | total_range = 2**total_range_bits
43 | cardinality = len(pdf)
44 | alpha = min_range * cardinality / total_range
45 | assert alpha <= 1, "you must reduce min_range"
46 | ranges = (((1 - alpha) * total_range) * pdf).floor().long()
47 | ranges += min_range
48 | quantized_cdf = torch.cumsum(ranges, dim=-1)
49 | if min_range < 2:
50 | raise ValueError("min_range must be at least 2.")
51 | if check:
52 | assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
53 | if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
54 | ).any() or quantized_cdf[0] < min_range:
55 | raise ValueError("You must increase your total_range_bits.")
56 | return quantized_cdf
57 |
58 |
59 | class ArithmeticCoder:
60 | """ArithmeticCoder,
61 | Let us take a distribution `p` over `N` symbols, and assume we have a stream
62 | of random variables `s_t` sampled from `p`. Let us assume that we have a budget
63 | of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
64 |     corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
65 | sequence `(s_t)` by doing the following:
66 |
67 |     1) Initialize the current range to `[0, 2 ** B - 1]`.
68 | 2) For each time step t, split the current range into contiguous chunks,
69 | one for each possible outcome, with size roughly proportional to `p`.
70 | For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
71 | would be `{[0, 2], [3, 3]}`.
72 | 3) Select the chunk corresponding to `s_t`, and replace the current range with this.
73 | 4) When done encoding all the values, just select any value remaining in the range.
74 |
75 | You will notice that this procedure can fail: for instance if at any point in time
76 | the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
77 | possible outcome. Intuitively, the more likely a value is, the less the range width
78 | will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
79 |     coding scheme, likely outcomes would take fewer bits, and more of them can be coded
80 | with a fixed budget.
81 |
82 | In practice, we do not know `B` ahead of time, but we have a way to inject new bits
83 | when the current range decreases below a given limit (given by `total_range_bits`), without
84 |     having to redo all the computations. If we mostly encode likely values, we will seldom
85 | need to inject new bits, but a single rare value can deplete our stock of entropy!
86 |
87 | In this explanation, we assumed that the distribution `p` was constant. In fact, the present
88 | code works for any sequence `(p_t)` possibly different for each timestep.
89 | We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
90 |     the KL between the true distribution and `p_t`, the more efficient the coding will be.
91 |
92 | Args:
93 |         fo (IO[bytes]): file-like object to which the bytes will be written.
94 |         total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
95 |             Any time the current range width falls below this limit, new bits will
96 | be injected to rescale the initial range.
97 | """
98 |
99 | def __init__(self, fo: tp.IO[bytes], total_range_bits: int=24):
100 | assert total_range_bits <= 30
101 | self.total_range_bits = total_range_bits
102 | self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
103 | self.low: int = 0
104 | self.high: int = 0
105 | self.max_bit: int = -1
106 | self._dbg: tp.List[tp.Any] = []
107 | self._dbg2: tp.List[tp.Any] = []
108 |
109 | @property
110 | def delta(self) -> int:
111 | """Return the current range width."""
112 | return self.high - self.low + 1
113 |
114 | def _flush_common_prefix(self):
115 |         # If self.low and self.high start with the same bits,
116 | # those won't change anymore as we always just increase the range
117 | # by powers of 2, and we can flush them out to the bit stream.
118 | assert self.high >= self.low, (self.low, self.high)
119 | assert self.high < 2**(self.max_bit + 1)
120 | while self.max_bit >= 0:
121 | b1 = self.low >> self.max_bit
122 | b2 = self.high >> self.max_bit
123 | if b1 == b2:
124 | self.low -= (b1 << self.max_bit)
125 | self.high -= (b1 << self.max_bit)
126 | assert self.high >= self.low, (self.high, self.low,
127 | self.max_bit)
128 | assert self.low >= 0
129 | self.max_bit -= 1
130 | self.packer.push(b1)
131 | else:
132 | break
133 |
134 | def push(self, symbol: int, quantized_cdf: torch.Tensor):
135 | """Push the given symbol on the stream, flushing out bits
136 | if possible.
137 |
138 | Args:
139 | symbol (int): symbol to encode with the AC.
140 | quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
141 | to build this from your pdf estimate.
142 | """
143 | while self.delta < 2**self.total_range_bits:
144 | self.low *= 2
145 | self.high = self.high * 2 + 1
146 | self.max_bit += 1
147 |
148 | range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
149 | range_high = quantized_cdf[symbol].item() - 1
150 | effective_low = int(
151 | math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
152 | effective_high = int(
153 | math.floor(range_high * (self.delta / (2**self.total_range_bits))))
154 | assert self.low <= self.high
155 | self.high = self.low + effective_high
156 | self.low = self.low + effective_low
157 | assert self.low <= self.high, (effective_low, effective_high, range_low,
158 | range_high)
159 | self._dbg.append((self.low, self.high))
160 | self._dbg2.append((self.low, self.high))
161 | outs = self._flush_common_prefix()
162 | assert self.low <= self.high
163 | assert self.max_bit >= -1
164 | assert self.max_bit <= 61, self.max_bit
165 | return outs
166 |
167 | def flush(self):
168 | """Flush the remaining information to the stream.
169 | """
170 | while self.max_bit >= 0:
171 | b1 = (self.low >> self.max_bit) & 1
172 | self.packer.push(b1)
173 | self.max_bit -= 1
174 | self.packer.flush()
175 |
176 |
177 | class ArithmeticDecoder:
178 | """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
179 |
180 | Note that this must be called with **exactly** the same parameters and sequence
181 | of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
182 |
183 |     If the AC encoder current range is [L, H], with `L` and `H` having the same common
184 |     prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
185 |     For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
186 |     `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
187 | for a specific sequence of symbols and a binary-search allows us to decode those symbols.
188 | At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
189 | and we will need to read new bits from the stream and repeat the process.
190 |
191 | """
192 |
193 | def __init__(self, fo: tp.IO[bytes], total_range_bits: int=24):
194 | self.total_range_bits = total_range_bits
195 | self.low: int = 0
196 | self.high: int = 0
197 | self.current: int = 0
198 | self.max_bit: int = -1
199 | self.unpacker = BitUnpacker(
200 | bits=1, fo=fo) # we pull single bits at a time.
201 | # Following is for debugging
202 | self._dbg: tp.List[tp.Any] = []
203 | self._dbg2: tp.List[tp.Any] = []
204 | self._last: tp.Any = None
205 |
206 | @property
207 | def delta(self) -> int:
208 | return self.high - self.low + 1
209 |
210 | def _flush_common_prefix(self):
211 | # Given the current range [L, H], if both have a common prefix,
212 | # we know we can remove it from our representation to avoid handling large numbers.
213 | while self.max_bit >= 0:
214 | b1 = self.low >> self.max_bit
215 | b2 = self.high >> self.max_bit
216 | if b1 == b2:
217 | self.low -= (b1 << self.max_bit)
218 | self.high -= (b1 << self.max_bit)
219 | self.current -= (b1 << self.max_bit)
220 | assert self.high >= self.low
221 | assert self.low >= 0
222 | self.max_bit -= 1
223 | else:
224 | break
225 |
226 | def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
227 | """Pull a symbol, reading as many bits from the stream as required.
228 | This returns `None` when the stream has been exhausted.
229 |
230 | Args:
231 | quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
232 |                 to build this from your pdf estimate. This must be **exactly**
233 | the same cdf as the one used at encoding time.
234 | """
235 | while self.delta < 2**self.total_range_bits:
236 | bit = self.unpacker.pull()
237 | if bit is None:
238 | return None
239 | self.low *= 2
240 | self.high = self.high * 2 + 1
241 | self.current = self.current * 2 + bit
242 | self.max_bit += 1
243 |
244 | def bin_search(low_idx: int, high_idx: int):
245 | # Binary search is not just for coding interviews :)
246 | if high_idx < low_idx:
247 | raise RuntimeError("Binary search failed")
248 | mid = (low_idx + high_idx) // 2
249 | range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
250 | range_high = quantized_cdf[mid].item() - 1
251 | effective_low = int(
252 | math.ceil(range_low * (self.delta / (2**self.total_range_bits)
253 | )))
254 | effective_high = int(
255 | math.floor(range_high * (self.delta / (2**self.total_range_bits)
256 | )))
257 | low = effective_low + self.low
258 | high = effective_high + self.low
259 | if self.current >= low:
260 | if self.current <= high:
261 | return (mid, low, high, self.current)
262 | else:
263 | return bin_search(mid + 1, high_idx)
264 | else:
265 | return bin_search(low_idx, mid - 1)
266 |
267 | self._last = (self.low, self.high, self.current, self.max_bit)
268 | sym, self.low, self.high, self.current = bin_search(
269 | 0, len(quantized_cdf) - 1)
270 | self._dbg.append((self.low, self.high, self.current))
271 | self._flush_common_prefix()
272 | self._dbg2.append((self.low, self.high, self.current))
273 |
274 | return sym
275 |
276 |
277 | def test():
278 | torch.manual_seed(1234)
279 | random.seed(1234)
280 | for _ in range(4):
281 | pdfs = []
282 | cardinality = random.randrange(4000)
283 | steps = random.randrange(100, 500)
284 | fo = io.BytesIO()
285 | encoder = ArithmeticCoder(fo)
286 | symbols = []
287 | for step in range(steps):
288 | pdf = torch.softmax(torch.randn(cardinality), dim=0)
289 | pdfs.append(pdf)
290 | q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
291 | symbol = torch.multinomial(pdf, 1).item()
292 | symbols.append(symbol)
293 | encoder.push(symbol, q_cdf)
294 | encoder.flush()
295 |
296 | fo.seek(0)
297 | decoder = ArithmeticDecoder(fo)
298 | for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
299 | q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
300 | decoded_symbol = decoder.pull(q_cdf)
301 | assert decoded_symbol == symbol, idx
302 | assert decoder.pull(torch.zeros(1)) is None
303 |
304 |
305 | if __name__ == "__main__":
306 | test()
307 |
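
To complement `test()` above, here is a minimal round-trip sketch with a constant distribution; the 4-symbol pdf and the message below are illustrative values, not taken from the original code.

```python
import io

import torch

from academicodec.quantization.ac import (ArithmeticCoder, ArithmeticDecoder,
                                           build_stable_quantized_cdf)

# Illustrative 4-symbol distribution and message; because the pdf is constant,
# the same quantized cdf can be reused for every symbol.
pdf = torch.tensor([0.5, 0.25, 0.125, 0.125])
message = [0, 0, 1, 3, 0, 2]

fo = io.BytesIO()
coder = ArithmeticCoder(fo)
q_cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
for symbol in message:
    coder.push(symbol, q_cdf)
coder.flush()

fo.seek(0)
decoder = ArithmeticDecoder(fo)
decoded = [decoder.pull(q_cdf) for _ in message]
assert decoded == message
print(len(fo.getvalue()), "bytes for", len(message), "symbols")
```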
--------------------------------------------------------------------------------
/academicodec/quantization/core_vq.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # This implementation is inspired from
8 | # https://github.com/lucidrains/vector-quantize-pytorch
9 | # which is released under MIT License. Hereafter, the original license:
10 | # MIT License
11 | #
12 | # Copyright (c) 2020 Phil Wang
13 | #
14 | # Permission is hereby granted, free of charge, to any person obtaining a copy
15 | # of this software and associated documentation files (the "Software"), to deal
16 | # in the Software without restriction, including without limitation the rights
17 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18 | # copies of the Software, and to permit persons to whom the Software is
19 | # furnished to do so, subject to the following conditions:
20 | #
21 | # The above copyright notice and this permission notice shall be included in all
22 | # copies or substantial portions of the Software.
23 | #
24 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30 | # SOFTWARE.
31 | """Core vector quantization implementation."""
32 | import typing as tp
33 |
34 | import torch
35 | import torch.nn.functional as F
36 | from einops import rearrange
37 | from einops import repeat
38 | from torch import nn
39 |
40 | from academicodec.quantization.distrib import broadcast_tensors
41 |
42 |
43 | def default(val: tp.Any, d: tp.Any) -> tp.Any:
44 | return val if val is not None else d
45 |
46 |
47 | def ema_inplace(moving_avg, new, decay: float):
48 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
49 |
50 |
51 | def laplace_smoothing(x, n_categories: int, epsilon: float=1e-5):
52 | return (x + epsilon) / (x.sum() + n_categories * epsilon)
53 |
54 |
55 | def uniform_init(*shape: int):
56 | t = torch.empty(shape)
57 | nn.init.kaiming_uniform_(t)
58 | return t
59 |
60 |
61 | def sample_vectors(samples, num: int):
62 | num_samples, device = samples.shape[0], samples.device
63 |
64 | if num_samples >= num:
65 | indices = torch.randperm(num_samples, device=device)[:num]
66 | else:
67 | indices = torch.randint(0, num_samples, (num, ), device=device)
68 |
69 | return samples[indices]
70 |
71 |
72 | def kmeans(samples, num_clusters: int, num_iters: int=10):
73 | dim, dtype = samples.shape[-1], samples.dtype
74 |
75 | means = sample_vectors(samples, num_clusters)
76 |
77 | for _ in range(num_iters):
78 | diffs = rearrange(samples, "n d -> n () d") - rearrange(means,
79 | "c d -> () c d")
80 | dists = -(diffs**2).sum(dim=-1)
81 |
82 | buckets = dists.max(dim=-1).indices
83 | bins = torch.bincount(buckets, minlength=num_clusters)
84 | zero_mask = bins == 0
85 | bins_min_clamped = bins.masked_fill(zero_mask, 1)
86 |
87 | new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
88 | new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
89 | new_means = new_means / bins_min_clamped[..., None]
90 |
91 | means = torch.where(zero_mask[..., None], means, new_means)
92 |
93 | return means, bins
94 |
95 |
96 | class EuclideanCodebook(nn.Module):
97 | """Codebook with Euclidean distance.
98 | Args:
99 | dim (int): Dimension.
100 | codebook_size (int): Codebook size.
101 | kmeans_init (bool): Whether to use k-means to initialize the codebooks.
102 | If set to true, run the k-means algorithm on the first training batch and use
103 | the learned centroids as initialization.
104 | kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
105 | decay (float): Decay for exponential moving average over the codebooks.
106 | epsilon (float): Epsilon value for numerical stability.
107 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
108 | that have an exponential moving average cluster size less than the specified threshold with
109 |             a randomly selected vector from the current batch.
110 | """
111 |
112 | def __init__(
113 | self,
114 | dim: int,
115 | codebook_size: int,
116 |             kmeans_init: bool=False,
117 | kmeans_iters: int=10,
118 | decay: float=0.99,
119 | epsilon: float=1e-5,
120 | threshold_ema_dead_code: int=2, ):
121 | super().__init__()
122 | self.decay = decay
123 | init_fn: tp.Union[
124 | tp.Callable[..., torch.Tensor],
125 | tp.Any] = uniform_init if not kmeans_init else torch.zeros
126 | embed = init_fn(codebook_size, dim)
127 |
128 | self.codebook_size = codebook_size
129 |
130 | self.kmeans_iters = kmeans_iters
131 | self.epsilon = epsilon
132 | self.threshold_ema_dead_code = threshold_ema_dead_code
133 |
134 | self.register_buffer("inited", torch.Tensor([not kmeans_init]))
135 | self.register_buffer("cluster_size", torch.zeros(codebook_size))
136 | self.register_buffer("embed", embed)
137 | self.register_buffer("embed_avg", embed.clone())
138 |
139 | @torch.jit.ignore
140 | def init_embed_(self, data):
141 | if self.inited:
142 | return
143 |
144 | embed, cluster_size = kmeans(data, self.codebook_size,
145 | self.kmeans_iters)
146 | self.embed.data.copy_(embed)
147 | self.embed_avg.data.copy_(embed.clone())
148 | self.cluster_size.data.copy_(cluster_size)
149 | self.inited.data.copy_(torch.Tensor([True]))
150 | # Make sure all buffers across workers are in sync after initialization
151 | broadcast_tensors(self.buffers())
152 |
153 | def replace_(self, samples, mask):
154 | modified_codebook = torch.where(
155 | mask[..., None],
156 | sample_vectors(samples, self.codebook_size), self.embed)
157 | self.embed.data.copy_(modified_codebook)
158 |
159 | def expire_codes_(self, batch_samples):
160 | if self.threshold_ema_dead_code == 0:
161 | return
162 |
163 | expired_codes = self.cluster_size < self.threshold_ema_dead_code
164 | if not torch.any(expired_codes):
165 | return
166 |
167 | batch_samples = rearrange(batch_samples, "... d -> (...) d")
168 | self.replace_(batch_samples, mask=expired_codes)
169 | broadcast_tensors(self.buffers())
170 |
171 | def preprocess(self, x):
172 | x = rearrange(x, "... d -> (...) d")
173 | return x
174 |
175 | def quantize(self, x):
176 | embed = self.embed.t()
177 | dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed +
178 | embed.pow(2).sum(0, keepdim=True))
179 | embed_ind = dist.max(dim=-1).indices
180 | return embed_ind
181 |
182 | def postprocess_emb(self, embed_ind, shape):
183 | return embed_ind.view(*shape[:-1])
184 |
185 | def dequantize(self, embed_ind):
186 | quantize = F.embedding(embed_ind, self.embed)
187 | return quantize
188 |
189 | def encode(self, x):
190 | shape = x.shape
191 | # pre-process
192 | x = self.preprocess(x)
193 | # quantize
194 | embed_ind = self.quantize(x)
195 | # post-process
196 | embed_ind = self.postprocess_emb(embed_ind, shape)
197 | return embed_ind
198 |
199 | def decode(self, embed_ind):
200 | quantize = self.dequantize(embed_ind)
201 | return quantize
202 |
203 | def forward(self, x):
204 | shape, dtype = x.shape, x.dtype
205 | x = self.preprocess(x)
206 |
207 | self.init_embed_(x)
208 |
209 | embed_ind = self.quantize(x)
210 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
211 | embed_ind = self.postprocess_emb(embed_ind, shape)
212 | quantize = self.dequantize(embed_ind)
213 |
214 | if self.training:
215 | # We do the expiry of code at that point as buffers are in sync
216 | # and all the workers will take the same decision.
217 | self.expire_codes_(x)
218 | ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
219 | embed_sum = x.t() @ embed_onehot
220 | ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
221 | cluster_size = (
222 | laplace_smoothing(self.cluster_size, self.codebook_size,
223 | self.epsilon) * self.cluster_size.sum())
224 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
225 | self.embed.data.copy_(embed_normalized)
226 |
227 | return quantize, embed_ind
228 |
229 |
230 | class VectorQuantization(nn.Module):
231 | """Vector quantization implementation.
232 | Currently supports only euclidean distance.
233 | Args:
234 | dim (int): Dimension
235 | codebook_size (int): Codebook size
236 | codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
237 | decay (float): Decay for exponential moving average over the codebooks.
238 | epsilon (float): Epsilon value for numerical stability.
239 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
240 | kmeans_iters (int): Number of iterations used for kmeans initialization.
241 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
242 | that have an exponential moving average cluster size less than the specified threshold with
243 |         a randomly selected vector from the current batch.
244 | commitment_weight (float): Weight for commitment loss.
245 | """
246 |
247 | def __init__(
248 | self,
249 | dim: int,
250 | codebook_size: int,
251 | codebook_dim: tp.Optional[int]=None,
252 | decay: float=0.99,
253 | epsilon: float=1e-5,
254 | kmeans_init: bool=True,
255 | kmeans_iters: int=50,
256 | threshold_ema_dead_code: int=2,
257 | commitment_weight: float=1., ):
258 | super().__init__()
259 | _codebook_dim: int = default(codebook_dim, dim)
260 |
261 | requires_projection = _codebook_dim != dim
262 | self.project_in = (nn.Linear(dim, _codebook_dim)
263 | if requires_projection else nn.Identity())
264 | self.project_out = (nn.Linear(_codebook_dim, dim)
265 | if requires_projection else nn.Identity())
266 |
267 | self.epsilon = epsilon
268 | self.commitment_weight = commitment_weight
269 |
270 | self._codebook = EuclideanCodebook(
271 | dim=_codebook_dim,
272 | codebook_size=codebook_size,
273 | kmeans_init=kmeans_init,
274 | kmeans_iters=kmeans_iters,
275 | decay=decay,
276 | epsilon=epsilon,
277 | threshold_ema_dead_code=threshold_ema_dead_code)
278 | self.codebook_size = codebook_size
279 |
280 | @property
281 | def codebook(self):
282 | return self._codebook.embed
283 |
284 | def encode(self, x):
285 | x = rearrange(x, "b d n -> b n d")
286 | x = self.project_in(x)
287 | embed_in = self._codebook.encode(x)
288 | return embed_in
289 |
290 | def decode(self, embed_ind):
291 | quantize = self._codebook.decode(embed_ind)
292 | quantize = self.project_out(quantize)
293 | quantize = rearrange(quantize, "b n d -> b d n")
294 | return quantize
295 |
296 | def forward(self, x):
297 | device = x.device
298 | x = rearrange(x, "b d n -> b n d")
299 | x = self.project_in(x)
300 |
301 | quantize, embed_ind = self._codebook(x)
302 |
303 | if self.training:
304 | quantize = x + (quantize - x).detach()
305 |
306 | loss = torch.tensor([0.0], device=device, requires_grad=self.training)
307 |
308 | if self.training:
309 | if self.commitment_weight > 0:
310 | commit_loss = F.mse_loss(quantize.detach(), x)
311 | loss = loss + commit_loss * self.commitment_weight
312 |
313 | quantize = self.project_out(quantize)
314 | quantize = rearrange(quantize, "b n d -> b d n")
315 | return quantize, embed_ind, loss
316 |
317 |
318 | class ResidualVectorQuantization(nn.Module):
319 | """Residual vector quantization implementation.
320 | Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
321 | """
322 |
323 | def __init__(self, *, num_quantizers, **kwargs):
324 | super().__init__()
325 | self.layers = nn.ModuleList(
326 | [VectorQuantization(**kwargs) for _ in range(num_quantizers)])
327 |
328 | def forward(self, x, n_q: tp.Optional[int]=None):
329 | quantized_out = 0.0
330 | residual = x
331 |
332 | all_losses = []
333 | all_indices = []
334 |
335 | n_q = n_q or len(self.layers)
336 |
337 | for layer in self.layers[:n_q]:
338 | quantized, indices, loss = layer(residual)
339 | residual = residual - quantized
340 | quantized_out = quantized_out + quantized
341 |
342 | all_indices.append(indices)
343 | all_losses.append(loss)
344 |
345 | out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
346 | return quantized_out, out_indices, out_losses
347 |
348 | def encode(self,
349 | x: torch.Tensor,
350 | n_q: tp.Optional[int]=None,
351 | st: tp.Optional[int]=None) -> torch.Tensor:
352 | residual = x
353 | all_indices = []
354 | n_q = n_q or len(self.layers)
355 | st = st or 0
356 |         for layer in self.layers[st:n_q]:  # set the start and end layers for decoding
357 | indices = layer.encode(residual)
358 | quantized = layer.decode(indices)
359 | residual = residual - quantized
360 | all_indices.append(indices)
361 | out_indices = torch.stack(all_indices)
362 | return out_indices
363 |
364 | def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
365 | quantized_out = torch.tensor(0.0, device=q_indices.device)
366 | for i, indices in enumerate(q_indices):
367 | layer = self.layers[i]
368 | quantized = layer.decode(indices)
369 | quantized_out = quantized_out + quantized
370 | return quantized_out
371 |
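
A small illustrative sketch (arbitrary shapes and hyper-parameters; `kmeans_init=False` is used so no training batch is needed for codebook initialization) of driving `ResidualVectorQuantization` directly on a `(batch, dim, time)` tensor; in the codec this class is normally wrapped by `ResidualVectorQuantizer` from `vq.py`.

```python
import torch

from academicodec.quantization.core_vq import ResidualVectorQuantization

# kmeans_init=False initializes the codebooks directly, so the module can be
# called without a first training batch.
rvq = ResidualVectorQuantization(
    num_quantizers=4, dim=128, codebook_size=1024, kmeans_init=False)
rvq.eval()

x = torch.randn(2, 128, 50)  # (batch, dim, time)
with torch.no_grad():
    quantized, indices, losses = rvq(x, n_q=4)  # full forward pass
    codes = rvq.encode(x, n_q=4)                # (n_q, batch, time) indices
    recon = rvq.decode(codes)                   # (batch, dim, time)

print(quantized.shape, codes.shape, recon.shape)
```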
--------------------------------------------------------------------------------
/academicodec/quantization/distrib.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Torch distributed utilities."""
7 | import typing as tp
8 |
9 | import torch
10 |
11 |
12 | def rank():
13 | if torch.distributed.is_initialized():
14 | return torch.distributed.get_rank()
15 | else:
16 | return 0
17 |
18 |
19 | def world_size():
20 | if torch.distributed.is_initialized():
21 | return torch.distributed.get_world_size()
22 | else:
23 | return 1
24 |
25 |
26 | def is_distributed():
27 | return world_size() > 1
28 |
29 |
30 | def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
31 | if is_distributed():
32 | return torch.distributed.all_reduce(tensor, op)
33 |
34 |
35 | def _is_complex_or_float(tensor):
36 | return torch.is_floating_point(tensor) or torch.is_complex(tensor)
37 |
38 |
39 | def _check_number_of_params(params: tp.List[torch.Tensor]):
40 | # utility function to check that the number of params in all workers is the same,
41 | # and thus avoid a deadlock with distributed all reduce.
42 | if not is_distributed() or not params:
43 | return
44 | #print('params[0].device ', params[0].device)
45 | tensor = torch.tensor(
46 | [len(params)], device=params[0].device, dtype=torch.long)
47 | all_reduce(tensor)
48 | if tensor.item() != len(params) * world_size():
49 | # If not all the workers have the same number, for at least one of them,
50 | # this inequality will be verified.
51 | raise RuntimeError(
52 | f"Mismatch in number of params: ours is {len(params)}, "
53 | "at least one worker has a different one.")
54 |
55 |
56 | def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int=0):
57 | """Broadcast the tensors from the given parameters to all workers.
58 | This can be used to ensure that all workers have the same model to start with.
59 | """
60 | if not is_distributed():
61 | return
62 | tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
63 | _check_number_of_params(tensors)
64 | handles = []
65 | for tensor in tensors:
66 | # src = int(rank()) # added code
67 | handle = torch.distributed.broadcast(
68 | tensor.data, src=src, async_op=True)
69 | handles.append(handle)
70 | for handle in handles:
71 | handle.wait()
72 |
73 |
74 | def sync_buffer(buffers, average=True):
75 | """
76 |     Sync buffers across workers. If average is False, broadcast instead of averaging.
77 | """
78 | if not is_distributed():
79 | return
80 | handles = []
81 | for buffer in buffers:
82 | if torch.is_floating_point(buffer.data):
83 | if average:
84 | handle = torch.distributed.all_reduce(
85 | buffer.data,
86 | op=torch.distributed.ReduceOp.SUM,
87 | async_op=True)
88 | else:
89 | handle = torch.distributed.broadcast(
90 | buffer.data, src=0, async_op=True)
91 | handles.append((buffer, handle))
92 | for buffer, handle in handles:
93 | handle.wait()
94 | if average:
95 |                 buffer.data /= world_size()
96 |
97 |
98 | def sync_grad(params):
99 | """
100 | Simpler alternative to DistributedDataParallel, that doesn't rely
101 | on any black magic. For simple models it can also be as fast.
102 | Just call this on your model parameters after the call to backward!
103 | """
104 | if not is_distributed():
105 | return
106 | handles = []
107 | for p in params:
108 | if p.grad is not None:
109 | handle = torch.distributed.all_reduce(
110 | p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
111 | handles.append((p, handle))
112 | for p, handle in handles:
113 | handle.wait()
114 | p.grad.data /= world_size()
115 |
116 |
117 | def average_metrics(metrics: tp.Dict[str, float], count=1.):
118 | """Average a dictionary of metrics across all workers, using the optional
119 |     `count` as unnormalized weight.
120 | """
121 | if not is_distributed():
122 | return metrics
123 | keys, values = zip(*metrics.items())
124 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
125 | tensor = torch.tensor(
126 | list(values) + [1], device=device, dtype=torch.float32)
127 | tensor *= count
128 | all_reduce(tensor)
129 | averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
130 | return dict(zip(keys, averaged))
131 |
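
A rough sketch of how these helpers are meant to be combined as a lightweight alternative to DistributedDataParallel; the model and data below are placeholders, and in a real multi-GPU run `torch.distributed` must already be initialized by the launcher (otherwise every call is a no-op).

```python
import torch

from academicodec.quantization import distrib

# Placeholder model/optimizer for illustration only.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Make sure every worker starts from identical weights.
distrib.broadcast_tensors(model.parameters())

x = torch.randn(8, 16)
loss = model(x).pow(2).mean()
loss.backward()

# Average gradients across workers, then step and keep buffers in sync.
distrib.sync_grad(model.parameters())
optimizer.step()
distrib.sync_buffer(model.buffers())
```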
--------------------------------------------------------------------------------
/academicodec/quantization/vq.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Residual vector quantizer implementation."""
7 | import math
8 | import typing as tp
9 | from dataclasses import dataclass
10 | from dataclasses import field
11 |
12 | import torch
13 | from torch import nn
14 |
15 | from academicodec.quantization.core_vq import ResidualVectorQuantization
16 |
17 |
18 | @dataclass
19 | class QuantizedResult:
20 | quantized: torch.Tensor
21 | codes: torch.Tensor
22 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
23 | penalty: tp.Optional[torch.Tensor] = None
24 | metrics: dict = field(default_factory=dict)
25 |
26 |
27 | class ResidualVectorQuantizer(nn.Module):
28 | """Residual Vector Quantizer.
29 | Args:
30 | dimension (int): Dimension of the codebooks.
31 | n_q (int): Number of residual vector quantizers used.
32 | bins (int): Codebook size.
33 | decay (float): Decay for exponential moving average over the codebooks.
34 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
35 | kmeans_iters (int): Number of iterations used for kmeans initialization.
36 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
37 | that have an exponential moving average cluster size less than the specified threshold with
38 |             a randomly selected vector from the current batch.
39 | """
40 |
41 | def __init__(
42 | self,
43 | dimension: int=256,
44 | n_q: int=8,
45 | bins: int=1024,
46 | decay: float=0.99,
47 | kmeans_init: bool=True,
48 | kmeans_iters: int=50,
49 | threshold_ema_dead_code: int=2, ):
50 | super().__init__()
51 | self.n_q = n_q
52 | self.dimension = dimension
53 | self.bins = bins
54 | self.decay = decay
55 | self.kmeans_init = kmeans_init
56 | self.kmeans_iters = kmeans_iters
57 | self.threshold_ema_dead_code = threshold_ema_dead_code
58 | self.vq = ResidualVectorQuantization(
59 | dim=self.dimension,
60 | codebook_size=self.bins,
61 | num_quantizers=self.n_q,
62 | decay=self.decay,
63 | kmeans_init=self.kmeans_init,
64 | kmeans_iters=self.kmeans_iters,
65 | threshold_ema_dead_code=self.threshold_ema_dead_code, )
66 |
67 | def forward(self,
68 | x: torch.Tensor,
69 | sample_rate: int,
70 | bandwidth: tp.Optional[float]=None) -> QuantizedResult:
71 | """Residual vector quantization on the given input tensor.
72 | Args:
73 | x (torch.Tensor): Input tensor.
74 | sample_rate (int): Sample rate of the input tensor.
75 | bandwidth (float): Target bandwidth.
76 | Returns:
77 | QuantizedResult:
78 | The quantized (or approximately quantized) representation with
79 | the associated bandwidth and any penalty term for the loss.
80 | """
81 | bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
82 | n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
83 | quantized, codes, commit_loss = self.vq(x, n_q=n_q)
84 | bw = torch.tensor(n_q * bw_per_q).to(x)
85 | return quantized, codes, bw, torch.mean(commit_loss)
86 | #return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
87 |
88 | def get_num_quantizers_for_bandwidth(
89 | self, sample_rate: int, bandwidth: tp.Optional[float]=None) -> int:
90 | """Return n_q based on specified target bandwidth.
91 | """
92 | bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
93 | n_q = self.n_q
94 | if bandwidth and bandwidth > 0.:
95 | n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
96 | return n_q
97 |
98 | def get_bandwidth_per_quantizer(self, sample_rate: int):
99 | """Return bandwidth per quantizer for a given input sample rate.
100 | """
101 | return math.log2(self.bins) * sample_rate / 1000
102 |
103 | def encode(self,
104 | x: torch.Tensor,
105 | sample_rate: int,
106 | bandwidth: tp.Optional[float]=None,
107 | st: tp.Optional[int]=None) -> torch.Tensor:
108 | """Encode a given input tensor with the specified sample rate at the given bandwidth.
109 |     The RVQ encode method sets the appropriate number of quantizers to use
110 | and returns indices for each quantizer.
111 | """
112 | n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
113 | st = st or 0
114 | codes = self.vq.encode(x, n_q=n_q, st=st)
115 | return codes
116 |
117 | def decode(self, codes: torch.Tensor) -> torch.Tensor:
118 | """Decode the given codes to the quantized representation.
119 | """
120 | quantized = self.vq.decode(codes)
121 | return quantized
122 |
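
Judging from `get_bandwidth_per_quantizer`, the `sample_rate` argument is effectively the frame rate of the latent sequence, since each quantizer emits `log2(bins)` bits per frame. A small worked sketch with illustrative numbers (a 75 Hz frame rate and 1024-entry codebooks give 0.75 kb/s per quantizer):

```python
import torch

from academicodec.quantization.vq import ResidualVectorQuantizer

quantizer = ResidualVectorQuantizer(
    dimension=128, n_q=8, bins=1024, kmeans_init=False)
quantizer.eval()

frame_rate = 75  # illustrative latent frame rate in Hz
per_q = quantizer.get_bandwidth_per_quantizer(frame_rate)          # 0.75 kb/s
n_q = quantizer.get_num_quantizers_for_bandwidth(frame_rate, 6.0)  # 8

x = torch.randn(2, 128, 50)  # (batch, dimension, time)
with torch.no_grad():
    quantized, codes, bw, commit_loss = quantizer(x, frame_rate, bandwidth=6.0)

print(per_q, n_q, codes.shape, bw)  # 0.75 8 torch.Size([8, 2, 50]) tensor(6.)
```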
--------------------------------------------------------------------------------
/academicodec/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import random
5 | import sys
6 | import time
7 | import warnings
8 |
9 | import matplotlib
10 | import numpy as np
11 | import torch
12 | import yaml
13 | from torch import distributed as dist
14 | from torch.nn.utils import weight_norm
15 | matplotlib.use("Agg")
16 | import matplotlib.pylab as plt
17 | import re
18 | import pathlib
19 |
20 |
21 | def seed_everything(seed, cudnn_deterministic=False):
22 | """
23 | Function that sets seed for pseudo-random number generators in:
24 | pytorch, numpy, python.random
25 |
26 | Args:
27 | seed: the integer value seed for global random state
28 | """
29 | if seed is not None:
30 | # print(f"Global seed set to {seed}")
31 | random.seed(seed)
32 | np.random.seed(seed)
33 | torch.manual_seed(seed)
34 | torch.cuda.manual_seed_all(seed)
35 |
36 | # if cudnn_deterministic:
37 | # torch.backends.cudnn.deterministic = True
38 | # warnings.warn('You have chosen to seed training. '
39 | # 'This will turn on the CUDNN deterministic setting, '
40 | # 'which can slow down your training considerably! '
41 | # 'You may see unexpected behavior when restarting '
42 | # 'from checkpoints.')
43 |
44 |
45 | def is_primary():
46 | return get_rank() == 0
47 |
48 |
49 | def get_rank():
50 | if not dist.is_available():
51 | return 0
52 | if not dist.is_initialized():
53 | return 0
54 |
55 | return dist.get_rank()
56 |
57 |
58 | def load_yaml_config(path):
59 | with open(path) as f:
60 | config = yaml.full_load(f)
61 | return config
62 |
63 |
64 | def save_config_to_yaml(config, path):
65 | assert path.endswith('.yaml')
66 | with open(path, 'w') as f:
67 | f.write(yaml.dump(config))
68 | f.close()
69 |
70 |
71 | def save_dict_to_json(d, path, indent=None):
72 | json.dump(d, open(path, 'w'), indent=indent)
73 |
74 |
75 | def load_dict_from_json(path):
76 | return json.load(open(path, 'r'))
77 |
78 |
79 | def write_args(args, path):
80 | args_dict = dict((name, getattr(args, name)) for name in dir(args)
81 | if not name.startswith('_'))
82 | with open(path, 'a') as args_file:
83 | args_file.write('==> torch version: {}\n'.format(torch.__version__))
84 | args_file.write(
85 | '==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
86 | args_file.write('==> Cmd:\n')
87 | args_file.write(str(sys.argv))
88 | args_file.write('\n==> args:\n')
89 | for k, v in sorted(args_dict.items()):
90 | args_file.write(' %s: %s\n' % (str(k), str(v)))
91 | args_file.close()
92 |
93 |
94 | class Logger(object):
95 | def __init__(self, args):
96 | self.args = args
97 | self.save_dir = args.save_dir
98 | self.is_primary = is_primary()
99 |
100 | if self.is_primary:
101 | os.makedirs(self.save_dir, exist_ok=True)
102 |
103 | # save the args and config
104 | self.config_dir = os.path.join(self.save_dir, 'configs')
105 | os.makedirs(self.config_dir, exist_ok=True)
106 | file_name = os.path.join(self.config_dir, 'args.txt')
107 | write_args(args, file_name)
108 |
109 | log_dir = os.path.join(self.save_dir, 'logs')
110 | if not os.path.exists(log_dir):
111 | os.makedirs(log_dir, exist_ok=True)
112 | self.text_writer = open(os.path.join(log_dir, 'log.txt'),
113 | 'a') # 'w')
114 | if args.tensorboard:
115 | self.log_info('using tensorboard')
116 |             # Imported lazily: `import torch` does not import torch.utils.tensorboard.
117 |             from torch.utils.tensorboard import SummaryWriter
118 |             self.tb_writer = SummaryWriter(log_dir=log_dir)
119 | else:
120 | self.tb_writer = None
121 |
122 | def save_config(self, config):
123 | if self.is_primary:
124 | save_config_to_yaml(config,
125 | os.path.join(self.config_dir, 'config.yaml'))
126 |
127 | def log_info(self, info, check_primary=True):
128 | if self.is_primary or (not check_primary):
129 | print(info)
130 | if self.is_primary:
131 | info = str(info)
132 | time_str = time.strftime('%Y-%m-%d-%H-%M')
133 | info = '{}: {}'.format(time_str, info)
134 | if not info.endswith('\n'):
135 | info += '\n'
136 | self.text_writer.write(info)
137 | self.text_writer.flush()
138 |
139 | def add_scalar(self, **kargs):
140 | """Log a scalar variable."""
141 | if self.is_primary:
142 | if self.tb_writer is not None:
143 | self.tb_writer.add_scalar(**kargs)
144 |
145 | def add_scalars(self, **kargs):
146 |         """Log multiple scalar variables."""
147 | if self.is_primary:
148 | if self.tb_writer is not None:
149 | self.tb_writer.add_scalars(**kargs)
150 |
151 | def add_image(self, **kargs):
152 |         """Log an image."""
153 | if self.is_primary:
154 | if self.tb_writer is not None:
155 | self.tb_writer.add_image(**kargs)
156 |
157 | def add_images(self, **kargs):
158 |         """Log multiple images."""
159 | if self.is_primary:
160 | if self.tb_writer is not None:
161 | self.tb_writer.add_images(**kargs)
162 |
163 | def close(self):
164 | if self.is_primary:
165 | self.text_writer.close()
166 | self.tb_writer.close()
167 |
168 |
169 | def plot_spectrogram(spectrogram):
170 | fig, ax = plt.subplots(figsize=(10, 2))
171 | im = ax.imshow(
172 | spectrogram, aspect="auto", origin="lower", interpolation='none')
173 | plt.colorbar(im, ax=ax)
174 |
175 | fig.canvas.draw()
176 | plt.close()
177 |
178 | return fig
179 |
180 |
181 | def init_weights(m, mean=0.0, std=0.01):
182 | classname = m.__class__.__name__
183 | if classname.find("Conv") != -1:
184 | m.weight.data.normal_(mean, std)
185 |
186 |
187 | def apply_weight_norm(m):
188 | classname = m.__class__.__name__
189 | if classname.find("Conv") != -1:
190 | weight_norm(m)
191 |
192 |
193 | def get_padding(kernel_size, dilation=1):
194 | return int((kernel_size * dilation - dilation) / 2)
195 |
196 |
197 | def load_checkpoint(filepath, device):
198 | assert os.path.isfile(filepath)
199 | print("Loading '{}'".format(filepath))
200 | checkpoint_dict = torch.load(filepath, map_location=device)
201 | print("Complete.")
202 | return checkpoint_dict
203 |
204 |
205 | def save_checkpoint(filepath, obj, num_ckpt_keep=5):
206 | name = re.match(r'(do|g)_\d+', pathlib.Path(filepath).name).group(1)
207 | ckpts = sorted(pathlib.Path(filepath).parent.glob(f'{name}_*'))
208 | if len(ckpts) > num_ckpt_keep:
209 | [os.remove(c) for c in ckpts[:-num_ckpt_keep]]
210 | print("Saving checkpoint to {}".format(filepath))
211 | torch.save(obj, filepath)
212 | print("Complete.")
213 |
214 |
215 | def scan_checkpoint(cp_dir, prefix):
216 | pattern = os.path.join(cp_dir, prefix + '????????')
217 | cp_list = glob.glob(pattern)
218 | if len(cp_list) == 0:
219 | return None
220 | return sorted(cp_list)[-1]
221 |
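
The checkpoint helpers expect HiFi-GAN-style file names: `g_<step>` for the generator and `do_<step>` for the discriminator/optimizer state, with the step zero-padded to eight digits so that `scan_checkpoint`'s glob pattern matches. A hedged usage sketch (directory name and payload are illustrative):

```python
import os

import torch

from academicodec.utils import load_checkpoint, save_checkpoint, scan_checkpoint

ckpt_dir = "ckpt"  # illustrative directory
os.makedirs(ckpt_dir, exist_ok=True)

step = 5000
state = {"generator": {"weight": torch.zeros(3)}, "steps": step}

# File names must match (do|g)_<digits>; eight digits keep them sortable and
# let scan_checkpoint's `prefix + '????????'` glob find them.
path = os.path.join(ckpt_dir, "g_{:08d}".format(step))
save_checkpoint(path, state, num_ckpt_keep=5)

latest = scan_checkpoint(ckpt_dir, "g_")
if latest is not None:
    restored = load_checkpoint(latest, device="cpu")
    print(restored["steps"])  # 5000
```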
--------------------------------------------------------------------------------
/egs/Encodec_16k_320d/path.sh:
--------------------------------------------------------------------------------
1 | ../Encodec_24k_32d/path.sh
--------------------------------------------------------------------------------
/egs/Encodec_16k_320d/readme.md:
--------------------------------------------------------------------------------
1 | # The training code of Encodec
2 |
3 | ### Note that this part of the code is based on Facebook's Encodec. We only provide the training process. The license is the same as Encodec's.
4 |
5 | ### For Training
6 | Set the correct paths in start.sh, then run:
7 | `bash start.sh`
8 |
9 | ### For Inference
10 | If you want to use our checkpoint, run the following:
11 | ```bash
12 | mkdir checkpoint
13 | cd checkpoint
14 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/encodec_16khz_320d.pth
15 | bash test.sh # set the root in test.sh before running it.
16 | ```
--------------------------------------------------------------------------------
/egs/Encodec_16k_320d/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 | log_root=logs
4 | # 16kHz *.wav in train_data_dir
5 | train_data_dir=dump/train
6 | valid_data_dir=dump/valid
7 |
8 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
9 | python3 -m torch.distributed.launch --nproc_per_node 8 ${BIN_DIR}/main_launch.py \
10 | --BATCH_SIZE 16 \
11 | --N_EPOCHS 300 \
12 | --save_dir ${log_root} \
13 | --PATH ${log_root} \
14 | --train_data_path ${train_data_dir} \
15 | --valid_data_path ${valid_data_dir} \
16 | --sr 16000 \
17 | --ratios 8 5 4 2 \
18 | --target_bandwidths 1 1.5 2 4 6 12
19 |
--------------------------------------------------------------------------------
/egs/Encodec_16k_320d/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 |
4 | python3 ${BIN_DIR}/test.py \
5 | --input=./test_wav \
6 | --output=./output \
7 | --resume_path=checkpoint/encodec_16k_320d.pth \
8 | --sr=16000 \
9 | --ratios 8 5 4 2 \
10 | --target_bandwidths 1 1.5 2 4 6 12 \
11 | --target_bw=12 \
12 | -r
13 |
--------------------------------------------------------------------------------
/egs/Encodec_24k_240d/path.sh:
--------------------------------------------------------------------------------
1 | ../Encodec_24k_32d/path.sh
--------------------------------------------------------------------------------
/egs/Encodec_24k_240d/readme.md:
--------------------------------------------------------------------------------
1 | # The training code of Encodec
2 |
3 | ### Note that this part of the code is based on Facebook's Encodec. We only provide the training process. The license is the same as Encodec's.
4 |
5 | ### For Training
6 | Set the correct paths in start.sh.
7 |
8 | Then run: `bash start.sh`
9 |
10 | ### For Finetune
11 | If you want to finetune the model, you can use the following command:
12 | ```bash
13 | python3 main3_ddp.py --BATCH_SIZE 16 --N_EPOCHS 300 \
14 |     --save_dir path_to_save_log \
15 |     --PATH path_to_save_model \
16 |     --train_data_path path_to_training_data \
17 |     --valid_data_path path_to_val_data \
18 |     --resume --resume_path the_model_path
19 | ```
20 |
21 | ### For Inference
22 | If you want to use our checkpoint, run the following:
23 | ```bash
24 | mkdir checkpoint
25 | cd checkpoint
26 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/encodec_24khz_240d.pth
27 | bash test.sh # set the root in test.sh before running it.
28 | ```
--------------------------------------------------------------------------------
/egs/Encodec_24k_240d/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 | log_root=logs
4 | # 24kHz *.wav in train_data_dir
5 | train_data_dir=dump/train
6 | valid_data_dir=dump/valid
7 |
8 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
9 | python3 -m torch.distributed.launch --nproc_per_node 8 ${BIN_DIR}/main_launch.py \
10 | --BATCH_SIZE 16 \
11 | --N_EPOCHS 300 \
12 | --save_dir ${log_root} \
13 | --PATH ${log_root} \
14 | --train_data_path ${train_data_dir} \
15 | --valid_data_path ${valid_data_dir} \
16 | --sr 24000 \
17 | --ratios 6 5 4 2 \
18 | --target_bandwidths 1 2 4 8 12
--------------------------------------------------------------------------------
/egs/Encodec_24k_240d/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 |
4 | python3 ${BIN_DIR}/test.py \
5 | --input=./test_wav \
6 | --output=./output \
7 | --resume_path=checkpoint/encodec_24khz_240d.pth \
8 | --sr=24000 \
9 | --ratios 6 5 4 2 \
10 | --target_bandwidths 1 2 4 8 12 \
11 | --target_bw=12 \
12 | -r
13 |
--------------------------------------------------------------------------------
/egs/Encodec_24k_32d/path.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export MAIN_ROOT=`realpath ${PWD}/../../`
3 |
4 | export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
5 | MODEL=encodec
6 | export BIN_DIR=${MAIN_ROOT}/academicodec/models/${MODEL}
--------------------------------------------------------------------------------
/egs/Encodec_24k_32d/readme.md:
--------------------------------------------------------------------------------
1 | # The training code of Encodec
2 |
3 | ### Note that this part of the code is based on Facebook's Encodec. We only provide the training process. The license is the same as Encodec's.
4 |
5 | ### For Training
6 | Set the correct paths in start.sh, then run:
7 |
8 | `bash start.sh`
9 |
10 | ### For Inference
11 | If you want to use our checkpoint, run the following:
12 | ```bash
13 | mkdir checkpoint
14 | cd checkpoint
15 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/encodec_24khz_32d.pth
16 | bash test.sh # set the root in test.sh before running it.
17 | ```
18 |
--------------------------------------------------------------------------------
/egs/Encodec_24k_32d/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 | log_root=logs
4 | # 24kHz *.wav in train_data_dir
5 | train_data_dir=dump/train
6 | valid_data_dir=dump/valid
7 |
8 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
9 | python3 -m torch.distributed.launch --nproc_per_node 8 ${BIN_DIR}/main_launch.py \
10 | --BATCH_SIZE 16 \
11 | --N_EPOCHS 300 \
12 | --save_dir ${log_root} \
13 | --PATH ${log_root} \
14 | --train_data_path ${train_data_dir} \
15 | --valid_data_path ${valid_data_dir} \
16 | --sr 24000 \
17 | --ratios 2 2 2 4 \
18 | --target_bandwidths 7.5 15
19 |
--------------------------------------------------------------------------------
/egs/Encodec_24k_32d/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 |
4 | python3 ${BIN_DIR}/test.py \
5 | --input=./test_wav \
6 | --output=./output \
7 | --resume_path=checkpoint/Encodec_24khz_32d.pth \
8 | --sr=24000 \
9 | --ratios 2 2 2 4 \
10 | --target_bandwidths 7.5 15 \
11 | --target_bw=7.5 \
12 | -r
13 |
14 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-16k-320d/config_16k_320d.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 8,
4 | "batch_size": 64,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.5,
7 | "adam_b2": 0.9,
8 | "lr_decay": 0.98,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,5,4,2],
12 | "upsample_kernel_sizes": [16,11,8,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 16000,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 200,
22 | "win_size": 800,
23 |
24 | "sampling_rate": 16000,
25 |
26 | "n_code_groups": 2,
27 | "n_codes": 1024,
28 | "codebook_loss_lambda": 1.0,
29 | "commitment_loss_lambda": 0.25,
30 |
31 | "fmin": 0,
32 | "fmax": 8000,
33 | "fmax_for_loss": null,
34 |
35 | "num_workers": 12,
36 |
37 | "dist_config": {
38 | "dist_backend": "nccl",
39 | "dist_url": "tcp://localhost:54321",
40 | "world_size": 1
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-16k-320d/path.sh:
--------------------------------------------------------------------------------
1 | ../HiFi-Codec-24k-240d/path.sh
--------------------------------------------------------------------------------
/egs/HiFi-Codec-16k-320d/readme.md:
--------------------------------------------------------------------------------
1 | ## How to train your model
2 | First, set the related paths in the start.sh file, then run:
3 | ```bash
4 | bash start.sh
5 | ```
6 |
7 | ## How to run inference
8 | ```bash
9 | mkdir checkpoint
10 | cd checkpoint
11 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/HiFi-Codec-16k-320d
12 | bash test.sh
13 | ```
14 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-16k-320d/start.sh:
--------------------------------------------------------------------------------
1 |
2 | #!/bin/bash
3 | source path.sh
4 | set -e
5 |
6 | log_root="logs"
7 | # the .lst files store the wav paths.
8 | input_training_file="train.lst"
9 | input_validation_file="valid.lst"
10 |
11 | #mode=debug
12 | mode=train
13 |
14 | if [ "${mode}" == "debug" ]; then
15 | ## debug
16 | echo "Debug"
17 | log_root=${log_root}_debug
18 | export CUDA_VISIBLE_DEVICES=0
19 | python ${BIN_DIR}/train.py \
20 | --config config_16k_320d.json \
21 | --checkpoint_path ${log_root} \
22 | --input_training_file ${input_training_file} \
23 | --input_validation_file ${input_validation_file} \
24 | --checkpoint_interval 100 \
25 | --summary_interval 10 \
26 | --validation_interval 100 \
27 |
28 | elif [ "$mode" == "train" ]; then
29 | ## train
30 | echo "Train model..."
31 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
32 | python ${BIN_DIR}/train.py \
33 | --config config_16k_320d.json \
34 | --checkpoint_path ${log_root} \
35 | --input_training_file ${input_training_file} \
36 | --input_validation_file ${input_validation_file} \
37 | --checkpoint_interval 5000 \
38 | --summary_interval 100 \
39 | --validation_interval 5000
40 | fi
41 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-16k-320d/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 |
4 | ckpt=checkpoint/HiFi-Codec-16k-320d
5 | echo checkpoint path: ${ckpt}
6 |
7 | # the directory of test wav files
8 | wav_dir=test_wav
9 |
10 | outputdir=output
11 | mkdir -p ${outputdir}
12 |
13 | python3 ${BIN_DIR}/vqvae_copy_syn.py \
14 | --model_path=${ckpt} \
15 | --config_path=config_16k_320d.json \
16 | --input_wavdir=${wav_dir} \
17 | --outputdir=${outputdir} \
18 | --num_gens=10000 \
19 | --sample_rate=16000
20 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-24k-240d/config_24k_240d.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 8,
4 | "batch_size": 32,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.5,
7 | "adam_b2": 0.9,
8 | "lr_decay": 0.98,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,5,3,2],
12 | "upsample_kernel_sizes": [16,11,7,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 12000,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 240,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 24000,
25 |
26 | "n_code_groups": 2,
27 | "n_codes": 1024,
28 | "codebook_loss_lambda": 1.0,
29 | "commitment_loss_lambda": 0.25,
30 |
31 | "fmin": 0,
32 | "fmax": 8000,
33 | "fmax_for_loss": null,
34 |
35 | "num_workers": 12,
36 |
37 | "dist_config": {
38 | "dist_backend": "nccl",
39 | "dist_url": "tcp://localhost:54321",
40 | "world_size": 1
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-24k-240d/path.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export MAIN_ROOT=`realpath ${PWD}/../../`
3 |
4 | export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
5 | MODEL=hificodec
6 | export BIN_DIR=${MAIN_ROOT}/academicodec/models/${MODEL}
7 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-24k-240d/readme.md:
--------------------------------------------------------------------------------
1 | ## How to train your model
2 | First, set the related paths in the start.sh file, then run:
3 | ```bash
4 | bash start.sh
5 | ```
6 |
7 | ## How to run inference
8 | ```bash
9 | mkdir checkpoint
10 | cd checkpoint
11 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/HiFi-Codec-24k-240d
12 | bash test.sh
13 | ```
14 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-24k-240d/start.sh:
--------------------------------------------------------------------------------
1 |
2 | #!/bin/bash
3 | source path.sh
4 | set -e
5 |
6 | log_root="logs"
7 | # the .lst files store the wav paths.
8 | input_training_file="train.lst"
9 | input_validation_file="valid.lst"
10 |
11 | #mode=debug
12 | mode=train
13 |
14 | if [ "${mode}" == "debug" ]; then
15 | ## debug
16 | echo "Debug"
17 | log_root=${log_root}_debug
18 | export CUDA_VISIBLE_DEVICES=0
19 | python ${BIN_DIR}/train.py \
20 | --config config_24k_240d.json \
21 | --checkpoint_path ${log_root} \
22 | --input_training_file ${input_training_file} \
23 | --input_validation_file ${input_validation_file} \
24 | --checkpoint_interval 100 \
25 | --summary_interval 10 \
26 | --validation_interval 100 \
27 |
28 | elif [ "$mode" == "train" ]; then
29 | ## train
30 | echo "Train model..."
31 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
32 | python ${BIN_DIR}/train.py \
33 | --config config_24k_240d.json \
34 | --checkpoint_path ${log_root} \
35 | --input_training_file ${input_training_file} \
36 | --input_validation_file ${input_validation_file} \
37 | --checkpoint_interval 5000 \
38 | --summary_interval 100 \
39 | --validation_interval 5000
40 | fi
41 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-24k-240d/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 |
4 | ckpt=checkpoint/HiFi-Codec-24k-240d
5 | echo checkpoint path: ${ckpt}
6 |
7 | # the path of test wave
8 | wav_dir=test_wav
9 |
10 | outputdir=output
11 | mkdir -p ${outputdir}
12 |
13 | python3 ${BIN_DIR}/vqvae_copy_syn.py \
14 | --model_path=${ckpt} \
15 | --config_path=config_24k_240d.json \
16 | --input_wavdir=${wav_dir} \
17 | --outputdir=${outputdir} \
18 | --num_gens=10000
19 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-24k-320d/config_24k_320d.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 8,
4 | "batch_size": 80,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.5,
7 | "adam_b2": 0.9,
8 | "lr_decay": 0.98,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,5,4,2],
12 | "upsample_kernel_sizes": [16,11,8,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 16000,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 240,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 24000,
25 |
26 | "n_code_groups": 2,
27 | "n_codes": 1024,
28 | "codebook_loss_lambda": 1.0,
29 | "commitment_loss_lambda": 0.25,
30 |
31 | "fmin": 0,
32 | "fmax": 8000,
33 | "fmax_for_loss": null,
34 |
35 | "num_workers": 12,
36 |
37 | "dist_config": {
38 | "dist_backend": "nccl",
39 | "dist_url": "tcp://localhost:54321",
40 | "world_size": 1
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-24k-320d/infer.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import sys\n",
10 | "\n",
11 | "sys.path.append('../../')"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 3,
17 | "metadata": {},
18 | "outputs": [
19 | {
20 | "name": "stdout",
21 | "output_type": "stream",
22 | "text": [
23 | "Init model and load weights\n",
24 | "Model ready\n",
25 | "Globbed 12 wav files.\n"
26 | ]
27 | },
28 | {
29 | "name": "stderr",
30 | "output_type": "stream",
31 | "text": [
32 | "100%|███████████| 1/1 [00:00<00:00, 11.08it/s]"
33 | ]
34 | },
35 | {
36 | "name": "stdout",
37 | "output_type": "stream",
38 | "text": [
39 | "wav.shape: (97681,)\n",
40 | "acoustic_token: tensor([[[ 11, 591, 281, 629],\n",
41 | " [733, 591, 401, 139],\n",
42 | " [500, 591, 733, 600],\n",
43 | " ...,\n",
44 | " [733, 591, 451, 346],\n",
45 | " [733, 591, 401, 139],\n",
46 | " [386, 591, 281, 461]]], device='cuda:0')\n",
47 | "acoustic_token.shape: torch.Size([1, 305, 4])\n",
48 | "acoustic_token.dtype: torch.int64\n"
49 | ]
50 | },
51 | {
52 | "name": "stderr",
53 | "output_type": "stream",
54 | "text": [
55 | "\n"
56 | ]
57 | }
58 | ],
59 | "source": [
60 | "import glob\n",
61 | "import json\n",
62 | "import os\n",
63 | "from pathlib import Path\n",
64 | "\n",
65 | "import librosa\n",
66 | "import torch\n",
67 | "from academicodec.models.hificodec.vqvae import VQVAE\n",
68 | "from librosa.util import normalize\n",
69 | "from tqdm import tqdm\n",
70 | "\n",
71 | "ckpt_path = './checkpoint/HiFi-Codec-24k-320d'\n",
72 | "config_path = './config_24k_320d.json'\n",
73 | "with open(config_path, 'r') as f:\n",
74 | " config = json.load(f)\n",
75 | " sample_rate = config['sampling_rate']\n",
76 | "\n",
77 | "outputdir = './output'\n",
78 | "inputdir = './test_wav'\n",
79 | "num = 1024\n",
80 | "\n",
81 | "if __name__ == '__main__':\n",
82 | " Path(outputdir).mkdir(parents=True, exist_ok=True)\n",
83 | " print(\"Init model and load weights\")\n",
84 | " # make sure you downloaded the weights from https://huggingface.co/Dongchao/AcademiCodec/blob/main/HiFi-Codec-24k-320d \n",
85 | " # and put it in ./checkpoint/\n",
86 | " model = VQVAE(\n",
87 | " config_path,\n",
88 | " ckpt_path,\n",
89 | " with_encoder=True)\n",
90 | " model.cuda()\n",
91 | " model.eval()\n",
92 | " print(\"Model ready\")\n",
93 | "\n",
94 | " wav_paths = glob.glob(f\"{inputdir}/*.wav\")[:num]\n",
95 | " print(f\"Globbed {len(wav_paths)} wav files.\")\n",
96 | " fid_to_acoustic_token = {}\n",
97 | " for wav_path in tqdm(wav_paths[:1]):\n",
98 | " wav, sr = librosa.load(wav_path, sr=sample_rate)\n",
99 | " print(\"wav.shape:\",wav.shape)\n",
100 | " assert sr == sample_rate\n",
101 | " fid = os.path.basename(wav_path)[:-4]\n",
102 | " wav = normalize(wav) * 0.95\n",
103 | " wav = torch.FloatTensor(wav).unsqueeze(0)\n",
104 | " wav = wav.to(torch.device('cuda'))\n",
105 | " acoustic_token = model.encode(wav)\n",
106 | " print(\"acoustic_token:\",acoustic_token)\n",
107 | " print(\"acoustic_token.shape:\",acoustic_token.shape)\n",
108 | " print(\"acoustic_token.dtype:\",acoustic_token.dtype)\n",
109 | " fid = os.path.basename(wav_path)[:-4]\n",
110 | " fid_to_acoustic_token[fid] = acoustic_token\n",
111 | "\n",
112 | " torch.save(fid_to_acoustic_token,\n",
113 | " os.path.join(outputdir, 'fid_to_acoustic_token.pth'))\n"
114 | ]
115 | }
116 | ],
117 | "metadata": {
118 | "kernelspec": {
119 | "display_name": "Python 3 (ipykernel)",
120 | "language": "python",
121 | "name": "python3"
122 | },
123 | "language_info": {
124 | "codemirror_mode": {
125 | "name": "ipython",
126 | "version": 3
127 | },
128 | "file_extension": ".py",
129 | "mimetype": "text/x-python",
130 | "name": "python",
131 | "nbconvert_exporter": "python",
132 | "pygments_lexer": "ipython3",
133 | "version": "3.7.14"
134 | }
135 | },
136 | "nbformat": 4,
137 | "nbformat_minor": 2
138 | }
139 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-24k-320d/path.sh:
--------------------------------------------------------------------------------
1 | ../HiFi-Codec-24k-240d/path.sh
--------------------------------------------------------------------------------
/egs/HiFi-Codec-24k-320d/readme.md:
--------------------------------------------------------------------------------
1 | ## How to train your model
2 | First, set the related paths in the start.sh file, then
3 | ```bash
4 | bash start.sh
5 | ```
6 |
7 | ## How to Inference
8 | ```bash
9 | mkdir checkpoint
10 | cd checkpoint
11 | wget https://huggingface.co/Dongchao/AcademiCodec/resolve/main/HiFi-Codec-24k-320d
12 | cd .. && bash test.sh
13 | ```
14 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-24k-320d/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 | set -e
4 |
5 | log_root="logs"
6 | # .lst save the wav path.
7 | input_training_file="train.lst"
8 | input_validation_file="valid.lst"
9 |
10 | #mode=debug
11 | mode=train
12 |
13 | if [ "${mode}" == "debug" ]; then
14 | ## debug
15 | echo "Debug"
16 | log_root=${log_root}_debug
17 | export CUDA_VISIBLE_DEVICES=0
18 | python ${BIN_DIR}/train.py \
19 | --config config_24k_320d.json \
20 | --checkpoint_path ${log_root} \
21 | --input_training_file ${input_training_file} \
22 | --input_validation_file ${input_validation_file} \
23 | --checkpoint_interval 100 \
24 | --summary_interval 10 \
25 | --validation_interval 100 \
26 |
27 | elif [ "$mode" == "train" ]; then
28 | ## train
29 | echo "Train model..."
30 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
31 | python ${BIN_DIR}/train.py \
32 | --config config_24k_320d.json \
33 | --checkpoint_path ${log_root} \
34 | --input_training_file ${input_training_file} \
35 | --input_validation_file ${input_validation_file} \
36 | --checkpoint_interval 5000 \
37 | --summary_interval 100 \
38 | --validation_interval 5000
39 | fi
40 |
--------------------------------------------------------------------------------
/egs/HiFi-Codec-24k-320d/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 |
4 | ckpt=checkpoint/HiFi-Codec-24k-320d
5 | echo checkpoint path: ${ckpt}
6 |
7 | # the path of test wave
8 | wav_dir=test_wav
9 |
10 | outputdir=output
11 | mkdir -p ${outputdir}
12 |
13 | python3 ${BIN_DIR}/vqvae_copy_syn.py \
14 | --model_path=${ckpt} \
15 | --config_path=config_24k_320d.json \
16 | --input_wavdir=${wav_dir} \
17 | --outputdir=${outputdir} \
18 | --num_gens=10000
19 |
--------------------------------------------------------------------------------
/egs/SoundStream_24k_240d/path.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export MAIN_ROOT=`realpath ${PWD}/../../`
3 |
4 | export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
5 | MODEL=encodec
6 | export BIN_DIR=${MAIN_ROOT}/academicodec/models/${MODEL}
--------------------------------------------------------------------------------
/egs/SoundStream_24k_240d/readme.md:
--------------------------------------------------------------------------------
1 | # The training code of SoundStream
2 |
3 |
4 | ### For Training
5 | Set the right paths in start.sh.
6 |
7 | run: `bash start.sh`
8 |
9 | ### For Inference
10 |
11 |
12 | The model is not open-sourced, and this directory has not been organized yet.
13 |
--------------------------------------------------------------------------------
/egs/SoundStream_24k_240d/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 |
4 | python3 main3_ddp.py \
5 | --BATCH_SIZE 16 \
6 | --N_EPOCHS 300 \
7 | --save_dir path_to_save_log \
8 | --PATH path_to_save_model \
9 | --train_data_path path_to_training_data \
10 | --valid_data_path path_to_val_data \
11 | --sr 24000 \
12 | --ratios 6 5 4 2 \
13 | --target_bandwidths 1 2 4 8 12
14 |
--------------------------------------------------------------------------------
/egs/SoundStream_24k_240d/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source path.sh
3 | python3 ${BIN_DIR}/test.py \
4 | --input=./test_wav \
5 | --output=./output \
6 | --resume_path=checkpoint/soundstream.pth \
7 | --sr=24000 \
8 | --ratios 6 5 4 2 \
9 | --target_bandwidths 1 2 4 8 12 \
10 | --target_bw=12
11 |
--------------------------------------------------------------------------------
/egs/util/wavlstgen.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # 2022-2023 by zhaomingwork@qq.com
3 | # can be used for generating train.lst or valid.lst only given a root dir
4 | # example:
5 | # python wavlstgen.py --wavdir /data/asr_data/aishell/ --outfile train.lst
6 | import os
7 | import time
8 |
9 | import argparse
10 | import json
11 | import traceback
12 |
13 |
14 | import logging
15 |
16 | logging.basicConfig(level=logging.ERROR)
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--wavdir",
20 | type=str,
21 | default="./",
22 | required=True,
23 | help="root dir of wav")
24 |
25 |
26 | parser.add_argument("--outfile",
27 | type=str,
28 | default="./wav.lst",
29 | required=False,
30 | help="output list file name")
31 |
32 | args = parser.parse_args()
33 |
34 | print(args)
35 |
36 | def genwavlist(rootdir):
37 | outlist = open(args.outfile, 'w+')
38 |
39 | for dirpath, dirnames, filenames in os.walk(rootdir):
40 | for filename in filenames:
41 | #print(os.path.join(dirpath, filename))
42 | if filename.endswith(".wav"):
43 | outlist.write(os.path.join(dirpath, filename)+"\n")
44 | outlist.close()
45 |
46 |
47 | if __name__ == '__main__':
48 |
49 | genwavlist(args.wavdir)
50 |
--------------------------------------------------------------------------------
/evaluation_metric/calculate_voc_obj_metrics/compute_metrics.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # pip install pesq
4 | # pip install pystoi
5 | # pip install pyworld
6 | # pip install pysptk
7 | # pip install -U numpy
8 | stage=1
9 | stop_stage=2
10 |
11 | #ref_dir=$1
12 | #gen_dir=$2
13 |
14 | ref_dir='your test folder'
15 | gen_dir='the generated samples'
16 | echo ${ref_dir}
17 | echo ${gen_dir}
18 |
19 |
20 |
21 | if [ $stage -le 1 ] && [ "${stop_stage}" -ge 1 ];then
22 | echo "Compute PESQ"
23 | python metrics/compute_pesq.py \
24 | -r ${ref_dir} \
25 | -d ${gen_dir}
26 | fi
27 |
28 | if [ $stage -le 2 ] && [ "${stop_stage}" -ge 2 ];then
29 | echo "Compute STOI"
30 | python metrics/compute_stoi.py \
31 | -r ${ref_dir} \
32 | -d ${gen_dir}
33 | fi
34 |
--------------------------------------------------------------------------------
/evaluation_metric/calculate_voc_obj_metrics/metrics/compute_pesq.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 |
5 | import scipy.signal as signal
6 | from pesq import pesq
7 | from scipy.io import wavfile
8 | from tqdm import tqdm
9 |
10 |
11 | def cal_pesq(ref_dir, deg_dir):
12 | input_files = glob.glob(f"{deg_dir}/*.wav")
13 |
14 | nb_pesq_scores = 0.0
15 | wb_pesq_scores = 0.0
16 | for deg_wav in tqdm(input_files):
17 | ref_wav = os.path.join(ref_dir, os.path.basename(deg_wav))
18 | ref_rate, ref = wavfile.read(ref_wav)
19 | deg_rate, deg = wavfile.read(deg_wav)
20 |         if ref_rate != 16000:
21 |             # resample to a 16 kHz rate (signal.resample takes a target sample count)
22 |             ref = signal.resample(ref, int(len(ref) * 16000 / ref_rate))
23 |         if deg_rate != 16000:
24 |             deg = signal.resample(deg, int(len(deg) * 16000 / deg_rate))
24 |
25 | min_len = min(len(ref), len(deg))
26 | ref = ref[:min_len]
27 | deg = deg[:min_len]
28 |
29 | nb_pesq_scores += pesq(16000, ref, deg, 'nb')
30 | wb_pesq_scores += pesq(16000, ref, deg, 'wb')
31 |
32 | return nb_pesq_scores / len(input_files), wb_pesq_scores / len(input_files)
33 |
34 |
35 | if __name__ == '__main__':
36 |
37 | parser = argparse.ArgumentParser(description="Compute PESQ measure.")
38 |
39 | parser.add_argument(
40 | '-r', '--ref_dir', required=True, help="Reference wave folder.")
41 | parser.add_argument(
42 | '-d', '--deg_dir', required=True, help="Degraded wave folder.")
43 |
44 | args = parser.parse_args()
45 |
46 | nb_score, wb_score = cal_pesq(args.ref_dir, args.deg_dir)
47 | print(f"NB PESQ: {nb_score}")
48 | print(f"WB PESQ: {wb_score}")
49 |
--------------------------------------------------------------------------------
/evaluation_metric/calculate_voc_obj_metrics/metrics/compute_stoi.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 |
5 | import numpy as np
6 | from pystoi import stoi
7 | from scipy.io import wavfile
8 | from tqdm import tqdm
9 |
10 |
11 | def calculate_stoi(ref_dir, deg_dir):
12 | input_files = glob.glob(f"{deg_dir}/*.wav")
13 | if len(input_files) < 1:
14 |         raise RuntimeError(f"Found no wavs in {deg_dir}")
15 |
16 | stoi_scores = []
17 | for deg_wav in tqdm(input_files):
18 | ref_wav = os.path.join(ref_dir, os.path.basename(deg_wav))
19 | rate, ref = wavfile.read(ref_wav)
20 | rate, deg = wavfile.read(deg_wav)
21 | min_len = min(len(ref), len(deg))
22 | ref = ref[:min_len]
23 | deg = deg[:min_len]
24 | cur_stoi = stoi(ref, deg, rate, extended=False)
25 | stoi_scores.append(cur_stoi)
26 |
27 | return np.mean(stoi_scores)
28 |
29 |
30 | if __name__ == '__main__':
31 |
32 | parser = argparse.ArgumentParser(description="Compute STOI measure")
33 |
34 | parser.add_argument(
35 | '-r', '--ref_dir', required=True, help="Reference wave folder.")
36 | parser.add_argument(
37 | '-d', '--deg_dir', required=True, help="Degraded wave folder.")
38 |
39 | args = parser.parse_args()
40 |
41 | stoi_score = calculate_stoi(args.ref_dir, args.deg_dir)
42 | print(f"STOI: {stoi_score}")
43 |
--------------------------------------------------------------------------------
/evaluation_metric/calculate_voc_obj_metrics/metrics/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from os.path import join as opj
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torchaudio  # used by write_result below
9 |
10 | def count_parameters(model):
11 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
12 |
13 |
14 | def get_params(args):
15 |
16 | params = {}
17 | args_ref = vars(args)
18 | args_keys = vars(args).keys()
19 |
20 | for key in args_keys:
21 | if '__' in key:
22 | continue
23 | else:
24 | temp_params = args_ref[key]
25 | if type(temp_params) == dict:
26 | params.update(temp_params)
27 | else:
28 | params[key] = temp_params
29 |
30 | return params
31 |
32 |
33 | def rescale_module(module, reference):
34 | for sub in module.modules():
35 | if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
36 | rescale_conv(sub, reference)
37 |
38 |
39 | def rescale_conv(conv, reference):
40 | std = conv.weight.std().detach()
41 | scale = (std / reference)**0.5
42 | conv.weight.data /= scale
43 | if conv.bias is not None:
44 | conv.bias.data /= scale
45 |
46 |
47 | def write_result(estimate, noise, file, args):
48 | if not os.path.exists(args.enhanced_path):
49 | os.makedirs(args.enhanced_path)
50 | file_name = opj(args.enhanced_path,
51 | file[0].rsplit('.', 1)[0].replace('\\', '/').split('/')[-1])
52 | noise_path = file_name + '_noise.wav'
53 | enhanced_path = file_name + '_enhanced.wav'
54 |
55 | torchaudio.save(noise_path, noise.squeeze(1), args.sample_rate)
56 | torchaudio.save(enhanced_path, estimate.squeeze(1), args.sample_rate)
57 |
58 |
59 | def seed_init(seed=100):
60 |
61 | random.seed(seed)
62 | np.random.seed(seed)
63 | torch.manual_seed(seed)
64 | torch.cuda.manual_seed(seed)
65 | torch.cuda.manual_seed_all(seed)
66 | torch.backends.cudnn.deterministic = True
67 | torch.backends.cudnn.benchmark = False
68 | os.environ['PYTHONHASHSEED'] = str(seed)
69 |
70 |
71 | def args_dict(args):
72 | """
73 | Get your arguments and make dictionary.
74 | If you add some arguments in the model, you should edit here also.
75 | """
76 | args.dataset = {
77 | 'train': args.train,
78 | 'val': args.val,
79 | 'test': args.test,
80 | 'matching': args.matching
81 | }
82 | args.setting = {
83 | 'sample_rate': args.sample_rate,
84 | 'segment': args.segment,
85 | 'pad': args.pad,
86 | 'stride': args.set_stride
87 | }
88 | args.manner = {
89 | 'in_channels': args.in_channels,
90 | 'out_channels': args.out_channels,
91 | 'hidden': args.hidden,
92 | 'depth': args.depth,
93 | 'kernel_size': args.kernel_size,
94 | 'stride': args.stride,
95 | 'growth': args.growth,
96 | 'head': args.head,
97 | 'segment_len': args.segment_len
98 | }
99 |
100 | args.ex_name = os.getcwd().replace('\\', '/').split('/')[-1]
101 |
102 | return args
103 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # AcademiCodec: An Open Source Audio Codec Model for Academic Research
2 |
3 | This repo is organized as follows:
4 |
5 | ```text
6 | AcademiCodec
7 | ├── academicodec
8 | │ ├── utils.py # common parts of various models
9 | │ ├── modules # common parts of various models
10 | │ ├── ...
11 | │ ├── quantization # common parts of various models
12 | │ └── models # parts that are not shared by various models
13 | │ ├── hificodec
14 | │ ├── encodec
15 | │ ├── soundstream
16 | │ └── ...
17 | ├── evaluation_metric
18 | ├── egs
19 | │ ├── SoundStream*
20 | │ ├── EnCodec*
21 | │ └── HiFi-Codec*
22 | │ ├── start.sh
23 | │ ├── ...
24 | │ └── test.sh
25 | └── README.md
26 | ```
27 |
28 | ### Ongoing
29 | This project is ongoing. You can find the paper at https://arxiv.org/pdf/2305.02765.pdf
30 | Furthermore, this project was launched from a university, and we hope more researchers will become contributors.
31 |
32 | #### Abstract
33 | Audio codec models are widely used in audio communication as a crucial technique for compressing audio into discrete representations. Nowadays, audio codec models are increasingly utilized in generation fields as intermediate representations. For instance, AudioLM is an audio generation model that uses the discrete representation of SoundStream as a training target, while VALL-E employs the Encodec model as an intermediate feature to aid TTS tasks. Despite their usefulness, two challenges persist: (1) training these audio codec models can be difficult due to the lack of publicly available training processes and the need for large-scale data and GPUs; (2) achieving good reconstruction performance requires many codebooks, which increases the burden on generation models. In this study, we propose a group-residual vector quantization (GRVQ) technique and use it to develop a novel \textbf{Hi}gh \textbf{Fi}delity Audio Codec model, HiFi-Codec, which only requires 4 codebooks. We train all the models using publicly available TTS data such as LibriTTS, VCTK, AISHELL, and more, with a total duration of over 1000 hours, using 8 GPUs. Our experimental results show that HiFi-Codec outperforms Encodec in terms of reconstruction performance despite requiring only 4 codebooks. To facilitate research in audio codec and generation, we introduce AcademiCodec, the first open-source audio codec toolkit that offers training codes and pre-trained models for Encodec, SoundStream, and HiFi-Codec.
34 |
35 | ## 🔥 News
36 | #### AcademiCodec
37 | - 2023.4.16: We first released the training code for Encodec and SoundStream and our pre-trained models, including 24 kHz and 16 kHz versions.
38 | - 2023.5.5: We released the code of HiFi-Codec.
39 | - 2023.6.2: Added `HiFi-Codec-24k-320d/infer.ipynb`, which can be used to infer acoustic tokens for later training of VALL-E, SoundStorm, etc.
40 | - 2023.06.13: Refactored the code structure.
41 | ### Dependencies
42 | * [PyTorch](http://pytorch.org/) version >= 1.13.0
43 | * Python version >= 3.8
44 |
45 | # Train your own model
46 | Please refer to the specific recipe folder under `egs/`.
47 |
48 | ## Data preparation
49 | Just put your audio data in one folder and make sure its sample rate matches the recipe you want to train. The `start.sh` recipes read `train.lst` / `valid.lst` files containing one wav path per line; `egs/util/wavlstgen.py` can generate such a list, and a minimal sketch is shown below.
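For example, here is a minimal sketch of building the list files (the wav directory is a placeholder, and the 99/1 train/valid split is an arbitrary choice of ours; `egs/util/wavlstgen.py` itself only writes a single list without splitting):

```python
# Collect wav paths and write train.lst / valid.lst (one path per line).
import glob
import random

wav_paths = sorted(glob.glob("/path/to/your/wavs/**/*.wav", recursive=True))
random.seed(0)
random.shuffle(wav_paths)

n_valid = max(1, len(wav_paths) // 100)  # hold out ~1% for validation (arbitrary)
with open("valid.lst", "w") as f:
    f.writelines(p + "\n" for p in wav_paths[:n_valid])
with open("train.lst", "w") as f:
    f.writelines(p + "\n" for p in wav_paths[n_valid:])
```

Alternatively, run `python egs/util/wavlstgen.py --wavdir /path/to/your/wavs --outfile train.lst` directly.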
50 |
51 | ## Training or Inference
52 | Refer to the specific recipe folders: e.g. Encodec_24k_240d denotes the Encodec model trained on 24 kHz audio with a downsampling factor of 240. If you want to use our pre-trained models, please refer to https://huggingface.co/Dongchao/AcademiCodec/tree/main; a condensed loading example is sketched below.
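As a rough orientation, here is a condensed version of `egs/HiFi-Codec-24k-320d/infer.ipynb` that loads a downloaded HiFi-Codec checkpoint and extracts acoustic tokens. The config, checkpoint, and wav paths are placeholders you need to adapt to your setup:

```python
import json

import librosa
import torch
from librosa.util import normalize

from academicodec.models.hificodec.vqvae import VQVAE

# Placeholder paths: adjust to wherever you put the config and downloaded checkpoint.
config_path = "egs/HiFi-Codec-24k-320d/config_24k_320d.json"
ckpt_path = "egs/HiFi-Codec-24k-320d/checkpoint/HiFi-Codec-24k-320d"
with open(config_path) as f:
    sample_rate = json.load(f)["sampling_rate"]

model = VQVAE(config_path, ckpt_path, with_encoder=True).cuda().eval()

wav, _ = librosa.load("example.wav", sr=sample_rate)
wav = torch.FloatTensor(normalize(wav) * 0.95).unsqueeze(0).cuda()
with torch.no_grad():
    tokens = model.encode(wav)  # e.g. shape [1, T_frames, 4], dtype int64
print(tokens.shape)
```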
53 |
54 | ## Version Description
55 | * Encodec_16k_320d: trained on 16 kHz audio with a downsampling factor of 320; can be used to train SpearTTS.
56 | * Encodec_24k_240d: trained on 24 kHz audio with a downsampling factor of 240; can be used for InstructTTS.
57 | * Encodec_24k_32d: trained on 24 kHz audio with a downsampling factor of only 32, so a single codebook can be enough, e.g. for AudioGen.
58 | * SoundStream_24k_240d: the same configuration as Encodec_24k_240d. (The frame rates implied by these factors are worked out in the sketch below.)
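A quick back-of-the-envelope check (not from the repo) of what these sample-rate / downsampling combinations mean in codec frames per second:

```python
# frame rate = sample_rate / downsampling factor
for name, (sr, hop) in {
    "Encodec_16k_320d": (16000, 320),
    "Encodec_24k_240d": (24000, 240),
    "Encodec_24k_32d": (24000, 32),
    "SoundStream_24k_240d": (24000, 240),
}.items():
    print(f"{name}: {sr // hop} frames/s")
# -> 50, 100, 750 and 100 frames per second respectively
```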
59 | ## What is the difference between SoundStream, Encodec and HiFi-Codec?
60 | In our view, the main difference between SoundStream and Encodec is the choice of discriminator. Encodec only uses an STFT discriminator, which pushes the STFT spectrogram toward realism. SoundStream uses two types of discriminator: one operates at the waveform level and one at the spectrogram level. In our code, we adopt the waveform-level discriminator from HiFi-GAN and the spectrogram-level discriminator from Encodec. In theory, we expect SoundStream to perform better, and Google's official SoundStream supports this: it can reconstruct high-quality audio with only 3 codebooks. Although our implementation can also achieve good performance with 3 codebooks, we admit it is not yet comparable to Google's version.
61 | HiFi-Codec is our proposed method, which aims to help generation tasks such as VALL-E, AudioLM, MusicLM, SpearTTS, InstructTTS, and so on. HiFi-Codec needs only 4 codebooks, which significantly reduces the number of tokens. Some researchers have used our HiFi-Codec to implement VALL-E and report better audio quality. A toy sketch of the group-residual quantization idea is given below.
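To make the codebook count concrete, here is a toy sketch of the group-residual idea only; it is not the repo's actual quantizer, and all shapes, codebooks, and the helper name `grvq_indices` are made up for illustration. The encoder output channels are split into groups, each group goes through a short residual VQ, and the token tensor ends up with one index per (group, residual stage) pair.

```python
import torch

def grvq_indices(z, codebooks, n_groups=2):
    """Toy group-residual VQ lookup. z: [B, C, T]; each codebook: [K, C // n_groups]."""
    n_residual = len(codebooks) // n_groups
    indices = []
    for g, group in enumerate(z.chunk(n_groups, dim=1)):   # [B, C/n_groups, T] each
        residual = group.permute(0, 2, 1)                   # [B, T, C/n_groups]
        for r in range(n_residual):
            cb = codebooks[g * n_residual + r]               # [K, C/n_groups]
            dist = torch.cdist(residual, cb.unsqueeze(0))    # distance to every code
            idx = dist.argmin(dim=-1)                        # nearest code per frame
            residual = residual - cb[idx]                    # quantize-and-subtract
            indices.append(idx)
    return torch.stack(indices, dim=-1)                      # [B, T, n_groups * n_residual]

z = torch.randn(1, 512, 305)                                 # fake encoder output
codebooks = [torch.randn(1024, 256) for _ in range(4)]       # 2 groups x 2 residual stages
print(grvq_indices(z, codebooks).shape)                      # torch.Size([1, 305, 4])
```

With 2 groups and 2 residual stages per group, you get the 4 codebooks mentioned above, and the token shape matches the `[1, 305, 4]` output shown in `infer.ipynb`.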
62 |
63 | ## Acknowledgements
64 | This implementation uses parts of the code from the following GitHub repos:
65 | https://github.com/facebookresearch/encodec
66 | https://github.com/yangdongchao/Text-to-sound-Synthesis
67 | https://github.com/b04901014/MQTTS
68 | ## Citations ##
69 | If you find this code useful in your research, please cite our work:
70 | ```bib
71 | @article{yang2023instructtts,
72 | title={InstructTTS: Modelling Expressive TTS in Discrete Latent Space with Natural Language Style Prompt},
73 | author={Yang, Dongchao and Liu, Songxiang and Huang, Rongjie and Lei, Guangzhi and Weng, Chao and Meng, Helen and Yu, Dong},
74 | journal={arXiv preprint arXiv:2301.13662},
75 | year={2023}
76 | }
77 | ```
78 | ```bibtex
79 | @article{yang2023hifi,
80 | title={HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec},
81 | author={Yang, Dongchao and Liu, Songxiang and Huang, Rongjie and Tian, Jinchuan and Weng, Chao and Zou, Yuexian},
82 | journal={arXiv preprint arXiv:2305.02765},
83 | year={2023}
84 | }
85 | ```
86 |
87 | ## Disclaimer ##
88 | MIT license
89 |
90 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchaudio
2 | tensorboard
3 | einops
4 | matplotlib
5 | pyyaml
6 | tqdm
--------------------------------------------------------------------------------